From b6575b22c8b4ded6b1b31cdd4781d6fe60e5bf53 Mon Sep 17 00:00:00 2001
From: brian
Date: Tue, 4 Oct 2022 17:07:28 +0200
Subject: [PATCH 001/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 .../wsl-x86_64-docker-cpu-build.jenkinsfile | 58 +++++++++++++++++++
 1 file changed, 58 insertions(+)
 create mode 100644 .jenkins/wsl-x86_64-docker-cpu-build.jenkinsfile

diff --git a/.jenkins/wsl-x86_64-docker-cpu-build.jenkinsfile b/.jenkins/wsl-x86_64-docker-cpu-build.jenkinsfile
new file mode 100644
index 000000000..d02007c75
--- /dev/null
+++ b/.jenkins/wsl-x86_64-docker-cpu-build.jenkinsfile
@@ -0,0 +1,58 @@
+/*
+ *
+ * ******************************************************************************
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ *
+ */
+
+pipeline {
+    agent {
+        dockerfile {
+            filename 'Dockerfile'
+            dir '.docker'
+            label 'WSL-docker'
+            //additionalBuildArgs '--build-arg version=1.0.2'
+            //args '--gpus all'
+        }
+    }
+
+    stages {
+        stage('prep-build-environment-linux-cpu') {
+            steps {
+                checkout scm
+                sh 'gcc --version'
+                sh 'cmake --version'
+                sh 'sh ./gradlew --version'
+            }
+        }
+        stage('build-linux-cpu') {
+            environment {
+                MAVEN = credentials('Internal Archiva')
+                OSSRH = credentials('OSSRH')
+            }
+
+            steps {
+                withGradle {
+                    sh 'sh ./gradlew publish --stacktrace -x test -PCAVIS_CHIP=cpu \
+                        -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
+                        -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
+                }
+                //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
+            }
+        }
+    }
+}

From 62060d81316620eca0206c69bcf82c6e8f297dcc Mon Sep 17 00:00:00 2001
From: brian
Date: Tue, 4 Oct 2022 20:06:21 +0200
Subject: [PATCH 002/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 .../linux-x86_64-docker-cpu-build.jenkinsfile | 32 ++++++++++++++++++-
 1 file changed, 31 insertions(+), 1 deletion(-)

diff --git a/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile b/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile
index 46b32531b..64cfec3cc 100644
--- a/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile
+++ b/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile
@@ -47,7 +47,37 @@ pipeline {
             steps {
                 withGradle {
-                    sh 'sh ./gradlew publish --stacktrace -x test -PCAVIS_CHIP=cpu \
+                    sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cpu \
                         -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \
                         -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW'
                 }
                 //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build'
             }
         }
+        stage('test-linux-cpu') {
+            environment {
+                MAVEN = credentials('Internal Archiva')
+                OSSRH = credentials('OSSRH')
+            }
+
+            steps {
+                withGradle {
+                    sh 'sh ./gradlew test --stacktrace -PCAVIS_CHIP=cpu \
+                        -Pmavenuser=$MAVEN_USR
-Pmavenpass=$MAVEN_PSW \ + -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' + } + //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build' + } + } + stage('publish-linux-cpu') { + environment { + MAVEN = credentials('Internal Archiva') + OSSRH = credentials('OSSRH') + } + + steps { + withGradle { + sh 'sh ./gradlew publish --stacktrace -PCAVIS_CHIP=cpu \ -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' } From 2d9558af6b266fefaf3f2cf1939ea5631a3e7bf8 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 4 Oct 2022 20:07:17 +0200 Subject: [PATCH 003/126] Additional integration tests Signed-off-by: brian --- brutex-extended-tests/build.gradle | 8 ++ .../IntegrationTestBaselineGenerator.java | 0 .../integration/IntegrationTestRunner.java | 9 +- .../integration/IntegrationTestsDL4J.java | 0 .../integration/IntegrationTestsSameDiff.java | 0 .../deeplearning4j/integration/ModelType.java | 0 .../deeplearning4j/integration/TestCase.java | 0 .../deeplearning4j/integration/TestUtils.java | 0 .../testcases/dl4j/CNN1DTestCases.java | 0 .../testcases/dl4j/CNN2DTestCases.java | 0 .../testcases/dl4j/CNN3DTestCases.java | 0 .../testcases/dl4j/MLPTestCases.java | 0 .../testcases/dl4j/RNNTestCases.java | 0 .../testcases/dl4j/UnsupervisedTestCases.java | 0 .../dl4j/misc/CharacterIterator.java | 0 .../testcases/samediff/SameDiffCNNCases.java | 0 .../samediff/SameDiffMLPTestCases.java | 0 .../samediff/SameDiffRNNTestCases.java | 2 +- .../util/CountingMultiDataSetIterator.java | 0 deeplearning4j/dl4j-integration-tests/pom.xml | 106 ------------------ .../dl4j-integration-tests/readme.md | 63 ----------- .../src/test/resources/logback-test.xml | 54 --------- 22 files changed, 13 insertions(+), 229 deletions(-) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java (99%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/ModelType.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/TestCase.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/TestUtils.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java (100%) rename {deeplearning4j/dl4j-integration-tests => 
brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/testcases/dl4j/misc/CharacterIterator.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java (100%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java (99%) rename {deeplearning4j/dl4j-integration-tests => brutex-extended-tests}/src/test/java/org/deeplearning4j/integration/util/CountingMultiDataSetIterator.java (100%) delete mode 100644 deeplearning4j/dl4j-integration-tests/pom.xml delete mode 100644 deeplearning4j/dl4j-integration-tests/readme.md delete mode 100644 deeplearning4j/dl4j-integration-tests/src/test/resources/logback-test.xml diff --git a/brutex-extended-tests/build.gradle b/brutex-extended-tests/build.gradle index d21da53de..db9ea1cae 100644 --- a/brutex-extended-tests/build.gradle +++ b/brutex-extended-tests/build.gradle @@ -37,6 +37,7 @@ dependencies { implementation projects.cavisDatavec.cavisDatavecApi implementation projects.cavisDatavec.cavisDatavecSpark.cavisDatavecSparkCore implementation projects.cavisDnn.cavisDnnCommon + implementation projects.cavisDnn.cavisDnnCommonTests implementation projects.cavisDnn.cavisDnnApi implementation "org.slf4j:slf4j-api" implementation "org.apache.hadoop:hadoop-client" @@ -47,6 +48,8 @@ dependencies { testImplementation "org.apache.spark:spark-sql_${scalaVersion}" testCompileOnly "org.scala-lang:scala-library" + implementation "it.unimi.dsi:fastutil-core:8.5.8" + implementation projects.cavisDnn.cavisDnnSpark.cavisDnnSparkCore implementation projects.cavisDnn.cavisDnnSpark.cavisDnnSparkParameterserver implementation projects.cavisDnn.cavisDnnNnParent.cavisDnnNnCore @@ -56,8 +59,13 @@ dependencies { implementation projects.cavisUi.cavisUiModel implementation projects.cavisNd4j.cavisNd4jParameterServer.cavisNd4jParameterServerCore implementation projects.cavisNd4j.cavisNd4jParameterServer.cavisNd4jParameterServerNode + implementation projects.cavisDnn.cavisDnnData.cavisDnnDataDatasets implementation projects.cavisDnn.cavisDnnData.cavisDnnDataDatavecIterators + implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators implementation projects.cavisDatavec.cavisDatavecData.cavisDatavecDataImage + implementation projects.cavisDnn.cavisDnnParallelwrapper + + implementation projects.cavisZoo.cavisZooModels } test { diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java diff --git 
a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java similarity index 99% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java index 43c112d0a..e32f65a67 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java @@ -21,6 +21,8 @@ package org.deeplearning4j.integration; +import com.google.common.collect.ImmutableSet; +import com.google.common.reflect.ClassPath; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; @@ -44,13 +46,14 @@ import org.deeplearning4j.optimize.listeners.CollectScoresListener; import org.deeplearning4j.parallelism.ParallelInference; import org.deeplearning4j.parallelism.inference.InferenceMode; import org.deeplearning4j.util.ModelSerializer; - import org.nd4j.autodiff.listeners.records.History; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.VariableType; import org.nd4j.autodiff.samediff.internal.SameDiffOp; import org.nd4j.common.base.Preconditions; +import org.nd4j.common.primitives.Pair; +import org.nd4j.common.resources.Resources; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.*; import org.nd4j.evaluation.regression.RegressionEvaluation; @@ -66,10 +69,6 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.ops.transforms.Transforms; -import org.nd4j.common.primitives.Pair; -import org.nd4j.common.resources.Resources; -import com.google.common.collect.ImmutableSet; -import com.google.common.reflect.ClassPath; import java.io.*; import java.lang.reflect.Modifier; diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/ModelType.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/ModelType.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/ModelType.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/ModelType.java diff --git 
a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestCase.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestCase.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestCase.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestCase.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java 
similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/misc/CharacterIterator.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/misc/CharacterIterator.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/misc/CharacterIterator.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/misc/CharacterIterator.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffMLPTestCases.java diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java similarity index 99% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java rename to brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java index e3daa2126..579aff0a4 100644 --- a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffRNNTestCases.java @@ -63,7 +63,7 @@ import java.util.Map; public class SameDiffRNNTestCases { public static TestCase getRnnCsvSequenceClassificationTestCase1() { - return new SameDiffRNNTestCases.RnnCsvSequenceClassificationTestCase1(); + return new RnnCsvSequenceClassificationTestCase1(); } protected static class RnnCsvSequenceClassificationTestCase1 extends TestCase { diff --git a/deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/util/CountingMultiDataSetIterator.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/util/CountingMultiDataSetIterator.java similarity index 100% rename from deeplearning4j/dl4j-integration-tests/src/test/java/org/deeplearning4j/integration/util/CountingMultiDataSetIterator.java rename to 
brutex-extended-tests/src/test/java/org/deeplearning4j/integration/util/CountingMultiDataSetIterator.java diff --git a/deeplearning4j/dl4j-integration-tests/pom.xml b/deeplearning4j/dl4j-integration-tests/pom.xml deleted file mode 100644 index bc595b5e7..000000000 --- a/deeplearning4j/dl4j-integration-tests/pom.xml +++ /dev/null @@ -1,106 +0,0 @@ - - - - - - - net.brutex.ai - deeplearning4j-parent - 1.0.0-SNAPSHOT - - - 4.0.0 - - dl4j-integration-tests - - - - - org.slf4j - slf4j-api - - - net.brutex.ai - nd4j-api - ${project.version} - - - net.brutex.ai - deeplearning4j-core - ${project.version} - - - net.brutex.ai - deeplearning4j-zoo - ${project.version} - - - net.brutex.ai - deeplearning4j-parallel-wrapper - ${project.version} - - - ch.qos.logback - logback-classic - test - - - net.brutex.ai - deeplearning4j-common-tests - ${project.version} - test - - - net.brutex.ai - nd4j-common - ${project.version} - test - - - - - - - - - org.apache.maven.plugins - maven-install-plugin - 2.5.2 - - - default-install - none - - - - - - org.apache.maven.plugins - maven-deploy-plugin - - true - - - - - \ No newline at end of file diff --git a/deeplearning4j/dl4j-integration-tests/readme.md b/deeplearning4j/dl4j-integration-tests/readme.md deleted file mode 100644 index e0b9697e1..000000000 --- a/deeplearning4j/dl4j-integration-tests/readme.md +++ /dev/null @@ -1,63 +0,0 @@ - -#DL4J and SameDiff Integration Tests - -These tests are designed to check a number of aspects of DL4J and SameDiff: -1. Predictions (i.e., network output) -2. Training (training curves, parameters, gradient calculation) -3. Evaluation (accuracy, etc) -4. Model serialization (saving + loading models) -5. Overfitting sanity checks (make sure we can overfit a single example) -6. Data pipelines -7. Parallel Wrapper -8. Validating conditions that should always hold (frozen layer params don't change, for example) - - -They are designed for the following purposes: -1. Detecting regressions: i.e., new commit changed or broke previously working functionality -2. Detecting integration issues - i.e., issues that show up only when components are used together (but not in isolation in unit test) -3. Detecting significant differences between CPU and CUDA backends -4. Validating implementation via sanity checks on training - i.e., can we overfit a single example? -5. Checking networks and data pipelines on real-world scale data and nets -6. Operating as fully automated pre-release checks (replacing manual sanity checks) - -## Main Classes - -Explanation of the main classes: -* **IntegrationTestBaselineGenerator**: Run *manually* to generate and save "expected results" for comparing in the future. - Output goes to dl4j-test-resources, for saving/uploading. -* **IntegrationTestRunner**: Actually runs the tests, and compares the output/result to those generated by the baseline generator -* **TestCase**: integration tests extend this -* **testcases/\*.java**: the actual integration test definitions -* **IntegrationTestsDL4J**: entry point for running the DL4J integration tests -* **IntegrationTestsSameDiff**: entry point for running the SameDiff integration tests - -## Types of Test Components - -The integration tests are set up to be able to run multiple types of tests on each network configuration. - -Networks may be pretrained (from model zoo) or randomly initialized (from specified configuration). 
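The readme above names the moving parts: a factory method in `testcases/*.java` builds a `TestCase`, and a JUnit entry point in `IntegrationTestsDL4J` hands it to `IntegrationTestRunner`. The sketch below illustrates that wiring only; the class and method names follow the readme's own examples (`MLPTestCases.getMLPMnist()`, `IntegrationTestsDL4J.testMLPMnist()`) and the `runTest(TestCase, Path)` overload added later in this series (patch 008), but the body is illustrative rather than the actual file contents.

```java
package org.deeplearning4j.integration;

import java.nio.file.Path;

import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.integration.testcases.dl4j.MLPTestCases;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

// Sketch only: shows how a test case definition is wired into a JUnit 5 entry point.
public class IntegrationTestsDL4J extends BaseDL4JTest {

    // JUnit 5 supplies a fresh working directory; results and serialized models land here.
    @TempDir
    public Path testDir;

    @Test
    public void testMLPMnist() throws Exception {
        // 1. A factory method in testcases/*.java creates and returns the TestCase definition.
        TestCase tc = MLPTestCases.getMLPMnist();
        // 2. The runner executes the enabled checks and compares against the saved baselines.
        IntegrationTestRunner.runTest(tc, testDir);
    }
}
```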
- -Specifically, test cases can be run with any subset of the following components to be tested, by setting TestCase.XYZ boolean options to true or false: - -1. **testPredictions**: Testing output (predictions) on some specified data vs. saved/known good arrays -2. **testGradients**: Testing gradients on some specified data vs. saved/known good arrays -3. **testPretrain**: Test layerwise pretraining parameters and training curves -4. **testTrainingCurves**: Train, and check score vs. iteration -5. **testParamsPostTraining**: validate params match post training -6. **testEvaluation**: test the evaluation performance (post training, if 4 or 5 are true) -7. **testParallelInference**: validate that single net and parallel inference results match -8. **testOverfitting**: sanity check - try to overfit a single example - -See TestCase.java for more details. - - -## Adding a New Integration Test - -The process to add a new test is simple: -1. Add a method that creates and returns a TestCase object (example: testcases/MLPTestCases.getMLPMnist()) -2. Add it as a unit test to IntegrationTests class (example: IntegrationTestsDL4J.testMLPMnist()) -3. Run IntegrationTestBaselineGenerator with the new test case, to generate and save the "known good" results. -4. Run the new integration test to make sure it passes, on both CPU and CUDA backends -5. Commit the generated test resources from step 3 to dl4j-test-resources repo - -Note that IntegrationTestBaselineGenerator assumes you have the dl4j-test-resources cloned parallel to the DL4J mono-repo. \ No newline at end of file diff --git a/deeplearning4j/dl4j-integration-tests/src/test/resources/logback-test.xml b/deeplearning4j/dl4j-integration-tests/src/test/resources/logback-test.xml deleted file mode 100644 index 6be67561e..000000000 --- a/deeplearning4j/dl4j-integration-tests/src/test/resources/logback-test.xml +++ /dev/null @@ -1,54 +0,0 @@ - - - - - - logs/application.log - - %logger{15} - %message%n%xException{5} - - - - - - - %logger{15} - %message%n%xException{5} - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file From 7362ea278bd34753b7e7a32578ca3f9bf6d31f24 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 4 Oct 2022 20:16:44 +0200 Subject: [PATCH 004/126] Fix compiler warning: comparison between signed and unsigned integer expressions Signed-off-by: brian --- .../src/main/cpp/blas/helpers/helper_generator.h | 2 +- .../cavis-native-lib/src/main/cpp/blas/legacy/NativeOps.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/helpers/helper_generator.h b/cavis-native/cavis-native-lib/src/main/cpp/blas/helpers/helper_generator.h index 6fd265f11..098bcf7d7 100644 --- a/cavis-native/cavis-native-lib/src/main/cpp/blas/helpers/helper_generator.h +++ b/cavis-native/cavis-native-lib/src/main/cpp/blas/helpers/helper_generator.h @@ -610,7 +610,7 @@ namespace sd { state[0] = seedConv(this->seed); state[1] = seedConv(this->seed * 119 + 3); - int fd = 3 + 3; + //int fd = 3 + 3; for (Nd4jLong i = 0; i < limit; i++) { buffer[i] = next64(); diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/legacy/NativeOps.h b/cavis-native/cavis-native-lib/src/main/cpp/blas/legacy/NativeOps.h index 6bc1f4fe1..f32182e54 100644 --- a/cavis-native/cavis-native-lib/src/main/cpp/blas/legacy/NativeOps.h +++ b/cavis-native/cavis-native-lib/src/main/cpp/blas/legacy/NativeOps.h @@ -1113,7 +1113,7 @@ static Nd4jPointer _numpyHeaderForNd4j(Nd4jPointer data,const Nd4jPointer shapeB auto npHeader = 
cnpy::createNpyHeader(data,npShape,rank,wordSize); char *ret = new char[npHeader.size() + 1]; int count = 0; - for(int i = 0; i < npHeader.size(); i++) { + for(int i = 0; (size_t)i < npHeader.size(); i++) { ret[count] = npHeader[i]; count++; } From 37e1e606033777fc81b0001f001e2d893ff600ae Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 5 Oct 2022 13:31:33 +0200 Subject: [PATCH 005/126] Additional integration tests Signed-off-by: brian --- cavis-zoo/cavis-zoo-models/build.gradle | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cavis-zoo/cavis-zoo-models/build.gradle b/cavis-zoo/cavis-zoo-models/build.gradle index 56b674d37..7cdf4a532 100644 --- a/cavis-zoo/cavis-zoo-models/build.gradle +++ b/cavis-zoo/cavis-zoo-models/build.gradle @@ -38,4 +38,6 @@ dependencies { implementation "com.fasterxml.jackson.core:jackson-databind" testImplementation "org.bytedeco:opencv" testImplementation "org.bytedeco:javacv" + + testImplementation "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT" } \ No newline at end of file From 7bcfa76df457857fd777d39e7947da0aef26cb95 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 5 Oct 2022 13:46:28 +0200 Subject: [PATCH 006/126] Fix compiler warnings Signed-off-by: brian --- .../main/java/org/deeplearning4j/eval/BaseEvaluation.java | 2 +- .../nn/conf/layers/objdetect/Yolo2OutputLayer.java | 2 ++ .../deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java | 6 ++---- .../deeplearning4j/nn/conf/memory/LayerMemoryReport.java | 6 ++---- 4 files changed, 7 insertions(+), 9 deletions(-) diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/BaseEvaluation.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/BaseEvaluation.java index 0a872ef72..e7318364d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/BaseEvaluation.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/BaseEvaluation.java @@ -38,7 +38,7 @@ import com.fasterxml.jackson.databind.module.SimpleModule; import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; @Deprecated -@EqualsAndHashCode +@EqualsAndHashCode(callSuper = false) public abstract class BaseEvaluation extends org.nd4j.evaluation.BaseEvaluation { @Getter diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java index ec93eee55..1229e8cfd 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.conf.layers.objdetect; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; import org.deeplearning4j.nn.api.Layer; @@ -49,6 +50,7 @@ import java.util.List; import java.util.Map; @Data +@EqualsAndHashCode(callSuper = false) public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer { private double lambdaCoord; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java index 438c98ad8..7cbebeaf2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java @@ 
-20,10 +20,7 @@ package org.deeplearning4j.nn.conf.layers.recurrent; -import lombok.Data; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; +import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -40,6 +37,7 @@ import java.util.Collection; import java.util.Map; @Data +@EqualsAndHashCode(callSuper = false) public class SimpleRnn extends BaseRecurrentLayer { private boolean hasLayerNorm = false; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/LayerMemoryReport.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/LayerMemoryReport.java index 33c0c3b6d..28725679b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/LayerMemoryReport.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/LayerMemoryReport.java @@ -20,10 +20,7 @@ package org.deeplearning4j.nn.conf.memory; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; -import lombok.NonNull; +import lombok.*; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.inputs.InputType; import org.nd4j.linalg.api.buffer.DataBuffer; @@ -35,6 +32,7 @@ import java.util.Map; @Data @AllArgsConstructor @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class LayerMemoryReport extends MemoryReport { private String layerName; From 6856b154b1885d8da6b8df20f048e33165bb681b Mon Sep 17 00:00:00 2001 From: brian Date: Thu, 6 Oct 2022 13:22:06 +0200 Subject: [PATCH 007/126] More test fixes Signed-off-by: brian --- .../aeron/ipc/AeronNDArraySubscriber.java | 2 +- .../nd4j/aeron/ipc/LargeNdArrayIpcTest.java | 36 ++++++++++++------- cavis-nd4j/cavis-nd4j-common/build.gradle | 2 ++ .../nd4j/common/io/ClassPathResourceTest.java | 7 ++-- .../ParameterServerSubscriber.java | 2 +- nd4j/pom.xml | 35 +++++++++--------- 6 files changed, 50 insertions(+), 34 deletions(-) diff --git a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArraySubscriber.java b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArraySubscriber.java index 204893e97..0b554093a 100644 --- a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArraySubscriber.java +++ b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArraySubscriber.java @@ -51,7 +51,7 @@ public class AeronNDArraySubscriber implements AutoCloseable { // Create a context, needed for client connection to media driver // A separate media driver process need to run prior to running this application private Aeron.Context ctx; - private AtomicBoolean running = new AtomicBoolean(true); + private AtomicBoolean running; private final AtomicBoolean init = new AtomicBoolean(false); private NDArrayCallback ndArrayCallback; diff --git a/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java b/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java index 85ec9f01b..de0f74356 100644 --- a/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java +++ b/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java @@ -72,18 +72,23 @@ public class LargeNdArrayIpcTest extends BaseND4JTest { INDArray arr = Nd4j.ones(length); AeronNDArrayPublisher publisher; ctx = new Aeron.Context() - 
.driverTimeoutMs(10000).availableImageHandler(AeronUtil::printAvailableImage) - .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()) - .errorHandler(err -> err.printStackTrace()); + .driverTimeoutMs(10000) + .availableImageHandler(AeronUtil::printAvailableImage) + .unavailableImageHandler(AeronUtil::printUnavailableImage) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()) + .errorHandler(err -> err.printStackTrace()); final AtomicBoolean running = new AtomicBoolean(true); Aeron aeron = Aeron.connect(ctx); int numSubscribers = 1; AeronNDArraySubscriber[] subscribers = new AeronNDArraySubscriber[numSubscribers]; for (int i = 0; i < numSubscribers; i++) { - AeronNDArraySubscriber subscriber = AeronNDArraySubscriber.builder().streamId(streamId).ctx(getContext()) - .channel(channel).aeron(aeron).running(running).ndArrayCallback(new NDArrayCallback() { + AeronNDArraySubscriber subscriber = AeronNDArraySubscriber.builder() + .streamId(streamId).ctx(getContext()) + .channel(channel) + .aeron(aeron) + .running(running) + .ndArrayCallback(new NDArrayCallback() { /** * A listener for ndarray message * @@ -110,7 +115,8 @@ public class LargeNdArrayIpcTest extends BaseND4JTest { try { subscriber.launch(); } catch (Exception e) { - log.error("",e); + System.out.println(e.getMessage()); + e.printStackTrace(); } }); @@ -122,17 +128,23 @@ public class LargeNdArrayIpcTest extends BaseND4JTest { Thread.sleep(1000); - publisher = AeronNDArrayPublisher.builder().publishRetryTimeOut(3000).streamId(streamId).channel(channel) - .aeron(aeron).build(); + publisher = AeronNDArrayPublisher.builder() + .publishRetryTimeOut(3000) + .streamId(streamId) + .channel(channel) + .aeron(aeron) + .build(); - for (int i = 0; i < 1 && running.get(); i++) { - log.info("About to send array."); + for (int i = 0; i < 1; i++) { + System.out.println("About to send array."); publisher.publish(arr); - log.info("Sent array"); + System.out.println("Sent array"); } + Thread.sleep( 5000); + for (int i = 0; i < numSubscribers; i++) CloseHelper.close(subscribers[i]); CloseHelper.close(aeron); diff --git a/cavis-nd4j/cavis-nd4j-common/build.gradle b/cavis-nd4j/cavis-nd4j-common/build.gradle index e801f3b5c..b58257047 100644 --- a/cavis-nd4j/cavis-nd4j-common/build.gradle +++ b/cavis-nd4j/cavis-nd4j-common/build.gradle @@ -18,6 +18,7 @@ * ***************************************************************************** * */ +apply from: "${project.rootProject.projectDir}/createTestBackends.gradle" dependencies { implementation "com.fasterxml.jackson.core:jackson-databind" @@ -29,4 +30,5 @@ dependencies { implementation "org.apache.commons:commons-compress" implementation "commons-codec:commons-codec" testImplementation projects.cavisNd4j.cavisNd4jCommonTests + testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT" } \ No newline at end of file diff --git a/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java b/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java index 9e1591f82..3e5f8bf72 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java +++ b/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java @@ -24,9 +24,11 @@ package org.nd4j.common.io; import org.apache.commons.io.FileUtils; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.io.ClassPathResource; import 
java.io.File; +import java.nio.file.Path; import java.util.UUID; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -36,13 +38,12 @@ public class ClassPathResourceTest { @Test - public void testDirExtractingIntelliJ() throws Exception { + public void testDirExtractingIntelliJ(@TempDir Path tempDir) throws Exception { //https://github.com/deeplearning4j/deeplearning4j/issues/6483 ClassPathResource cpr = new ClassPathResource("somedir"); - File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID().toString()); - FileUtils.forceMkdir(f); + File f = tempDir.toFile(); cpr.copyDirectory(f); File[] files = f.listFiles(); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java index 6aaca2e49..b32e97147 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java @@ -325,7 +325,7 @@ public class ParameterServerSubscriber implements AutoCloseable { int tries=0; while (!subscriber.launched() && tries<12) { tries++; - LockSupport.parkNanos(100000); + Thread.sleep(1000); } if(!subscriber.launched()) { throw new Exception("Subscriber did not start in time."); diff --git a/nd4j/pom.xml b/nd4j/pom.xml index 8dfe09d06..33f87c4b3 100644 --- a/nd4j/pom.xml +++ b/nd4j/pom.xml @@ -1,22 +1,23 @@ Date: Fri, 7 Oct 2022 10:49:08 +0200 Subject: [PATCH 008/126] More test fixes Signed-off-by: brian --- brutex-extended-tests/build.gradle | 2 + .../integration/IntegrationTestRunner.java | 4 ++ .../integration/IntegrationTestsDL4J.java | 8 +--- .../integration/IntegrationTestsSameDiff.java | 6 --- cavis-common-platform/build.gradle | 1 + .../image/loader/NativeImageLoader.java | 41 ++++++++++--------- .../recordreader/BaseImageRecordReader.java | 7 +++- .../java/org/deeplearning4j/BaseDL4JTest.java | 7 ---- .../earlystopping/TestEarlyStopping.java | 2 + .../custom/testclasses/CustomActivation.java | 2 +- .../testlayers/SameDiffDenseVertex.java | 2 + .../build.gradle | 1 + .../RemoteParameterServerClientTests.java | 6 ++- .../ParameterServerSubscriber.java | 2 +- 14 files changed, 49 insertions(+), 42 deletions(-) diff --git a/brutex-extended-tests/build.gradle b/brutex-extended-tests/build.gradle index db9ea1cae..bd53f61bd 100644 --- a/brutex-extended-tests/build.gradle +++ b/brutex-extended-tests/build.gradle @@ -66,6 +66,8 @@ dependencies { implementation projects.cavisDnn.cavisDnnParallelwrapper implementation projects.cavisZoo.cavisZooModels + + testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT" } test { diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java index e32f65a67..29e80ce99 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java @@ -73,6 +73,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import java.io.*; import java.lang.reflect.Modifier; import 
java.nio.charset.StandardCharsets; +import java.nio.file.Path; import java.util.*; import java.util.concurrent.atomic.AtomicInteger; @@ -154,6 +155,9 @@ public class IntegrationTestRunner { evaluationClassesSeen = new HashMap<>(); } + public static void runTest(TestCase tc, Path testDir) throws Exception { + runTest(tc, testDir.toFile()); + } public static void runTest(TestCase tc, File testDir) throws Exception { BaseDL4JTest.skipUnlessIntegrationTests(); //Tests will ONLY be run if integration test profile is enabled. //This could alternatively be done via maven surefire configuration diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java index ebf4a9442..6b50f265e 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsDL4J.java @@ -28,18 +28,14 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import java.io.File; +import java.nio.file.Path; ////@Ignore("AB - 2019/05/27 - Integration tests need to be updated") public class IntegrationTestsDL4J extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 300_000L; - } - @TempDir - public File testDir; + public Path testDir; @AfterAll public static void afterClass(){ diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java index cc86b5cb3..5eb1ce856 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestsSameDiff.java @@ -30,12 +30,6 @@ import java.io.File; public class IntegrationTestsSameDiff extends BaseDL4JTest { - - @Override - public long getTimeoutMilliseconds() { - return 300_000L; - } - @TempDir public File testDir; diff --git a/cavis-common-platform/build.gradle b/cavis-common-platform/build.gradle index 05368a76b..a6202c6a8 100644 --- a/cavis-common-platform/build.gradle +++ b/cavis-common-platform/build.gradle @@ -65,6 +65,7 @@ dependencies { /*Logging*/ api 'org.slf4j:slf4j-api:1.7.30' + api 'org.slf4j:slf4j-simple:1.7.25' api "org.apache.logging.log4j:log4j-core:2.17.0" api "ch.qos.logback:logback-classic:1.2.3" diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java index cf3e4abbe..4db72decd 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java @@ -48,6 +48,11 @@ import static org.bytedeco.opencv.global.opencv_core.*; import static org.bytedeco.opencv.global.opencv_imgcodecs.*; import static org.bytedeco.opencv.global.opencv_imgproc.*; +/** + * Uses JavaCV to load images. 
Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp + * + * @author saudet + */ public class NativeImageLoader extends BaseImageLoader { private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024; private byte[] buffer = null; @@ -57,14 +62,16 @@ public class NativeImageLoader extends BaseImageLoader { "png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM", "PNG", "TIF", "TIFF", "EXR", "WEBP"}; - protected OpenCVFrameConverter.ToMat converter = new OpenCVFrameConverter.ToMat(); + protected OpenCVFrameConverter.ToMat converter; boolean direct = !Loader.getPlatform().startsWith("android"); /** * Loads images with no scaling or conversion. */ - public NativeImageLoader() {} + public NativeImageLoader() { + this.converter = new OpenCVFrameConverter.ToMat(); + } /** * Instantiate an image with the given @@ -74,6 +81,7 @@ public class NativeImageLoader extends BaseImageLoader { */ public NativeImageLoader(long height, long width) { + this(); this.height = height; this.width = width; } @@ -87,8 +95,7 @@ public class NativeImageLoader extends BaseImageLoader { * @param channels the number of channels for the image* */ public NativeImageLoader(long height, long width, long channels) { - this.height = height; - this.width = width; + this(height, width); this.channels = channels; } @@ -132,12 +139,9 @@ public class NativeImageLoader extends BaseImageLoader { } protected NativeImageLoader(NativeImageLoader other) { - this.height = other.height; - this.width = other.width; - this.channels = other.channels; + this(other.height, other.width, other.channels, other.multiPageMode); this.centerCropIfNeeded = other.centerCropIfNeeded; this.imageTransform = other.imageTransform; - this.multiPageMode = other.multiPageMode; } @Override @@ -297,7 +301,7 @@ public class NativeImageLoader extends BaseImageLoader { private Mat streamToMat(InputStream is) throws IOException { if(buffer == null){ buffer = IOUtils.toByteArray(is); - if(buffer.length <= 0){ + if(buffer.length == 0){ throw new IOException("Could not decode image from input stream: input stream was empty (no data)"); } bufferMat = new Mat(buffer); @@ -545,10 +549,15 @@ public class NativeImageLoader extends BaseImageLoader { } public void asMatrixView(InputStream is, INDArray view) throws IOException { - Mat mat = streamToMat(is); - Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); + throw new RuntimeException("Not implemented"); + + } + + public void asMatrixView(String filename, INDArray view) throws IOException { + Mat image = imread(filename,IMREAD_ANYDEPTH | IMREAD_ANYCOLOR ); + //Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); if (image == null || image.empty()) { - PIX pix = pixReadMem(mat.data(), mat.cols()); + PIX pix = pixReadMem(image.data(), image.cols()); if (pix == null) { throw new IOException("Could not decode image from input stream"); } @@ -561,14 +570,8 @@ public class NativeImageLoader extends BaseImageLoader { image.deallocate(); } - public void asMatrixView(String filename, INDArray view) throws IOException { - asMatrixView(new File(filename), view); - } - public void asMatrixView(File f, INDArray view) throws IOException { - try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { - asMatrixView(bis, view); - } + asMatrixView(f.getAbsolutePath(), view); } public void asMatrixView(Mat image, INDArray view) throws IOException { diff --git 
a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java index 86a6a59c1..079049257 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java @@ -53,6 +53,10 @@ import java.io.*; import java.net.URI; import java.util.*; +/** +* Base class for the image record reader +* +*/ @Slf4j public abstract class BaseImageRecordReader extends BaseRecordReader { protected boolean finishedInputStreamSplit; @@ -344,7 +348,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { ((NativeImageLoader) imageLoader).asMatrixView(currBatch.get(i), features.tensorAlongDimension(i, 1, 2, 3)); } catch (Exception e) { - System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath()); + System.out.println("Image file failed during load: " + currBatch.get(i).getAbsolutePath() + "\n" + e.getMessage()); + e.printStackTrace(); throw new RuntimeException(e); } } diff --git a/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java index aca151aa2..0dcd1fe08 100644 --- a/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java +++ b/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java @@ -58,13 +58,6 @@ public abstract class BaseDL4JTest { return DEFAULT_THREADS; } - /** - * Override this method to set the default timeout for methods in the test class - */ - public long getTimeoutMilliseconds(){ - return 90_000; - } - /** * Override this to set the profiling mode for the tests defined in the child class */ diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java index b4e790ea1..2774e9961 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java @@ -21,6 +21,7 @@ package org.deeplearning4j.earlystopping; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; @@ -817,6 +818,7 @@ public class TestEarlyStopping extends BaseDL4JTest { } @Data + @EqualsAndHashCode(callSuper = false) public static class TestListener extends BaseTrainingListener { private int countEpochStart = 0; private int countEpochEnd = 0; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java index e73a6fea6..f88e76a17 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomActivation.java @@ -26,7 +26,7 @@ import org.nd4j.linalg.activations.IActivation; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; -@EqualsAndHashCode +@EqualsAndHashCode(callSuper = false) public class CustomActivation extends BaseActivationFunction implements IActivation { @Override public INDArray getActivation(INDArray in, boolean training) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java index 1f1c632e9..da674ea7c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.layers.samediff.testlayers; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams; @@ -37,6 +38,7 @@ import java.util.Map; @NoArgsConstructor @Data +@EqualsAndHashCode(callSuper = false) public class SameDiffDenseVertex extends SameDiffVertex { private int nIn; diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/build.gradle b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/build.gradle index af5e0aa84..99b6d1866 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/build.gradle +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/build.gradle @@ -27,6 +27,7 @@ dependencies { implementation "com.fasterxml.jackson.core:jackson-core" implementation "com.fasterxml.jackson.core:jackson-databind" implementation "org.slf4j:slf4j-api" + implementation "org.slf4j:slf4j-simple" implementation projects.cavisNd4j.cavisNd4jParameterServer.cavisNd4jParameterServerModel implementation projects.cavisNd4j.cavisNd4jAeron implementation projects.cavisDnn.cavisDnnApi diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java index 62175a7fa..005443fe3 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java @@ -56,7 +56,11 @@ public class RemoteParameterServerClientTests extends BaseND4JTest { new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirDeleteOnStart(true) .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) .receiverIdleStrategy(new BusySpinIdleStrategy()) - .senderIdleStrategy(new BusySpinIdleStrategy()); + .senderIdleStrategy(new BusySpinIdleStrategy()) + .driverTimeoutMs(1000*1000 *1000) + .clientLivenessTimeoutNs(1000*1000*1000) + .timerIntervalNs( 1000 * 1000); + mediaDriver = MediaDriver.launchEmbedded(ctx); aeron = Aeron.connect(getContext()); diff --git 
a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java index b32e97147..59bce0ad0 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/ParameterServerSubscriber.java @@ -325,7 +325,7 @@ public class ParameterServerSubscriber implements AutoCloseable { int tries=0; while (!subscriber.launched() && tries<12) { tries++; - Thread.sleep(1000); + Thread.sleep(2000); } if(!subscriber.launched()) { throw new Exception("Subscriber did not start in time."); From b8a21bc99173921141fde18cac6aa87e55033005 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 7 Oct 2022 12:28:58 +0200 Subject: [PATCH 009/126] More test fixes Signed-off-by: brian --- .../ndarray/NDArrayColumnsMathOpTransform.java | 2 ++ .../ndarray/NDArrayMathFunctionTransform.java | 2 ++ .../ndarray/NDArrayScalarOpTransform.java | 2 ++ .../StringListToIndicesNDArrayTransform.java | 2 ++ .../image/transform/LargestBlobCropTransform.java | 2 ++ .../image/transform/PipelineImageTransform.java | 2 ++ .../image/transform/RandomCropTransform.java | 2 ++ .../autodiff/functions/DifferentialFunction.java | 5 +---- .../org/nd4j/autodiff/samediff/SDVariable.java | 14 ++------------ .../nd4j/autodiff/samediff/internal/Variable.java | 1 + .../java/org/nd4j/autodiff/util/SameDiffUtils.java | 2 +- .../classification/EvaluationCalibration.java | 2 +- .../java/org/nd4j/evaluation/curves/Histogram.java | 2 ++ .../deallocation/DeallocatableReference.java | 3 +++ .../nd4j/linalg/api/ops/BaseIndexAccumulation.java | 2 ++ .../org/nd4j/linalg/api/ops/DynamicCustomOp.java | 3 +++ .../org/nd4j/linalg/api/ops/custom/Flatten.java | 2 ++ .../ops/impl/controlflow/compat/BaseCompatOp.java | 3 +++ .../api/ops/impl/controlflow/compat/Enter.java | 2 ++ .../api/ops/impl/controlflow/compat/While.java | 2 ++ .../linalg/api/ops/impl/indexaccum/FirstIndex.java | 2 ++ .../linalg/api/ops/impl/indexaccum/LastIndex.java | 2 ++ .../api/ops/impl/indexaccum/custom/ArgAmax.java | 2 ++ .../api/ops/impl/indexaccum/custom/ArgAmin.java | 2 ++ .../api/ops/impl/indexaccum/custom/ArgMax.java | 2 ++ .../api/ops/impl/indexaccum/custom/ArgMin.java | 2 ++ .../layers/convolution/config/Conv1DConfig.java | 7 +++---- .../layers/convolution/config/Conv2DConfig.java | 2 ++ .../layers/convolution/config/Conv3DConfig.java | 2 ++ .../layers/convolution/config/DeConv2DConfig.java | 2 ++ .../layers/convolution/config/DeConv3DConfig.java | 2 ++ .../config/LocalResponseNormalizationConfig.java | 2 ++ .../layers/convolution/config/Pooling2DConfig.java | 2 ++ .../layers/convolution/config/Pooling3DConfig.java | 2 ++ .../org/nd4j/linalg/api/ops/impl/reduce/Mmul.java | 2 +- .../nd4j/linalg/api/ops/impl/reduce/MmulBp.java | 2 +- .../api/ops/impl/reduce/custom/BatchMmul.java | 2 +- .../nd4j/linalg/dataset/BalanceMinibatches.java | 5 +++++ .../org/nd4j/linalg/learning/config/AdaMax.java | 6 ++++-- .../java/org/nd4j/linalg/profiler/OpProfiler.java | 14 -------------- .../cavis-dnn-data-datavec-iterators/build.gradle | 1 + .../nn/modelimport/keras/layers/TFOpLayerImpl.java | 2 ++ .../keras/layers/core/KerasMasking.java | 2 ++ .../modelimport/keras/layers/core/KerasMerge.java 
| 2 ++ .../deeplearning4j/clustering/vptree/VPTree.java | 4 +++- .../org/deeplearning4j/eval/curves/Histogram.java | 2 ++ .../paramavg/ParameterAveragingTrainingMaster.java | 3 ++- .../training/SharedTrainingMaster.java | 3 ++- .../client/ParameterServerClient.java | 8 ++++---- .../zoo/model/TextGenerationLSTM.java | 2 +- 50 files changed, 101 insertions(+), 49 deletions(-) diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayColumnsMathOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayColumnsMathOpTransform.java index 8be648fa0..ddbf34503 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayColumnsMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayColumnsMathOpTransform.java @@ -21,6 +21,7 @@ package org.datavec.api.transform.ndarray; import lombok.Data; +import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.MathOp; import org.datavec.api.transform.metadata.ColumnMetaData; @@ -36,6 +37,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Arrays; @Data +@EqualsAndHashCode(callSuper = false) public class NDArrayColumnsMathOpTransform extends BaseColumnsMathOpTransform { public NDArrayColumnsMathOpTransform(@JsonProperty("newColumnName") String newColumnName, diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayMathFunctionTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayMathFunctionTransform.java index ca5b7921c..74a91332c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayMathFunctionTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayMathFunctionTransform.java @@ -21,6 +21,7 @@ package org.datavec.api.transform.ndarray; import lombok.Data; +import lombok.EqualsAndHashCode; import org.datavec.api.transform.MathFunction; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; @@ -32,6 +33,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import com.fasterxml.jackson.annotation.JsonProperty; @Data +@EqualsAndHashCode(callSuper = false) public class NDArrayMathFunctionTransform extends BaseColumnTransform { //Can't guarantee that the writable won't be re-used, for example in different Spark ops on the same RDD diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayScalarOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayScalarOpTransform.java index 4f4dcc4c3..708625c8a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayScalarOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayScalarOpTransform.java @@ -21,6 +21,7 @@ package org.datavec.api.transform.ndarray; import lombok.Data; +import lombok.EqualsAndHashCode; import org.datavec.api.transform.MathOp; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.NDArrayMetaData; @@ -33,6 +34,7 @@ import org.nd4j.linalg.ops.transforms.Transforms; import com.fasterxml.jackson.annotation.JsonProperty; @Data +@EqualsAndHashCode(callSuper = false) 
public class NDArrayScalarOpTransform extends BaseColumnTransform { private final MathOp mathOp; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToIndicesNDArrayTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToIndicesNDArrayTransform.java index 19a03ce83..2d31b5c05 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToIndicesNDArrayTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToIndicesNDArrayTransform.java @@ -21,6 +21,7 @@ package org.datavec.api.transform.transform.string; import lombok.Data; +import lombok.EqualsAndHashCode; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import com.fasterxml.jackson.annotation.JsonProperty; @@ -31,6 +32,7 @@ import java.util.Collections; import java.util.List; @Data +@EqualsAndHashCode(callSuper = false) public class StringListToIndicesNDArrayTransform extends StringListToCountsNDArrayTransform { /** * @param columnName The name of the column to convert diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/LargestBlobCropTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/LargestBlobCropTransform.java index 0a12ee7f7..d823e5736 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/LargestBlobCropTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/LargestBlobCropTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.ImageWritable; import org.nd4j.linalg.factory.Nd4j; @@ -32,6 +33,7 @@ import org.bytedeco.opencv.opencv_core.*; import static org.bytedeco.opencv.global.opencv_imgproc.*; @Data +@EqualsAndHashCode(callSuper = false) public class LargestBlobCropTransform extends BaseImageTransform { protected org.nd4j.linalg.api.rng.Random rng; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/PipelineImageTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/PipelineImageTransform.java index 84c4191be..524cf0a13 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/PipelineImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/PipelineImageTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NonNull; import org.datavec.image.data.ImageWritable; @@ -32,6 +33,7 @@ import java.util.*; import org.bytedeco.opencv.opencv_core.*; @Data +@EqualsAndHashCode(callSuper = false) public class PipelineImageTransform extends BaseImageTransform { protected List> imageTransforms; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/RandomCropTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/RandomCropTransform.java index 
8d76ecb2c..86013c239 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/RandomCropTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/RandomCropTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.ImageWritable; import org.nd4j.linalg.factory.Nd4j; @@ -35,6 +36,7 @@ import org.bytedeco.opencv.opencv_core.*; @JsonIgnoreProperties({"rng", "converter"}) @JsonInclude(JsonInclude.Include.NON_NULL) @Data +@EqualsAndHashCode(callSuper = false) public class RandomCropTransform extends BaseImageTransform { protected int outputHeight; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 3ab872d91..c7920422f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -20,11 +20,8 @@ package org.nd4j.autodiff.functions; -import lombok.Data; -import lombok.Getter; -import lombok.Setter; +import lombok.*; import lombok.extern.slf4j.Slf4j; -import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index bfbd49ff0..609a1ceba 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -46,7 +46,9 @@ public class SDVariable implements Serializable { protected SameDiff sameDiff; @Getter + @Setter protected String varName; + @Getter @Setter protected VariableType variableType; @@ -83,18 +85,6 @@ public class SDVariable implements Serializable { return varName; } - public void setVarName(String varName) { - this.varName = varName; - } - - /** - * @deprecated Use {@link #name()} - */ - @Deprecated - public String getVarName(){ - return name(); - } - /** * Returns true if this variable is a place holder * @return diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java index 9f0bb48f8..dbc36b239 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/Variable.java @@ -39,5 +39,6 @@ public class Variable { protected String outputOfOp; //Null for placeholders/constants. 
For array type SDVariables, the name of the op it's an output of protected List controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution protected SDVariable gradient; //Variable corresponding to the gradient of this variable + @Builder.Default protected int variableIndex = -1; } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java index 560717ad8..a5e3f9b66 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java @@ -76,7 +76,7 @@ public class SameDiffUtils { public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map externalGradients, SDVariable... inputs) { Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" + - " be specified when using external errors: got %s", inputs); + " be specified when using external errors: got %s", (Object) inputs); ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients); fn.outputVariable(); return fn; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java index e154fad61..670309928 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java @@ -49,7 +49,7 @@ import java.io.Serializable; import java.util.List; @Getter -@EqualsAndHashCode +@EqualsAndHashCode(callSuper = false) public class EvaluationCalibration extends BaseEvaluation { public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/curves/Histogram.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/curves/Histogram.java index afaf32b32..f3510d615 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/curves/Histogram.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/curves/Histogram.java @@ -22,8 +22,10 @@ package org.nd4j.evaluation.curves; import lombok.Data; import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.EqualsAndHashCode; @Data +@EqualsAndHashCode(callSuper = false) public class Histogram extends BaseHistogram { private final String title; private final double lower; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatableReference.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatableReference.java index 24d0fe424..2c5cf41ac 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatableReference.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatableReference.java @@ -21,6 +21,8 @@ package org.nd4j.linalg.api.memory.deallocation; import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.NoArgsConstructor; import org.nd4j.linalg.api.memory.Deallocatable; import org.nd4j.linalg.api.memory.Deallocator; @@ -28,6 +30,7 @@ import java.lang.ref.ReferenceQueue; import java.lang.ref.WeakReference; @Data 
+@EqualsAndHashCode(callSuper = false) public class DeallocatableReference extends WeakReference { private String id; private Deallocator deallocator; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java index 7ae3ade6c..bcaa443d9 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.api.ops; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -36,6 +37,7 @@ import java.util.List; @Slf4j @Data +@EqualsAndHashCode(callSuper = false) public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccumulation { protected boolean keepDims = false; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index 794b00eb0..36e6982a0 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -44,6 +44,9 @@ import java.lang.reflect.Array; import java.util.*; @Slf4j +@Builder +@AllArgsConstructor +@EqualsAndHashCode(callSuper = true) public class DynamicCustomOp extends DifferentialFunction implements CustomOp { private String opName; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java index 7c1fa9526..7981803e8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.api.ops.custom; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; @@ -35,6 +36,7 @@ import java.util.List; @Data @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class Flatten extends DynamicCustomOp { private int order; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java index b6664aeab..0e71c1db2 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java @@ -21,6 +21,8 @@ package org.nd4j.linalg.api.ops.impl.controlflow.compat; import java.util.List; + +import lombok.EqualsAndHashCode; import lombok.NonNull; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; @@ -39,6 +41,7 @@ import org.tensorflow.framework.NodeDef; import java.util.HashMap; import java.util.Map; +@EqualsAndHashCode(callSuper = false) public abstract class BaseCompatOp extends DynamicCustomOp { protected String frameName; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java index 2b24f6944..f88dd0f76 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Enter.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow.compat; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -37,6 +38,7 @@ import java.util.Map; @Data @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class Enter extends BaseCompatOp { protected boolean isConstant; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/While.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/While.java index 518662ba8..b1f1b4b5c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/While.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/While.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.api.ops.impl.controlflow.compat; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -36,6 +37,7 @@ import java.util.Map; @Data @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class While extends BaseCompatOp { protected boolean isConstant; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java index 8c7b7739b..fe08a1299 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/FirstIndex.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; @@ -36,6 +37,7 @@ import java.util.List; @Data @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class FirstIndex extends BaseIndexAccumulation { protected Condition condition; protected double compare; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java index 7b0e47dce..305b859c7 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/LastIndex.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; @@ -38,6 +39,7 @@ import java.util.Map; @Data @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class LastIndex extends BaseIndexAccumulation { protected Condition condition; protected double compare; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmax.java 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmax.java index 228996407..84da21d67 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmax.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmax.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum.custom; import lombok.Data; +import lombok.EqualsAndHashCode; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; @@ -38,6 +39,7 @@ import java.util.List; import java.util.Map; @Data +@EqualsAndHashCode(callSuper = false) public class ArgAmax extends DynamicCustomOp { protected boolean keepDims = false; private int[] dimensions; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmin.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmin.java index 16db04b2c..38562baee 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmin.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmin.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum.custom; import lombok.Data; +import lombok.EqualsAndHashCode; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; @@ -38,6 +39,7 @@ import java.util.List; import java.util.Map; @Data +@EqualsAndHashCode(callSuper = false) public class ArgAmin extends DynamicCustomOp { protected boolean keepDims = false; private int[] dimensions; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java index e2a7438bd..557b1f21f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum.custom; import lombok.Data; +import lombok.EqualsAndHashCode; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; @@ -37,6 +38,7 @@ import java.util.List; import java.util.Map; @Data +@EqualsAndHashCode(callSuper = false) public class ArgMax extends DynamicCustomOp { protected boolean keepDims = false; private int[] dimensions; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java index 00445ee87..65d6ec1f3 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.api.ops.impl.indexaccum.custom; import lombok.Data; +import lombok.EqualsAndHashCode; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; @@ -37,6 +38,7 @@ import java.util.List; import java.util.Map; @Data +@EqualsAndHashCode(callSuper = false) public class ArgMin extends DynamicCustomOp { protected 
boolean keepDims = false; private int[] dimensions; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java index d86ed9f48..833aa5263 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv1DConfig.java @@ -22,16 +22,15 @@ package org.nd4j.linalg.api.ops.impl.layers.convolution.config; import java.util.LinkedHashMap; import java.util.Map; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; -import lombok.NonNull; + +import lombok.*; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.util.ConvConfigUtil; @Data @Builder @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class Conv1DConfig extends BaseConvolutionConfig { public static final String NCW = "NCW"; public static final String NWC = "NWC"; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java index 7981da12f..66e870027 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java @@ -24,6 +24,7 @@ import java.util.LinkedHashMap; import java.util.Map; import lombok.Builder; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.nd4j.common.base.Preconditions; import org.nd4j.enums.WeightsFormat; @@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil; @Data @Builder @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class Conv2DConfig extends BaseConvolutionConfig { public static final String NCHW = "NCHW"; public static final String NHWC = "NHWC"; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv3DConfig.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv3DConfig.java index 1607d6ccd..574868013 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv3DConfig.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv3DConfig.java @@ -25,6 +25,7 @@ import java.util.LinkedHashMap; import java.util.Map; import lombok.Builder; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.util.ConvConfigUtil; @@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil; @Data @Builder @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class Conv3DConfig extends BaseConvolutionConfig { public static final String NDHWC = "NDHWC"; public static final String NCDHW = "NCDHW"; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv2DConfig.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv2DConfig.java index 0b57e5cd5..bf9a041b9 100644 --- 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv2DConfig.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv2DConfig.java @@ -24,6 +24,7 @@ import java.util.LinkedHashMap; import java.util.Map; import lombok.Builder; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.util.ConvConfigUtil; @@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil; @Data @Builder @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class DeConv2DConfig extends BaseConvolutionConfig { public static final String NCHW = "NCHW"; public static final String NHWC = "NHWC"; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv3DConfig.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv3DConfig.java index 041cda3d2..6c793e7d0 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv3DConfig.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/DeConv3DConfig.java @@ -24,6 +24,7 @@ import java.util.LinkedHashMap; import java.util.Map; import lombok.Builder; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.util.ConvConfigUtil; @@ -31,6 +32,7 @@ import org.nd4j.linalg.util.ConvConfigUtil; @Data @Builder @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class DeConv3DConfig extends BaseConvolutionConfig { public static final String NCDHW = "NCDHW"; public static final String NDHWC = "NDHWC"; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/LocalResponseNormalizationConfig.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/LocalResponseNormalizationConfig.java index e7edfcf4e..400ddd733 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/LocalResponseNormalizationConfig.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/LocalResponseNormalizationConfig.java @@ -24,12 +24,14 @@ import java.util.LinkedHashMap; import java.util.Map; import lombok.Builder; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.nd4j.linalg.util.ConvConfigUtil; @Data @Builder @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class LocalResponseNormalizationConfig extends BaseConvolutionConfig { private double alpha, beta, bias; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling2DConfig.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling2DConfig.java index 1bf19aad2..bab13ae06 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling2DConfig.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling2DConfig.java @@ -24,6 +24,7 @@ import java.util.LinkedHashMap; import java.util.Map; import lombok.Builder; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import 
org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D.Divisor; @@ -33,6 +34,7 @@ import org.nd4j.linalg.util.ConvConfigUtil; @Data @Builder @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class Pooling2DConfig extends BaseConvolutionConfig { @Builder.Default private long kH = -1, kW = -1; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling3DConfig.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling3DConfig.java index 007c89e89..1ed04bcaf 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling3DConfig.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Pooling3DConfig.java @@ -24,6 +24,7 @@ import java.util.LinkedHashMap; import java.util.Map; import lombok.Builder; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling3D.Pooling3DType; @@ -32,6 +33,7 @@ import org.nd4j.linalg.util.ConvConfigUtil; @Data @Builder @NoArgsConstructor +@EqualsAndHashCode(callSuper = false) public class Pooling3DConfig extends BaseConvolutionConfig { @Builder.Default private long kD = -1, kW = -1, kH = -1; // kernel @Builder.Default private long sD = 1, sW = 1, sH = 1; // strides diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index b37be57c0..be291a9c3 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -39,7 +39,7 @@ import org.tensorflow.framework.NodeDef; import java.lang.reflect.Field; import java.util.*; -@EqualsAndHashCode +@EqualsAndHashCode(callSuper = false) public class Mmul extends DynamicCustomOp { protected MMulTranspose mt; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/MmulBp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/MmulBp.java index 58ed6f027..4471244ef 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/MmulBp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/MmulBp.java @@ -32,7 +32,7 @@ import org.nd4j.common.util.ArrayUtil; import java.util.List; -@EqualsAndHashCode +@EqualsAndHashCode(callSuper = false) public class MmulBp extends DynamicCustomOp { protected MMulTranspose mt; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java index ea312036d..a7d63e4bd 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/BatchMmul.java @@ -32,7 +32,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.*; -@EqualsAndHashCode +@EqualsAndHashCode(callSuper = false) public class BatchMmul extends DynamicCustomOp { protected int transposeA; diff --git 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java index 0e400c756..b77a29eb4 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/BalanceMinibatches.java @@ -39,10 +39,15 @@ import java.util.Map; public class BalanceMinibatches { private DataSetIterator dataSetIterator; private int numLabels; + @Builder.Default private Map> paths = Maps.newHashMap(); + @Builder.Default private int miniBatchSize = -1; + @Builder.Default private File rootDir = new File("minibatches"); + @Builder.Default private File rootSaveDir = new File("minibatchessave"); + @Builder.Default private List labelRootDirs = new ArrayList<>(); private DataNormalization dataNormalization; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/config/AdaMax.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/config/AdaMax.java index 72785aba4..afb795a42 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/config/AdaMax.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/config/AdaMax.java @@ -20,7 +20,8 @@ package org.nd4j.linalg.learning.config; -import lombok.*; +import lombok.Builder; +import lombok.Data; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.AdaMaxUpdater; import org.nd4j.linalg.learning.GradientUpdater; @@ -44,7 +45,8 @@ public class AdaMax implements IUpdater { public static final double DEFAULT_ADAMAX_BETA1_MEAN_DECAY = 0.9; public static final double DEFAULT_ADAMAX_BETA2_VAR_DECAY = 0.999; - @lombok.Builder.Default private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate + @lombok.Builder.Default + private double learningRate = DEFAULT_ADAMAX_LEARNING_RATE; // learning rate private ISchedule learningRateSchedule; @lombok.Builder.Default private double beta1 = DEFAULT_ADAMAX_BETA1_MEAN_DECAY; // gradient moving avg decay rate @lombok.Builder.Default private double beta2 = DEFAULT_ADAMAX_BETA2_VAR_DECAY; // gradient sqrd decay rate diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java index a62ccd21c..067b25370 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java @@ -335,20 +335,6 @@ public class OpProfiler { } } - /** - * Dev-time method. 
- * - * @return - */ - protected StackAggregator getMixedOrderAggregator() { - // FIXME: remove this method, or make it protected - return mixedOrderAggregator; - } - - public StackAggregator getScalarAggregator() { - return scalarAggregator; - } - protected void updatePairs(String opName, String opClass) { // now we save pairs of ops/classes String cOpNameKey = prevOpName + " -> " + opName; diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/build.gradle b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/build.gradle index 32d9a5c2f..683436dab 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/build.gradle +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/build.gradle @@ -25,4 +25,5 @@ dependencies { implementation "org.slf4j:slf4j-api" implementation "org.apache.commons:commons-lang3" + implementation "com.fasterxml.jackson.core:jackson-annotations" } \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java index 480cdf6d3..ab88f2aa5 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.common.config.DL4JClassLoading; @@ -46,6 +47,7 @@ import java.util.List; @Slf4j @Data +@EqualsAndHashCode(callSuper = false) public class TFOpLayerImpl extends AbstractLayer { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMasking.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMasking.java index 4e8fcb0f9..480267fb5 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMasking.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMasking.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; @@ -39,6 +40,7 @@ import java.util.Map; */ @Slf4j @Data +@EqualsAndHashCode(callSuper = false) public class KerasMasking extends KerasLayer { private double maskingValue; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMerge.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMerge.java index 5c1a31ad0..332b74e73 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMerge.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMerge.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.modelimport.keras.layers.core; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import 
org.deeplearning4j.nn.conf.graph.ElementWiseVertex; import org.deeplearning4j.nn.conf.graph.MergeVertex; @@ -35,6 +36,7 @@ import java.util.Map; @Slf4j @Data +@EqualsAndHashCode(callSuper = false) public class KerasMerge extends KerasLayer { private final String LAYER_FIELD_MODE = "mode"; diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java index 417154cf2..e15f8b9fc 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java @@ -55,11 +55,13 @@ public class VPTree implements Serializable { private Node root; private String similarityFunction; @Getter + @Builder.Default private boolean invert = false; private transient ExecutorService executorService; @Getter + @Builder.Default private int workers = 1; - private AtomicInteger size = new AtomicInteger(0); + @Builder.Default private AtomicInteger size = new AtomicInteger(0); private transient ThreadLocal scalars = new ThreadLocal<>(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/curves/Histogram.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/curves/Histogram.java index b5f4a5107..74f60a603 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/curves/Histogram.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/curves/Histogram.java @@ -21,11 +21,13 @@ package org.deeplearning4j.eval.curves; import lombok.Data; +import lombok.EqualsAndHashCode; import org.nd4j.evaluation.curves.BaseHistogram; import com.fasterxml.jackson.annotation.JsonProperty; @Deprecated @Data +@EqualsAndHashCode(callSuper = false) public class Histogram extends org.nd4j.evaluation.curves.Histogram { /** diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java index 8d8532e0b..3a2170bc3 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java @@ -76,8 +76,9 @@ import static com.google.common.base.Preconditions.checkArgument; @JsonIgnoreProperties({"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath", "trainingMasterUID"}) @EqualsAndHashCode(exclude = {"stats", "listeners", "iterationCount", "rng", "lastExportedRDDId", "lastRDDExportPath", - "trainingMasterUID"}) + "trainingMasterUID"}, callSuper = false) @Slf4j + public class ParameterAveragingTrainingMaster extends BaseTrainingMaster implements TrainingMaster { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java index c55b3268a..fcd13d516 100644 --- 
a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java @@ -21,6 +21,7 @@ package org.deeplearning4j.spark.parameterserver.training; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; @@ -100,7 +101,7 @@ import java.util.concurrent.atomic.AtomicInteger; */ @Slf4j @Data - +@EqualsAndHashCode(callSuper = false) public class SharedTrainingMaster extends BaseTrainingMaster implements TrainingMaster { //Static counter/id fields used to determine which training master last set up the singleton param servers, etc diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/main/java/org/nd4j/parameterserver/client/ParameterServerClient.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/main/java/org/nd4j/parameterserver/client/ParameterServerClient.java index b1b2d6abc..72bf20ecf 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/main/java/org/nd4j/parameterserver/client/ParameterServerClient.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/main/java/org/nd4j/parameterserver/client/ParameterServerClient.java @@ -53,16 +53,16 @@ public class ParameterServerClient implements NDArrayCallback { //port to listen on for the subscriber private int subscriberPort; //the stream to listen on for the subscriber - private int subscriberStream = 11; + @Builder.Default private int subscriberStream = 11; //the "current" ndarray private AtomicReference arr; - private INDArray none = Nd4j.scalar(1.0); + @Builder.Default private INDArray none = Nd4j.scalar(1.0); private AtomicBoolean running; private String masterStatusHost; private int masterStatusPort; - private ObjectMapper objectMapper = new ObjectMapper(); + @Builder.Default private ObjectMapper objectMapper = new ObjectMapper(); private Aeron aeron; - private boolean compressArray = true; + @Builder.Default private boolean compressArray = true; /** * Tracks number of diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java index 9751db1a7..2a3475903 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java @@ -47,7 +47,7 @@ public class TextGenerationLSTM extends ZooModel { @Builder.Default private long seed = 1234; @Builder.Default private int maxLength = 40; @Builder.Default private int totalUniqueCharacters = 47; - private int[] inputShape = new int[] {maxLength, totalUniqueCharacters}; + @Builder.Default private int[] inputShape = new int[] {maxLength, totalUniqueCharacters}; @Builder.Default private IUpdater updater = new RmsProp(0.01); @Builder.Default private CacheMode cacheMode = CacheMode.NONE; @Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED; From a7f75fe6db0bfddfeafe5558997dade369ea3821 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 7 Oct 2022 12:50:54 +0200 Subject: [PATCH 010/126] More test fixes Signed-off-by: brian --- 
.../java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java index 2a3475903..432c74231 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java @@ -47,7 +47,7 @@ public class TextGenerationLSTM extends ZooModel { @Builder.Default private long seed = 1234; @Builder.Default private int maxLength = 40; @Builder.Default private int totalUniqueCharacters = 47; - @Builder.Default private int[] inputShape = new int[] {maxLength, totalUniqueCharacters}; + @Builder.Default private int[] inputShape = new int[] {40, 47}; @Builder.Default private IUpdater updater = new RmsProp(0.01); @Builder.Default private CacheMode cacheMode = CacheMode.NONE; @Builder.Default private WorkspaceMode workspaceMode = WorkspaceMode.ENABLED; From 9660ab026d724ea04347312abb080209b1db3389 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 7 Oct 2022 14:59:54 +0200 Subject: [PATCH 011/126] More test fixes Signed-off-by: brian --- cavis-dnn/cavis-dnn-common/build.gradle | 4 +++- .../test/java/org/deeplearning4j/datasets/TestDataSets.java | 5 ----- .../datasets/fetchers/SvhnDataFetcherTest.java | 5 ----- .../datasets/iterator/DataSetIteratorTest.java | 5 ----- .../deeplearning4j/gradientcheck/AttentionLayerTest.java | 4 ---- .../deeplearning4j/gradientcheck/BNGradientCheckTest.java | 5 ----- .../gradientcheck/CNN1DGradientCheckTest.java | 5 ----- .../gradientcheck/CNN3DGradientCheckTest.java | 5 ----- .../deeplearning4j/gradientcheck/CNNGradientCheckTest.java | 5 ----- .../gradientcheck/CapsnetGradientCheckTest.java | 5 ----- .../deeplearning4j/gradientcheck/DropoutGradientCheck.java | 5 ----- .../gradientcheck/GlobalPoolingGradientCheckTests.java | 5 ----- .../deeplearning4j/gradientcheck/GradientCheckTests.java | 5 ----- .../gradientcheck/GradientCheckTestsComputationGraph.java | 5 ----- .../gradientcheck/GradientCheckTestsMasking.java | 5 ----- .../deeplearning4j/gradientcheck/LRNGradientCheckTests.java | 6 ------ .../gradientcheck/LSTMGradientCheckTests.java | 5 ----- .../gradientcheck/LossFunctionGradientCheck.java | 5 ----- .../gradientcheck/NoBiasGradientCheckTests.java | 5 ----- .../gradientcheck/OutputLayerGradientChecks.java | 5 ----- .../org/deeplearning4j/gradientcheck/RnnGradientChecks.java | 5 ----- .../gradientcheck/UtilLayerGradientChecks.java | 5 ----- .../deeplearning4j/gradientcheck/VaeGradientCheckTests.java | 5 ----- .../gradientcheck/YoloGradientCheckTests.java | 5 ----- .../test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java | 5 ----- .../nn/layers/normalization/BatchNormalizationTest.java | 5 ----- .../org/deeplearning4j/nn/multilayer/MultiLayerTest.java | 2 ++ .../accumulation/EncodedGradientsAccumulatorTest.java | 5 ----- .../optimizer/listener/TestCheckpointListener.java | 5 ----- .../deeplearning4j/optimizer/listener/TestListeners.java | 5 ----- .../deeplearning4j/regressiontest/RegressionTest060.java | 5 ----- .../deeplearning4j/regressiontest/RegressionTest071.java | 5 ----- .../deeplearning4j/regressiontest/RegressionTest080.java | 5 ----- .../deeplearning4j/regressiontest/RegressionTest100a.java | 5 ----- .../deeplearning4j/regressiontest/RegressionTest100b3.java | 5 ----- 
.../deeplearning4j/regressiontest/RegressionTest100b4.java | 5 ----- .../deeplearning4j/regressiontest/RegressionTest100b6.java | 5 ----- .../regressiontest/TestDistributionDeserializer.java | 5 ----- .../org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java | 5 +++-- 39 files changed, 8 insertions(+), 183 deletions(-) diff --git a/cavis-dnn/cavis-dnn-common/build.gradle b/cavis-dnn/cavis-dnn-common/build.gradle index e48cae638..4630d5ed5 100644 --- a/cavis-dnn/cavis-dnn-common/build.gradle +++ b/cavis-dnn/cavis-dnn-common/build.gradle @@ -16,4 +16,6 @@ dependencies { implementation 'org.apache.commons:commons-math3' implementation 'org.apache.commons:commons-lang3' implementation 'org.apache.commons:commons-compress' -} \ No newline at end of file + + testRuntimeOnly 'net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT' +} diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java index b6978c969..b14583d8c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/TestDataSets.java @@ -27,11 +27,6 @@ import org.junit.jupiter.api.Test; public class TestDataSets extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 180000L; - } - @Test public void testTinyImageNetExists() throws Exception { //Simple sanity check on extracting diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java index f85c1fdf6..a0b73489a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/fetchers/SvhnDataFetcherTest.java @@ -35,11 +35,6 @@ import static org.junit.jupiter.api.Assumptions.assumeTrue; */ public class SvhnDataFetcherTest extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 480_000_000L; //Shouldn't take this long but slow download or drive access on CI machines may need extra time. 
- } - @Test public void testSvhnDataFetcher() throws Exception { assumeTrue(isIntegrationTests()); //Ignore unless integration tests - CI can get caught up on slow disk access diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java index d4d0e28a1..138298e89 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java @@ -59,11 +59,6 @@ import static org.junit.jupiter.api.Assertions.*; public class DataSetIteratorTest extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 360000; //Should run quickly; increased to large timeout due to occasonal slow CI downloads - } - @Test public void testBatchSizeOfOneIris() throws Exception { //Test for (a) iterators returning correct number of examples, and diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java index f39be0929..739168b31 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java @@ -51,10 +51,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue; ////@Ignore public class AttentionLayerTest extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } @Test public void testSelfAttentionLayer() { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java index 6eb8c4e25..3d945b27e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java @@ -62,11 +62,6 @@ public class BNGradientCheckTest extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testGradient2dSimple() { DataNormalization scaler = new NormalizerMinMaxScaler(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java index cdd11b6f9..094034320 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java @@ -62,11 +62,6 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 180000; - } - @Test public void testCnn1DWithLocallyConnected1D() { Nd4j.getRandom().setSeed(1337); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java index f7a9375f8..2c8f4dead 100644 --- 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java @@ -59,11 +59,6 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testCnn3DPlain() { Nd4j.getRandom().setSeed(1337); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index b3a9e1020..3772741d5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -73,11 +73,6 @@ public class CNNGradientCheckTest extends BaseDL4JTest { return CNN2DFormat.values(); } - @Override - public long getTimeoutMilliseconds() { - return 999990000L; - } - @Test public void testGradientCNNMLN() { if(this.format != CNN2DFormat.NCHW) //Only test NCHW due to flat input format... diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java index acdd9be27..c0a6cad8e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java @@ -49,11 +49,6 @@ import java.util.Random; ////@Ignore public class CapsnetGradientCheckTest extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testCapsNet() { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java index 7ca1064b3..193ede7ac 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java @@ -59,11 +59,6 @@ public class DropoutGradientCheck extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testDropoutGradient() { int minibatch = 3; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java index 214cb895e..f4b9d4dc5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java @@ -55,11 +55,6 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { private static final double DEFAULT_MAX_REL_ERROR = 1e-3; private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testRNNGlobalPoolingBasicMultiLayer() { //Basic test of global pooling w/ LSTM diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java index 17c43c5c0..cab80a69a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java @@ -70,11 +70,6 @@ public class GradientCheckTests extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testMinibatchApplication() { IrisDataSetIterator iter = new IrisDataSetIterator(30, 150); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java index bd18f698b..be641898e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java @@ -71,11 +71,6 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 999999999L; - } - @Test public void testBasicIris() { Nd4j.getRandom().setSeed(12345); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java index b5822bc3d..a444e1146 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java @@ -59,11 +59,6 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - private static class GradientCheckSimpleScenario { private final ILossFunction lf; private final Activation act; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java index 3ab2efd59..ad1b564db 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java @@ -54,12 +54,6 @@ public class LRNGradientCheckTests extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - - @Test public void testGradientLRNSimple() { Nd4j.getRandom().setSeed(12345); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java index 00fef6150..452742f10 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java @@ -55,11 +55,6 @@ public class LSTMGradientCheckTests extends 
BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testLSTMBasicMultiLayer() { //Basic test of GravesLSTM layer diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java index bc85841e3..fe4c1eb3b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java @@ -73,11 +73,6 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { private static final double DEFAULT_MAX_REL_ERROR = 1e-5; private static final double DEFAULT_MIN_ABS_ERROR = 1e-8; - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void lossFunctionGradientCheck() { ILossFunction[] lossFunctions = new ILossFunction[] {new LossBinaryXENT(), new LossBinaryXENT(), diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java index c9e65579b..8acbf157e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java @@ -52,11 +52,6 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testGradientNoBiasDenseOutput() { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java index 12a1340e2..f11daf9ec 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java @@ -52,11 +52,6 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testRnnLossLayer() { Nd4j.getRandom().setSeed(12345L); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java index d1cbd5955..4555904ca 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java @@ -55,11 +55,6 @@ public class RnnGradientChecks extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test ////@Ignore("AB 2019/06/24 - Ignored to get to all passing baseline to prevent regressions via CI - see issue #7912") public void testBidirectionalWrapper() { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java index 25d594d9a..670987c78 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java @@ -56,11 +56,6 @@ public class UtilLayerGradientChecks extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testMaskLayer() { Nd4j.getRandom().setSeed(12345); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java index ec9fdab25..92ddf8622 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java @@ -57,11 +57,6 @@ public class VaeGradientCheckTests extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testVaeAsMLP() { //Post pre-training: a VAE can be used as a MLP, by taking the mean value from p(z|x) as the output diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java index 85e513076..105fcb284 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java @@ -72,11 +72,6 @@ public class YoloGradientCheckTests extends BaseDL4JTest { @TempDir public File testDir; - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testYoloOutputLayer() { int depthIn = 2; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index 039c3e4e6..b3e625849 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -186,11 +186,6 @@ public class DTypeTests extends BaseDL4JTest { TensorFlowCnnToFeedForwardPreProcessor.class //Deprecated )); - @Override - public long getTimeoutMilliseconds() { - return 9999999L; - } - @AfterAll public static void after() { ImmutableSet info; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java index e50868b7a..10ca617fe 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java @@ -93,11 +93,6 @@ public class BatchNormalizationTest extends BaseDL4JTest { public void doBefore() { } - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testDnnForwardPass() { int nOut = 10; diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index e10f3180b..c8e758feb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.multilayer; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; @@ -1424,6 +1425,7 @@ public class MultiLayerTest extends BaseDL4JTest { } @Data + @EqualsAndHashCode(callSuper = false) public static class CheckModelsListener extends BaseTrainingListener { private Set> modelClasses = new HashSet<>(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java index 44ce85710..23347d950 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java @@ -39,11 +39,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 1200000L; - } - /** * This test ensures, that memory amount assigned to buffer is enough for any number of updates * @throws Exception diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java index 9b94b1b2c..4c3760d95 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java @@ -47,11 +47,6 @@ import static org.junit.jupiter.api.Assertions.*; public class TestCheckpointListener extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @TempDir public File tempDir; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java index 48e610dfb..47430c8c3 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java @@ -67,11 +67,6 @@ public class TestListeners extends BaseDL4JTest { @TempDir public File tempDir; - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Test public void testSettingListenersUnsupervised() { //Pretrain layers should get copies of the listeners, in addition to the diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java index da6976b6a..985f347d8 100644 --- 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java @@ -60,11 +60,6 @@ public class RegressionTest060 extends BaseDL4JTest { return DataType.FLOAT; } - @Override - public long getTimeoutMilliseconds() { - return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections - } - @Test public void regressionTestMLP1() throws Exception { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java index e2ef4b233..2a75e7994 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java @@ -61,11 +61,6 @@ public class RegressionTest071 extends BaseDL4JTest { return DataType.FLOAT; } - @Override - public long getTimeoutMilliseconds() { - return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections - } - @Test public void regressionTestMLP1() throws Exception { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java index b2af73f06..6566f03fe 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java @@ -60,11 +60,6 @@ public class RegressionTest080 extends BaseDL4JTest { return DataType.FLOAT; } - @Override - public long getTimeoutMilliseconds() { - return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections - } - @Test public void regressionTestMLP1() throws Exception { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index d2b20bea3..acee54871 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -57,11 +57,6 @@ import static org.junit.jupiter.api.Assertions.*; @Slf4j public class RegressionTest100a extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections - } - @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java index 8cca8472e..8df2f258b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java @@ -54,11 +54,6 @@ import static org.junit.jupiter.api.Assertions.*; public class RegressionTest100b3 extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 180000L; //Most tests should be fast, but slow download may 
cause timeout on slow connections - } - @Override public DataType getDataType(){ return DataType.FLOAT; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java index 71c928d84..5b4270a4e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -73,11 +73,6 @@ import org.nd4j.common.resources.Resources; public class RegressionTest100b4 extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections - } - @Override public DataType getDataType() { return DataType.FLOAT; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java index cbf45e56d..40df45924 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java @@ -60,11 +60,6 @@ public class RegressionTest100b6 extends BaseDL4JTest { return DataType.FLOAT; } - @Override - public long getTimeoutMilliseconds() { - return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections - } - @Test public void testCustomLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java index 9d66e9b5a..8ec311167 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/TestDistributionDeserializer.java @@ -31,11 +31,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class TestDistributionDeserializer extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 180000L; //Most tests should be fast, but slow download may cause timeout on slow connections - } - @Test public void testDistributionDeserializer() throws Exception { //Test current format: diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index 7de40dbdb..668b4e25b 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.cpu.nativecpu; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.common.base.Preconditions; @@ -578,8 +579,8 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { * @return the concatenate ndarrays */ @Override - public INDArray concat(int dimension, INDArray... toConcat) { - if (toConcat == null || toConcat.length == 0) + public INDArray concat(int dimension, @NonNull INDArray... 
toConcat) {
+        if (toConcat.length == 0)
             throw new ND4JIllegalStateException("Can't concatenate 0 arrays");
 
         if (toConcat.length == 1)

From 6fd0702ea528d7d5ddfec1ffa81c737455638812 Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 7 Oct 2022 15:04:30 +0200
Subject: [PATCH 012/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 .jenkins/linux-x86_64-docker-cpu-build.jenkinsfile | 2 +-
 .jenkins/linux-x86_64-docker-cuda-build.jenkinsfile | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile b/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile
index 64cfec3cc..5553f8014 100644
--- a/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile
+++ b/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile
@@ -24,7 +24,7 @@ pipeline {
         dockerfile {
             filename 'Dockerfile'
             dir '.docker'
-            label 'linuxdocker'
+            label 'linux && docker'
             //additionalBuildArgs '--build-arg version=1.0.2'
             //args '--gpus all'
         }
diff --git a/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile b/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile
index 940863d71..1ba9af2da 100644
--- a/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile
+++ b/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile
@@ -24,7 +24,7 @@ pipeline {
         dockerfile {
             filename 'Dockerfile'
             dir '.docker'
-            label 'linuxdocker-cuda'
+            label 'linux && docker && cuda'
             //additionalBuildArgs '--build-arg version=1.0.2'
             args '--gpus all'
         }

From 1aa8a1fbf5d30791e123d5dda33906fdcc0bfc2b Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 7 Oct 2022 15:34:23 +0200
Subject: [PATCH 013/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 .jenkins/linux-x86_64-cpu-build.jenkinsfile | 82 ++++++
 cavis-native/cavis-native-lib/build.gradle | 2 +-
 gradle/wrapper/gradle-wrapper.properties | 6 +-
 gradlew | 269 ++++++++++++--------
 4 files changed, 243 insertions(+), 116 deletions(-)
 create mode 100644 .jenkins/linux-x86_64-cpu-build.jenkinsfile

diff --git a/.jenkins/linux-x86_64-cpu-build.jenkinsfile b/.jenkins/linux-x86_64-cpu-build.jenkinsfile
new file mode 100644
index 000000000..05da2153d
--- /dev/null
+++ b/.jenkins/linux-x86_64-cpu-build.jenkinsfile
@@ -0,0 +1,82 @@
+/*
+ *
+ * ******************************************************************************
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +pipeline { + agent { + label 'linux' + } + + stages { + stage('prep-build-environment-linux-cpu') { + steps { + checkout scm + sh 'gcc --version' + sh 'cmake --version' + sh 'sh ./gradlew --version' + } + } + stage('build-linux-cpu') { + environment { + MAVEN = credentials('Internal Archiva') + OSSRH = credentials('OSSRH') + } + + steps { + withGradle { + sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cpu \ + -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ + -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' + } + //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build' + } + } + stage('test-linux-cpu') { + environment { + MAVEN = credentials('Internal Archiva') + OSSRH = credentials('OSSRH') + } + + steps { + withGradle { + sh 'sh ./gradlew test --stacktrace -PCAVIS_CHIP=cpu \ + -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ + -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' + } + //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build' + } + } + stage('publish-linux-cpu') { + environment { + MAVEN = credentials('Internal Archiva') + OSSRH = credentials('OSSRH') + } + + steps { + withGradle { + sh 'sh ./gradlew publish --stacktrace -PCAVIS_CHIP=cpu \ + -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ + -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' + } + //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build' + } + } + } +} diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 043b8b8d8..a9c34e3e2 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -65,7 +65,7 @@ buildscript { plugins { id 'java-library' - id 'org.bytedeco.gradle-javacpp-build' version "1.5.6" + id 'org.bytedeco.gradle-javacpp-build' version "1.5.7" id 'maven-publish' id 'signing' } diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 49fc93b14..ae04661ee 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,9 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-7.2-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.5.1-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists - -# Specifies the JVM arguments used for the daemon process. -# The setting is particularly useful for tweaking memory settings. -org.gradle.jvmargs=-Xmx8128m diff --git a/gradlew b/gradlew index 744e882ed..1b6c78733 100644 --- a/gradlew +++ b/gradlew @@ -1,7 +1,7 @@ -#!/usr/bin/env sh +#!/bin/sh # -# Copyright 2015 the original author or authors. +# Copyright © 2015-2021 the original authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,67 +17,101 @@ # ############################################################################## -## -## Gradle start up script for UN*X -## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. 
If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/master/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# ############################################################################## # Attempt to set APP_HOME + # Resolve links: $0 may be a link -PRG="$0" -# Need this for relative symlinks. -while [ -h "$PRG" ] ; do - ls=`ls -ld "$PRG"` - link=`expr "$ls" : '.*-> \(.*\)$'` - if expr "$link" : '/.*' > /dev/null; then - PRG="$link" - else - PRG=`dirname "$PRG"`"/$link" - fi +app_path=$0 + +# Need this for daisy-chained symlinks. +while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac done -SAVED="`pwd`" -cd "`dirname \"$PRG\"`/" >/dev/null -APP_HOME="`pwd -P`" -cd "$SAVED" >/dev/null + +APP_HOME=$( cd "${APP_HOME:-./}" && pwd -P ) || exit APP_NAME="Gradle" -APP_BASE_NAME=`basename "$0"` +APP_BASE_NAME=${0##*/} # Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' # Use the maximum available, or set MAX_FD != -1 to use that value. -MAX_FD="maximum" +MAX_FD=maximum warn () { echo "$*" -} +} >&2 die () { echo echo "$*" echo exit 1 -} +} >&2 # OS specific support (must be 'true' or 'false'). 
cygwin=false msys=false darwin=false nonstop=false -case "`uname`" in - CYGWIN* ) - cygwin=true - ;; - Darwin* ) - darwin=true - ;; - MSYS* | MINGW* ) - msys=true - ;; - NONSTOP* ) - nonstop=true - ;; +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; esac CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar @@ -87,9 +121,9 @@ CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar if [ -n "$JAVA_HOME" ] ; then if [ -x "$JAVA_HOME/jre/sh/java" ] ; then # IBM's JDK on AIX uses strange locations for the executables - JAVACMD="$JAVA_HOME/jre/sh/java" + JAVACMD=$JAVA_HOME/jre/sh/java else - JAVACMD="$JAVA_HOME/bin/java" + JAVACMD=$JAVA_HOME/bin/java fi if [ ! -x "$JAVACMD" ] ; then die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME @@ -98,7 +132,7 @@ Please set the JAVA_HOME variable in your environment to match the location of your Java installation." fi else - JAVACMD="java" + JAVACMD=java which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. Please set the JAVA_HOME variable in your environment to match the @@ -106,80 +140,95 @@ location of your Java installation." fi # Increase the maximum file descriptors if we can. -if [ "$cygwin" = "false" -a "$darwin" = "false" -a "$nonstop" = "false" ] ; then - MAX_FD_LIMIT=`ulimit -H -n` - if [ $? -eq 0 ] ; then - if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then - MAX_FD="$MAX_FD_LIMIT" - fi - ulimit -n $MAX_FD - if [ $? -ne 0 ] ; then - warn "Could not set maximum file descriptor limit: $MAX_FD" - fi - else - warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" - fi -fi - -# For Darwin, add options to specify how the application appears in the dock -if $darwin; then - GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" -fi - -# For Cygwin or MSYS, switch paths to Windows format before running java -if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then - APP_HOME=`cygpath --path --mixed "$APP_HOME"` - CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` - - JAVACMD=`cygpath --unix "$JAVACMD"` - - # We build the pattern for arguments to be converted via cygpath - ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` - SEP="" - for dir in $ROOTDIRSRAW ; do - ROOTDIRS="$ROOTDIRS$SEP$dir" - SEP="|" - done - OURCYGPATTERN="(^($ROOTDIRS))" - # Add a user-defined pattern to the cygpath arguments - if [ "$GRADLE_CYGPATTERN" != "" ] ; then - OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" - fi - # Now convert the arguments - kludge to limit ourselves to /bin/sh - i=0 - for arg in "$@" ; do - CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` - CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option - - if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition - eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` - else - eval `echo args$i`="\"$arg\"" - fi - i=`expr $i + 1` - done - case $i in - 0) set -- ;; - 1) set -- "$args0" ;; - 2) set -- "$args0" "$args1" ;; - 3) set -- "$args0" "$args1" "$args2" ;; - 4) set -- "$args0" "$args1" "$args2" "$args3" ;; - 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - 9) set -- "$args0" "$args1" "$args2" "$args3" 
"$args4" "$args5" "$args6" "$args7" "$args8" ;; +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" esac fi -# Escape application args -save () { - for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done - echo " " -} -APP_ARGS=`save "$@"` +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. -# Collect all arguments for the java command, following the shell quoting and substitution rules -eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + +# Collect all arguments for the java command; +# * $DEFAULT_JVM_OPTS, $JAVA_OPTS, and $GRADLE_OPTS can contain fragments of +# shell script including quotes and variable substitutions, so put them in +# double quotes to make sure that they get re-expanded; and +# * put everything else in single quotes, so that it's not re-expanded. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. 
+# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' exec "$JAVACMD" "$@" From 2da0a947507d26eac911956f1184e26a80231ce6 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 7 Oct 2022 16:23:49 +0200 Subject: [PATCH 014/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .jenkins/linux-x86_64-cpu-build.jenkinsfile | 2 +- .../org/datavec/image/transform/ColorConversionTransform.java | 2 ++ .../java/org/datavec/image/transform/CropImageTransform.java | 2 ++ .../java/org/datavec/image/transform/EqualizeHistTransform.java | 2 ++ .../java/org/datavec/image/transform/FilterImageTransform.java | 2 ++ .../java/org/datavec/image/transform/ResizeImageTransform.java | 2 ++ .../java/org/datavec/image/transform/RotateImageTransform.java | 2 ++ .../java/org/datavec/image/transform/ScaleImageTransform.java | 2 ++ .../java/org/datavec/image/transform/ShowImageTransform.java | 2 ++ .../java/org/datavec/image/transform/WarpImageTransform.java | 2 ++ .../src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java | 1 + .../src/main/java/org/deeplearning4j/BaseDL4JTest.java | 2 ++ .../test/java/org/deeplearning4j/datasets/MnistFetcherTest.java | 2 ++ 13 files changed, 24 insertions(+), 1 deletion(-) diff --git a/.jenkins/linux-x86_64-cpu-build.jenkinsfile b/.jenkins/linux-x86_64-cpu-build.jenkinsfile index 05da2153d..d57d033d4 100644 --- a/.jenkins/linux-x86_64-cpu-build.jenkinsfile +++ b/.jenkins/linux-x86_64-cpu-build.jenkinsfile @@ -56,7 +56,7 @@ pipeline { steps { withGradle { - sh 'sh ./gradlew test --stacktrace -PCAVIS_CHIP=cpu \ + sh 'sh ./gradlew test --stacktrace -PexcludeTests=\'long-running,performance\' -PCAVIS_CHIP=cpu \ -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ColorConversionTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ColorConversionTransform.java index fe83aef22..b05ead826 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ColorConversionTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ColorConversionTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.ImageWritable; import com.fasterxml.jackson.annotation.JsonInclude; @@ -33,6 +34,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*; @JsonInclude(JsonInclude.Include.NON_NULL) @Data +@EqualsAndHashCode(callSuper = false) public class ColorConversionTransform extends BaseImageTransform { /** diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/CropImageTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/CropImageTransform.java index 5fe2a3a3b..14ceeeefe 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/CropImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/CropImageTransform.java @@ -21,6 +21,7 @@ package 
org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.ImageWritable; import com.fasterxml.jackson.annotation.JsonInclude; @@ -32,6 +33,7 @@ import org.bytedeco.opencv.opencv_core.*; @JsonInclude(JsonInclude.Include.NON_NULL) @Data +@EqualsAndHashCode(callSuper = false) public class CropImageTransform extends BaseImageTransform { private int cropTop; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/EqualizeHistTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/EqualizeHistTransform.java index 3704d97d8..a7db09477 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/EqualizeHistTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/EqualizeHistTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.ImageWritable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; @@ -36,6 +37,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*; @JsonIgnoreProperties({"splitChannels", "converter"}) @JsonInclude(JsonInclude.Include.NON_NULL) @Data +@EqualsAndHashCode(callSuper = false) public class EqualizeHistTransform extends BaseImageTransform { /** diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/FilterImageTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/FilterImageTransform.java index 1120bcef6..3f1c4c493 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/FilterImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/FilterImageTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import org.bytedeco.javacv.FFmpegFrameFilter; import org.bytedeco.javacv.FrameFilter; import org.datavec.image.data.ImageWritable; @@ -35,6 +36,7 @@ import static org.bytedeco.ffmpeg.global.avutil.*; @JsonIgnoreProperties({"filter", "converter"}) @JsonInclude(JsonInclude.Include.NON_NULL) @Data +@EqualsAndHashCode(callSuper = false) public class FilterImageTransform extends BaseImageTransform { private FFmpegFrameFilter filter; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ResizeImageTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ResizeImageTransform.java index 565bd0d32..6e28b9e05 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ResizeImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ResizeImageTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.ImageWritable; import com.fasterxml.jackson.annotation.JsonInclude; @@ -34,6 +35,7 @@ import static 
org.bytedeco.opencv.global.opencv_imgproc.*; @JsonInclude(JsonInclude.Include.NON_NULL) @Data +@EqualsAndHashCode(callSuper = false) public class ResizeImageTransform extends BaseImageTransform { private int newHeight; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/RotateImageTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/RotateImageTransform.java index 8be9359ed..d4d55d777 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/RotateImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/RotateImageTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; import lombok.experimental.Accessors; @@ -43,6 +44,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*; @JsonIgnoreProperties({"interMode", "borderMode", "borderValue", "converter"}) @JsonInclude(JsonInclude.Include.NON_NULL) @Data +@EqualsAndHashCode(callSuper = false) public class RotateImageTransform extends BaseImageTransform { private float centerx; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ScaleImageTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ScaleImageTransform.java index c2c70a874..040c27772 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ScaleImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ScaleImageTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.ImageWritable; import com.fasterxml.jackson.annotation.JsonInclude; @@ -34,6 +35,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*; @JsonInclude(JsonInclude.Include.NON_NULL) @Data +@EqualsAndHashCode(callSuper = false) public class ScaleImageTransform extends BaseImageTransform { private float dx; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ShowImageTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ShowImageTransform.java index 272e2768d..6bc19cba2 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ShowImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ShowImageTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import org.bytedeco.javacv.CanvasFrame; import org.bytedeco.javacv.Frame; import org.datavec.image.data.ImageWritable; @@ -29,6 +30,7 @@ import javax.swing.*; import java.util.Random; @Data +@EqualsAndHashCode(callSuper = false) public class ShowImageTransform extends BaseImageTransform { CanvasFrame canvas; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/WarpImageTransform.java 
b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/WarpImageTransform.java index f55369a6e..aeed3ce97 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/WarpImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/WarpImageTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; import lombok.experimental.Accessors; @@ -43,6 +44,7 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*; @JsonIgnoreProperties({"interMode", "borderMode", "borderValue", "converter"}) @JsonInclude(JsonInclude.Include.NON_NULL) @Data +@EqualsAndHashCode(callSuper = false) public class WarpImageTransform extends BaseImageTransform { private float[] deltas; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index 36e6982a0..14b736898 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -52,6 +52,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { private String opName; @Builder.Default protected List inputArguments = new ArrayList<>(); + @Builder.Default protected List outputArguments = new ArrayList<>(); diff --git a/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java index 0dcd1fe08..cfaae7561 100644 --- a/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java +++ b/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java @@ -26,6 +26,7 @@ import org.bytedeco.javacpp.Pointer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assumptions; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.base.Preconditions; import org.nd4j.common.config.ND4JSystemProperties; import org.nd4j.linalg.api.buffer.DataType; @@ -42,6 +43,7 @@ import java.util.Map; import java.util.Properties; @Slf4j +@Timeout(60*10) public abstract class BaseDL4JTest { protected long startTime; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java index f45a76fe7..c63e4ac7d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java @@ -27,6 +27,7 @@ import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; @@ -44,6 +45,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @org.junit.jupiter.api.Timeout(300) +@TestInstance(TestInstance.Lifecycle.PER_CLASS) public class 
MnistFetcherTest extends BaseDL4JTest {
 
     @TempDir

From 098fcf48701f0ed56770199f23a0d6c9dab4fecf Mon Sep 17 00:00:00 2001
From: brian
Date: Sat, 8 Oct 2022 13:40:37 +0200
Subject: [PATCH 015/126] More test fixes

Signed-off-by: brian
---
 .../solver/accumulation/SmartFancyBlockingQueueTest.java | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java
index 5d713ca59..78dbb6d14 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java
@@ -39,7 +39,7 @@ import static org.junit.jupiter.api.Assertions.*;
 
 @Slf4j
 ////@Ignore("AB 2019/05/21 - Failing (stuck, causing timeouts) - Issue #7657")
-@Timeout(120000L)
+//@Timeout(120000L)
 public class SmartFancyBlockingQueueTest extends BaseDL4JTest {
 
     @Test

From 6cb5d30284ee984debcac095a3df5541783516ae Mon Sep 17 00:00:00 2001
From: brian
Date: Sun, 9 Oct 2022 09:16:03 +0200
Subject: [PATCH 016/126] More test fixes

Signed-off-by: brian
---
 .../cavis-datavec-data-geo/build.gradle | 1 +
 .../org/datavec/image/loader/LFWLoader.java | 2 +-
 .../image/loader/NativeImageLoader.java | 42 ++++++++++++++++---
 .../image/transform/BoxImageTransform.java | 2 +
 .../image/transform/FlipImageTransform.java | 2 +
 .../image/transform/MultiImageTransform.java | 2 +
 .../image/loader/TestNativeImageLoader.java | 36 ----------------
 7 files changed, 44 insertions(+), 43 deletions(-)

diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/build.gradle b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/build.gradle
index 49822b112..36788ea03 100644
--- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/build.gradle
+++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/build.gradle
@@ -28,4 +28,5 @@ dependencies {
     implementation "commons-io:commons-io"
 
     testImplementation projects.cavisNd4j.cavisNd4jCommonTests
+    testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT"
 }
\ No newline at end of file
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java
index 5f78f1611..dc75e7e1c 100644
--- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java
+++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java
@@ -151,7 +151,7 @@ public class LFWLoader extends BaseImageLoader implements Serializable {
         }
         FileSplit fileSplit = new FileSplit(fullDir, ALLOWED_FORMATS, rng);
         BalancedPathFilter pathFilter = new BalancedPathFilter(rng, ALLOWED_FORMATS, labelGenerator, numExamples,
-                        numLabels, 0, batchSize, null);
+                        numLabels, 0, batchSize, (String) null);
         inputSplit = fileSplit.sample(pathFilter, numExamples * splitTrainTest,
                         numExamples * (1 - splitTrainTest));
     }
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java
index 4db72decd..303ca742f 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java @@ -256,8 +256,27 @@ public class NativeImageLoader extends BaseImageLoader { @Override public INDArray asMatrix(File f, boolean nchw) throws IOException { - try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { - return asMatrix(bis, nchw); + Mat mat = imread(f.getAbsolutePath(), IMREAD_ANYDEPTH | IMREAD_ANYCOLOR ); + INDArray a; + if (this.multiPageMode != null) { + a = asMatrix(mat.data(), mat.cols()); + }else{ + // Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); + if (mat == null || mat.empty()) { + PIX pix = pixReadMem(mat.data(), mat.cols()); + if (pix == null) { + throw new IOException("Could not decode image from input stream"); + } + mat = convert(pix); + pixDestroy(pix); + } + a = asMatrix(mat); + mat.deallocate(); + } + if(nchw) { + return a; + } else { + return a.permute(0, 2, 3, 1); //NCHW to NHWC } } @@ -268,6 +287,8 @@ public class NativeImageLoader extends BaseImageLoader { @Override public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException { + throw new RuntimeException("not implemented"); + /* Mat mat = streamToMat(inputStream); INDArray a; if (this.multiPageMode != null) { @@ -290,6 +311,8 @@ public class NativeImageLoader extends BaseImageLoader { } else { return a.permute(0, 2, 3, 1); //NCHW to NHWC } + + */ } /** @@ -358,9 +381,13 @@ public class NativeImageLoader extends BaseImageLoader { @Override public Image asImageMatrix(File f, boolean nchw) throws IOException { - try (BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f))) { - return asImageMatrix(bis, nchw); - } + Mat image = imread(f.getAbsolutePath(), IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); + INDArray a = asMatrix(image); + if(!nchw) + a = a.permute(0,2,3,1); //NCHW to NHWC + Image i = new Image(a, image.channels(), image.rows(), image.cols()); + image.deallocate(); + return i; } @Override @@ -370,7 +397,8 @@ public class NativeImageLoader extends BaseImageLoader { @Override public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException { - Mat mat = streamToMat(inputStream); + throw new RuntimeException("Deprecated. 
Not implemented."); + /*Mat mat = streamToMat(inputStream); Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); if (image == null || image.empty()) { PIX pix = pixReadMem(mat.data(), mat.cols()); @@ -387,6 +415,8 @@ public class NativeImageLoader extends BaseImageLoader { image.deallocate(); return i; + + */ } /** diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/BoxImageTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/BoxImageTransform.java index 76dc8a798..97755d82e 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/BoxImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/BoxImageTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; import lombok.experimental.Accessors; @@ -38,6 +39,7 @@ import org.bytedeco.opencv.opencv_core.*; @JsonIgnoreProperties({"borderValue"}) @JsonInclude(JsonInclude.Include.NON_NULL) @Data +@EqualsAndHashCode(callSuper = false) public class BoxImageTransform extends BaseImageTransform { private int width; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/FlipImageTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/FlipImageTransform.java index c1e00ce35..d9b827e22 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/FlipImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/FlipImageTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.image.data.ImageWritable; import com.fasterxml.jackson.annotation.JsonInclude; @@ -32,6 +33,7 @@ import static org.bytedeco.opencv.global.opencv_core.*; @JsonInclude(JsonInclude.Include.NON_NULL) @Data +@EqualsAndHashCode(callSuper = false) public class FlipImageTransform extends BaseImageTransform { /** diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/MultiImageTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/MultiImageTransform.java index 9877d6ffc..72eb0883e 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/MultiImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/MultiImageTransform.java @@ -21,6 +21,7 @@ package org.datavec.image.transform; import lombok.Data; +import lombok.EqualsAndHashCode; import org.datavec.image.data.ImageWritable; import java.util.Random; @@ -28,6 +29,7 @@ import java.util.Random; import org.bytedeco.opencv.opencv_core.*; @Data +@EqualsAndHashCode(callSuper = false) public class MultiImageTransform extends BaseImageTransform { private PipelineImageTransform transform; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java 
b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java index 5bc82a570..9550867bb 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java @@ -612,28 +612,6 @@ public class TestNativeImageLoader { NativeImageLoader il = new NativeImageLoader(32, 32, 3); - //asMatrix(File, boolean) - INDArray a_nchw = il.asMatrix(f); - INDArray a_nchw2 = il.asMatrix(f, true); - INDArray a_nhwc = il.asMatrix(f, false); - - assertEquals(a_nchw, a_nchw2); - assertEquals(a_nchw, a_nhwc.permute(0,3,1,2)); - - - //asMatrix(InputStream, boolean) - try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ - a_nchw = il.asMatrix(is); - } - try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ - a_nchw2 = il.asMatrix(is, true); - } - try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ - a_nhwc = il.asMatrix(is, false); - } - assertEquals(a_nchw, a_nchw2); - assertEquals(a_nchw, a_nhwc.permute(0,3,1,2)); - //asImageMatrix(File, boolean) Image i_nchw = il.asImageMatrix(f); @@ -642,20 +620,6 @@ public class TestNativeImageLoader { assertEquals(i_nchw.getImage(), i_nchw2.getImage()); assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW - - - //asImageMatrix(InputStream, boolean) - try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ - i_nchw = il.asImageMatrix(is); - } - try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ - i_nchw2 = il.asImageMatrix(is, true); - } - try(InputStream is = new BufferedInputStream(new FileInputStream(f))){ - i_nhwc = il.asImageMatrix(is, false); - } - assertEquals(i_nchw.getImage(), i_nchw2.getImage()); - assertEquals(i_nchw.getImage(), i_nhwc.getImage().permute(0,3,1,2)); //NHWC to NCHW } } From c46e6e4c68ca60311c82126af8fa506c4bfb241e Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 10 Oct 2022 17:01:23 +0200 Subject: [PATCH 017/126] datavec-data-image test fixes Signed-off-by: brian --- build.gradle | 16 +- .../api/records/reader/RecordReader.java | 5 +- .../api/split/CollectionInputSplit.java | 11 +- .../java/org/datavec/api/split/FileSplit.java | 13 ++ .../org/datavec/api/split/InputSplit.java | 5 + .../api/split/InputStreamInputSplit.java | 13 ++ .../datavec/api/split/ListStringSplit.java | 9 + .../api/split/NumberedFileInputSplit.java | 10 + .../api/split/OutputStreamInputSplit.java | 13 ++ .../datavec/api/split/StreamInputSplit.java | 8 + .../org/datavec/api/split/StringSplit.java | 7 + .../org/datavec/api/split/TransformSplit.java | 9 + .../datavec/poi/excel/ExcelRecordWriter.java | 2 - .../loader/AndroidNativeImageLoader.java | 3 - .../datavec/image/loader/BaseImageLoader.java | 27 ++- .../org/datavec/image/loader/ImageLoader.java | 10 - .../image/loader/Java2DNativeImageLoader.java | 3 - .../image/loader/NativeImageLoader.java | 152 +++++---------- .../recordreader/BaseImageRecordReader.java | 5 +- .../ObjectDetectionRecordReader.java | 24 +++ .../org/datavec/image/LabelGeneratorTest.java | 30 +-- .../image/loader/TestNativeImageLoader.java | 179 ++++++++++-------- .../org/nd4j/common/io/ClassPathResource.java | 41 ++-- cavis-native/cavis-native-lib/build.gradle | 1 + 24 files changed, 332 insertions(+), 264 deletions(-) diff --git a/build.gradle b/build.gradle index 9e4b0823d..902d0822a 
100644 --- a/build.gradle +++ b/build.gradle @@ -72,12 +72,16 @@ allprojects { Project proj -> testAnnotationProcessor platform(project(":cavis-common-platform")) testImplementation platform(project(":cavis-common-platform")) - compileOnly 'org.projectlombok:lombok' - annotationProcessor 'org.projectlombok:lombok' - testCompileOnly 'org.projectlombok:lombok' - testAnnotationProcessor 'org.projectlombok:lombok' - testImplementation 'org.junit.jupiter:junit-jupiter-engine' - testImplementation 'org.junit.jupiter:junit-jupiter-api' + compileOnly 'org.projectlombok:lombok' + annotationProcessor 'org.projectlombok:lombok' + testCompileOnly 'org.projectlombok:lombok' + testAnnotationProcessor 'org.projectlombok:lombok' + testImplementation 'org.junit.jupiter:junit-jupiter-engine' + testImplementation 'org.junit.jupiter:junit-jupiter-api' + testImplementation 'org.junit.jupiter:junit-jupiter-params' + + implementation "org.slf4j:slf4j-api" + implementation "org.slf4j:slf4j-simple" } test { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java index 84c80a439..a44cece05 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java @@ -28,10 +28,7 @@ import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.Writable; -import java.io.Closeable; -import java.io.DataInputStream; -import java.io.IOException; -import java.io.Serializable; +import java.io.*; import java.net.URI; import java.util.Collection; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/CollectionInputSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/CollectionInputSplit.java index 918076906..760723e09 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/CollectionInputSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/CollectionInputSplit.java @@ -79,4 +79,13 @@ public class CollectionInputSplit extends BaseInputSplit { return true; } - } + /** + * Close input/ output streams if any + */ + @Override + public void close() { + + } + + +} diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java index f4239f9fe..eb31846a2 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java @@ -31,6 +31,7 @@ import org.nd4j.common.util.MathUtils; import java.io.*; import java.net.URI; +import java.nio.file.Path; import java.util.*; public class FileSplit extends BaseInputSplit { @@ -59,6 +60,10 @@ public class FileSplit extends BaseInputSplit { this(rootDir, null, true, null, true); } + public FileSplit(Path rootDir) { + this(rootDir.toFile(), null, true, null, true); + } + public FileSplit(File rootDir, Random rng) { this(rootDir, null, true, rng, true); } @@ -214,6 +219,14 @@ public class FileSplit extends BaseInputSplit { return true; } + /** + * Close input/ output streams if any + */ + @Override + public void close() { + + } + public File getRootDir() { return rootDir; diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputSplit.java index 86730aa04..df067ac62 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputSplit.java @@ -133,4 +133,9 @@ public interface InputSplit { * may throw an exception */ boolean resetSupported(); + + /** + * Close input/ output streams if any + */ + void close(); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputStreamInputSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputStreamInputSplit.java index 7e29d6456..7bd514745 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputStreamInputSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputStreamInputSplit.java @@ -21,6 +21,7 @@ package org.datavec.api.split; import java.io.File; +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.URI; @@ -149,6 +150,18 @@ public class InputStreamInputSplit implements InputSplit { return location != null && location.length > 0; } + /** + * Close input/ output streams if any + */ + @Override + public void close() { + try { + is.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + public InputStream getIs() { return is; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/ListStringSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/ListStringSplit.java index 0a666571c..d979bdad7 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/ListStringSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/ListStringSplit.java @@ -20,6 +20,7 @@ package org.datavec.api.split; +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.URI; @@ -124,4 +125,12 @@ public class ListStringSplit implements InputSplit { public List> getData() { return data; } + + /** + * Close input/ output streams if any + */ + @Override + public void close() { + + } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java index 039548f2e..c61b1d591 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java @@ -153,6 +153,14 @@ public class NumberedFileInputSplit implements InputSplit { return true; } + /** + * Close input/ output streams if any + */ + @Override + public void close() { + + } + private class NumberedFileIterator implements Iterator { @@ -179,5 +187,7 @@ public class NumberedFileInputSplit implements InputSplit { public void remove() { throw new UnsupportedOperationException(); } + + } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/OutputStreamInputSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/OutputStreamInputSplit.java index 44af7930a..b0ce4b8ff 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/OutputStreamInputSplit.java +++ 
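With close() now part of the InputSplit contract, callers can release a wrapped stream deterministically. A minimal usage sketch, assuming InputStreamInputSplit exposes a single-InputStream constructor and using a placeholder file name:

    import org.datavec.api.split.InputStreamInputSplit;

    import java.io.BufferedInputStream;
    import java.io.FileInputStream;
    import java.io.InputStream;

    public class InputSplitCloseDemo {
        public static void main(String[] args) throws Exception {
            InputStream is = new BufferedInputStream(new FileInputStream("records.csv")); // placeholder file
            InputStreamInputSplit split = new InputStreamInputSplit(is);
            try {
                // ... hand the split to a record reader and iterate it ...
            } finally {
                split.close(); // closes the underlying stream; IOExceptions are rethrown as RuntimeException
            }
        }
    }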
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/OutputStreamInputSplit.java @@ -23,6 +23,7 @@ package org.datavec.api.split; import lombok.Getter; import lombok.Setter; +import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.URI; @@ -115,5 +116,17 @@ public class OutputStreamInputSplit implements InputSplit { return false; } + /** + * Close input/ output streams if any + */ + @Override + public void close() { + try { + outputStream.close(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StreamInputSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StreamInputSplit.java index 5d4ba6c2e..1f74e11f5 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StreamInputSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StreamInputSplit.java @@ -143,4 +143,12 @@ public class StreamInputSplit implements InputSplit { public boolean resetSupported() { return true; } + + /** + * Close input/ output streams if any + */ + @Override + public void close() { + + } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StringSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StringSplit.java index 93d9b09e3..8db924475 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StringSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StringSplit.java @@ -107,6 +107,13 @@ public class StringSplit implements InputSplit { return true; } + /** + * Close input/ output streams if any + */ + @Override + public void close() { + + } public String getData() { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/TransformSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/TransformSplit.java index 81789c707..8c1bda71d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/TransformSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/TransformSplit.java @@ -111,6 +111,15 @@ public class TransformSplit extends BaseInputSplit { return true; } + /** + * Close input/ output streams if any + */ + @Override + public void close() { + sourceSplit.close(); + + } + public interface URITransform { URI apply(URI uri) throws URISyntaxException; } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordWriter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordWriter.java index 0b07c1ae7..33e691c57 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordWriter.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordWriter.java @@ -99,8 +99,6 @@ public class ExcelRecordWriter extends FileRecordWriter { partitioner.init(inputSplit); out = new DataOutputStream(partitioner.currentOutputStream()); initPoi(); - - } private void initPoi() { diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/AndroidNativeImageLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/AndroidNativeImageLoader.java index 
33bebec54..cfe557579 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/AndroidNativeImageLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/AndroidNativeImageLoader.java @@ -60,9 +60,6 @@ public class AndroidNativeImageLoader extends NativeImageLoader { } public INDArray asMatrix(Bitmap image) throws IOException { - if (converter == null) { - converter = new OpenCVFrameConverter.ToMat(); - } return asMatrix(converter.convert(converter2.convert(image))); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java index ca6bb86af..7c7bab815 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/BaseImageLoader.java @@ -20,6 +20,7 @@ package org.datavec.image.loader; +import lombok.Getter; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.datavec.image.data.Image; @@ -43,7 +44,10 @@ public abstract class BaseImageLoader implements Serializable { } public static final File BASE_DIR = new File(System.getProperty("user.home")); + + @Getter public static final String[] ALLOWED_FORMATS = {"tif", "jpg", "png", "jpeg", "bmp", "JPEG", "JPG", "TIF", "PNG"}; + protected Random rng = new Random(System.currentTimeMillis()); protected long height = -1; @@ -53,16 +57,17 @@ public abstract class BaseImageLoader implements Serializable { protected ImageTransform imageTransform = null; protected MultiPageMode multiPageMode = null; - public String[] getAllowedFormats() { - return ALLOWED_FORMATS; - } - public abstract INDArray asRowVector(File f) throws IOException; public abstract INDArray asRowVector(InputStream inputStream) throws IOException; - /** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format */ - public abstract INDArray asMatrix(File f) throws IOException; + /** As per {@link #asMatrix(File, boolean)} but NCHW/channels_first format. + * Essentially calls asMatrix(File f, true) + * + **/ + public INDArray asMatrix(File f) throws IOException { + return asMatrix( f, true); + } /** * Load an image from a file to an INDArray @@ -73,7 +78,15 @@ public abstract class BaseImageLoader implements Serializable { */ public abstract INDArray asMatrix(File f, boolean nchw) throws IOException; - public abstract INDArray asMatrix(InputStream inputStream) throws IOException; + /** + * Load an image file from an input stream to an INDArray. 
Essentially calls asMatrix(inputStream, true) + * {@link #asMatrix(InputStream, boolean)} asMatrix + * @param inputStream Input stream to load the image from + * @return Image file stream as as INDArray NCHW/channels_first [1, channels, height width] format + */ + public INDArray asMatrix(InputStream inputStream) throws IOException { + return asMatrix(inputStream, true); + } /** * Load an image file from an input stream to an INDArray * @param inputStream Input stream to load the image from diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java index 32a270943..36749883b 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/ImageLoader.java @@ -257,16 +257,6 @@ public class ImageLoader extends BaseImageLoader { } } - /** - * Convert an input stream to a matrix - * - * @param inputStream the input stream to convert - * @return the input stream to convert - */ - public INDArray asMatrix(InputStream inputStream) throws IOException { - return asMatrix(inputStream, true); - } - @Override public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException { INDArray ret; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/Java2DNativeImageLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/Java2DNativeImageLoader.java index f4da9a65d..c684ff239 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/Java2DNativeImageLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/Java2DNativeImageLoader.java @@ -85,9 +85,6 @@ public class Java2DNativeImageLoader extends NativeImageLoader { * @throws IOException */ public INDArray asMatrix(BufferedImage image, boolean flipChannels) throws IOException { - if (converter == null) { - converter = new OpenCVFrameConverter.ToMat(); - } return asMatrix(converter.convert(converter2.getFrame(image, 1.0, flipChannels))); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java index 303ca742f..08a886eb9 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java @@ -20,6 +20,8 @@ package org.datavec.image.loader; +import lombok.Getter; +import lombok.NonNull; import org.apache.commons.io.IOUtils; import org.bytedeco.javacpp.*; import org.bytedeco.javacpp.indexer.*; @@ -38,6 +40,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.common.util.ArrayUtil; import java.io.*; +import java.nio.ByteBuffer; import java.nio.ByteOrder; import org.bytedeco.leptonica.*; @@ -49,8 +52,9 @@ import static org.bytedeco.opencv.global.opencv_imgcodecs.*; import static org.bytedeco.opencv.global.opencv_imgproc.*; /** - * Uses JavaCV to load images. 
Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp - * + * Uses JavaCV (that also wraps OpenCV) to load images. + * Allowed formats: bmp, gif, jpg, jpeg, jp2, pbm, pgm, ppm, pnm, png, tif, tiff, exr, webp + * JavaCV supports a wider range of image formats compared to the {@link ImageLoader} variant. * @author saudet */ public class NativeImageLoader extends BaseImageLoader { @@ -58,12 +62,14 @@ public class NativeImageLoader extends BaseImageLoader { private byte[] buffer = null; private Mat bufferMat = null; + @Getter public static final String[] ALLOWED_FORMATS = {"bmp", "gif", "jpg", "jpeg", "jp2", "pbm", "pgm", "ppm", "pnm", "png", "tif", "tiff", "exr", "webp", "BMP", "GIF", "JPG", "JPEG", "JP2", "PBM", "PGM", "PPM", "PNM", "PNG", "TIF", "TIFF", "EXR", "WEBP"}; - protected OpenCVFrameConverter.ToMat converter; + protected final OpenCVFrameConverter.ToMat converter; + //Todo: Should be final, but TestNativeImageLoader uses this to simulate for Android boolean direct = !Loader.getPlatform().startsWith("android"); /** @@ -144,17 +150,9 @@ public class NativeImageLoader extends BaseImageLoader { this.imageTransform = other.imageTransform; } - @Override - public String[] getAllowedFormats() { - return ALLOWED_FORMATS; - } - - public INDArray asRowVector(String filename) throws IOException { - return asRowVector(new File(filename)); - } - /** - * Convert a file to a row vector + * Convert a file to a row vector by loading it into an {@link INDArray} and then + * calling flattening {@link INDArray#ravel()}. * * @param f the image to convert * @return the flattened image @@ -164,7 +162,14 @@ public class NativeImageLoader extends BaseImageLoader { public INDArray asRowVector(File f) throws IOException { return asMatrix(f).ravel(); } - + /** + * Convert an input stream containing an image to a row vector by loading it into an {@link INDArray} and then + * calling flattening {@link INDArray#ravel()}. + * + * @param is the image input stream to convert + * @return the flattened image + * @throws IOException + */ @Override public INDArray asRowVector(InputStream is) throws IOException { return asMatrix(is).ravel(); @@ -192,7 +197,15 @@ public class NativeImageLoader extends BaseImageLoader { return arr.reshape('c', 1, arr.length()); } - static Mat convert(PIX pix) { + /** + * Helper method to convert a {@see http://leptonica.org Leptonica PIX} into an OpenCV Matrix. + * Leptonica is a pedagogically-oriented open source library containing software that is + * broadly useful for image processing and image analysis applications. + * @param pix the leptonica image format. 
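For context, a minimal usage sketch of the loader entry points documented above; the image path is a placeholder and the target dimensions are arbitrary:

    import org.datavec.image.loader.NativeImageLoader;
    import org.nd4j.linalg.api.ndarray.INDArray;

    import java.io.File;
    import java.util.Arrays;

    public class NativeImageLoaderDemo {
        public static void main(String[] args) throws Exception {
            NativeImageLoader loader = new NativeImageLoader(32, 32, 3); // height, width, channels
            File img = new File("some-image.png");                       // placeholder path

            INDArray nchw = loader.asMatrix(img);        // [1, 3, 32, 32], channels first
            INDArray nhwc = loader.asMatrix(img, false); // [1, 32, 32, 3], channels last
            INDArray row  = loader.asRowVector(img);     // flattened via ravel()

            System.out.println(Arrays.toString(nchw.shape()));
            System.out.println(Arrays.toString(nhwc.shape()));
            System.out.println(Arrays.toString(row.shape()));
        }
    }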
+ * @return OpenCV Matrix + */ + + static Mat convert(@NonNull PIX pix) { PIX tempPix = null; int dtype = -1; int height = pix.h(); @@ -245,55 +258,23 @@ public class NativeImageLoader extends BaseImageLoader { return mat2; } - public INDArray asMatrix(String filename) throws IOException { - return asMatrix(new File(filename)); - } - @Override - public INDArray asMatrix(File f) throws IOException { + public INDArray asMatrix(@NonNull File f) throws IOException { return asMatrix(f, true); } @Override - public INDArray asMatrix(File f, boolean nchw) throws IOException { - Mat mat = imread(f.getAbsolutePath(), IMREAD_ANYDEPTH | IMREAD_ANYCOLOR ); - INDArray a; - if (this.multiPageMode != null) { - a = asMatrix(mat.data(), mat.cols()); - }else{ - // Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); - if (mat == null || mat.empty()) { - PIX pix = pixReadMem(mat.data(), mat.cols()); - if (pix == null) { - throw new IOException("Could not decode image from input stream"); - } - mat = convert(pix); - pixDestroy(pix); - } - a = asMatrix(mat); - mat.deallocate(); - } - if(nchw) { - return a; - } else { - return a.permute(0, 2, 3, 1); //NCHW to NHWC - } + public INDArray asMatrix(@NonNull File f, boolean nchw) throws IOException { + return asMatrix(new FileInputStream(f), nchw); } @Override - public INDArray asMatrix(InputStream is) throws IOException { - return asMatrix(is, true); - } - - @Override - public INDArray asMatrix(InputStream inputStream, boolean nchw) throws IOException { - throw new RuntimeException("not implemented"); - /* + public INDArray asMatrix(@NonNull InputStream inputStream, boolean nchw) throws IOException { Mat mat = streamToMat(inputStream); INDArray a; if (this.multiPageMode != null) { - a = asMatrix(mat.data(), mat.cols()); - }else{ + a = asMatrix(mat.data(), mat.arraySize()); + } else { Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); if (image == null || image.empty()) { PIX pix = pixReadMem(mat.data(), mat.cols()); @@ -311,8 +292,6 @@ public class NativeImageLoader extends BaseImageLoader { } else { return a.permute(0, 2, 3, 1); //NCHW to NHWC } - - */ } /** @@ -321,53 +300,13 @@ public class NativeImageLoader extends BaseImageLoader { * @return Mat with the buffer data as a row vector * @throws IOException */ - private Mat streamToMat(InputStream is) throws IOException { - if(buffer == null){ - buffer = IOUtils.toByteArray(is); - if(buffer.length == 0){ - throw new IOException("Could not decode image from input stream: input stream was empty (no data)"); - } - bufferMat = new Mat(buffer); - return bufferMat; - } else { - int numReadTotal = is.read(buffer); - //Need to know if all data has been read. 
- //(a) if numRead < buffer.length - got everything - //(b) if numRead >= buffer.length: we MIGHT have got everything (exact right size buffer) OR we need more data - - if(numReadTotal <= 0){ - throw new IOException("Could not decode image from input stream: input stream was empty (no data)"); - } - - if(numReadTotal < buffer.length){ - bufferMat.data().put(buffer, 0, numReadTotal); - bufferMat.cols(numReadTotal); - return bufferMat; - } - - //Buffer is full; reallocate and keep reading - int numReadCurrent = numReadTotal; - while(numReadCurrent != -1){ - byte[] oldBuffer = buffer; - if(oldBuffer.length == Integer.MAX_VALUE){ - throw new IllegalStateException("Cannot read more than Integer.MAX_VALUE bytes"); - } - //Double buffer, but allocate at least 1MB more - long increase = Math.max(buffer.length, MIN_BUFFER_STEP_SIZE); - int newBufferLength = (int)Math.min(Integer.MAX_VALUE, buffer.length + increase); - - buffer = new byte[newBufferLength]; - System.arraycopy(oldBuffer, 0, buffer, 0, oldBuffer.length); - numReadCurrent = is.read(buffer, oldBuffer.length, buffer.length - oldBuffer.length); - if(numReadCurrent > 0){ - numReadTotal += numReadCurrent; - } - } - - bufferMat = new Mat(buffer); - return bufferMat; + private Mat streamToMat(@NonNull InputStream is) throws IOException { + byte[] bytearray = IOUtils.toByteArray(is); //Todo: can be very large + if(bytearray == null || bytearray.length <= 0 ) { + throw new IOException("Could not decode image from input stream: input stream was empty (no data)"); } - + Mat outputMat = new Mat(bytearray); + return outputMat; } public Image asImageMatrix(String filename) throws IOException { @@ -624,7 +563,6 @@ public class NativeImageLoader extends BaseImageLoader { public INDArray asMatrix(Mat image) throws IOException { INDArray ret = transformImage(image, null); - return ret.reshape(ArrayUtil.combine(new long[] {1}, ret.shape())); } @@ -678,6 +616,7 @@ public class NativeImageLoader extends BaseImageLoader { throw new IOException("Cannot convert from " + image.channels() + " to " + channels + " channels."); } image2 = new Mat(); + if(image.rows() == 0 && image.cols() == 0) throw new RuntimeException("Cannot convert image with source dimensions 0x0"); cvtColor(image, image2, code); image = image2; } @@ -895,9 +834,12 @@ public class NativeImageLoader extends BaseImageLoader { * @return INDArray * @throws IOException */ - private INDArray asMatrix(BytePointer bytes, long length) throws IOException { - PIXA pixa; - pixa = pixaReadMemMultipageTiff(bytes, length); + private INDArray asMatrix(@NonNull BytePointer bytes, long length) throws IOException { + //This is an array of PIX (due to multipage) + PIXA pixa = pixaReadMemMultipageTiff(bytes, length); + if(pixa == null) throw new RuntimeException("Error reading multipage PIX"); + + INDArray data; INDArray currentD; INDArrayIndex[] index = null; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java index 079049257..f2644b3ac 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java @@ -122,7 +122,7 @@ public abstract class BaseImageRecordReader extends 
BaseRecordReader { } protected boolean containsFormat(String format) { - for (String format2 : imageLoader.getAllowedFormats()) + for (String format2 : imageLoader.getALLOWED_FORMATS()) if (format.endsWith("." + format2)) return true; return false; @@ -172,6 +172,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { //remove the root directory FileSplit split1 = (FileSplit) split; labels.remove(split1.getRootDir()); + split1.close(); } //To ensure consistent order for label assignment (irrespective of file iteration order), we want to sort the list of labels @@ -405,7 +406,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { @Override public void close() throws IOException { - //No op + this.inputSplit.close(); } @Override diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java index b35926de6..a8f25d876 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java @@ -272,6 +272,30 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader { } } + + public List record(URI uri, File f) throws IOException { + invokeListeners(uri); + if (imageLoader == null) { + imageLoader = new NativeImageLoader(height, width, channels, imageTransform); + } + Image image = this.imageLoader.asImageMatrix(f); + if(!nchw) + image.setImage(image.getImage().permute(0,2,3,1)); + Nd4j.getAffinityManager().ensureLocation(image.getImage(), AffinityManager.Location.DEVICE); + + List ret = RecordConverter.toRecord(image.getImage()); + if (appendLabel) { + List imageObjectsForPath = labelProvider.getImageObjectsForPath(uri.getPath()); + int nClasses = labels.size(); + INDArray outLabel = Nd4j.create(1, 4 + nClasses, gridH, gridW); + label(image, imageObjectsForPath, outLabel, 0); + if(!nchw) + outLabel = outLabel.permute(0,2,3,1); //NCHW to NHWC + ret.add(new NDArrayWritable(outLabel)); + } + return ret; + } + @Override public List record(URI uri, DataInputStream dataInputStream) throws IOException { invokeListeners(uri); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java index e539c7fa8..3e27c84e4 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/LabelGeneratorTest.java @@ -21,14 +21,24 @@ package org.datavec.image; import org.apache.commons.io.FileUtils; +import org.apache.commons.io.IOUtils; import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.split.FileSplit; import org.datavec.image.recordreader.ImageRecordReader; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.CleanupMode; import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import 
org.nd4j.common.io.ClassPathResource; import java.io.File; +import java.io.IOException; +import java.nio.file.FileVisitResult; +import java.nio.file.FileVisitor; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.attribute.BasicFileAttributes; import java.util.Arrays; import java.util.List; @@ -37,32 +47,28 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class LabelGeneratorTest { - @TempDir - public File testDir; - @Test - public void testParentPathLabelGenerator() throws Exception { + @ParameterizedTest + @ValueSource(strings = {"m", "m.", "something"}) + public void testParentPathLabelGenerator(String dirPrefix, @TempDir Path testDir) throws Exception { //https://github.com/deeplearning4j/DataVec/issues/273 File orig = new ClassPathResource("datavec-data-image/testimages/class0/0.jpg").getFile(); - for(String dirPrefix : new String[]{"m.", "m"}) { - File f = testDir; int numDirs = 3; int filesPerDir = 4; for (int i = 0; i < numDirs; i++) { - File currentLabelDir = new File(f, dirPrefix + i); - currentLabelDir.mkdirs(); + File currentLabelDir = new File(testDir.toFile(), dirPrefix + i); for (int j = 0; j < filesPerDir; j++) { File f3 = new File(currentLabelDir, "myImg_" + j + ".jpg"); - FileUtils.copyFile(orig, f3); + FileUtils.copyFile(orig, f3); //will create directories as needed assertTrue(f3.exists()); } } ImageRecordReader rr = new ImageRecordReader(28, 28, 1, new ParentPathLabelGenerator()); - rr.initialize(new FileSplit(f)); + rr.initialize(new FileSplit(testDir)); List labelsAct = rr.getLabels(); List labelsExp = Arrays.asList(dirPrefix + "0", dirPrefix + "1", dirPrefix + "2"); @@ -74,7 +80,9 @@ public class LabelGeneratorTest { rr.next(); actCount++; } + rr.close(); assertEquals(expCount, actCount); - } + + } } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java index 9550867bb..391fcd25d 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestNativeImageLoader.java @@ -28,29 +28,36 @@ import org.bytedeco.javacpp.indexer.UByteIndexer; import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.Java2DFrameConverter; import org.bytedeco.javacv.OpenCVFrameConverter; +import org.bytedeco.leptonica.PIX; +import org.bytedeco.leptonica.PIXCMAP; +import org.bytedeco.opencv.opencv_core.Mat; import org.datavec.image.data.Image; import org.datavec.image.data.ImageWritable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.nd4j.common.io.ClassPathResource; import org.nd4j.common.resources.Resources; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.io.ClassPathResource; import java.awt.image.BufferedImage; -import java.io.*; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; import java.lang.reflect.Field; +import java.nio.file.Path; import java.util.Random; +import java.util.stream.IntStream; 
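The reworked LabelGeneratorTest above relies on JUnit 5 parameterized tests (hence the junit-jupiter-params dependency added to build.gradle earlier in this series). A generic sketch of the pattern, independent of the project classes:

    import org.junit.jupiter.api.io.TempDir;
    import org.junit.jupiter.params.ParameterizedTest;
    import org.junit.jupiter.params.provider.ValueSource;

    import java.nio.file.Files;
    import java.nio.file.Path;

    import static org.junit.jupiter.api.Assertions.assertTrue;

    class DirPrefixTest {
        // The @ValueSource argument comes first; the @TempDir parameter is resolved per invocation.
        @ParameterizedTest
        @ValueSource(strings = {"m", "m.", "something"})
        void createsPrefixedDirectory(String prefix, @TempDir Path tempDir) throws Exception {
            Path dir = Files.createDirectory(tempDir.resolve(prefix + 0));
            assertTrue(Files.isDirectory(dir));
        }
    }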
+import java.util.stream.Stream; -import org.bytedeco.leptonica.*; -import org.bytedeco.opencv.opencv_core.*; import static org.bytedeco.leptonica.global.lept.*; import static org.bytedeco.opencv.global.opencv_core.*; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assertions.*; /** * @@ -61,73 +68,74 @@ public class TestNativeImageLoader { static final long seed = 10; static final Random rng = new Random(seed); - @TempDir - public File testDir; + @ParameterizedTest(name = "#{index} - Run test with arguments WxHxD {0}, {1}, {2}") + @MethodSource("generateDimensions") + public void testConvertPIX(int width, int height, int depth, int matType) { + PIX pix = pixCreate(width, height, depth); + Mat mat = NativeImageLoader.convert(pix); + + assertEquals(width, mat.cols()); + assertEquals(height, mat.rows()); + if(matType==CV_8UC4) matType= CV_8UC1; //this would be for 8 bit 256 gradients + assertEquals(matType, mat.type()); + } + + /** + * Run PIX creation and conversation test with ColorMap + * @param width + * @param height + * @param depth + * @param matType + */ + @ParameterizedTest(name = "#{index} - Run test with arguments WxHxD {0}, {1}, {2}") + @MethodSource("generateDimensions") + public void testConvertPIXCMAP(int width, int height, int depth, int matType) { + // a GIF file, for example + PIX pix = pixCreate(width, height, depth); + PIXCMAP cmap = pixcmapCreateLinear(depth, 256); + pixSetColormap(pix, cmap); + Mat mat = NativeImageLoader.convert(pix); + assertEquals(width, mat.cols()); + assertEquals(height, mat.rows()); + if( matType == CV_8UC1 && depth >= 8 ) matType = CV_8UC4; //change the argument, as this is depth 8 with 256 shades + assertEquals(matType, mat.type()); + } + + // Static stream of arguments + static Stream generateDimensions() { + return Stream.of( + Arguments.arguments(20, 20 ,1, CV_8UC1), + Arguments.arguments(1, 1, 1, CV_8UC1), + Arguments.arguments(1014, 1080, 1, CV_8UC1), + Arguments.arguments(20, 20 ,2, CV_8UC1), + Arguments.arguments(1, 1, 2, CV_8UC1), + Arguments.arguments(1014, 1080, 2, CV_8UC1), + Arguments.arguments(20, 20 ,4, CV_8UC1), + Arguments.arguments(1, 1, 4, CV_8UC1), + Arguments.arguments(1014, 1080, 4, CV_8UC1), + Arguments.arguments(20, 20 ,8, CV_8UC1), + Arguments.arguments(1, 1, 8, CV_8UC1), + Arguments.arguments(1014, 1080, 16, CV_16UC(1)), + Arguments.arguments(1014, 1080, 24, CV_8UC(3)), + Arguments.arguments(1014, 1080, 32, CV_32FC1), + Arguments.arguments(2048, 4096, 32, CV_32FC1), + Arguments.arguments(300, 300, 8, CV_8UC4) + ); + } @Test public void testConvertPix() throws Exception { - PIX pix; - Mat mat; - - pix = pixCreate(11, 22, 1); - mat = NativeImageLoader.convert(pix); - assertEquals(11, mat.cols()); - assertEquals(22, mat.rows()); - assertEquals(CV_8UC1, mat.type()); - - pix = pixCreate(33, 44, 2); - mat = NativeImageLoader.convert(pix); - assertEquals(33, mat.cols()); - assertEquals(44, mat.rows()); - assertEquals(CV_8UC1, mat.type()); - - pix = pixCreate(55, 66, 4); - mat = NativeImageLoader.convert(pix); - assertEquals(55, mat.cols()); - assertEquals(66, mat.rows()); - assertEquals(CV_8UC1, mat.type()); - - pix = pixCreate(77, 88, 8); - mat = NativeImageLoader.convert(pix); - assertEquals(77, mat.cols()); - assertEquals(88, mat.rows()); - assertEquals(CV_8UC1, mat.type()); - - pix = pixCreate(99, 
111, 16); - mat = NativeImageLoader.convert(pix); - assertEquals(99, mat.cols()); - assertEquals(111, mat.rows()); - assertEquals(CV_16UC(1), mat.type()); - - pix = pixCreate(222, 333, 24); - mat = NativeImageLoader.convert(pix); - assertEquals(222, mat.cols()); - assertEquals(333, mat.rows()); - assertEquals(CV_8UC(3), mat.type()); - - pix = pixCreate(444, 555, 32); - mat = NativeImageLoader.convert(pix); - assertEquals(444, mat.cols()); - assertEquals(555, mat.rows()); - assertEquals(CV_32FC1, mat.type()); - - // a GIF file, for example - pix = pixCreate(32, 32, 8); - PIXCMAP cmap = pixcmapCreateLinear(8, 256); - pixSetColormap(pix, cmap); - mat = NativeImageLoader.convert(pix); - assertEquals(32, mat.cols()); - assertEquals(32, mat.rows()); - assertEquals(CV_8UC4, mat.type()); int w4 = 100, h4 = 238, ch4 = 1, pages = 1, depth = 1; String path2MitosisFile = "datavec-data-image/testimages2/mitosis.tif"; NativeImageLoader loader5 = new NativeImageLoader(h4, w4, ch4, NativeImageLoader.MultiPageMode.FIRST); INDArray array6 = null; try { - array6 = loader5.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath()); + File f = new ClassPathResource(path2MitosisFile).getFile(); + assertTrue(!f.isDirectory() && f.canRead()); + array6 = loader5.asMatrix( f ); } catch (IOException e) { - log.error("",e); + System.out.println(e.getMessage()); fail(); } assertEquals(5, array6.rank()); @@ -158,7 +166,7 @@ public class TestNativeImageLoader { NativeImageLoader loader7 = new NativeImageLoader(h4, w4, ch6, NativeImageLoader.MultiPageMode.MINIBATCH); INDArray array8 = null; try { - array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile().getAbsolutePath()); + array8 = loader7.asMatrix(new ClassPathResource(path2MitosisFile).getFile()); } catch (IOException e) { log.error("",e); } @@ -174,7 +182,7 @@ public class TestNativeImageLoader { NativeImageLoader loader8 = new NativeImageLoader(h5, w5, ch6, NativeImageLoader.MultiPageMode.MINIBATCH); INDArray array9 = null; try { - array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile().getAbsolutePath()); + array9 = loader8.asMatrix(new ClassPathResource(braintiff).getFile()); } catch (IOException e) { log.error("",e); fail(); @@ -481,7 +489,7 @@ public class TestNativeImageLoader { int w1 = 33, h1 = 77, ch1 = 1; NativeImageLoader loader1 = new NativeImageLoader(h1, w1, ch1); - INDArray array1 = loader1.asMatrix(f0); + INDArray array1 = loader1.asMatrix(new File(f0)); assertEquals(4, array1.rank()); assertEquals(1, array1.size(0)); assertEquals(1, array1.size(1)); @@ -565,40 +573,43 @@ public class TestNativeImageLoader { @Test - public void testNativeImageLoaderEmptyStreams() throws Exception { - File dir = testDir; + public void testNativeImageLoaderEmptyStreams(@TempDir Path tempDir) throws Exception { + File dir = tempDir.toFile(); File f = new File(dir, "myFile.jpg"); f.createNewFile(); NativeImageLoader nil = new NativeImageLoader(32, 32, 3); - try(InputStream is = new FileInputStream(f)){ - nil.asMatrix(is); + try { + nil.asMatrix(f); fail("Expected exception"); } catch (IOException e){ String msg = e.getMessage(); assertTrue(msg.contains("decode image"), msg); } - try(InputStream is = new FileInputStream(f)){ - nil.asImageMatrix(is); + try { + nil.asImageMatrix(f); + fail("Expected exception"); + } catch (IOException e){ + String msg = e.getMessage(); + assertTrue(msg.contains("decode image"), msg); + } catch (RuntimeException e) { + String msg = e.getMessage(); + assertTrue(msg.contains("Cannot convert 
image with source dimensions 0x0")); + } + + try{ + nil.asRowVector(f); fail("Expected exception"); } catch (IOException e){ String msg = e.getMessage(); assertTrue(msg.contains("decode image"), msg); } - try(InputStream is = new FileInputStream(f)){ - nil.asRowVector(is); - fail("Expected exception"); - } catch (IOException e){ - String msg = e.getMessage(); - assertTrue(msg.contains("decode image"), msg); - } - - try(InputStream is = new FileInputStream(f)){ + try{ INDArray arr = Nd4j.create(DataType.FLOAT, 1, 3, 32, 32); - nil.asMatrixView(is, arr); + nil.asMatrixView(f, arr); fail("Expected exception"); } catch (IOException e){ String msg = e.getMessage(); diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java index cf3d45944..03fa53767 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java @@ -20,6 +20,8 @@ package org.nd4j.common.io; +import lombok.Getter; +import lombok.NonNull; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.IOUtils; @@ -38,43 +40,32 @@ import java.util.zip.ZipFile; public class ClassPathResource extends AbstractFileResolvingResource { - private final String path; - private ClassLoader classLoader; + @Getter private final String path; + @Getter private final ClassLoader classLoader; private Class clazz; - public ClassPathResource(String path) { - this(path, (ClassLoader) null); + public ClassPathResource(@NonNull String path) { + this(path, ND4JClassLoading.getNd4jClassloader()); } - public ClassPathResource(String path, ClassLoader classLoader) { - Assert.notNull(path, "Path must not be null"); + public ClassPathResource(@NonNull String path, @NonNull ClassLoader classLoader) { String pathToUse = StringUtils.cleanPath(path); if (pathToUse.startsWith("/")) { pathToUse = pathToUse.substring(1); } this.path = pathToUse; - this.classLoader = classLoader != null ? classLoader : ND4JClassLoading.getNd4jClassloader(); - } - - public ClassPathResource(String path, Class clazz) { - Assert.notNull(path, "Path must not be null"); - this.path = StringUtils.cleanPath(path); - this.clazz = clazz; - } - - protected ClassPathResource(String path, ClassLoader classLoader, Class clazz) { - this.path = StringUtils.cleanPath(path); this.classLoader = classLoader; + } + + public ClassPathResource(@NonNull String path, @NonNull Class clazz) { + this(path, clazz.getClassLoader()); this.clazz = clazz; } - public final String getPath() { - return this.path; - } - - public final ClassLoader getClassLoader() { - return this.classLoader != null ? 
this.classLoader : this.clazz.getClassLoader(); + protected ClassPathResource(@NonNull String path, @NonNull ClassLoader classLoader, @NonNull Class clazz) { + this(path, classLoader); + this.clazz = clazz; } /** @@ -133,14 +124,12 @@ public class ClassPathResource extends AbstractFileResolvingResource { } else { tmpFile = Files.createTempFile(FilenameUtils.getName(path), "tmp").toFile(); } - - tmpFile.deleteOnExit(); - BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(tmpFile)); IOUtils.copy(is, bos); bos.flush(); bos.close(); + is.close(); return tmpFile; } diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index a9c34e3e2..c65b768d3 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -222,6 +222,7 @@ chipList.each { thisChip -> if (project.hasProperty("skip-native") && project.getProperty("skip-native").equals("true")) { enabled = false } + dependsOn "processResources" properties = getBuildPlatform( thisChip, it ) From f8067f8f960c4d2da236b45d42ea2e3a6dafc969 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 10 Oct 2022 23:20:18 +0200 Subject: [PATCH 018/126] datavec-data-image test fixes Signed-off-by: brian --- .../api/records/reader/BaseRecordReader.java | 51 +++++++++++++++++++ .../api/records/reader/RecordReader.java | 2 +- .../reader/impl/ComposableRecordReader.java | 7 +-- .../impl/ConcatenatingRecordReader.java | 2 +- .../collection/CollectionRecordReader.java | 5 -- .../CollectionSequenceRecordReader.java | 5 -- .../collection/ListStringRecordReader.java | 18 ------- .../impl/filebatch/FileBatchRecordReader.java | 2 +- .../FileBatchSequenceRecordReader.java | 2 +- .../TransformProcessRecordReader.java | 2 +- .../TransformProcessSequenceRecordReader.java | 2 +- .../java/org/datavec/api/split/FileSplit.java | 8 +-- .../org/datavec/api/split/InputSplit.java | 6 +-- .../org/datavec/api/split/TransformSplit.java | 3 +- .../impl/CSVSequenceRecordReaderTest.java | 4 +- .../datavec/api/split/InputSplitTests.java | 5 ++ .../PartitionerTests.java | 15 ++---- .../recordreader/BaseAudioRecordReader.java | 5 -- .../image/loader/NativeImageLoader.java | 15 ++++-- .../recordreader/BaseImageRecordReader.java | 4 +- .../org/datavec/image/loader/LoaderTests.java | 2 +- .../TestObjectDetectionRecordReader.java | 16 +++--- .../nd4j/linalg/api/ops/DynamicCustomOp.java | 3 -- .../org/nd4j/common/io/ClassPathResource.java | 6 ++- .../cavis-dnn-data-datasets/build.gradle | 2 + 25 files changed, 104 insertions(+), 88 deletions(-) rename cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/{parittion => partition}/PartitionerTests.java (84%) diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java index 1014dfce4..9fd71d227 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java @@ -27,6 +27,7 @@ import org.datavec.api.split.streams.FileStreamCreatorFunction; import org.datavec.api.writable.Writable; import org.nd4j.common.function.Function; +import java.io.Closeable; import java.io.IOException; import java.io.InputStream; import java.net.URI; @@ -84,4 +85,54 @@ public abstract class BaseRecordReader implements RecordReader { public List> next(int num) 
{ throw new UnsupportedOperationException(); }
+
+    /**
+     * Closes this resource, relinquishing any underlying resources.
+     * This method is invoked automatically on objects managed by the
+     * {@code try}-with-resources statement.
+     *
+     * <p>While this interface method is declared to throw {@code
+     * Exception}, implementers are strongly encouraged to
+     * declare concrete implementations of the {@code close} method to
+     * throw more specific exceptions, or to throw no exception at all
+     * if the close operation cannot fail.
+     *
+     * <p>Cases where the close operation may fail require careful
+     * attention by implementers. It is strongly advised to relinquish
+     * the underlying resources and to internally mark the
+     * resource as closed, prior to throwing the exception. The {@code
+     * close} method is unlikely to be invoked more than once and so
+     * this ensures that the resources are released in a timely manner.
+     * Furthermore it reduces problems that could arise when the resource
+     * wraps, or is wrapped, by another resource.
+     *
+     * <p>Implementers of this interface are also strongly advised
+     * to not have the {@code close} method throw {@link
+     * InterruptedException}.
+     *
+     * This exception interacts with a thread's interrupted status,
+     * and runtime misbehavior is likely to occur if an {@code
+     * InterruptedException} is {@linkplain Throwable#addSuppressed
+     * suppressed}.
+     *
+     * More generally, if it would cause problems for an
+     * exception to be suppressed, the {@code AutoCloseable.close}
+     * method should not throw it.
+     *
+     * <p>Note that unlike the {@link Closeable#close close}
+     * method of {@link Closeable}, this {@code close} method
+     * is not required to be idempotent. In other words,
+     * calling this {@code close} method more than once may have some
+     * visible side effect, unlike {@code Closeable.close} which is
+     * required to have no effect if called more than once.
+     *
+ * However, implementers of this interface are strongly encouraged + * to make their {@code close} methods idempotent. + * + * @throws Exception if this resource cannot be closed + */ + @Override + public void close() throws Exception { + inputSplit.close(); + } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java index a44cece05..14d1da31d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java @@ -33,7 +33,7 @@ import java.net.URI; import java.util.Collection; import java.util.List; -public interface RecordReader extends Closeable, Serializable, Configurable { +public interface RecordReader extends AutoCloseable, Serializable, Configurable { String NAME_SPACE = RecordReader.class.getName(); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java index 52035e7f1..96693a1a3 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java @@ -82,9 +82,10 @@ public class ComposableRecordReader extends BaseRecordReader { } @Override - public void close() throws IOException { - for (RecordReader reader : readers) - reader.close(); + public void close() throws Exception { + for (RecordReader reader : readers) { + reader.close(); + } } @Override diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java index 14692114e..e01d93ed1 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java @@ -80,7 +80,7 @@ public class ConcatenatingRecordReader extends BaseRecordReader { } @Override - public void close() throws IOException { + public void close() throws Exception { for (RecordReader reader : readers) reader.close(); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionRecordReader.java index 760ea75cb..a8e02e2c4 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionRecordReader.java @@ -68,11 +68,6 @@ public class CollectionRecordReader extends BaseRecordReader { return records.hasNext(); } - @Override - public void close() throws IOException { - - } - @Override public void setConf(Configuration conf) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java index cf60ba546..c87b541f8 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java @@ -72,11 +72,6 @@ public class CollectionSequenceRecordReader extends BaseRecordReader implements return records.hasNext(); } - @Override - public void close() throws IOException { - - } - @Override public void setConf(Configuration conf) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/ListStringRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/ListStringRecordReader.java index 921cef917..7c99ca300 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/ListStringRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/ListStringRecordReader.java @@ -150,24 +150,6 @@ public class ListStringRecordReader extends BaseRecordReader { throw new UnsupportedOperationException("Loading from metadata not yet implemented"); } - /** - * Closes this stream and releases any system resources associated - * with it. If the stream is already closed then invoking this - * method has no effect. - *
- * <p>
As noted in {@link AutoCloseable#close()}, cases where the - * close may fail require careful attention. It is strongly advised - * to relinquish the underlying resources and to internally - * mark the {@code Closeable} as closed, prior to throwing - * the {@code IOException}. - * - * @throws IOException if an I/O error occurs - */ - @Override - public void close() throws IOException { - - } - /** * Set the configuration to be used by this object. * diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchRecordReader.java index 219a99870..b827400d6 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchRecordReader.java @@ -152,7 +152,7 @@ public class FileBatchRecordReader implements RecordReader { } @Override - public void close() throws IOException { + public void close() throws Exception { recordReader.close(); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchSequenceRecordReader.java index 20bdb91b4..133089920 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchSequenceRecordReader.java @@ -172,7 +172,7 @@ public class FileBatchSequenceRecordReader implements SequenceRecordReader { } @Override - public void close() throws IOException { + public void close() throws Exception { recordReader.close(); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java index 374b54c45..160b2c134 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java @@ -258,7 +258,7 @@ public class TransformProcessRecordReader implements RecordReader { * @throws IOException if an I/O error occurs */ @Override - public void close() throws IOException { + public void close() throws Exception { recordReader.close(); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java index 48cd6e937..7023e70b4 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java @@ -308,7 +308,7 @@ public class TransformProcessSequenceRecordReader implements SequenceRecordReade * @throws IOException if an I/O error occurs */ @Override - 
public void close() throws IOException { + public void close() throws Exception { sequenceRecordReader.close(); } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java index eb31846a2..97183f346 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java @@ -20,6 +20,7 @@ package org.datavec.api.split; +import lombok.Getter; import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.filefilter.IOFileFilter; import org.apache.commons.io.filefilter.RegexFileFilter; @@ -36,7 +37,7 @@ import java.util.*; public class FileSplit extends BaseInputSplit { - protected File rootDir; + @Getter protected File rootDir; // Use for Collections, pass in list of file type strings protected String[] allowFormat = null; protected boolean recursive = true; @@ -227,11 +228,6 @@ public class FileSplit extends BaseInputSplit { } - - public File getRootDir() { - return rootDir; - } - private List listFiles(File dir, String[] allowedFormats, boolean recursive) { Preconditions.checkState(dir.isDirectory(), "Argument is not a directory: %s", dir); IOFileFilter filter; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputSplit.java index df067ac62..70d6b620b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputSplit.java @@ -26,7 +26,7 @@ import java.io.OutputStream; import java.net.URI; import java.util.Iterator; -public interface InputSplit { +public interface InputSplit extends AutoCloseable { /** @@ -134,8 +134,4 @@ public interface InputSplit { */ boolean resetSupported(); - /** - * Close input/ output streams if any - */ - void close(); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/TransformSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/TransformSplit.java index 8c1bda71d..efb3f043a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/TransformSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/TransformSplit.java @@ -115,9 +115,8 @@ public class TransformSplit extends BaseInputSplit { * Close input/ output streams if any */ @Override - public void close() { + public void close() throws Exception { sourceSplit.close(); - } public interface URITransform { diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java index 8fdce2165..4032b912c 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java @@ -230,8 +230,10 @@ public class CSVSequenceRecordReaderTest extends BaseND4JTest { } + @Override + public void close() throws Exception { - + } } diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java 
b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java index f7c413d34..43d274151 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java @@ -46,6 +46,11 @@ public class InputSplitTests extends BaseND4JTest { @Test public void testSample() throws URISyntaxException { BaseInputSplit split = new BaseInputSplit() { + @Override + public void close() throws Exception { + + } + { String[] paths = {"label0/group1_img.tif", "label1/group1_img.jpg", "label2/group1_img.png", "label3/group1_img.jpeg", "label4/group1_img.bmp", "label5/group1_img.JPEG", diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/partition/PartitionerTests.java similarity index 84% rename from cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java rename to cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/partition/PartitionerTests.java index f8e2c99f1..71211c9cd 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/partition/PartitionerTests.java @@ -18,19 +18,16 @@ * ***************************************************************************** */ -package org.datavec.api.split.parittion; +package org.datavec.api.split.partition; +import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; -import com.google.common.io.Files; import org.datavec.api.conf.Configuration; import org.datavec.api.split.FileSplit; -import org.datavec.api.split.partition.NumberOfRecordsPartitioner; -import org.datavec.api.split.partition.PartitionMetaData; -import org.datavec.api.split.partition.Partitioner; import org.junit.jupiter.api.Test; -import java.io.File; import java.io.OutputStream; +import java.nio.file.Path; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -38,9 +35,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class PartitionerTests extends BaseND4JTest { @Test - public void testRecordsPerFilePartition() { + public void testRecordsPerFilePartition(@TempDir Path tmpDir) { Partitioner partitioner = new NumberOfRecordsPartitioner(); - File tmpDir = Files.createTempDir(); FileSplit fileSplit = new FileSplit(tmpDir); assertTrue(fileSplit.needsBootstrapForWrite()); fileSplit.bootStrapForWrite(); @@ -49,9 +45,8 @@ public class PartitionerTests extends BaseND4JTest { } @Test - public void testInputAddFile() throws Exception { + public void testInputAddFile(@TempDir Path tmpDir) throws Exception { Partitioner partitioner = new NumberOfRecordsPartitioner(); - File tmpDir = Files.createTempDir(); FileSplit fileSplit = new FileSplit(tmpDir); assertTrue(fileSplit.needsBootstrapForWrite()); fileSplit.bootStrapForWrite(); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/BaseAudioRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/BaseAudioRecordReader.java index 82c9bb1ce..a98b0d1d5 100644 --- 
a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/BaseAudioRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/BaseAudioRecordReader.java @@ -154,11 +154,6 @@ public abstract class BaseAudioRecordReader extends BaseRecordReader { } - @Override - public void close() throws IOException { - - } - @Override public void setConf(Configuration conf) { this.conf = conf; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java index 08a886eb9..bda972a86 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java @@ -336,8 +336,7 @@ public class NativeImageLoader extends BaseImageLoader { @Override public Image asImageMatrix(InputStream inputStream, boolean nchw) throws IOException { - throw new RuntimeException("Deprecated. Not implemented."); - /*Mat mat = streamToMat(inputStream); + Mat mat = streamToMat(inputStream); Mat image = imdecode(mat, IMREAD_ANYDEPTH | IMREAD_ANYCOLOR); if (image == null || image.empty()) { PIX pix = pixReadMem(mat.data(), mat.cols()); @@ -348,14 +347,20 @@ public class NativeImageLoader extends BaseImageLoader { pixDestroy(pix); } INDArray a = asMatrix(image); - if(!nchw) - a = a.permute(0,2,3,1); //NCHW to NHWC + if(!nchw) a = swapNCHWtoNHWC(a); Image i = new Image(a, image.channels(), image.rows(), image.cols()); image.deallocate(); return i; + } - */ + /** + * Change channel order from NCHW to NHWC + * @param a input INDArray + * @return swapped INDArray + */ + private INDArray swapNCHWtoNHWC(INDArray a) { + return a.permute(0,2,3,1); //NCHW to NHWC } /** diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java index f2644b3ac..48502f95d 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java @@ -405,8 +405,8 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { } @Override - public void close() throws IOException { - this.inputSplit.close(); + public void close() throws Exception { + inputSplit.close(); } @Override diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java index b9b4cf3a2..2ca05dd55 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java @@ -80,7 +80,7 @@ public class LoaderTests { String subDir = "cifar/cifar-10-batches-bin/data_batch_1.bin"; String path = FilenameUtils.concat(System.getProperty("user.home"), 
subDir); byte[] fullDataExpected = new byte[3073]; - FileInputStream inExpected = new FileInputStream(new File(path)); + FileInputStream inExpected = new FileInputStream(path); inExpected.read(fullDataExpected); byte[] fullDataActual = new byte[3073]; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java index 81d667f78..483be1090 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java @@ -45,6 +45,7 @@ import org.nd4j.common.io.ClassPathResource; import java.io.File; import java.net.URI; +import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -54,18 +55,13 @@ import static org.junit.jupiter.api.Assertions.*; public class TestObjectDetectionRecordReader { - @TempDir - public File testDir; - @Test - public void test() throws Exception { + public void test(@TempDir Path testDir) throws Exception { for(boolean nchw : new boolean[]{true, false}) { ImageObjectLabelProvider lp = new TestImageObjectDetectionLabelProvider(); - File f = testDir; - new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(f); - - String path = new File(f, "000012.jpg").getParent(); + new ClassPathResource("datavec-data-image/objdetect/").copyDirectory(testDir); + Path path = testDir.resolve("000012.jpg").getParent(); int h = 32; int w = 32; @@ -74,7 +70,7 @@ public class TestObjectDetectionRecordReader { int gH = 10; //Enforce consistent iteration order for tests - URI[] u = new FileSplit(new File(path)).locations(); + URI[] u = new FileSplit(path).locations(); Arrays.sort(u); RecordReader rr = new ObjectDetectionRecordReader(h, w, c, gH, gW, nchw, lp); @@ -154,7 +150,7 @@ public class TestObjectDetectionRecordReader { rr.reset(); Record record = rr.nextRecord(); RecordMetaDataImageURI metadata = (RecordMetaDataImageURI) record.getMetaData(); - assertEquals(new File(path, "000012.jpg"), new File(metadata.getURI())); + assertEquals( path.resolve( "000012.jpg").toFile(), new File(metadata.getURI())); assertEquals(3, metadata.getOrigC()); assertEquals((int) origH[0], metadata.getOrigH()); assertEquals((int) origW[0], metadata.getOrigW()); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index 14b736898..304f40b99 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -44,9 +44,7 @@ import java.lang.reflect.Array; import java.util.*; @Slf4j -@Builder @AllArgsConstructor -@EqualsAndHashCode(callSuper = true) public class DynamicCustomOp extends DifferentialFunction implements CustomOp { private String opName; @@ -56,7 +54,6 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { @Builder.Default protected List outputArguments = new ArrayList<>(); - @Builder.Default protected List tArguments = new ArrayList<>(); diff --git 
a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java index 03fa53767..7c5687ef9 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java @@ -33,6 +33,7 @@ import java.net.MalformedURLException; import java.net.URISyntaxException; import java.net.URL; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.attribute.FileAttribute; import java.util.Enumeration; import java.util.zip.ZipEntry; @@ -143,7 +144,6 @@ public class ClassPathResource extends AbstractFileResolvingResource { public void copyDirectory(File destination) throws IOException { Preconditions.checkState(destination.exists() && destination.isDirectory(), "Destination directory must exist and be a directory: %s", destination); - URL url = this.getUrl(); if (isJarURL(url)) { @@ -180,6 +180,7 @@ public class ClassPathResource extends AbstractFileResolvingResource { try(BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(extractTo))){ InputStream is = getInputStream(name, clazz, classLoader); IOUtils.copy(is, bos); + is.close(); } } } @@ -209,6 +210,9 @@ public class ClassPathResource extends AbstractFileResolvingResource { } } + public void copyDirectory(Path destination) throws IOException { + copyDirectory(destination.toFile()); + } public boolean exists() { URL url; if (this.clazz != null) { diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/build.gradle b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/build.gradle index 6eabedb4f..814edc3b6 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/build.gradle +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/build.gradle @@ -27,4 +27,6 @@ dependencies { implementation projects.cavisDnn.cavisDnnApi implementation projects.cavisDatavec.cavisDatavecApi implementation "commons-io:commons-io" + implementation "com.fasterxml.jackson.core:jackson-annotations" + implementation "com.fasterxml.jackson.core:jackson-core" } \ No newline at end of file From 6960418295e7f5966e3548d186cd9f37dacd248c Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 11 Oct 2022 07:59:22 +0200 Subject: [PATCH 019/126] More test fixes Signed-off-by: brian --- .docker/Dockerfile | 2 ++ .../configurations/Keras2ModelConfigurationTest.java | 7 ++----- .../keras/configurations/KerasModelImportTest.java | 7 ++----- .../modelimport/keras/e2e/KerasModelEndToEndTest.java | 11 ++--------- .../keras/weights/KerasWeightSettingTests.java | 7 ++----- 5 files changed, 10 insertions(+), 24 deletions(-) diff --git a/.docker/Dockerfile b/.docker/Dockerfile index 6184075d9..2e8e9a472 100644 --- a/.docker/Dockerfile +++ b/.docker/Dockerfile @@ -6,3 +6,5 @@ RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2 tar -xvf cmake-3.24.2.tar.gz && cd cmake-3.24.2 && \ ./bootstrap && make && make install +RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf + diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index 41f0cf0c7..a8eab6be6 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java 
+++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -33,6 +33,7 @@ import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceTo import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -48,6 +49,7 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; @Slf4j +@Timeout(300) public class Keras2ModelConfigurationTest extends BaseDL4JTest { ClassLoader classLoader = getClass().getClassLoader(); @@ -231,11 +233,6 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { runModelConfigTest("modelimport/keras/configs/keras2/simple_add_tf_keras_2.json"); } - @Override - public long getTimeoutMilliseconds() { - return 999999999L; - } - @Test public void embeddingConcatTest() throws Exception { runModelConfigTest("/modelimport/keras/configs/keras2/model_concat_embedding_sequences_tf_keras_2.json"); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java index 41fa9edc5..c45b3c52b 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java @@ -32,6 +32,7 @@ import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurat import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.common.resources.Resources; import org.nd4j.linalg.convolution.Convolution; import org.nd4j.linalg.factory.Nd4j; @@ -45,12 +46,8 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; * Test import of Keras models. 
*/ @Slf4j +@Timeout(300) public class KerasModelImportTest extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 9999999999999L; - } - @Test public void testH5WithoutTensorflowScope() throws Exception { MultiLayerNetwork model = loadModel("modelimport/keras/tfscope/model.h5"); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index 65bac76b4..9b6797c06 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -44,10 +44,7 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; import org.deeplearning4j.nn.transferlearning.TransferLearning; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; @@ -78,6 +75,7 @@ import static org.junit.jupiter.api.Assertions.*; * @author dave@skymind.io, Max Pumperla */ @Slf4j +@Timeout(300) public class KerasModelEndToEndTest extends BaseDL4JTest { private static final String GROUP_ATTR_INPUTS = "inputs"; private static final String GROUP_ATTR_OUTPUTS = "outputs"; @@ -93,11 +91,6 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @TempDir public File testDir; - @Override - public long getTimeoutMilliseconds() { - return 900000000L; //Most benchmarks should run very quickly; large timeout is to avoid issues with unusually slow download of test resources - } - @Test public void fileNotFoundEndToEnd() throws Exception { String modelPath = "modelimport/keras/examples/foo/bar.h5"; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java index 9de2cb73a..b40eb37c1 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -45,16 +46,12 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j +@Timeout(300) public class KerasWeightSettingTests extends BaseDL4JTest { @TempDir private File testDir; - @Override - public long getTimeoutMilliseconds() { - return 9999999L; - } - @Test public void testSimpleLayersWithWeights() throws 
Exception { int[] kerasVersions = new int[]{1, 2}; From 21e7f1c8b829695bd1073a1258f3f0c61fd73f2e Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 11 Oct 2022 09:29:26 +0200 Subject: [PATCH 020/126] More test fixes Signed-off-by: brian --- .../api/io/filters/BalancedPathFilter.java | 8 ++++---- .../common/resources/DL4JResources.java | 5 +++++ .../datasets/MnistFetcherTest.java | 20 ++++--------------- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java index 3a58cc3a7..348b4e0fd 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java @@ -87,14 +87,14 @@ public class BalancedPathFilter extends RandomPathFilter { protected boolean acceptLabel(String name) { if (labels == null || labels.length == 0) { - return true; + return false; } for (String label : labels) { if (name.equals(label)) { - return true; + return false; } } - return false; + return true; } @Override @@ -107,7 +107,7 @@ public class BalancedPathFilter extends RandomPathFilter { URI path = paths[i]; Writable label = labelGenerator.getLabelForPath(path); if (!acceptLabel(label.toString())) { - continue; + continue; //we skip label in case it is null, empty or already in the collection } List pathList = labelPaths.get(label); if (pathList == null) { diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/resources/DL4JResources.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/resources/DL4JResources.java index e8f6ecda0..c73955595 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/resources/DL4JResources.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/resources/DL4JResources.java @@ -27,6 +27,7 @@ import org.nd4j.common.base.Preconditions; import java.io.File; import java.net.MalformedURLException; import java.net.URL; +import java.nio.file.Path; public class DL4JResources { @@ -128,6 +129,10 @@ public class DL4JResources { baseDirectory = f; } + public static void setBaseDirectory(@NonNull Path p) { + setBaseDirectory(p.toFile()); + } + /** * @return The base storage directory for DL4J resources */ diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java index c63e4ac7d..1bc48d33e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/MnistFetcherTest.java @@ -24,10 +24,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.base.MnistFetcher; import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.*; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; @@ -37,6 +34,7 @@ import org.nd4j.linalg.factory.Nd4j; import 
org.nd4j.linalg.indexing.conditions.Conditions; import java.io.File; +import java.nio.file.Path; import java.util.HashSet; import java.util.Set; @@ -45,24 +43,14 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; @org.junit.jupiter.api.Timeout(300) -@TestInstance(TestInstance.Lifecycle.PER_CLASS) public class MnistFetcherTest extends BaseDL4JTest { @TempDir - public File testDir; - - @BeforeAll - public void setup() throws Exception { - DL4JResources.setBaseDirectory(testDir); - } - - @AfterAll - public void after() { - DL4JResources.resetBaseDirectoryLocation(); - } + public Path testDir; @Test public void testMnist() throws Exception { + DL4JResources.setBaseDirectory(testDir); DataSetIterator iter = new MnistDataSetIterator(32, 60000, false, true, false, -1); int count = 0; while(iter.hasNext()){ From 011ce913c9c781748c4a0624eed97e7a760b34a0 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 11 Oct 2022 10:16:51 +0200 Subject: [PATCH 021/126] More test fixes Signed-off-by: brian --- .../src/test/java/org/deeplearning4j/TsneTest.java | 7 ++----- .../models/paragraphvectors/ParagraphVectorsTest.java | 5 ----- .../deeplearning4j/models/word2vec/Word2VecTestsSmall.java | 6 +----- .../word2vec/iterator/Word2VecDataSetIteratorTest.java | 7 ++----- 4 files changed, 5 insertions(+), 20 deletions(-) diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/TsneTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/TsneTest.java index 7edb8ea3e..20ba99cbe 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/TsneTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/TsneTest.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.conf.WorkspaceMode; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -42,13 +43,9 @@ import java.util.ArrayList; import java.util.List; @Slf4j +@Timeout(300) public class TsneTest extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 180000L; - } - @TempDir public File testDir; diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java index cac34901f..84869467b 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java @@ -78,11 +78,6 @@ import static org.junit.jupiter.api.Assertions.*; @Timeout(240) public class ParagraphVectorsTest extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return isIntegrationTests() ? 
600_000 : 240_000; - } - @TempDir public File testDir; diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java index 6d7bfaf63..19681185c 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java @@ -60,14 +60,10 @@ import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j +@Timeout(300) public class Word2VecTestsSmall extends BaseDL4JTest { WordVectors word2vec; - @Override - public long getTimeoutMilliseconds() { - return isIntegrationTests() ? 240000 : 60000; - } - @BeforeEach public void setUp() throws Exception { word2vec = WordVectorSerializer.readWord2VecModel(new ClassPathResource("vec.bin").getFile()); diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java index a72f92211..7d806aafb 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java @@ -34,6 +34,7 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFac import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.common.resources.Resources; @@ -46,13 +47,9 @@ import java.util.concurrent.atomic.AtomicInteger; import static org.junit.jupiter.api.Assertions.assertArrayEquals; +@Timeout(300) public class Word2VecDataSetIteratorTest extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 60000L; - } - /** * Basically all we want from this test - being able to finish without exceptions. 
*/ From a2792424593d571e5bf673f52301087f78996240 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 11 Oct 2022 13:52:52 +0200 Subject: [PATCH 022/126] More test fixes Signed-off-by: brian --- .../org/datavec/api/util/files/URIUtil.java | 4 +- .../cavis-datavec-data-arrow}/build.gradle | 0 .../org/datavec/arrow/ArrowConverter.java | 0 .../arrow/recordreader/ArrowRecord.java | 0 .../arrow/recordreader/ArrowRecordReader.java | 0 .../arrow/recordreader/ArrowRecordWriter.java | 0 .../ArrowWritableRecordBatch.java | 0 .../ArrowWritableRecordTimeSeriesBatch.java | 0 .../org/datavec/arrow/ArrowConverterTest.java | 0 .../arrow/AssertTestsExtendBaseClass.java | 0 .../org/datavec/arrow/RecordMapperTest.java | 0 ...rowWritableRecordTimeSeriesBatchTests.java | 0 .../cavis-datavec-local/build.gradle | 2 +- .../common/config/DL4JSystemProperties.java | 2 + cavis-dnn/cavis-dnn-nn/build.gradle | 1 + .../deeplearning4j/nn/layers/HelperUtils.java | 116 ++++++++++++++++++ .../deeplearning4j/nn/layers/LayerHelper.java | 2 + .../nn/layers/mkldnn/BaseMKLDNNHelper.java | 4 + .../layers/mkldnn/MKLDNNBatchNormHelper.java | 5 + .../nn/layers/mkldnn/MKLDNNLSTMHelper.java | 7 ++ ...KLDNNLocalResponseNormalizationHelper.java | 1 + .../layers/recurrent/BidirectionalLayer.java | 5 + .../nn/layers/HelperUtilsTest.java | 31 +---- settings.gradle | 2 +- 24 files changed, 149 insertions(+), 33 deletions(-) rename cavis-datavec/{cavis-datavec-arrow => cavis-datavec-data/cavis-datavec-data-arrow}/build.gradle (100%) rename cavis-datavec/{cavis-datavec-arrow => cavis-datavec-data/cavis-datavec-data-arrow}/src/main/java/org/datavec/arrow/ArrowConverter.java (100%) rename cavis-datavec/{cavis-datavec-arrow => cavis-datavec-data/cavis-datavec-data-arrow}/src/main/java/org/datavec/arrow/recordreader/ArrowRecord.java (100%) rename cavis-datavec/{cavis-datavec-arrow => cavis-datavec-data/cavis-datavec-data-arrow}/src/main/java/org/datavec/arrow/recordreader/ArrowRecordReader.java (100%) rename cavis-datavec/{cavis-datavec-arrow => cavis-datavec-data/cavis-datavec-data-arrow}/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java (100%) rename cavis-datavec/{cavis-datavec-arrow => cavis-datavec-data/cavis-datavec-data-arrow}/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordBatch.java (100%) rename cavis-datavec/{cavis-datavec-arrow => cavis-datavec-data/cavis-datavec-data-arrow}/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatch.java (100%) rename cavis-datavec/{cavis-datavec-arrow => cavis-datavec-data/cavis-datavec-data-arrow}/src/test/java/org/datavec/arrow/ArrowConverterTest.java (100%) rename cavis-datavec/{cavis-datavec-arrow => cavis-datavec-data/cavis-datavec-data-arrow}/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java (100%) rename cavis-datavec/{cavis-datavec-arrow => cavis-datavec-data/cavis-datavec-data-arrow}/src/test/java/org/datavec/arrow/RecordMapperTest.java (100%) rename cavis-datavec/{cavis-datavec-arrow => cavis-datavec-data/cavis-datavec-data-arrow}/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java (100%) create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/files/URIUtil.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/files/URIUtil.java index 5469476db..82911ebb4 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/files/URIUtil.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/files/URIUtil.java @@ -20,13 +20,15 @@ package org.datavec.api.util.files; +import lombok.NonNull; + import java.io.File; import java.net.URI; import java.net.URISyntaxException; public class URIUtil { - public static URI fileToURI(File f) { + public static URI fileToURI(@NonNull File f) { try { // manually construct URI (this is faster) String sp = slashify(f.getAbsoluteFile().getPath(), false); diff --git a/cavis-datavec/cavis-datavec-arrow/build.gradle b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/build.gradle similarity index 100% rename from cavis-datavec/cavis-datavec-arrow/build.gradle rename to cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/build.gradle diff --git a/cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java similarity index 100% rename from cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java rename to cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java diff --git a/cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecord.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecord.java similarity index 100% rename from cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecord.java rename to cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecord.java diff --git a/cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordReader.java similarity index 100% rename from cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordReader.java rename to cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordReader.java diff --git a/cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java similarity index 100% rename from cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java rename to cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java diff --git a/cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordBatch.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordBatch.java similarity index 100% rename from cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordBatch.java rename to cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordBatch.java diff --git a/cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatch.java 
b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatch.java similarity index 100% rename from cavis-datavec/cavis-datavec-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatch.java rename to cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatch.java diff --git a/cavis-datavec/cavis-datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java similarity index 100% rename from cavis-datavec/cavis-datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java rename to cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java diff --git a/cavis-datavec/cavis-datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java similarity index 100% rename from cavis-datavec/cavis-datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java rename to cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java diff --git a/cavis-datavec/cavis-datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java similarity index 100% rename from cavis-datavec/cavis-datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java rename to cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java diff --git a/cavis-datavec/cavis-datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java similarity index 100% rename from cavis-datavec/cavis-datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java rename to cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java diff --git a/cavis-datavec/cavis-datavec-local/build.gradle b/cavis-datavec/cavis-datavec-local/build.gradle index 153bf0499..b9fb9c07c 100644 --- a/cavis-datavec/cavis-datavec-local/build.gradle +++ b/cavis-datavec/cavis-datavec-local/build.gradle @@ -23,7 +23,7 @@ apply from: "${project.rootProject.projectDir}/createTestBackends.gradle" dependencies { implementation "com.codepoetics:protonpack:1.15" implementation projects.cavisDatavec.cavisDatavecApi - implementation projects.cavisDatavec.cavisDatavecArrow + implementation projects.cavisDatavec.cavisDatavecData.cavisDatavecDataArrow implementation projects.cavisDnn.cavisDnnApi implementation "com.google.guava:guava" diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JSystemProperties.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JSystemProperties.java index 4a8299eb0..494e07c9a 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JSystemProperties.java +++ 
b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JSystemProperties.java @@ -40,6 +40,8 @@ public class DL4JSystemProperties { */ public static final String DL4J_RESOURCES_DIR_PROPERTY = "org.deeplearning4j.resources.directory"; + public static final String DISABLE_HELPER_PROPERTY = "org.deeplearning4j.disablehelperloading"; + public static final String HELPER_DISABLE_DEFAULT_VALUE = "false"; /** * Applicability: Numerous modules, including deeplearning4j-datasets and deeplearning4j-zoo
* Description: Used to set the base URL for hosting of resources such as datasets (like MNIST) and pretrained diff --git a/cavis-dnn/cavis-dnn-nn/build.gradle b/cavis-dnn/cavis-dnn-nn/build.gradle index d9792730a..3ffdbee6a 100644 --- a/cavis-dnn/cavis-dnn-nn/build.gradle +++ b/cavis-dnn/cavis-dnn-nn/build.gradle @@ -18,6 +18,7 @@ * ***************************************************************************** * */ +apply from: "${project.rootProject.projectDir}/createTestBackends.gradle" dependencies { implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java new file mode 100644 index 000000000..eb59a2c5f --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java @@ -0,0 +1,116 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ +package org.deeplearning4j.nn.layers; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.common.config.DL4JClassLoading; +import org.nd4j.linalg.factory.Nd4j; + +import static org.deeplearning4j.common.config.DL4JSystemProperties.DISABLE_HELPER_PROPERTY; +import static org.deeplearning4j.common.config.DL4JSystemProperties.HELPER_DISABLE_DEFAULT_VALUE; + +/** + * Simple meta helper util class for instantiating + * platform specific layer helpers that handle interaction with + * lower level libraries like cudnn and onednn. + * + * @author Adam Gibson + */ +@Slf4j +public class HelperUtils { + + + /** + * Creates a {@link LayerHelper} + * for use with platform specific code. + * @param the actual class type to be returned + * @param cudnnHelperClassName the cudnn class name + * @param oneDnnClassName the one dnn class name + * @param layerHelperSuperClass the layer helper super class + * @param layerName the name of the layer to be created + * @param arguments the arguments to be used in creation of the layer + * @return + */ + public static T createHelper(String cudnnHelperClassName, + String oneDnnClassName, + Class layerHelperSuperClass, + String layerName, + Object... 
arguments) { + + Boolean disabled = Boolean.parseBoolean(System.getProperty(DISABLE_HELPER_PROPERTY,HELPER_DISABLE_DEFAULT_VALUE)); + if(disabled) { + System.out.println("Disabled helper creation, returning null"); + return null; + } + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + LayerHelper helperRet = null; + if("CUDA".equalsIgnoreCase(backend) && cudnnHelperClassName != null && !cudnnHelperClassName.isEmpty()) { + if(DL4JClassLoading.loadClassByName(cudnnHelperClassName) != null) { + log.debug("Attempting to initialize cudnn helper {}",cudnnHelperClassName); + helperRet = (LayerHelper) DL4JClassLoading.createNewInstance( + cudnnHelperClassName, + (Class) layerHelperSuperClass, + new Object[]{arguments}); + log.debug("Cudnn helper {} successfully initialized",cudnnHelperClassName); + + } + else { + log.warn("Unable to find class {} using the classloader set for Dl4jClassLoading. Trying to use class loader that loaded the class {} instead.",cudnnHelperClassName,layerHelperSuperClass.getName()); + ClassLoader classLoader = DL4JClassLoading.getDl4jClassloader(); + DL4JClassLoading.setDl4jClassloaderFromClass(layerHelperSuperClass); + try { + helperRet = (LayerHelper) DL4JClassLoading.createNewInstance( + cudnnHelperClassName, + (Class) layerHelperSuperClass, + arguments); + + } catch (Exception e) { + log.warn("Unable to use helper implementation {} for helper type {}, please check your classpath. Falling back to built in normal methods for now.",cudnnHelperClassName,layerHelperSuperClass.getName()); + } + + log.warn("Returning class loader to original one."); + DL4JClassLoading.setDl4jClassloader(classLoader); + + } + + if (helperRet != null && !helperRet.checkSupported()) { + return null; + } + + if(helperRet != null) { + log.debug("{} successfully initialized",cudnnHelperClassName); + } + + } else if("CPU".equalsIgnoreCase(backend) && oneDnnClassName != null && !oneDnnClassName.isEmpty()) { + helperRet = DL4JClassLoading.createNewInstance( + oneDnnClassName, + arguments); + log.trace("Created oneDNN helper: {}, layer {}", oneDnnClassName,layerName); + } + + if (helperRet != null && !helperRet.checkSupported()) { + log.debug("Removed helper {} as not supported", helperRet.getClass()); + return null; + } + + return (T) helperRet; + } + +} \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LayerHelper.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LayerHelper.java index fd3bc4ce3..82fc974ac 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LayerHelper.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LayerHelper.java @@ -37,4 +37,6 @@ public interface LayerHelper { */ Map helperMemoryUse(); + boolean checkSupported(); + } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/BaseMKLDNNHelper.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/BaseMKLDNNHelper.java index 0e25760b8..a349eef20 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/BaseMKLDNNHelper.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/BaseMKLDNNHelper.java @@ -59,4 +59,8 @@ public class BaseMKLDNNHelper { } } + public boolean checkSupported() { + return mklDnnEnabled(); + } + } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java index 27fec5626..388125e82 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java @@ -197,4 +197,9 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { public Map helperMemoryUse() { return Collections.emptyMap(); } + + @Override + public boolean checkSupported() { + return BaseMKLDNNHelper.mklDnnEnabled(); + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java index 353d7c664..a8803eda2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java @@ -30,6 +30,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.*; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; @@ -43,6 +44,7 @@ import java.util.List; import java.util.Map; public class MKLDNNLSTMHelper implements LSTMHelper { + public MKLDNNLSTMHelper(DataType dataType) {} @Override public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, boolean hasPeepholeConnections) { //TODO check other activation functions for MKLDNN @@ -159,6 +161,11 @@ public class MKLDNNLSTMHelper implements LSTMHelper { return Collections.emptyMap(); } + @Override + public boolean checkSupported() { + return BaseMKLDNNHelper.mklDnnEnabled(); + } + private int activationToArg(IActivation a){ //0=tanh, 1=relu, 2=sigmoid, 3=affine, 4=leaky relu, 5= thresholded relu, 6=scaled tanh, 7=hard sigmoid, 8=ELU, 9=softsign, 10=softplus if(a instanceof ActivationTanH) diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLocalResponseNormalizationHelper.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLocalResponseNormalizationHelper.java index d0c9f90ad..c38a81b58 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLocalResponseNormalizationHelper.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLocalResponseNormalizationHelper.java @@ -94,4 +94,5 @@ public class MKLDNNLocalResponseNormalizationHelper extends BaseMKLDNNHelper imp public Map helperMemoryUse() { return Collections.emptyMap(); } + } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalLayer.java index a33baf754..3d2784fa5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalLayer.java @@ -592,6 +592,11 @@ public class BidirectionalLayer implements RecurrentLayer { } return ret; } + + @Override + public boolean checkSupported() { + return true; + } } @Override 
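For context, the helper-creation flow added above can be exercised as in the following minimal sketch. This is an illustrative assumption, not code from the patch set: ExampleLayer and the cuDNN helper class name are placeholders, while HelperUtils.createHelper, LayerHelper, and the MKLDNNLSTMHelper(DataType) constructor are the pieces introduced or modified in these commits.

// Minimal usage sketch (assumption, not part of the patches): shows how a layer could
// obtain a platform-specific helper through the HelperUtils.createHelper(...) method
// introduced above. ExampleLayer and the cuDNN class name are hypothetical placeholders;
// MKLDNNLSTMHelper and its new DataType constructor come from this patch set.
import org.deeplearning4j.nn.layers.HelperUtils;
import org.deeplearning4j.nn.layers.LayerHelper;
import org.nd4j.linalg.api.buffer.DataType;

public class ExampleLayer {

    private LayerHelper helper;

    void initHelper(DataType dataType) {
        // createHelper returns null when helpers are disabled via the
        // DL4JSystemProperties.DISABLE_HELPER_PROPERTY system property, when the named
        // class is not on the classpath, or when checkSupported() reports false;
        // callers then fall back to the built-in (non-helper) implementation.
        helper = HelperUtils.createHelper(
                "org.example.cudnn.CudnnExampleHelper",                 // hypothetical cuDNN helper class
                "org.deeplearning4j.nn.layers.mkldnn.MKLDNNLSTMHelper", // oneDNN helper from this patch
                LayerHelper.class,                                      // helper super class
                "example-layer",                                        // layer name used for logging
                dataType);                                              // forwarded to the helper constructor
    }

    boolean hasHelper() {
        return helper != null;
    }
}

The reflective lookup keeps cuDNN and oneDNN strictly optional at compile time: the helper classes are referenced only as strings and loaded through DL4JClassLoading, so the core module never links against the optional backends.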
diff --git a/cavis-dnn/cavis-dnn-nn/src/test/java/org/deeplearning4j/nn/layers/HelperUtilsTest.java b/cavis-dnn/cavis-dnn-nn/src/test/java/org/deeplearning4j/nn/layers/HelperUtilsTest.java index 0f034301c..bd05f187f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/test/java/org/deeplearning4j/nn/layers/HelperUtilsTest.java +++ b/cavis-dnn/cavis-dnn-nn/src/test/java/org/deeplearning4j/nn/layers/HelperUtilsTest.java @@ -20,50 +20,21 @@ package org.deeplearning4j.nn.layers; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.ActivationLayer; -import org.deeplearning4j.nn.conf.layers.OutputLayer; -import org.deeplearning4j.nn.conf.layers.*; -import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper; import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper; import org.deeplearning4j.nn.layers.mkldnn.*; import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper; import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper; import org.deeplearning4j.nn.layers.recurrent.LSTMHelper; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.nn.weights.WeightInit; import org.junit.jupiter.api.DisplayName; -import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; -import org.nd4j.common.tests.tags.NativeTag; -import org.nd4j.common.tests.tags.TagNames; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.impl.ActivationELU; -import org.nd4j.linalg.activations.impl.ActivationRationalTanh; -import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertNotNull; /** */ @DisplayName("Activation Layer Test") -@NativeTag -@Tag(TagNames.CUSTOM_FUNCTIONALITY) -@Tag(TagNames.DL4J_OLD_API) public class HelperUtilsTest extends BaseDL4JTest { @Override diff --git a/settings.gradle b/settings.gradle index d6355b295..2e4e68cce 100644 --- a/settings.gradle +++ b/settings.gradle @@ -116,8 +116,8 @@ include ':cavis-dnn:cavis-dnn-spark:cavis-dnn-spark-parameterserver' include ':cavis-dnn:cavis-dnn-tsne' include ':cavis-datavec' include ':cavis-datavec:cavis-datavec-api' -include ':cavis-datavec:cavis-datavec-arrow' include ':cavis-datavec:cavis-datavec-data' +include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-arrow' include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-image' include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-audio' include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-codec' From efbb341742a6d0531f74bd05efc04f5d3c3e6ecb Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 12 Oct 2022 11:01:57 +0200 Subject: [PATCH 023/126] More test fixes Signed-off-by: brian --- build.gradle | 5 +++++ 
build_requirements.md | 5 ++++- cavis-common-platform/build.gradle | 4 ++-- cavis-dnn/cavis-dnn-parallelwrapper/build.gradle | 4 +++- cavis-full/build.gradle | 6 +++++- 5 files changed, 19 insertions(+), 5 deletions(-) diff --git a/build.gradle b/build.gradle index 902d0822a..ab3337562 100644 --- a/build.gradle +++ b/build.gradle @@ -63,6 +63,11 @@ allprojects { Project proj -> plugins.withType(JavaPlugin) { + sourceCompatibility = 11 + targetCompatibility = 1.8 + tasks.withType(JavaCompile) { + options.release = 8 + } dependencies { implementation platform(project(":cavis-common-platform")) diff --git a/build_requirements.md b/build_requirements.md index 50d83268a..db6532203 100644 --- a/build_requirements.md +++ b/build_requirements.md @@ -126,4 +126,7 @@ sudo sh cmake-3.20.4-linux-x86_64.sh --skip-license echo "supersede domain-name-servers 172.31.0.2, 8.8.8.8" | sudo tee -a /etc/dhcp/dhclient.conf echo "nameserver 8.8.8.8" | sudo tee -a /etc/resolv.conf - \ No newline at end of file + # Buildparameter: # + + -P\\ + CAVIS_AVX_EXTENSION = {avx2 | avx512}, default is avx2 \ No newline at end of file diff --git a/cavis-common-platform/build.gradle b/cavis-common-platform/build.gradle index a6202c6a8..7941b39ed 100644 --- a/cavis-common-platform/build.gradle +++ b/cavis-common-platform/build.gradle @@ -64,8 +64,8 @@ dependencies { api "org.projectlombok:lombok:1.18.24" /*Logging*/ - api 'org.slf4j:slf4j-api:1.7.30' - api 'org.slf4j:slf4j-simple:1.7.25' + api 'org.slf4j:slf4j-api:2.0.3' + api 'org.slf4j:slf4j-simple:2.0.3' api "org.apache.logging.log4j:log4j-core:2.17.0" api "ch.qos.logback:logback-classic:1.2.3" diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/build.gradle b/cavis-dnn/cavis-dnn-parallelwrapper/build.gradle index c039ab783..735e02b78 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/build.gradle +++ b/cavis-dnn/cavis-dnn-parallelwrapper/build.gradle @@ -25,6 +25,9 @@ dependencies { implementation 'org.slf4j:slf4j-api' implementation "com.google.guava:guava" + implementation "com.fasterxml.jackson.core:jackson-annotations" + implementation "com.fasterxml.jackson.core:jackson-core" + implementation projects.cavisNd4j.cavisNd4jParameterServer.cavisNd4jParameterServerCore implementation projects.cavisNd4j.cavisNd4jParameterServer.cavisNd4jParameterServerClient implementation projects.cavisDnn.cavisDnnCore @@ -36,7 +39,6 @@ dependencies { testImplementation projects.cavisUi.cavisUiStandalone - testImplementation projects.cavisDnn.cavisDnnCommonTests testImplementation projects.cavisUi.cavisUiModel testImplementation projects.cavisUi.cavisUiVertx diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index a986d6671..68e847fdf 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -12,8 +12,10 @@ configurations.archives.artifacts.with { archives -> dependencies { //Todo clean this api platform(project(":cavis-common-platform")) - api "org.bytedeco:javacpp" + api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5" + api 'org.slf4j:slf4j-simple:2.0.3' + api 'org.slf4j:slf4j-api:2.0.3' //api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86" rootProject.getAllprojects().each { Project sproj -> @@ -85,3 +87,5 @@ publishing { } } } + + From b08a0ac24bb542a67e890d5cf483fdd44eff059e Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 12 Oct 2022 12:03:08 +0200 Subject: [PATCH 024/126] Add jenkinsfile for pipeline build 
and dockerfile for build Signed-off-by: brian --- .docker/Dockerfile | 10 +++++++--- build.gradle | 5 +++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/.docker/Dockerfile b/.docker/Dockerfile index 2e8e9a472..483b99544 100644 --- a/.docker/Dockerfile +++ b/.docker/Dockerfile @@ -2,9 +2,13 @@ FROM nvidia/cuda:11.4.0-cudnn8-devel-ubuntu20.04 RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git -RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2.tar.gz && \ - tar -xvf cmake-3.24.2.tar.gz && cd cmake-3.24.2 && \ - ./bootstrap && make && make install +#Build cmake version from source \ +#RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2.tar.gz && \ +# tar -xvf cmake-3.24.2.tar.gz && cd cmake-3.24.2 && \ +# ./bootstrap && make && make install +RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2-linux-x86_64.sh && \ + mkdir /opt/cmake && sh ./cmake-3.24.2-linux-x86_64.sh --skip-license --prefix=/opt/cmake && ln -s /opt/cmake/bin/cmake /usr/bin/cmake && \ + rm cmake-3.24.2-linux-x86_64.sh RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf diff --git a/build.gradle b/build.gradle index ab3337562..cd5911461 100644 --- a/build.gradle +++ b/build.gradle @@ -44,6 +44,7 @@ ext { scalaVersion = "2.12" logger.quiet("Scala main version is set to {}", scalaVersion) + logger.quiet("Running java {}", JavaVersion.current()) } configurations.all { @@ -63,8 +64,8 @@ allprojects { Project proj -> plugins.withType(JavaPlugin) { - sourceCompatibility = 11 - targetCompatibility = 1.8 + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_1_8 tasks.withType(JavaCompile) { options.release = 8 } From b43e5860a99a563773aca7b12995f0936fed0f6e Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 12 Oct 2022 12:07:39 +0200 Subject: [PATCH 025/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .docker/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.docker/Dockerfile b/.docker/Dockerfile index 483b99544..fe174610a 100644 --- a/.docker/Dockerfile +++ b/.docker/Dockerfile @@ -1,6 +1,7 @@ FROM nvidia/cuda:11.4.0-cudnn8-devel-ubuntu20.04 -RUN apt-get update && \ +RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf && \ + apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git #Build cmake version from source \ #RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2.tar.gz && \ @@ -10,5 +11,4 @@ RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3. 
mkdir /opt/cmake && sh ./cmake-3.24.2-linux-x86_64.sh --skip-license --prefix=/opt/cmake && ln -s /opt/cmake/bin/cmake /usr/bin/cmake && \ rm cmake-3.24.2-linux-x86_64.sh -RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf From 4f3393ceb46378c8710adb9c67c52c7072ef94f4 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 12 Oct 2022 12:12:40 +0200 Subject: [PATCH 026/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .docker/Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.docker/Dockerfile b/.docker/Dockerfile index fe174610a..4e2c0ece8 100644 --- a/.docker/Dockerfile +++ b/.docker/Dockerfile @@ -1,7 +1,6 @@ FROM nvidia/cuda:11.4.0-cudnn8-devel-ubuntu20.04 -RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf && \ - apt-get update && \ +RUN apt-get update && \ DEBIAN_FRONTEND=noninteractive apt-get install -y openjdk-11-jdk wget build-essential checkinstall zlib1g-dev libssl-dev git #Build cmake version from source \ #RUN wget https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3.24.2.tar.gz && \ From 0e50a1a04cf4d2c111c0e72772a446d4300aab13 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 12 Oct 2022 12:24:19 +0200 Subject: [PATCH 027/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .jenkins/linux-x86_64-docker-cuda-build.jenkinsfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile b/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile index 1ba9af2da..b465b0d3f 100644 --- a/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile +++ b/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile @@ -26,7 +26,7 @@ pipeline { dir '.docker' label 'linux && docker && cuda' //additionalBuildArgs '--build-arg version=1.0.2' - args '--gpus all' + //args '--gpus all' --needed for test only, you can build without GPU } } @@ -49,7 +49,7 @@ pipeline { steps { withGradle { - sh 'sh ./gradlew publish --stacktrace -x test -PCAVIS_CHIP=cuda \ + sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \ -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' } From 81d49ba1f0c2d2ef8c8c5e789a05b893efb917fd Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 12 Oct 2022 12:40:01 +0200 Subject: [PATCH 028/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .jenkins/linux-x86_64-docker-cuda-build.jenkinsfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile b/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile index b465b0d3f..331283531 100644 --- a/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile +++ b/.jenkins/linux-x86_64-docker-cuda-build.jenkinsfile @@ -34,7 +34,7 @@ pipeline { stage('prep-build-environment-linux-cuda') { steps { checkout scm - sh 'nvidia-smi' + //sh 'nvidia-smi' sh 'nvcc --version' sh 'gcc --version' sh 'cmake --version' From e3d64a1cac74947c1d02d7d82238542ae0787dd1 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 12 Oct 2022 16:55:13 +0200 Subject: [PATCH 029/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .jenkins/linux-x86_64-cpu-build.jenkinsfile | 4 ++-- cavis-native/cavis-native-lib/build.gradle | 7 ++++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.jenkins/linux-x86_64-cpu-build.jenkinsfile b/.jenkins/linux-x86_64-cpu-build.jenkinsfile index d57d033d4..bc4a988b1 100644 --- 
a/.jenkins/linux-x86_64-cpu-build.jenkinsfile +++ b/.jenkins/linux-x86_64-cpu-build.jenkinsfile @@ -48,7 +48,7 @@ pipeline { //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build' } } - stage('test-linux-cpu') { + /*stage('test-linux-cpu') { environment { MAVEN = credentials('Internal Archiva') OSSRH = credentials('OSSRH') @@ -62,7 +62,7 @@ pipeline { } //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build' } - } + }*/ stage('publish-linux-cpu') { environment { MAVEN = credentials('Internal Archiva') diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index c65b768d3..85e445b64 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -467,13 +467,13 @@ javadoc { if(! osdetector.os.startsWith("windows")) { tasks.getByName("publish") { - enabled = false + enabled = true } tasks.getByName("generatePomFileForMavenJavaPublication") { - enabled = false + enabled = true } tasks.getByName("publishMavenJavaPublicationToLocalRemoteRepository") { - enabled = false + enabled = true } chipList.each {thisChip -> artifacts { @@ -485,6 +485,7 @@ if(! osdetector.os.startsWith("windows")) { publishing { publications { mavenJava(MavenPublication) { + log.quiet("Adding artifacts from task {} to the publication.", "${thisChip}SupportJar" ) artifact tasks.getByName("${thisChip}SupportJar") } } From 1ee6b7a2312b3c6288c04781775be4633c64f65e Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 12 Oct 2022 19:45:41 +0200 Subject: [PATCH 030/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-native/cavis-native-lib/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 85e445b64..f5c623d56 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -485,7 +485,7 @@ if(! 
osdetector.os.startsWith("windows")) { publishing { publications { mavenJava(MavenPublication) { - log.quiet("Adding artifacts from task {} to the publication.", "${thisChip}SupportJar" ) + logger.quiet("Adding artifacts from task {} to the publication.", "${thisChip}SupportJar" ) artifact tasks.getByName("${thisChip}SupportJar") } } From 941275df3a9ec2f2b6e67c4afb287cd84e249f8f Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 12 Oct 2022 20:39:12 +0200 Subject: [PATCH 031/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-native/cavis-native-lib/build.gradle | 28 ++++++++++++---------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index f5c623d56..4cfef72f6 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -388,7 +388,7 @@ chipList.each { thisChip -> //} } - // Generates jnijavacpp.cpp and jniNativeLibrary.cpp, compiles and links it + // Create Jar with classifier tasks.getByName("${thisChip}SupportJar") { Jar thisTask -> dependsOn "javacpp${thisChip.capitalize()}SupportBuildCompiler" dependsOn "javacpp${thisChip.capitalize()}SupportBuildCommand" @@ -405,7 +405,7 @@ chipList.each { thisChip -> } return exclude } - into "${javacppPlatform}/" //we need it in a platform, that javacpp Loader understands + into "${javacppPlatform}/" //path within jar, we need it in a platform, that javacpp Loader understands } from(sourceSets.getByName("${thisChip}Support").getOutput()) { @@ -466,34 +466,36 @@ javadoc { if(! osdetector.os.startsWith("windows")) { - tasks.getByName("publish") { - enabled = true - } + //tasks.getByName("publish") { + // enabled = false + // } tasks.getByName("generatePomFileForMavenJavaPublication") { enabled = true } tasks.getByName("publishMavenJavaPublicationToLocalRemoteRepository") { enabled = true } - chipList.each {thisChip -> + chipList.each { thisChip -> artifacts { archives tasks.getByName("${thisChip}SupportJar") } } +} - chipList.each { thisChip -> - publishing { - publications { - mavenJava(MavenPublication) { - logger.quiet("Adding artifacts from task {} to the publication.", "${thisChip}SupportJar" ) - artifact tasks.getByName("${thisChip}SupportJar") - } + +chipList.each { thisChip -> + publishing { + publications { + mavenJava(MavenPublication) { + logger.quiet("Adding artifacts from task {} to the publication.", "${thisChip}SupportJar" ) + artifact tasks.getByName("${thisChip}SupportJar") } } } } + if( osdetector.os.startsWith("windows")) { FileCollection collection = layout.files { file("build/libs/").listFiles() } From a9890feb9f230f67bba558bf68cdfe403dd1f9e5 Mon Sep 17 00:00:00 2001 From: brian Date: Thu, 13 Oct 2022 16:53:50 +0200 Subject: [PATCH 032/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .jenkins/linux-x86_64-docker-cpu-build.jenkinsfile | 14 +++++++------- cavis-native/cavis-native-lib/build.gradle | 1 - 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile b/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile index 5553f8014..43e77bc0d 100644 --- a/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile +++ b/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile @@ -60,14 +60,14 @@ pipeline { OSSRH = credentials('OSSRH') } - steps { - withGradle { - sh 'sh ./gradlew test --stacktrace -PCAVIS_CHIP=cpu \ - -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ 
- -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' - } + //steps { + // withGradle { + // sh 'sh ./gradlew test --stacktrace -PCAVIS_CHIP=cpu \ + // -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ + // -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' + // } //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build' - } + //} } stage('publish-linux-cpu') { environment { diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 4cfef72f6..84217af15 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -487,7 +487,6 @@ chipList.each { thisChip -> publishing { publications { mavenJava(MavenPublication) { - logger.quiet("Adding artifacts from task {} to the publication.", "${thisChip}SupportJar" ) artifact tasks.getByName("${thisChip}SupportJar") } } From 2a96e7185347825126eaa219b2f9300973be0721 Mon Sep 17 00:00:00 2001 From: brian Date: Thu, 13 Oct 2022 16:55:37 +0200 Subject: [PATCH 033/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .jenkins/linux-x86_64-docker-cpu-build.jenkinsfile | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile b/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile index 43e77bc0d..1379d630e 100644 --- a/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile +++ b/.jenkins/linux-x86_64-docker-cpu-build.jenkinsfile @@ -60,14 +60,14 @@ pipeline { OSSRH = credentials('OSSRH') } - //steps { - // withGradle { - // sh 'sh ./gradlew test --stacktrace -PCAVIS_CHIP=cpu \ - // -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ - // -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' - // } + steps { + withGradle { + //sh 'sh ./gradlew test --stacktrace -PCAVIS_CHIP=cpu \ + // -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ + // -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' + } //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build' - //} + } } stage('publish-linux-cpu') { environment { From e695f1c653dda347183a7b4b18b241045d5930c5 Mon Sep 17 00:00:00 2001 From: brian Date: Thu, 13 Oct 2022 17:52:47 +0200 Subject: [PATCH 034/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- build_requirements.md | 14 +++++++++++++- cavis-full/build.gradle | 3 ++- cavis-native/cavis-native-lib/build.gradle | 3 ++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/build_requirements.md b/build_requirements.md index db6532203..602190b95 100644 --- a/build_requirements.md +++ b/build_requirements.md @@ -129,4 +129,16 @@ echo "nameserver 8.8.8.8" | sudo tee -a /etc/resolv.conf # Buildparameter: # -P\\ - CAVIS_AVX_EXTENSION = {avx2 | avx512}, default is avx2 \ No newline at end of file + CAVIS_AVX_EXTENSION = {avx2 | avx512}, default is avx2 + +# Zeppelin Spark dependencies # +3 + + +To add the dependency to the language models, use the following format in the Dependencies section of the of the Spark Interpreter configuration (Interpreters -> Spark -> Edit -> Dependencies): + +groupId:artifactId:packaging:classifier:version + +In your case it should work with + +edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0 \ No newline at end of file diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 68e847fdf..659e119e2 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -16,7 +16,8 @@ dependencies { api 
"com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5" api 'org.slf4j:slf4j-simple:2.0.3' api 'org.slf4j:slf4j-api:2.0.3' - //api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86" + api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86" + api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" rootProject.getAllprojects().each { Project sproj -> if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 84217af15..10648759d 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -117,7 +117,8 @@ dependencies { api platform(project(':cavis-common-platform')) - api "org.bytedeco:javacpp" + implementation "org.bytedeco:javacpp" + implementation group: "org.bytedeco", name: "javacpp", classifier: "${javacppPlatform}" if(withCuda()) { cudaSupportImplementation platform(project(':cavis-common-platform')) From 47181956818d92dc4c6ba088ec961bcbb1474ea9 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 14 Oct 2022 12:00:34 +0200 Subject: [PATCH 035/126] upgrade versions Signed-off-by: brian --- cavis-common-platform/build.gradle | 30 +++++++++++-------- .../cavis-datavec-python/build.gradle | 4 +-- cavis-full/build.gradle | 5 ++-- 3 files changed, 23 insertions(+), 16 deletions(-) diff --git a/cavis-common-platform/build.gradle b/cavis-common-platform/build.gradle index 7941b39ed..81851f931 100644 --- a/cavis-common-platform/build.gradle +++ b/cavis-common-platform/build.gradle @@ -5,18 +5,22 @@ plugins { ext { scalaVersion = rootProject.ext.scalaVersion + javacppPlatform = osdetector.classifier } - def javacpp = [version: "1.5.6", presetsVersion: "1.5.6"] + def javacpp = [version: "1.5.7", presetsVersion: "1.5.7"] def hdf5 = [version: "1.12.1"] def jackson = [version: "2.13.4"] - def cuda = [version: "11.4"] - def cudnn = [version: "8.2"] - def openblas = [version: "0.3.17"] + def cuda = [version: "11.6"] + def cudnn = [version: "8.3"] + def openblas = [version: "0.3.19"] + def numpy = [version: "1.22.2"] + def tensorflow = [version: "1.15.5"] + def cpython = [version: "3.10.2"] - def javacv = [version:"1.5.6"] - def opencv = [version: "4.5.3"] - def leptonica = [version: "1.81.1"] + def javacv = [version:"1.5.7"] + def opencv = [version: "4.5.5"] + def leptonica = [version: "1.82.0"] def junit = [version: "5.9.1"] def flatbuffers = [version: "1.10.0"] @@ -111,17 +115,19 @@ dependencies { api "org.bytedeco:leptonica-platform:${leptonica.version}-${javacpp.presetsVersion}" api "org.bytedeco:hdf5-platform:${hdf5.version}-${javacpp.presetsVersion}" api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}" - api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:windows-x86_64" - api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:linux-x86_64" + api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:${javacppPlatform}" + //api "org.bytedeco:hdf5:${hdf5.version}-${javacpp.presetsVersion}:linux-x86_64" api "org.bytedeco:cuda:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}" api "org.bytedeco:cuda-platform-redist:${cuda.version}-${cudnn.version}-${javacpp.presetsVersion}" api "org.bytedeco:mkl-dnn:0.21.5-${javacpp.presetsVersion}" - api "org.bytedeco:tensorflow:1.15.5-${javacpp.presetsVersion}" - api "org.bytedeco:cpython:3.9.6-${javacpp.presetsVersion}" - api 
"org.bytedeco:numpy:1.21.1-${javacpp.presetsVersion}" + api "org.bytedeco:tensorflow:${tensorflow.version}-${javacpp.presetsVersion}" + api "org.bytedeco:cpython:${cpython.version}-${javacpp.presetsVersion}:${javacppPlatform}" + api "org.bytedeco:numpy:${numpy.version}-${javacpp.presetsVersion}:${javacppPlatform}" + //implementation "org.bytedeco:cpython-platform:3.9.6-1.5.6" + //implementation "org.bytedeco:numpy-platform:1.21.1-1.5.6" /* Apache Spark */ api "org.apache.spark:spark-core_${scalaVersion}:${spark.version}" diff --git a/cavis-datavec/cavis-datavec-python/build.gradle b/cavis-datavec/cavis-datavec-python/build.gradle index 2b9292500..0ee4b03dc 100644 --- a/cavis-datavec/cavis-datavec-python/build.gradle +++ b/cavis-datavec/cavis-datavec-python/build.gradle @@ -22,8 +22,8 @@ apply from: "${project.rootProject.projectDir}/createTestBackends.gradle" dependencies { implementation 'org.json:json:20190722' - implementation "org.bytedeco:cpython-platform:3.9.6-1.5.6" - implementation "org.bytedeco:numpy-platform:1.21.1-1.5.6" + implementation "org.bytedeco:cpython" + implementation "org.bytedeco:numpy" implementation 'com.google.code.findbugs:jsr305:3.0.2' implementation projects.cavisDatavec.cavisDatavecApi implementation projects.cavisDnn.cavisDnnApi diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 659e119e2..b3d2231e5 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -16,8 +16,9 @@ dependencies { api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5" api 'org.slf4j:slf4j-simple:2.0.3' api 'org.slf4j:slf4j-api:2.0.3' - api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86" - api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" + //TODO for the two below.. 
either platform specific uber jars or a single big one with all platforms + //api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86" + //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" rootProject.getAllprojects().each { Project sproj -> if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") From 5eb3d1c33d8389377b99816df7e119261823bcb7 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 14 Oct 2022 12:11:15 +0200 Subject: [PATCH 036/126] More test fixes Signed-off-by: brian --- ...inux-x86_64-docker-cpu-publish.jenkinsfile | 49 +++++++++++++++++++ .../clustering/kdtree/KDTreeTest.java | 10 +--- .../clustering/kmeans/KMeansTest.java | 7 +-- .../clustering/sptree/SPTreeTest.java | 5 -- .../spark/BaseSparkKryoTest.java | 5 -- .../deeplearning4j/spark/BaseSparkTest.java | 5 -- .../multilayer/TestSparkDl4jMultiLayer.java | 5 -- ...TestSparkMultiLayerParameterAveraging.java | 6 --- .../spark/util/TestRepartitioning.java | 7 +-- .../spark/text/BaseSparkTest.java | 5 -- .../spark/parameterserver/BaseSparkTest.java | 6 --- .../train/GradientSharingTrainingTest.java | 5 -- .../ui/stats/TestTransferStatsCollection.java | 5 -- .../deeplearning4j/ui/TestVertxUIManual.java | 7 +-- .../org/deeplearning4j/zoo/MiscTests.java | 5 -- .../org/deeplearning4j/zoo/TestImageNet.java | 5 -- 16 files changed, 57 insertions(+), 80 deletions(-) create mode 100644 .jenkins/linux-x86_64-docker-cpu-publish.jenkinsfile diff --git a/.jenkins/linux-x86_64-docker-cpu-publish.jenkinsfile b/.jenkins/linux-x86_64-docker-cpu-publish.jenkinsfile new file mode 100644 index 000000000..d185b5946 --- /dev/null +++ b/.jenkins/linux-x86_64-docker-cpu-publish.jenkinsfile @@ -0,0 +1,49 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +pipeline { + agent { + dockerfile { + filename 'Dockerfile' + dir '.docker' + label 'linux && docker' + //additionalBuildArgs '--build-arg version=1.0.2' + //args '--gpus all' + } + } + + stages { + stage('publish-linux-cpu') { + environment { + MAVEN = credentials('Internal Archiva') + OSSRH = credentials('OSSRH') + } + + steps { + withGradle { + sh 'sh ./gradlew publish -x test -PCAVIS_CHIP=cpu \ + -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ + -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' + } + } + } + } +} diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java index cb6b05d89..00beb9e71 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java @@ -19,10 +19,7 @@ package org.deeplearning4j.clustering.kdtree; import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.joda.time.Duration; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Tag; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.*; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -44,12 +41,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue; /** * Created by agibsonccc on 1/1/15. */ +@Timeout(120) public class KDTreeTest extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } private KDTree kdTree; diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java index 4b35b9f6a..40683daa9 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java @@ -22,6 +22,7 @@ import org.deeplearning4j.Performance; import org.deeplearning4j.clustering.algorithm.Distance; import org.deeplearning4j.clustering.cluster.*; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -34,15 +35,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; /** * Created by agibsonccc on 7/2/17. 
*/ +@Timeout(120) public class KMeansTest extends BaseDL4JTest { private boolean[] useKMeansPlusPlus = {true, false}; - @Override - public long getTimeoutMilliseconds() { - return 60000L; - } - @Test public void testKMeans() { Nd4j.getRandom().setSeed(7); diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java index 17af2afd4..5973a1f5a 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java @@ -33,11 +33,6 @@ import static org.junit.jupiter.api.Assertions.*; */ public class SPTreeTest extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } - @BeforeEach public void setUp() { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java index 9cab73c5e..e60185e46 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java @@ -29,11 +29,6 @@ import java.util.Map; public class BaseSparkKryoTest extends BaseSparkTest { - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } - @Override public JavaSparkContext getContext() { if (sc != null) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java index e00f8d6d3..5a8ac5d7e 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java @@ -56,11 +56,6 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable protected transient DataSet data; protected transient JavaRDD sparkData; - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } - @BeforeEach public void before() { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java index c64618557..9c7f783e0 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java @@ -53,11 +53,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue; @Slf4j public class TestSparkDl4jMultiLayer extends BaseSparkTest { - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } - @Override public DataType getDataType() { return DataType.FLOAT; diff --git 
a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java index bc1ced484..48a30034a 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java @@ -98,12 +98,6 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { @TempDir public File testDir; - - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } - @Override public DataType getDefaultFPDataType() { return DataType.FLOAT; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java index c83282547..77fdff58e 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java @@ -31,6 +31,7 @@ import org.deeplearning4j.spark.impl.common.CountPartitionsFunction; import org.deeplearning4j.spark.impl.repartitioner.DefaultRepartitioner; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import scala.Tuple2; import java.util.ArrayList; @@ -42,13 +43,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +@Timeout(300) public class TestRepartitioning extends BaseSparkTest { - @Override - public long getTimeoutMilliseconds() { - return isIntegrationTests() ? 
240000 : 60000; - } - @Test public void testRepartitioning() { if(Platform.isWindows()) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java index d998ddde4..1e647013a 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java @@ -35,11 +35,6 @@ import java.util.Map; public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable { protected transient JavaSparkContext sc; - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } - @BeforeEach public void before() throws Exception { sc = getContext(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java index d110e41bd..7a022a132 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java @@ -54,12 +54,6 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable protected transient DataSet data; protected transient JavaRDD sparkData; - - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } - @BeforeEach public void before() { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java index 31cd119d7..f535372b2 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java @@ -78,11 +78,6 @@ public class GradientSharingTrainingTest extends BaseSparkTest { @TempDir public File testDir; - @Override - public long getTimeoutMilliseconds() { - return 180000L; - } - @Test public void trainSanityCheck() throws Exception { diff --git a/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java b/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java index 5736cdb7a..3cf4ec7d9 100644 --- a/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java +++ b/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java @@ -40,11 +40,6 @@ import java.io.IOException; public class TestTransferStatsCollection extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 90_000L; - } - @Test public void test() throws IOException { diff --git a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java 
b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java index eb3d19c51..d6f11df5e 100644 --- a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java +++ b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java @@ -41,6 +41,7 @@ import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.common.function.Function; @@ -58,13 +59,9 @@ import static org.junit.jupiter.api.Assertions.*; @Slf4j //@Ignore +@Timeout(600) public class TestVertxUIManual extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 3600_000L; - } - @AfterAll public void shutdownServer() throws InterruptedException { UIServer.getInstance().stop(); diff --git a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/MiscTests.java b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/MiscTests.java index 0bc1572e7..04a023596 100644 --- a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/MiscTests.java +++ b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/MiscTests.java @@ -37,11 +37,6 @@ import java.io.File; //@Ignore("Times out too often") public class MiscTests extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return Long.MAX_VALUE; - } - @Test public void testTransferVGG() throws Exception { DataSet ds = new DataSet(); diff --git a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java index 8be3a65cc..04d1f8fce 100644 --- a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java +++ b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java @@ -57,11 +57,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue; //@Ignore("Times out too often") public class TestImageNet extends BaseDL4JTest { - @Override - public long getTimeoutMilliseconds() { - return 90000L; - } - @Override public DataType getDataType(){ return DataType.FLOAT; From 976491ee8608505d6b0045f46d2cf1e9f3394bc8 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 14 Oct 2022 12:32:41 +0200 Subject: [PATCH 037/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .jenkins/linux-x86_64-cuda-build.jenkinsfile | 60 ++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 .jenkins/linux-x86_64-cuda-build.jenkinsfile diff --git a/.jenkins/linux-x86_64-cuda-build.jenkinsfile b/.jenkins/linux-x86_64-cuda-build.jenkinsfile new file mode 100644 index 000000000..1b9399028 --- /dev/null +++ b/.jenkins/linux-x86_64-cuda-build.jenkinsfile @@ -0,0 +1,60 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. 
+ * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +pipeline { + agent { + dockerfile { + filename 'Dockerfile' + dir '.docker' + label 'linux && cuda' + //additionalBuildArgs '--build-arg version=1.0.2' + //args '--gpus all' --needed for test only, you can build without GPU + } + } + + stages { + stage('prep-build-environment-linux-cuda') { + steps { + checkout scm + //sh 'nvidia-smi' + sh 'nvcc --version' + sh 'gcc --version' + sh 'cmake --version' + sh 'sh ./gradlew --version' + } + } + stage('build-linux-cuda') { + environment { + MAVEN = credentials('Internal Archiva') + OSSRH = credentials('OSSRH') + } + + steps { + withGradle { + sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \ + -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ + -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' + } + //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build' + } + } + } +} From f1a0a66021c2b418c20b5039d514352514e590d4 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 14 Oct 2022 17:07:14 +0200 Subject: [PATCH 038/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index b3d2231e5..62b16a2bd 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -17,8 +17,8 @@ dependencies { api 'org.slf4j:slf4j-simple:2.0.3' api 'org.slf4j:slf4j-api:2.0.3' //TODO for the two below.. 
either platform specific uber jars or a single big one with all platforms
-    //api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86"
-    //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
+    api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86"
+    api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
     rootProject.getAllprojects().each { Project sproj ->
         if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")

From 320b4430dd58ce8ee1e867f23fe4ae43bf48a2e1 Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 14 Oct 2022 21:12:07 +0200
Subject: [PATCH 039/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-full/build.gradle | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index 62b16a2bd..1d924339b 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -12,12 +12,13 @@ configurations.archives.artifacts.with { archives ->
 dependencies {
     //Todo clean this
     api platform(project(":cavis-common-platform"))
-    api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise
+    //api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise
     api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5"
     api 'org.slf4j:slf4j-simple:2.0.3'
     api 'org.slf4j:slf4j-api:2.0.3'
     //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
-    api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86"
+    api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x64_86"
+    api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
     api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
     rootProject.getAllprojects().each { Project sproj ->

From 97fbd07a0cee2c3fee7fed7b8f62a7b6b124791e Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 14 Oct 2022 21:47:05 +0200
Subject: [PATCH 040/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-full/build.gradle | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index 1d924339b..b62a7b1ca 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -17,7 +17,7 @@ dependencies {
     api 'org.slf4j:slf4j-simple:2.0.3'
     api 'org.slf4j:slf4j-api:2.0.3'
     //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
-    api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x64_86"
+    api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x64_86", ext: "jar"
     api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
     api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"

From 05e8a78d5156b85325caa213df70095a24f712a7 Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 14 Oct 2022 21:53:22 +0200
Subject: [PATCH 041/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-full/build.gradle | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index b62a7b1ca..971d402a5 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -17,8 +17,8 @@ dependencies {
     api 'org.slf4j:slf4j-simple:2.0.3'
     api 'org.slf4j:slf4j-api:2.0.3'
     //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
-    api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x64_86", ext: "jar"
-    api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
+    api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x64_86"
+    api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: ""
     api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
     rootProject.getAllprojects().each { Project sproj ->

From fd4da57a0daed704764bf5dcb89eb3bf48d62c72 Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 14 Oct 2022 21:59:13 +0200
Subject: [PATCH 042/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-full/build.gradle | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index 971d402a5..15aa6a034 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -17,8 +17,8 @@ dependencies {
     api 'org.slf4j:slf4j-simple:2.0.3'
     api 'org.slf4j:slf4j-api:2.0.3'
     //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
-    api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x64_86"
-    api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: ""
+    // api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x64_86"
+    api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
     api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
     rootProject.getAllprojects().each { Project sproj ->

From 6174642bfe5ef7cb43d0425ba861162f03d3ffd3 Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 14 Oct 2022 22:24:59 +0200
Subject: [PATCH 043/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-full/build.gradle | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index 15aa6a034..dc3bacb75 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -18,7 +18,7 @@ dependencies {
     api 'org.slf4j:slf4j-api:2.0.3'
     //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
     // api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x64_86"
-    api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
+    api group: "org.bytedeco", name: "javacpp", version: "1.5.7", ext: "jar"
     api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
     rootProject.getAllprojects().each { Project sproj ->

From 87488dbdd7c9c9fa1b4bd1f3f01ba44d2998fb92 Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 14 Oct 2022 23:07:02 +0200
Subject: [PATCH 044/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-full/build.gradle | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index dc3bacb75..6e2b426dd 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -17,8 +17,8 @@ dependencies {
     api 'org.slf4j:slf4j-simple:2.0.3'
     api 'org.slf4j:slf4j-api:2.0.3'
     //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
-    // api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x64_86"
-    api group: "org.bytedeco", name: "javacpp", version: "1.5.7", ext: "jar"
+    //api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x64_86"
+    api group: "org.bytedeco", name: "javacpp-platform", version: "1.5.7"
     api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
     rootProject.getAllprojects().each { Project sproj ->

From 2ca92d343169de815e1bd20c3362dcb1c4d76740 Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 14 Oct 2022 23:21:27 +0200
Subject: [PATCH 045/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-full/build.gradle | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index 6e2b426dd..7b77b9c46 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -17,10 +17,10 @@ dependencies {
     api 'org.slf4j:slf4j-simple:2.0.3'
     api 'org.slf4j:slf4j-api:2.0.3'
     //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
-    //api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x64_86"
-    api group: "org.bytedeco", name: "javacpp-platform", version: "1.5.7"
+    api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64"
+    api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
     api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
-
+    api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT'
     rootProject.getAllprojects().each { Project sproj ->
         if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")
             && !sproj.name.equals("Cavis")

From 55c9d7d10ce16f05f0f6d27653123a337246796e Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 14 Oct 2022 23:38:41 +0200
Subject: [PATCH 046/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-full/build.gradle | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index 7b77b9c46..989cfb274 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -19,7 +19,7 @@ dependencies {
     //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
     api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64"
     api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
-    api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
+    //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
     api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT'
     rootProject.getAllprojects().each { Project sproj ->
         if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")

From df7675fba7dd6271e019aed1a781a3bb87eff67c Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 14 Oct 2022 23:56:48 +0200
Subject: [PATCH 047/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-dnn/cavis-dnn-api/build.gradle | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/cavis-dnn/cavis-dnn-api/build.gradle b/cavis-dnn/cavis-dnn-api/build.gradle
index ffd000c1c..499b6a96a 100644
--- a/cavis-dnn/cavis-dnn-api/build.gradle
+++ b/cavis-dnn/cavis-dnn-api/build.gradle
@@ -13,9 +13,6 @@ plugins {
 }
 apply from: "${project.rootProject.projectDir}/createTestBackends.gradle"
-group 'net.brutex'
-version '1.0.0-SNAPSHOT'
-
 dependencies {
     testRuntimeOnly 'net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT'

From 28b1df3773b7cc076d2f8fa435d8579a28fcffe7 Mon Sep 17 00:00:00 2001
From: brian
Date: Wed, 19 Oct 2022 13:31:00 +0200
Subject: [PATCH 048/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-full/build.gradle | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index 989cfb274..7b77b9c46 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -19,7 +19,7 @@ dependencies {
     //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
     api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64"
     api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
-    //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
+    api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
    api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT'
     rootProject.getAllprojects().each { Project sproj ->
         if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")

From e5d45fb620f9c8082875158f18578ea091707af8 Mon Sep 17 00:00:00 2001
From: brian
Date: Wed, 19 Oct 2022 15:59:59 +0200
Subject: [PATCH 049/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-native/cavis-native-lib/build.gradle | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle
index 10648759d..eb3779c1a 100644
--- a/cavis-native/cavis-native-lib/build.gradle
+++ b/cavis-native/cavis-native-lib/build.gradle
@@ -84,7 +84,7 @@ chipList.each {thisChip ->
 }
-if(osdetector.os.startsWith("windows")) {
+//if(osdetector.os.startsWith("windows")) {
     sourceSets {
         main {
             java {
@@ -93,7 +93,7 @@ if(osdetector.os.startsWith("windows")) {
             }
         }
     }
-}
+//}
 java {

From 0a9a0cdf3c424f73111ad9b6705363b918b21ec6 Mon Sep 17 00:00:00 2001
From: brian
Date: Wed, 19 Oct 2022 16:19:01 +0200
Subject: [PATCH 050/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-full/build.gradle | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index 7b77b9c46..016c3ef67 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -20,13 +20,14 @@ dependencies {
     api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64"
     api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
     api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
-    api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT'
+    //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT'
     rootProject.getAllprojects().each { Project sproj ->
         if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform")
             && !sproj.name.equals("Cavis")
             && !sproj.name.equals("cavis-datavec")
             && !sproj.name.equals("cavis-dnn")
             && !sproj.name.equals("cavis-native")
+            && !sproj.name.equals("cavis-native-lib")
             && !sproj.name.equals("cavis-nd4j")
             && !sproj.name.equals("cavis-ui")
             && !sproj.name.equals("cavis-zoo")) {

From 234b91a5b1d3534d815080ec3428c3a48f1c2b10 Mon Sep 17 00:00:00 2001
From: brian
Date: Thu, 20 Oct 2022 09:44:26 +0200
Subject: [PATCH 051/126] Add jenkinsfile for pipeline build and dockerfile for build

Signed-off-by: brian
---
 cavis-full/build.gradle                    | 3 +--
 cavis-native/cavis-native-lib/build.gradle | 1 +
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle
index 016c3ef67..2e587fa8e 100644
--- a/cavis-full/build.gradle
+++ b/cavis-full/build.gradle
@@ -18,7 +18,7 @@ dependencies {
     api 'org.slf4j:slf4j-api:2.0.3'
     //TODO for the two below.. either platform specific uber jars or a single big one with all platforms
     api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64"
-    //api group: "org.bytedeco", name: "javacpp", version: "1.5.7"
     api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu"
     //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT'
     rootProject.getAllprojects().each { Project sproj ->
@@ -27,7 +27,6 @@ dependencies {
             && !sproj.name.equals("cavis-datavec")
             && !sproj.name.equals("cavis-dnn")
             && !sproj.name.equals("cavis-native")
-            && !sproj.name.equals("cavis-native-lib")
             && !sproj.name.equals("cavis-nd4j")
             && !sproj.name.equals("cavis-ui")
             && !sproj.name.equals("cavis-zoo")) {
diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle
index eb3779c1a..0a638ff15 100644
--- a/cavis-native/cavis-native-lib/build.gradle
+++ b/cavis-native/cavis-native-lib/build.gradle
@@ -488,6 +488,7 @@ chipList.each { thisChip ->
         publishing {
             publications {
                 mavenJava(MavenPublication) {
+                    artifact jar
                     artifact tasks.getByName("${thisChip}SupportJar")
                 }
             }

From 656d36781211c19fa373d4a237d6749e1f0439c5 Mon Sep 17 00:00:00 2001
From: brian
Date: Fri, 21 Oct 2022 15:19:32 +0200
Subject: [PATCH 052/126] Fix javadoc and cleanup

Signed-off-by: brian
---
 .../test/java/net/brutex/spark/BrianTest.java | 3 +-
 .../java/net/brutex/spark/BrianTest2.java | 2 +-
 .../integration/IntegrationTestRunner.java | 8 +-
 .../testcases/dl4j/CNN1DTestCases.java | 4 +-
 .../testcases/dl4j/CNN2DTestCases.java | 10 +-
 .../testcases/dl4j/CNN3DTestCases.java | 4 +-
 .../testcases/dl4j/RNNTestCases.java | 4 +-
 .../dl4j/misc/CharacterIterator.java | 17 +-
 .../testcases/samediff/SameDiffCNNCases.java | 8 +-
 build_requirements.md | 6 +-
 .../org/datavec/api/conf/Configuration.java | 4 +-
 .../api/formats/output/OutputFormat.java | 2 +-
 .../org/datavec/api/io/BinaryComparable.java | 4 +-
 .../org/datavec/api/io/DataInputBuffer.java | 2 +-
 .../org/datavec/api/io/DataOutputBuffer.java | 4 +-
 .../org/datavec/api/io/RawComparator.java | 2 +-
 .../datavec/api/io/WritableComparator.java | 2 +-
 .../org/datavec/api/io/WritableUtils.java | 4 +-
 .../io/converters/LabelWriterConverter.java | 2 +-
 .../api/io/labels/PathLabelGenerator.java | 2 +-
 .../io/serializers/SerializationFactory.java | 4 +-
 .../java/org/datavec/api/records/Buffer.java | 2 +-
 .../org/datavec/api/records/IOUtils.java | 6 +-
 .../api/records/reader/RecordReader.java | 2 -
 .../reader/impl/ComposableRecordReader.java | 2 +-
 .../impl/ConcatenatingRecordReader.java | 2 +-
 .../records/reader/impl/FileRecordReader.java | 2 +-
 .../records/reader/impl/LineRecordReader.java | 8 +-
 .../CollectionSequenceRecordReader.java | 2 +-
 .../csv/CSVMultiSequenceRecordReader.java | 6 +-
 .../csv/CSVNLinesSequenceRecordReader.java | 4 +-
 .../CSVVariableSlidingWindowRecordReader.java | 10 +-
 .../impl/csv/SerializableCSVParser.java | 2 +-
 .../impl/inmemory/InMemoryRecordReader.java | 2 -
 .../InMemorySequenceRecordReader.java | 2 -
 .../reader/impl/jackson/FieldSelection.java | 8 +-
 .../impl/jackson/JacksonLineRecordReader.java | 4 +-
 .../JacksonLineSequenceRecordReader.java | 4 +-
 .../impl/jackson/JacksonRecordReader.java | 12 +-
 .../reader/impl/misc/MatlabRecordReader.java | 2 +-
 .../impl/misc/SVMLightRecordReader.java | 6 +-
 .../impl/regex/RegexLineRecordReader.java | 6 +-
.../impl/regex/RegexSequenceRecordReader.java | 8 +- .../TransformProcessRecordReader.java | 2 - .../TransformProcessSequenceRecordReader.java | 2 - .../impl/misc/SVMLightRecordWriter.java | 4 +- .../org/datavec/api/split/BaseInputSplit.java | 4 +- .../java/org/datavec/api/split/FileSplit.java | 2 +- .../api/split/InputStreamInputSplit.java | 4 +- .../datavec/api/split/ListStringSplit.java | 2 +- .../api/split/NumberedFileInputSplit.java | 4 +- .../org/datavec/api/split/StringSplit.java | 2 +- .../api/transform/TransformProcess.java | 15 +- .../analysis/columns/NDArrayAnalysis.java | 2 +- .../counter/IntegerAnalysisCounter.java | 4 +- .../counter/NDArrayAnalysisCounter.java | 2 +- .../counter/StringAnalysisCounter.java | 2 +- .../CategoricalHistogramCounter.java | 4 +- .../analysis/json/TDigestDeserializer.java | 4 +- .../analysis/json/TDigestSerializer.java | 2 +- .../bytes/BytesQualityAnalysisState.java | 2 +- .../CategoricalQualityAnalysisState.java | 4 +- .../integer/IntegerQualityAnalysisState.java | 4 +- .../longq/LongQualityAnalysisState.java | 4 +- .../real/RealQualityAnalysisState.java | 4 +- .../string/StringQualityAnalysisState.java | 4 +- .../time/TimeQualityAnalysisState.java | 4 +- .../sequence/SequenceLengthAnalysis.java | 11 +- .../condition/column/BaseColumnCondition.java | 4 +- .../column/DoubleColumnCondition.java | 4 +- .../column/FloatColumnCondition.java | 4 +- .../column/IntegerColumnCondition.java | 4 +- .../condition/column/LongColumnCondition.java | 4 +- .../condition/column/TimeColumnCondition.java | 4 +- .../transform/filter/FilterInvalidValues.java | 18 +- .../org/datavec/api/transform/join/Join.java | 2 +- .../transform/metadata/BinaryMetaData.java | 7 +- .../transform/metadata/BooleanMetaData.java | 7 +- .../transform/metadata/DoubleMetaData.java | 10 +- .../api/transform/metadata/FloatMetaData.java | 10 +- .../transform/metadata/IntegerMetaData.java | 8 +- .../api/transform/metadata/LongMetaData.java | 10 +- .../api/transform/ops/AggregatorImpls.java | 20 +- .../ops/DispatchWithConditionOp.java | 2 +- .../reduce/AggregableColumnReduction.java | 3 +- .../api/transform/reduce/ColumnReduction.java | 2 +- .../datavec/api/transform/reduce/Reducer.java | 12 +- .../impl/GeographicMidpointReduction.java | 6 +- .../schema/conversion/TypeConversion.java | 2 +- .../split/SplitMaxLengthSequence.java | 2 +- .../api/transform/serde/BaseSerializer.java | 2 +- .../api/transform/serde/JsonMappers.java | 4 +- .../api/transform/serde/JsonSerializer.java | 2 +- .../api/transform/serde/ListWrappers.java | 12 +- .../api/transform/serde/YamlSerializer.java | 2 +- .../transform/stringreduce/StringReducer.java | 8 +- .../transform/BaseColumnTransform.java | 2 +- .../CategoricalToIntegerTransform.java | 2 +- .../CategoricalToOneHotTransform.java | 2 +- .../IntegerToCategoricalTransform.java | 2 +- .../transform/categorical/PivotTransform.java | 2 +- .../column/DuplicateColumnsTransform.java | 2 +- .../RemoveAllColumnsExceptForTransform.java | 2 +- .../column/RemoveColumnsTransform.java | 2 +- .../integer/IntegerToOneHotTransform.java | 2 +- .../nlp/TextToCharacterIndexTransform.java | 4 +- .../nlp/TextToTermIndexSequenceTransform.java | 4 +- .../sequence/SequenceDifferenceTransform.java | 4 +- .../StringListToCategoricalSetTransform.java | 2 +- .../StringListToCountsNDArrayTransform.java | 2 +- .../time/DeriveColumnsFromTimeTransform.java | 2 +- .../RenderableComponentHistogram.java | 6 +- .../RenderableComponentLineChart.java | 6 +- .../org/datavec/api/util/ReflectionUtils.java | 2 
+- .../DateTimeFieldTypeDeserializer.java | 2 +- .../jackson/DateTimeFieldTypeSerializer.java | 2 +- .../api/util/ndarray/RecordConverter.java | 2 +- .../org/datavec/api/vector/Vectorizer.java | 2 +- .../datavec/api/writable/BooleanWritable.java | 2 +- .../datavec/api/writable/ByteWritable.java | 12 +- .../datavec/api/writable/DoubleWritable.java | 10 +- .../datavec/api/writable/FloatWritable.java | 10 +- .../org/datavec/api/writable/IntWritable.java | 10 +- .../datavec/api/writable/LongWritable.java | 10 +- .../datavec/api/writable/NDArrayWritable.java | 2 +- .../java/org/datavec/api/writable/Text.java | 12 +- .../datavec/api/writable/WritableFactory.java | 6 +- .../batch/AbstractWritableRecordBatch.java | 2 +- .../impl/CSVLineSequenceRecordReaderTest.java | 14 +- .../CSVMultiSequenceRecordReaderTest.java | 16 +- .../reader/impl/CSVRecordReaderTest.java | 12 +- .../impl/JacksonLineRecordReaderTest.java | 8 +- .../reader/impl/JacksonRecordReaderTest.java | 28 +- .../reader/impl/RegexRecordReaderTest.java | 18 +- .../impl/TestCollectionRecordReaders.java | 8 +- .../TransformProcessRecordReaderTests.java | 9 +- .../writer/impl/LibSvmRecordWriterTest.java | 12 +- .../writer/impl/SVMLightRecordWriterTest.java | 12 +- .../datavec/api/split/InputSplitTests.java | 3 +- .../api/split/TestStreamInputSplit.java | 19 +- .../api/transform/TestTransformProcess.java | 8 +- .../transform/condition/TestConditions.java | 178 +++---- .../api/transform/filter/TestFilters.java | 40 +- .../datavec/api/transform/join/TestJoin.java | 20 +- .../transform/ops/AggregableMultiOpTest.java | 21 +- .../transform/ops/AggregatorImplsTest.java | 12 +- .../api/transform/ops/DispatchOpTest.java | 37 +- .../transform/reduce/TestMultiOpReduce.java | 72 +-- .../api/transform/reduce/TestReductions.java | 6 +- .../TestReduceSequenceByWindowFunction.java | 20 +- .../transform/sequence/TestSequenceSplit.java | 20 +- .../sequence/TestWindowFunctions.java | 120 ++--- .../transform/stringreduce/TestReduce.java | 6 +- .../transform/transform/TestTransforms.java | 452 +++++++++--------- .../TestNDArrayWritableTransforms.java | 16 +- .../parse/ParseDoubleTransformTest.java | 3 +- .../org/datavec/api/transform/ui/TestUI.java | 2 +- .../api/writable/RecordConverterTest.java | 8 +- .../datavec/api/writable/WritableTest.java | 12 +- .../org/datavec/arrow/ArrowConverter.java | 4 +- .../arrow/recordreader/ArrowRecordWriter.java | 5 +- .../org/datavec/arrow/ArrowConverterTest.java | 42 +- ...rowWritableRecordTimeSeriesBatchTests.java | 18 +- .../src/main/java/org/datavec/audio/Wave.java | 13 +- .../java/org/datavec/audio/WaveHeader.java | 69 ++- .../org/datavec/audio/dsp/WindowFunction.java | 10 +- .../extension/NormalizedSampleAmplitudes.java | 7 +- .../datavec/audio/extension/Spectrogram.java | 6 +- .../audio/fingerprint/FingerprintManager.java | 10 +- .../fingerprint/FingerprintSimilarity.java | 2 +- .../FingerprintSimilarityComputer.java | 2 +- .../datavec/audio/fingerprint/MapRank.java | 2 +- .../audio/fingerprint/MapRankDouble.java | 4 +- .../audio/fingerprint/MapRankInteger.java | 4 +- .../audio/fingerprint/PairManager.java | 18 +- .../audio/fingerprint/QuickSortDouble.java | 4 +- .../fingerprint/QuickSortIndexPreserved.java | 2 +- .../audio/fingerprint/QuickSortInteger.java | 4 +- .../audio/fingerprint/QuickSortShort.java | 4 +- .../audio/processor/IntensityProcessor.java | 4 +- .../processor/RobustIntensityProcessor.java | 2 +- .../properties/FingerprintProperties.java | 30 +- .../codec/reader/CodecRecordReader.java | 2 +- 
.../datavec/poi/excel/ExcelRecordReader.java | 2 +- .../reduce/geo/CoordinatesReduction.java | 10 +- .../transform/reduce/TestGeoReduction.java | 10 +- .../transform/TestGeoTransforms.java | 10 +- .../hadoop/conf/ConfigurationUtil.java | 2 +- .../records/reader/mapfile/MapFileReader.java | 8 +- .../reader/mapfile/MapFileRecordReader.java | 2 +- .../mapfile/MapFileSequenceRecordReader.java | 2 +- .../reader/TestMapFileRecordReader.java | 24 +- .../TestMapFileRecordReaderMultipleParts.java | 8 +- ...ileRecordReaderMultiplePartsSomeEmpty.java | 8 +- .../org/datavec/image/loader/CifarLoader.java | 17 +- .../org/datavec/image/loader/LFWLoader.java | 7 +- .../image/loader/NativeImageLoader.java | 21 +- .../org/datavec/image/mnist/MnistDbFile.java | 2 +- .../org/datavec/image/mnist/MnistFetcher.java | 2 +- .../datavec/image/mnist/MnistImageFile.java | 4 +- .../image/mnist/draw/DrawReconstruction.java | 2 +- .../recordreader/BaseImageRecordReader.java | 4 +- .../ObjectDetectionRecordReader.java | 4 +- .../objdetect/impl/SvhnLabelProvider.java | 22 +- .../objdetect/impl/VocLabelProvider.java | 2 +- .../transform/ImageTransformProcess.java | 2 +- .../org/datavec/image/loader/LoaderTests.java | 4 +- .../datavec/image/loader/TestImageLoader.java | 4 +- .../recordreader/TestImageRecordReader.java | 20 +- .../datavec/image/transform/JsonYamlTest.java | 8 +- .../image/transform/TestImageTransform.java | 54 +-- .../org/datavec/nlp/annotator/PoStagger.java | 2 +- .../nlp/metadata/DefaultVocabCache.java | 6 +- .../movingwindow/ContextLabelRetriever.java | 10 +- .../org/datavec/nlp/movingwindow/Window.java | 4 +- .../tokenizer/ConcurrentTokenizer.java | 2 +- .../tokenizer/DefaultStreamTokenizer.java | 2 +- .../tokenizer/DefaultTokenizer.java | 2 +- .../tokenizer/PosUimaTokenizer.java | 8 +- .../tokenization/tokenizer/UimaTokenizer.java | 10 +- .../PosUimaTokenizerFactory.java | 4 +- .../UimaTokenizerFactory.java | 4 +- .../nlp/transforms/GazeteerTransform.java | 2 +- .../nlp/transforms/MultiNlpTransform.java | 8 +- .../nlp/transforms/TestGazeteerTransform.java | 2 +- .../nlp/transforms/TestMultiNLPTransform.java | 2 +- ...OfWordsTermSequenceIndexTransformTest.java | 4 +- .../transforms/LocalTransformExecutor.java | 4 +- ...lTransformProcessSequenceRecordReader.java | 3 +- .../misc/SequenceMergeFunction.java | 2 +- ...ocalTransformProcessRecordReaderTests.java | 9 +- .../TestNDArrayToWritablesFunction.java | 3 +- .../TestWritablesToStringFunctions.java | 6 +- .../transforms/transform/ExecutionTest.java | 56 +-- .../transform/TestGeoTransforms.java | 10 +- .../transform/TestPythonTransformProcess.java | 16 +- .../transforms/transform/join/TestJoin.java | 76 +-- .../rank/TestCalculateSortedRank.java | 8 +- .../sequence/TestConvertToSequence.java | 26 +- .../java/org/datavec/python/NumpyArray.java | 10 +- .../org/datavec/python/PythonCondition.java | 2 +- .../datavec/python/PythonContextManager.java | 4 +- .../org/datavec/python/PythonExecutioner.java | 16 +- .../java/org/datavec/python/PythonJob.java | 4 +- .../java/org/datavec/python/PythonObject.java | 12 +- .../org/datavec/python/PythonProcess.java | 10 +- .../java/org/datavec/python/PythonType.java | 2 +- .../java/org/datavec/python/PythonUtils.java | 4 +- .../java/org/datavec/python/keras/Model.java | 2 +- .../org/datavec/python/PythonNumpyTest.java | 2 +- .../datavec/python/ScalarAndArrayTest.java | 2 +- .../org/datavec/python/TestPythonList.java | 2 +- .../datavec/python/TestPythonVariables.java | 4 +- .../functions/pairdata/PathToKeyFunction.java | 4 +- 
.../datavec/spark/transform/DataFrames.java | 4 +- .../spark/transform/Normalization.java | 12 +- .../transform/misc/SequenceMergeFunction.java | 2 +- .../sparkfunction/SequenceToRows.java | 4 +- .../spark/transform/sparkfunction/ToRow.java | 4 +- .../spark/transform/utils/SparkExport.java | 4 +- .../spark/transform/utils/SparkUtils.java | 2 +- .../spark/util/SerializableHadoopConfig.java | 2 +- .../datavec/spark/TestKryoSerialization.java | 2 +- .../TestNDArrayToWritablesFunction.java | 3 +- ...PairSequenceRecordReaderBytesFunction.java | 2 +- .../TestRecordReaderBytesFunction.java | 2 +- ...TestSequenceRecordReaderBytesFunction.java | 2 +- .../TestWritablesToStringFunctions.java | 6 +- .../spark/storage/TestSparkStorageUtils.java | 24 +- .../spark/transform/DataFramesTests.java | 48 +- .../spark/transform/ExecutionTest.java | 110 ++--- .../transform/analysis/TestAnalysis.java | 38 +- .../spark/transform/join/TestJoin.java | 76 +-- .../rank/TestCalculateSortedRank.java | 8 +- .../sequence/TestConvertToSequence.java | 26 +- .../org/datavec/spark/util/TestSparkUtil.java | 4 +- .../autodiff/execution/input/Operands.java | 2 +- .../functions/DifferentialFunction.java | 6 +- .../autodiff/listeners/ListenerResponse.java | 2 +- .../org/nd4j/autodiff/listeners/Loss.java | 2 +- .../checkpoint/CheckpointListener.java | 37 +- .../debugging/ExecDebuggingListener.java | 2 +- .../listeners/impl/HistoryListener.java | 6 +- .../autodiff/listeners/impl/UIListener.java | 28 +- .../listeners/profiler/ProfilingListener.java | 6 +- .../profiler/comparison/ProfileAnalyzer.java | 2 +- .../listeners/records/EvaluationRecord.java | 2 +- .../autodiff/listeners/records/History.java | 10 +- .../autodiff/listeners/records/LossCurve.java | 2 +- .../org/nd4j/autodiff/samediff/SDIndex.java | 2 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 20 +- .../autodiff/samediff/TrainingConfig.java | 10 +- .../samediff/internal/AbstractSession.java | 8 +- .../autodiff/samediff/internal/FrameIter.java | 2 +- .../samediff/internal/InferenceSession.java | 2 +- .../internal/memory/ArrayCacheMemoryMgr.java | 6 +- .../samediff/serde/LegacyOpMapper.java | 2 +- .../autodiff/validation/OpValidation.java | 10 +- .../nd4j/autodiff/validation/TestCase.java | 2 +- .../NonInplaceValidationListener.java | 7 +- .../java/org/nd4j/context/Nd4jContext.java | 4 +- .../org/nd4j/enums/ImageResizeMethod.java | 2 +- .../org/nd4j/evaluation/BaseEvaluation.java | 2 +- .../java/org/nd4j/evaluation/IMetric.java | 4 +- .../classification/ConfusionMatrix.java | 2 +- .../classification/EvaluationBinary.java | 16 +- .../evaluation/custom/EvaluationLambda.java | 2 +- .../nd4j/evaluation/custom/MergeLambda.java | 2 +- .../nd4j/evaluation/custom/ResultLambda.java | 2 +- .../regression/RegressionEvaluation.java | 5 +- .../serde/ConfusionMatrixDeserializer.java | 2 +- .../serde/ConfusionMatrixSerializer.java | 2 +- .../evaluation/serde/ROCArraySerializer.java | 2 +- .../main/java/org/nd4j/graph/FlatArray.java | 2 +- .../org/nd4j/graph/FlatConfiguration.java | 2 +- .../main/java/org/nd4j/graph/FlatNode.java | 6 +- .../java/org/nd4j/graph/FlatProperties.java | 2 +- .../DifferentialFunctionClassHolder.java | 26 +- .../converters/ImportClassMapping.java | 2 +- .../properties/adapters/BooleanAdapter.java | 2 +- .../properties/adapters/DataTypeAdapter.java | 4 +- .../imports/graphmapper/tf/TFGraphMapper.java | 2 +- .../tf/tensors/TFTensorMapper.java | 2 +- .../tensorflow/TensorFlowImportValidator.java | 18 +- .../activations/impl/ActivationELU.java | 2 +- 
.../activations/impl/ActivationGELU.java | 2 +- .../activations/impl/ActivationHardTanH.java | 2 + .../activations/impl/ActivationLReLU.java | 2 +- .../activations/impl/ActivationPReLU.java | 2 +- .../activations/impl/ActivationRReLU.java | 3 +- .../activations/impl/ActivationReLU.java | 6 +- .../impl/ActivationThresholdedReLU.java | 4 +- .../java/org/nd4j/linalg/api/blas/Blas.java | 4 +- .../nd4j/linalg/api/blas/BlasBufferUtil.java | 8 +- .../nd4j/linalg/api/blas/impl/BaseLapack.java | 27 +- .../nd4j/linalg/api/blas/impl/BaseLevel2.java | 20 +- .../nd4j/linalg/api/blas/impl/BaseLevel3.java | 20 +- .../linalg/api/blas/params/GemmParams.java | 42 +- .../api/blas/params/GemvParameters.java | 16 +- .../linalg/api/blas/params/MMulTranspose.java | 2 +- .../linalg/api/buffer/BaseDataBuffer.java | 37 +- .../linalg/api/buffer/util/DataTypeUtil.java | 2 +- .../linalg/api/iter/FirstAxisIterator.java | 2 +- .../nd4j/linalg/api/iter/FlatIterator.java | 6 +- .../linalg/api/iter/INDArrayIterator.java | 2 +- .../linalg/api/iter/LinearIndexLookup.java | 10 +- .../nd4j/linalg/api/iter/NdIndexIterator.java | 4 +- .../linalg/api/memory/AllocationsTracker.java | 2 +- .../linalg/api/memory/BasicMemoryManager.java | 4 +- .../api/memory/DeviceAllocationsTracker.java | 2 +- .../deallocation/DeallocatorService.java | 8 +- .../memory/pointers/ImmortalFloatPointer.java | 2 +- .../provider/BasicWorkspaceManager.java | 4 +- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 42 +- .../java/org/nd4j/linalg/api/ops/BaseOp.java | 17 +- .../nd4j/linalg/api/ops/BaseOpContext.java | 3 +- .../linalg/api/ops/BaseTransformSameOp.java | 2 +- .../nd4j/linalg/api/ops/DynamicCustomOp.java | 27 +- .../api/ops/aggregates/BaseAggregate.java | 10 +- .../nd4j/linalg/api/ops/aggregates/Batch.java | 6 +- .../ops/aggregates/impl/AggregateAxpy.java | 2 +- .../nd4j/linalg/api/ops/custom/Flatten.java | 8 +- .../linalg/api/ops/custom/FusedBatchNorm.java | 2 +- .../linalg/api/ops/custom/LinearSolve.java | 2 +- .../impl/broadcast/bool/BroadcastEqualTo.java | 3 +- .../broadcast/bool/BroadcastLessThan.java | 3 +- .../bool/BroadcastLessThanOrEqual.java | 3 +- .../linalg/api/ops/impl/grid/BaseGridOp.java | 5 +- .../api/ops/impl/image/CropAndResize.java | 3 +- .../linalg/api/ops/impl/image/ResizeArea.java | 4 +- .../api/ops/impl/image/ResizeBilinear.java | 4 +- .../impl/layers/convolution/AvgPooling2D.java | 2 +- .../impl/layers/convolution/DeConv2DTF.java | 2 +- .../impl/layers/convolution/DepthToSpace.java | 2 +- .../LocalResponseNormalization.java | 2 +- .../ops/impl/layers/convolution/SConv2D.java | 4 +- .../impl/layers/convolution/SpaceToDepth.java | 4 +- .../recurrent/outputs/GRUCellOutputs.java | 8 +- .../recurrent/outputs/LSTMCellOutputs.java | 14 +- .../recurrent/outputs/LSTMLayerOutputs.java | 8 +- .../recurrent/outputs/SRUCellOutputs.java | 4 +- .../recurrent/outputs/SRULayerOutputs.java | 6 +- ...arseSoftmaxCrossEntropyLossWithLogits.java | 2 +- .../nd4j/linalg/api/ops/impl/reduce/Mmul.java | 4 +- .../api/ops/impl/reduce/TensorMmul.java | 6 +- .../api/ops/impl/reduce3/EqualsWithEps.java | 3 +- .../api/ops/impl/scalar/ScalarDivision.java | 3 +- .../ops/impl/scalar/ScalarSubtraction.java | 3 +- .../ops/impl/scalar/comparison/ScalarAnd.java | 3 +- .../ops/impl/scalar/comparison/ScalarEps.java | 3 +- .../impl/scalar/comparison/ScalarEquals.java | 3 +- .../scalar/comparison/ScalarGreaterThan.java | 3 +- .../comparison/ScalarGreaterThanOrEqual.java | 3 +- .../scalar/comparison/ScalarLessThan.java | 3 +- .../comparison/ScalarLessThanOrEqual.java | 3 
+- .../ops/impl/scalar/comparison/ScalarNot.java | 3 +- .../scalar/comparison/ScalarNotEquals.java | 3 +- .../ops/impl/scalar/comparison/ScalarOr.java | 3 +- .../scalar/comparison/ScalarSetValue.java | 3 +- .../ops/impl/scalar/comparison/ScalarXor.java | 3 +- .../api/ops/impl/scatter/ScatterAdd.java | 2 +- .../api/ops/impl/scatter/ScatterDiv.java | 2 +- .../api/ops/impl/scatter/ScatterMax.java | 2 +- .../api/ops/impl/scatter/ScatterMin.java | 2 +- .../api/ops/impl/scatter/ScatterMul.java | 2 +- .../api/ops/impl/scatter/ScatterNd.java | 2 +- .../api/ops/impl/scatter/ScatterNdAdd.java | 2 +- .../api/ops/impl/scatter/ScatterNdSub.java | 2 +- .../api/ops/impl/scatter/ScatterNdUpdate.java | 2 +- .../api/ops/impl/scatter/ScatterSub.java | 2 +- .../api/ops/impl/scatter/ScatterUpdate.java | 4 +- .../ops/impl/shape/ApplyGradientDescent.java | 3 +- .../linalg/api/ops/impl/shape/Create.java | 4 +- .../linalg/api/ops/impl/shape/ExpandDims.java | 2 +- .../nd4j/linalg/api/ops/impl/shape/Eye.java | 2 +- .../linalg/api/ops/impl/shape/OnesLike.java | 2 +- .../linalg/api/ops/impl/shape/Repeat.java | 2 +- .../linalg/api/ops/impl/shape/Reshape.java | 1 - .../linalg/api/ops/impl/shape/Squeeze.java | 5 +- .../api/ops/impl/transforms/MaxOut.java | 5 +- .../api/ops/impl/transforms/any/IsMax.java | 2 +- .../ops/impl/transforms/clip/ClipByValue.java | 2 +- .../impl/transforms/custom/BatchToSpace.java | 2 +- .../transforms/custom/BatchToSpaceND.java | 2 +- .../ops/impl/transforms/custom/Choose.java | 2 +- .../impl/transforms/custom/SpaceToBatch.java | 2 +- .../transforms/custom/SpaceToBatchND.java | 2 +- .../impl/transforms/custom/StandardizeBp.java | 3 +- .../arithmetic/SquaredDifferenceOp.java | 2 +- .../api/ops/impl/transforms/same/Abs.java | 2 +- .../api/ops/impl/transforms/same/Ceil.java | 3 +- .../api/ops/impl/transforms/same/Floor.java | 3 +- .../ops/impl/transforms/same/Identity.java | 2 +- .../ops/impl/transforms/same/Negative.java | 3 +- .../api/ops/impl/transforms/same/Round.java | 3 +- .../api/ops/impl/transforms/same/Sign.java | 3 +- .../api/ops/impl/transforms/strict/ACosh.java | 3 +- .../api/ops/impl/transforms/strict/ASinh.java | 3 +- .../api/ops/impl/transforms/strict/ATanh.java | 3 +- .../api/ops/impl/transforms/strict/Cos.java | 3 +- .../api/ops/impl/transforms/strict/Cosh.java | 3 +- .../api/ops/impl/transforms/strict/Exp.java | 3 +- .../api/ops/impl/transforms/strict/Expm1.java | 3 +- .../api/ops/impl/transforms/strict/Swish.java | 3 +- .../ops/performance/PerformanceTracker.java | 6 +- .../random/compat/RandomStandardNormal.java | 2 +- .../api/ops/random/custom/RandomPoisson.java | 2 +- .../nd4j/linalg/api/rng/DefaultRandom.java | 2 +- .../impl/ConstantDistribution.java | 2 +- .../impl/OrthogonalDistribution.java | 2 +- .../impl/UniformDistribution.java | 3 +- .../linalg/api/shape/LongShapeDescriptor.java | 26 +- .../java/org/nd4j/linalg/api/shape/Shape.java | 76 ++- .../linalg/api/shape/ShapeDescriptor.java | 28 +- .../nd4j/linalg/cache/ArrayDescriptor.java | 2 +- .../linalg/checkutil/NDArrayCreationUtil.java | 8 +- .../compression/BasicNDArrayCompressor.java | 2 +- .../compression/CompressedDataBuffer.java | 2 +- .../linalg/compression/CompressionUtils.java | 10 +- .../nd4j/linalg/convolution/Convolution.java | 6 +- .../linalg/convolution/OldConvolution.java | 2 +- .../linalg/dataset/AsyncDataSetIterator.java | 10 +- .../dataset/AsyncMultiDataSetIterator.java | 10 +- .../ExistingMiniBatchDataSetIterator.java | 2 +- .../dataset/MiniBatchFileDataSetIterator.java | 8 +- 
.../org/nd4j/linalg/dataset/ViewIterator.java | 2 +- .../adapter/MultiDataSetIteratorAdapter.java | 2 +- .../nd4j/linalg/dataset/api/DataSetUtil.java | 2 +- .../api/iterator/CachingDataSetIterator.java | 8 +- .../api/iterator/MultipleEpochsIterator.java | 4 +- .../api/iterator/SamplingDataSetIterator.java | 6 +- .../dataset/api/iterator/StandardScaler.java | 2 +- .../api/iterator/TestDataSetIterator.java | 6 +- .../iterator/TestMultiDataSetIterator.java | 2 +- .../cache/InFileAndMemoryDataSetCache.java | 4 +- .../iterator/cache/InFileDataSetCache.java | 2 +- .../iterator/cache/InMemoryDataSetCache.java | 4 +- .../CompositeDataSetPreProcessor.java | 2 +- .../CompositeMultiDataSetPreProcessor.java | 2 +- .../ImageMultiPreProcessingScaler.java | 7 +- .../api/preprocessor/MinMaxStrategy.java | 4 +- .../BaseUnderSamplingPreProcessor.java | 2 +- ...lingByMaskingMultiDataSetPreProcessor.java | 4 +- .../UnderSamplingByMaskingPreProcessor.java | 2 +- .../serializer/NormalizerSerializer.java | 2 +- .../linalg/dimensionalityreduction/PCA.java | 2 +- .../RandomProjection.java | 4 +- .../env/impl/WorkspacesDebugAction.java | 4 +- .../nd4j/linalg/factory/BaseBlasWrapper.java | 2 +- .../linalg/factory/BaseNDArrayFactory.java | 2 +- .../nd4j/linalg/factory/NDArrayFactory.java | 2 +- .../java/org/nd4j/linalg/factory/Nd4j.java | 12 +- .../org/nd4j/linalg/factory/Nd4jBackend.java | 6 +- .../nd4j/linalg/factory/RandomFactory.java | 4 +- .../org/nd4j/linalg/heartbeat/Heartbeat.java | 2 +- .../linalg/heartbeat/reports/Environment.java | 9 +- .../nd4j/linalg/heartbeat/reports/Task.java | 9 +- .../heartbeat/utils/EnvironmentUtils.java | 4 +- .../nd4j/linalg/indexing/BooleanIndexing.java | 20 +- .../org/nd4j/linalg/indexing/IndexInfo.java | 2 +- .../nd4j/linalg/indexing/NDArrayIndex.java | 8 +- .../org/nd4j/linalg/indexing/PointIndex.java | 2 +- .../nd4j/linalg/indexing/SpecifiedIndex.java | 4 +- .../nd4j/linalg/indexing/conditions/And.java | 2 +- .../indexing/conditions/ConditionEquals.java | 2 +- .../nd4j/linalg/indexing/conditions/Not.java | 2 +- .../nd4j/linalg/indexing/conditions/Or.java | 2 +- .../indexing/functions/StableNumber.java | 2 +- .../nd4j/linalg/indexing/functions/Value.java | 2 +- .../lossfunctions/impl/LossBinaryXENT.java | 2 +- .../impl/LossMixtureDensity.java | 10 +- .../linalg/ops/transforms/Transforms.java | 2 +- .../org/nd4j/linalg/profiler/OpProfiler.java | 8 +- .../linalg/profiler/data/StackAggregator.java | 2 +- .../profiler/data/StringAggregator.java | 4 +- .../linalg/profiler/data/StringCounter.java | 4 +- .../data/primitives/StackDescriptor.java | 5 +- .../profiler/data/primitives/StackNode.java | 4 +- .../profiler/data/primitives/StackTree.java | 4 +- .../profiler/data/primitives/TimeSet.java | 2 +- .../org/nd4j/linalg/schedule/MapSchedule.java | 4 +- .../nd4j/linalg/string/NDArrayStrings.java | 9 +- .../org/nd4j/linalg/util/DataSetUtils.java | 12 +- .../org/nd4j/linalg/util/ND4JTestUtils.java | 2 +- .../org/nd4j/linalg/util/NDArrayMath.java | 2 +- .../linalg/workspace/BaseWorkspaceMgr.java | 6 +- .../java/org/nd4j/serde/json/JsonMappers.java | 4 +- .../java/org/nd4j/systeminfo/SystemInfo.java | 2 +- .../org/nd4j/versioncheck/VersionCheck.java | 2 +- .../nd4j/weightinit/BaseWeightInitScheme.java | 2 +- .../weightinit/impl/ConstantInitScheme.java | 2 +- .../impl/DistributionInitScheme.java | 2 +- .../impl/LecunUniformInitScheme.java | 2 +- .../nd4j/weightinit/impl/ReluInitScheme.java | 2 +- .../impl/ReluUniformInitScheme.java | 2 +- .../impl/SigmoidUniformInitScheme.java | 4 +- 
.../weightinit/impl/UniformInitScheme.java | 2 +- .../VarScalingNormalFanAvgInitScheme.java | 4 +- .../impl/VarScalingNormalFanInInitScheme.java | 2 +- .../VarScalingNormalFanOutInitScheme.java | 2 +- ...arScalingNormalUniformFanInInitScheme.java | 2 +- ...rScalingNormalUniformFanOutInitScheme.java | 2 +- .../VarScalingUniformFanAvgInitScheme.java | 4 +- .../impl/XavierFanInInitScheme.java | 2 +- .../weightinit/impl/XavierInitScheme.java | 4 +- .../impl/XavierUniformInitScheme.java | 4 +- .../java/org/deeplearning4j/BaseDL4JTest.java | 2 +- .../org/nd4j/common/base/Preconditions.java | 4 +- .../collection/CompactHeapStringList.java | 2 +- .../common/collection/IntArrayKeyMap.java | 4 +- .../common/collection/IntArrayKeySet.java | 2 +- .../collection/MultiDimensionalMap.java | 2 +- .../collection/MultiDimensionalSet.java | 2 +- .../common/holder/ObjectMapperHolder.java | 2 +- .../io/AbstractFileResolvingResource.java | 2 +- .../org/nd4j/common/io/AbstractResource.java | 4 +- .../org/nd4j/common/io/ClassPathResource.java | 2 +- .../org/nd4j/common/io/CollectionUtils.java | 17 +- .../java/org/nd4j/common/io/ObjectUtils.java | 6 +- .../org/nd4j/common/io/ReflectionUtils.java | 4 +- .../java/org/nd4j/common/io/StringUtils.java | 14 +- .../java/org/nd4j/common/io/VfsUtils.java | 40 +- .../nd4j/common/primitives/CounterMap.java | 2 +- .../serde/JsonDeserializerAtomicBoolean.java | 2 +- .../serde/JsonDeserializerAtomicDouble.java | 2 +- .../serde/JsonSerializerAtomicBoolean.java | 2 +- .../serde/JsonSerializerAtomicDouble.java | 2 +- .../org/nd4j/common/resources/Resources.java | 6 +- .../resources/strumpf/ResourceFile.java | 5 +- .../resources/strumpf/StrumpfResolver.java | 12 +- .../java/org/nd4j/common/tools/BTools.java | 4 +- .../main/java/org/nd4j/common/tools/SIS.java | 2 +- .../org/nd4j/common/util/ArchiveUtils.java | 4 +- .../java/org/nd4j/common/util/ArrayUtil.java | 10 +- .../main/java/org/nd4j/common/util/Index.java | 9 +- .../java/org/nd4j/common/util/MathUtils.java | 4 +- .../java/org/nd4j/common/util/Rational.java | 4 +- .../nd4j/common/util/SynchronizedTable.java | 2 +- .../common/function/FunctionalUtilsTest.java | 6 +- .../nd4j/common/io/ClassPathResourceTest.java | 2 +- .../org/nd4j/common/loader/TestFileBatch.java | 2 +- .../org/nd4j/common/tools/InfoValuesTest.java | 4 +- .../java/org/nd4j/common/tools/SISTest.java | 2 +- .../datasets/test/TestDataSetIterator.java | 2 +- .../core/evaluation/EvaluationTools.java | 2 +- .../core/listener/SystemPolling.java | 8 +- .../core/parallelism/AsyncIterator.java | 8 +- .../impl/RemoteUIStatsStorageRouter.java | 6 +- .../core/ui/UiConnectionInfo.java | 7 +- .../core/util/ModelGuesser.java | 2 +- .../core/util/MovingWindowMatrix.java | 2 +- .../RecordReaderDataSetiteratorTest.java | 108 ++--- .../RecordReaderMultiDataSetIteratorTest.java | 22 +- .../tools/SpecialImageRecordReader.java | 14 +- .../iterator/AbstractDataSetIteratorTest.java | 6 +- .../iterator/AsyncDataSetIteratorTest.java | 4 +- .../AsyncMultiDataSetIteratorTest.java | 4 +- .../iterator/DataSetIteratorTest.java | 4 +- .../iterator/DataSetSplitterTests.java | 6 +- .../EarlyTerminationDataSetIteratorTest.java | 6 +- ...lyTerminationMultiDataSetIteratorTest.java | 7 +- .../JointParallelDataSetIteratorTest.java | 12 +- .../iterator/MultiDataSetSplitterTests.java | 12 +- .../iterator/MultipleEpochsIteratorTest.java | 2 +- .../datasets/iterator/TestAsyncIterator.java | 4 +- .../tools/SimpleVariableGenerator.java | 12 +- .../earlystopping/TestEarlyStopping.java | 4 +- 
.../TestEarlyStoppingCompGraph.java | 2 +- .../org/deeplearning4j/eval/EvalTest.java | 12 +- .../java/org/deeplearning4j/eval/ROCTest.java | 4 +- .../eval/RegressionEvalTest.java | 6 +- .../exceptions/TestRecordReaders.java | 29 +- .../gradientcheck/AttentionLayerTest.java | 20 +- .../gradientcheck/BNGradientCheckTest.java | 12 +- .../gradientcheck/CNN1DGradientCheckTest.java | 10 +- .../gradientcheck/CNN3DGradientCheckTest.java | 24 +- .../gradientcheck/CNNGradientCheckTest.java | 2 +- .../gradientcheck/DropoutGradientCheck.java | 2 +- .../GlobalPoolingGradientCheckTests.java | 2 +- .../GradientCheckTestsComputationGraph.java | 26 +- .../GradientCheckTestsMasking.java | 4 +- .../gradientcheck/LRNGradientCheckTests.java | 2 +- .../gradientcheck/LSTMGradientCheckTests.java | 2 +- .../LossFunctionGradientCheck.java | 18 +- .../NoBiasGradientCheckTests.java | 6 +- .../OutputLayerGradientChecks.java | 28 +- .../gradientcheck/RnnGradientChecks.java | 8 +- .../gradientcheck/YoloGradientCheckTests.java | 2 +- .../MultiLayerNeuralNetConfigurationTest.java | 4 +- .../MultiNeuralNetConfLayerBuilderTest.java | 5 +- .../nn/conf/constraints/TestConstraints.java | 6 +- .../nn/conf/graph/ElementWiseVertexTest.java | 4 +- .../nn/conf/graph/ShiftVertexTest.java | 2 +- .../nn/conf/layers/LayerBuilderTest.java | 2 +- .../nn/conf/layers/LayerConfigTest.java | 16 +- .../conf/preprocessor/CNNProcessorTest.java | 24 +- .../conf/preprocessor/TestPreProcessors.java | 8 +- .../nn/conf/weightnoise/TestWeightNoise.java | 2 +- .../nn/graph/ComputationGraphTestRNN.java | 30 +- .../nn/graph/TestCompGraphCNN.java | 2 +- .../nn/graph/TestComputationGraphNetwork.java | 24 +- .../nn/graph/TestSetGetParameters.java | 4 +- .../nn/graph/TestVariableLengthTSCG.java | 14 +- .../nn/graph/graphnodes/TestGraphNodes.java | 18 +- .../nn/layers/ActivationLayerTest.java | 6 +- .../nn/layers/CacheModeTest.java | 2 +- .../nn/layers/DropoutLayerTest.java | 4 +- .../nn/layers/OutputLayerTest.java | 22 +- .../nn/layers/RepeatVectorTest.java | 11 +- .../deeplearning4j/nn/layers/SeedTest.java | 4 +- .../layers/convolution/Convolution3DTest.java | 32 +- .../ConvolutionLayerSetupTest.java | 10 +- .../LocallyConnectedLayerTest.java | 2 +- .../layers/convolution/SpaceToDepthTest.java | 25 +- .../convolution/SubsamplingLayerTest.java | 28 +- .../convolution/TestConvolutionModes.java | 8 +- .../layers/convolution/Upsampling1DTest.java | 20 +- .../layers/convolution/Upsampling2DTest.java | 21 +- .../layers/feedforward/dense/DenseTest.java | 6 +- .../embedding/EmbeddingLayerTest.java | 1 - .../normalization/BatchNormalizationTest.java | 20 +- .../normalization/LocalResponseTest.java | 15 +- .../objdetect/TestYolo2OutputLayer.java | 6 +- .../nn/layers/ocnn/OCNNOutputLayerTest.java | 2 +- .../pooling/GlobalPoolingMaskingTests.java | 16 +- .../layers/recurrent/BidirectionalTest.java | 18 +- .../GravesBidirectionalLSTMTest.java | 36 +- .../nn/layers/recurrent/GravesLSTMTest.java | 20 +- .../layers/recurrent/MaskZeroLayerTest.java | 8 +- .../recurrent/TestLastTimeStepLayer.java | 2 +- .../nn/layers/recurrent/TestRnnLayers.java | 16 +- .../nn/layers/recurrent/TestSimpleRnn.java | 2 +- .../layers/recurrent/TestTimeDistributed.java | 2 +- .../nn/layers/samediff/TestSameDiffConv.java | 4 +- .../samediff/testlayers/SameDiffDense.java | 4 +- .../testlayers/SameDiffMSEOutputLayer.java | 8 +- .../TestReconstructionDistributions.java | 2 +- .../nn/layers/variational/TestVAE.java | 2 +- .../nn/misc/TestNetConversion.java | 10 +- .../nn/misc/WorkspaceTests.java | 20 
+- .../nn/mkldnn/ValidateMKLDNN.java | 6 +- .../nn/multilayer/MultiLayerTest.java | 26 +- .../nn/multilayer/MultiLayerTestRNN.java | 50 +- .../nn/multilayer/TestMasking.java | 2 +- .../nn/multilayer/TestSetGetParameters.java | 8 +- .../nn/multilayer/TestVariableLengthTS.java | 24 +- .../nn/transferlearning/TestFrozenLayers.java | 8 +- .../nn/updater/TestUpdaters.java | 8 +- .../nn/updater/custom/TestCustomUpdater.java | 2 +- .../nn/util/TestDataSetConsumer.java | 4 +- .../optimize/solver/TestOptimizers.java | 3 +- .../SmartFancyBlockingQueueTest.java | 2 +- .../optimizer/listener/ScoreStatTest.java | 18 +- .../optimizer/listener/TestListeners.java | 2 +- .../deeplearning4j/util/ArrayUtilTest.java | 13 +- .../deeplearning4j/util/ModelGuesserTest.java | 6 +- .../util/ModelSerializerTest.java | 2 +- .../deeplearning4j/util/TestUIDProvider.java | 9 +- .../datasets/base/IrisUtils.java | 2 +- .../datasets/fetchers/EmnistDataFetcher.java | 10 +- .../datasets/fetchers/MnistDataFetcher.java | 4 +- .../datasets/mnist/MnistDbFile.java | 2 +- .../datasets/mnist/MnistImageFile.java | 4 +- .../RecordReaderMultiDataSetIterator.java | 18 +- .../SequenceRecordReaderDataSetIterator.java | 6 +- .../iterator/AbstractDataSetIterator.java | 2 +- .../iterator/AsyncShieldDataSetIterator.java | 2 +- .../AsyncShieldMultiDataSetIterator.java | 2 +- .../CombinedMultiDataSetPreProcessor.java | 2 +- .../iterator/CombinedPreProcessor.java | 2 +- .../iterator/DataSetIteratorSplitter.java | 10 +- .../EarlyTerminationDataSetIterator.java | 4 +- .../EarlyTerminationMultiDataSetIterator.java | 4 +- .../iterator/FileSplitDataSetIterator.java | 8 +- .../MultiDataSetIteratorSplitter.java | 10 +- .../iterator/MultipleEpochsIterator.java | 2 +- .../iterator/RandomDataSetIterator.java | 2 +- .../iterator/RandomMultiDataSetIterator.java | 6 +- .../ReconstructionDataSetIterator.java | 2 +- .../iterator/ScrollableDataSetIterator.java | 7 +- .../ScrollableMultiDataSetIterator.java | 7 +- .../callbacks/InterleavedDataSetCallback.java | 6 +- .../impl/BenchmarkDataSetIterator.java | 8 +- .../impl/BenchmarkMultiDataSetIterator.java | 8 +- .../iterator/impl/ListDataSetIterator.java | 2 +- .../parallel/BaseParallelDataSetIterator.java | 5 +- .../FileSplitParallelDataSetIterator.java | 6 +- .../JointParallelDataSetIterator.java | 6 +- .../iterator/parallel/MultiBoolean.java | 2 +- .../nn/modelimport/keras/Hdf5Archive.java | 4 +- .../nn/modelimport/keras/KerasModel.java | 2 +- .../modelimport/keras/layers/TFOpLayer.java | 4 +- .../keras/layers/TFOpLayerImpl.java | 4 +- .../advanced/activations/KerasPReLU.java | 2 +- .../KerasDepthwiseConvolution2D.java | 2 +- .../keras/layers/core/KerasPermute.java | 2 +- .../keras/layers/recurrent/KerasLSTM.java | 2 +- .../layers/recurrent/KerasSimpleRnn.java | 4 +- .../layers/wrappers/KerasBidirectional.java | 2 +- .../preprocessors/ReshapePreprocessor.java | 4 +- .../keras/utils/KerasModelBuilder.java | 8 +- .../keras/utils/KerasModelUtils.java | 4 +- .../nn/modelimport/keras/MiscTests.java | 24 +- .../configurations/DeepCTRLambdaTest.java | 10 +- .../configurations/FullModelComparisons.java | 20 +- .../keras/configurations/JsonTest.java | 2 +- .../Keras1ModelConfigurationTest.java | 2 +- .../Keras2ModelConfigurationTest.java | 2 +- .../KerasInitilizationTest.java | 16 +- .../keras/e2e/KerasCustomLayerTest.java | 2 +- .../keras/e2e/KerasCustomLossTest.java | 2 +- .../keras/e2e/KerasLambdaTest.java | 4 +- .../keras/e2e/KerasYolo9000PredictTest.java | 2 +- .../activation/KerasLeakyReLUTest.java | 4 +- 
.../advanced/activation/KerasPReLUTest.java | 4 +- .../activation/KerasThresholdedReLUTest.java | 4 +- .../KerasAtrousConvolution1DTest.java | 4 +- .../KerasAtrousConvolution2DTest.java | 2 +- .../convolution/KerasConvolution1DTest.java | 8 +- .../convolution/KerasConvolution2DTest.java | 8 +- .../convolution/KerasConvolution3DTest.java | 8 +- .../convolution/KerasCropping1DTest.java | 4 +- .../convolution/KerasCropping2DTest.java | 4 +- .../convolution/KerasCropping3DTest.java | 4 +- .../convolution/KerasDeconvolution2DTest.java | 8 +- .../KerasDepthwiseConvolution2DTest.java | 4 +- .../KerasSeparableConvolution2DTest.java | 8 +- .../convolution/KerasUpsampling1DTest.java | 10 +- .../convolution/KerasUpsampling2DTest.java | 10 +- .../convolution/KerasUpsampling3DTest.java | 10 +- .../convolution/KerasZeroPadding1DTest.java | 4 +- .../convolution/KerasZeroPadding2DTest.java | 4 +- .../convolution/KerasZeroPadding3DTest.java | 4 +- .../layers/core/KerasActivationLayer.java | 8 +- .../keras/layers/core/KerasDenseTest.java | 8 +- .../keras/layers/core/KerasDropoutTest.java | 8 +- .../keras/layers/core/KerasMaskingTest.java | 4 +- .../keras/layers/core/KerasPermuteTest.java | 8 +- .../layers/core/KerasRepeatVectorTest.java | 10 +- .../keras/layers/core/KerasReshapeTest.java | 8 +- .../core/KerasSpatialDropout2DTest.java | 8 +- .../layers/embeddings/KerasEmbeddingTest.java | 8 +- .../local/KerasLocallyConnected1DTest.java | 8 +- .../local/KerasLocallyConnected2DTest.java | 8 +- .../layers/noise/KerasAlphaDropoutTest.java | 8 +- .../noise/KerasGaussianDropoutTest.java | 8 +- .../layers/noise/KerasGaussianNoiseTest.java | 8 +- .../KerasBatchNormalizationTest.java | 8 +- .../layers/pooling/KerasPooling1DTest.java | 8 +- .../layers/pooling/KerasPooling2DTest.java | 8 +- .../layers/pooling/KerasPooling3DTest.java | 8 +- .../keras/layers/recurrent/KerasLSTMTest.java | 16 +- .../layers/recurrent/KerasSimpleRnnTest.java | 10 +- .../wrappers/KerasBidirectionalTest.java | 8 +- .../weights/KerasWeightSettingTests.java | 3 +- .../vectorizer/BagOfWordsVectorizer.java | 3 +- .../vectorizer/DefaultInputStreamCreator.java | 2 +- .../vectorizer/TfidfVectorizer.java | 3 +- .../deeplearning4j/iterator/BertIterator.java | 8 +- .../iterator/CnnSentenceDataSetIterator.java | 2 +- .../provider/LabelAwareConverter.java | 4 +- .../learning/impl/elements/BatchItem.java | 2 +- .../impl/elements/BatchSequences.java | 4 +- .../learning/impl/elements/CBOW.java | 2 +- .../learning/impl/elements/SkipGram.java | 2 +- .../embeddings/learning/impl/sequence/DM.java | 2 +- .../loader/VectorsConfiguration.java | 3 +- .../loader/WordVectorSerializer.java | 60 +-- .../reader/impl/BasicModelUtils.java | 7 +- .../reader/impl/FlatModelUtils.java | 3 +- .../reader/impl/TreeModelUtils.java | 5 +- .../wordvectors/WordVectorsImpl.java | 2 +- .../models/fasttext/FastText.java | 2 +- .../models/node2vec/Node2Vec.java | 2 +- .../paragraphvectors/ParagraphVectors.java | 4 +- .../sequencevectors/SequenceVectors.java | 8 +- .../graph/huffman/GraphHuffman.java | 2 +- .../graph/primitives/IGraph.java | 22 +- .../walkers/impl/NearestVertexWalker.java | 4 +- .../graph/walkers/impl/PopularityWalker.java | 2 +- .../iterators/AbstractSequenceIterator.java | 4 +- .../listeners/ScoreListener.java | 5 +- .../listeners/SerializingListener.java | 13 +- .../sequencevectors/sequence/Sequence.java | 2 +- .../transformers/impl/GraphTransformer.java | 2 +- .../ParallelTransformerIterator.java | 4 +- .../models/word2vec/Huffman.java | 4 +- 
.../models/word2vec/StaticWord2Vec.java | 6 +- .../models/word2vec/StreamWork.java | 2 +- .../models/word2vec/VocabWork.java | 10 +- .../iterator/Word2VecDataFetcher.java | 10 +- .../iterator/Word2VecDataSetIterator.java | 8 +- .../word2vec/wordstore/VocabConstructor.java | 6 +- .../word2vec/wordstore/VocabularyHolder.java | 15 +- .../word2vec/wordstore/VocabularyWord.java | 3 +- .../wordstore/inmemory/AbstractCache.java | 6 +- .../inmemory/InMemoryLookupCache.java | 11 +- .../FileDocumentIterator.java | 4 +- .../FilenamesLabelAwareIterator.java | 4 +- .../text/documentiterator/LabelsSource.java | 4 +- .../inputsanitation/InputHomogenization.java | 2 +- .../movingwindow/ContextLabelRetriever.java | 10 +- .../text/movingwindow/Window.java | 6 +- .../text/movingwindow/WordConverter.java | 2 +- .../AggregatingSentenceIterator.java | 6 +- .../sentenceiterator/BasicLineIterator.java | 2 +- .../BasicResultSetIterator.java | 4 +- .../CollectionSentenceIterator.java | 2 +- .../FileSentenceIterator.java | 5 +- .../LineSentenceIterator.java | 2 +- .../MutipleEpochsSentenceIterator.java | 6 +- .../PrefetchingSentenceIterator.java | 18 +- .../sentenceiterator/StreamLineIterator.java | 6 +- .../SynchronizedSentenceIterator.java | 2 +- .../SentenceIteratorConverter.java | 4 +- .../LabelAwareFileSentenceIterator.java | 3 +- .../tokenizer/DefaultStreamTokenizer.java | 6 +- .../tokenizer/DefaultTokenizer.java | 2 +- .../preprocessor/CompositePreProcessor.java | 2 +- .../preprocessor/StringCleaning.java | 2 +- .../NGramTokenizerFactory.java | 2 +- .../iterator/TestBertIterator.java | 38 +- .../inmemory/InMemoryLookupTableTest.java | 16 +- .../models/fasttext/FastTextTest.java | 8 +- .../ParagraphVectorsTest.java | 14 +- .../sequencevectors/SequenceVectorsTest.java | 2 +- .../graph/walkers/impl/RandomWalkerTest.java | 16 +- .../ParallelTransformerIteratorTest.java | 6 +- .../iterator/Word2VecDataSetIteratorTest.java | 2 +- .../wordstore/VocabConstructorTest.java | 7 +- .../wordstore/inmemory/AbstractCacheTest.java | 4 +- .../BertWordPieceTokenizerTests.java | 4 +- .../client/NearestNeighborsClient.java | 2 +- .../clustering/algorithm/Distance.java | 4 +- .../clustering/cluster/CentersHolder.java | 9 +- .../clustering/info/ClusterSetInfo.java | 4 +- .../clustering/kdtree/HyperRect.java | 8 +- .../clustering/kdtree/KDTree.java | 2 +- .../clustering/lsh/RandomProjectionLSH.java | 12 +- .../clustering/quadtree/QuadTree.java | 4 +- .../clustering/randomprojection/RPTree.java | 7 +- .../clustering/randomprojection/RPUtils.java | 2 +- .../clustering/sptree/Cell.java | 2 +- .../clustering/sptree/SpTree.java | 2 +- .../clustering/util/MathUtils.java | 4 +- .../clustering/util/MultiThreadUtils.java | 2 +- .../clustering/vptree/VPTreeFillSearch.java | 6 +- .../clustering/kdtree/KDTreeTest.java | 6 +- .../clustering/kmeans/KMeansTest.java | 14 +- .../clustering/sptree/SPTreeTest.java | 11 +- .../vptree/VPTreeSerializationTests.java | 4 +- .../clustering/vptree/VpTreeNodeTest.java | 2 +- .../server/NearestNeighborsServer.java | 15 +- .../server/NearestNeighborTest.java | 4 +- .../EarlyStoppingConfiguration.java | 2 +- .../saver/LocalFileGraphSaver.java | 4 +- .../saver/LocalFileModelSaver.java | 4 +- .../scorecalc/ROCScoreCalculator.java | 4 +- .../trainer/BaseEarlyStoppingTrainer.java | 6 +- .../trainer/EarlyStoppingGraphTrainer.java | 2 +- .../trainer/EarlyStoppingTrainer.java | 4 +- .../deeplearning4j/eval/BaseEvaluation.java | 4 +- .../conf/ComputationGraphConfiguration.java | 12 +- 
.../nn/conf/NeuralNetConfiguration.java | 4 +- .../nn/conf/constraint/MaxNormConstraint.java | 2 +- .../conf/constraint/MinMaxNormConstraint.java | 2 +- .../conf/constraint/UnitNormConstraint.java | 2 +- .../distribution/BinomialDistribution.java | 4 +- .../conf/distribution/NormalDistribution.java | 4 +- .../serde/LegacyDistributionDeserializer.java | 2 +- .../nn/conf/graph/SubsetVertex.java | 2 +- .../nn/conf/layers/AbstractLSTM.java | 4 +- .../nn/conf/layers/CapsuleLayer.java | 4 +- .../nn/conf/layers/CenterLossOutputLayer.java | 8 +- .../nn/conf/layers/ConvolutionLayer.java | 2 +- .../conf/layers/GravesBidirectionalLSTM.java | 2 +- .../nn/conf/layers/InputTypeUtil.java | 69 ++- .../nn/conf/layers/LocallyConnected2D.java | 2 +- .../nn/conf/layers/PrimaryCapsules.java | 8 +- .../nn/conf/layers/Subsampling3DLayer.java | 2 +- .../nn/conf/layers/Upsampling1D.java | 4 +- .../nn/conf/layers/Upsampling3D.java | 2 +- .../conf/layers/convolutional/Cropping1D.java | 2 +- .../conf/layers/convolutional/Cropping2D.java | 2 +- .../conf/layers/convolutional/Cropping3D.java | 2 +- .../objdetect/BoundingBoxesDeserializer.java | 2 +- .../layers/samediff/SameDiffLambdaVertex.java | 4 +- .../CompositeReconstructionDistribution.java | 6 +- .../variational/VariationalAutoencoder.java | 5 +- .../nn/conf/memory/LayerMemoryReport.java | 8 +- .../FeedForwardToCnn3DPreProcessor.java | 2 +- .../conf/serde/BaseNetConfigDeserializer.java | 2 +- .../nn/conf/serde/JsonMappers.java | 4 +- .../serde/format/DataFormatDeserializer.java | 2 +- .../legacy/LegacyIntArrayDeserializer.java | 2 +- .../stepfunctions/DefaultStepFunction.java | 4 +- .../stepfunctions/GradientStepFunction.java | 4 +- .../NegativeDefaultStepFunction.java | 4 +- .../NegativeGradientStepFunction.java | 4 +- .../nn/gradient/DefaultGradient.java | 2 +- .../nn/graph/ComputationGraph.java | 14 +- .../graph/vertex/impl/ElementWiseVertex.java | 2 +- .../graph/vertex/impl/L2NormalizeVertex.java | 4 +- .../nn/graph/vertex/impl/L2Vertex.java | 2 +- .../nn/graph/vertex/impl/LayerVertex.java | 13 +- .../nn/graph/vertex/impl/MergeVertex.java | 2 +- .../graph/vertex/impl/PreprocessorVertex.java | 2 +- .../nn/graph/vertex/impl/ReshapeVertex.java | 6 +- .../nn/graph/vertex/impl/ScaleVertex.java | 2 +- .../nn/graph/vertex/impl/ShiftVertex.java | 2 +- .../nn/graph/vertex/impl/StackVertex.java | 4 +- .../nn/graph/vertex/impl/SubsetVertex.java | 4 +- .../nn/graph/vertex/impl/UnstackVertex.java | 6 +- .../impl/rnn/DuplicateToTimeSeriesVertex.java | 4 +- .../vertex/impl/rnn/LastTimeStepVertex.java | 4 +- .../deeplearning4j/nn/layers/FrozenLayer.java | 4 +- .../nn/layers/FrozenLayerWithBackprop.java | 4 +- .../deeplearning4j/nn/layers/HelperUtils.java | 6 +- .../nn/layers/RepeatVector.java | 4 +- .../layers/convolution/ConvolutionLayer.java | 2 +- .../layers/convolution/Cropping1DLayer.java | 4 +- .../layers/convolution/Cropping2DLayer.java | 4 +- .../layers/convolution/Cropping3DLayer.java | 4 +- .../convolution/ZeroPadding1DLayer.java | 4 +- .../convolution/ZeroPadding3DLayer.java | 4 +- .../layers/convolution/ZeroPaddingLayer.java | 2 +- .../autoencoder/recursive/Tree.java | 19 +- .../embedding/EmbeddingSequenceLayer.java | 2 +- .../nn/layers/mkldnn/MKLDNNConvHelper.java | 4 +- .../normalization/BatchNormalization.java | 7 +- .../nn/layers/objdetect/Yolo2OutputLayer.java | 12 +- .../nn/layers/objdetect/YoloUtils.java | 2 +- .../nn/layers/ocnn/OCNNOutputLayer.java | 4 +- .../layers/recurrent/BidirectionalLayer.java | 6 +- .../nn/layers/recurrent/LSTMHelpers.java | 2 +- 
.../nn/layers/recurrent/MaskZeroLayer.java | 2 +- .../nn/layers/recurrent/SimpleRnn.java | 2 +- .../recurrent/TimeDistributedLayer.java | 2 +- .../nn/layers/util/MaskLayer.java | 2 +- .../variational/VariationalAutoencoder.java | 5 +- .../nn/multilayer/MultiLayerNetwork.java | 6 +- .../DepthwiseConvolutionParamInitializer.java | 2 +- .../nn/params/PReLUParamInitializer.java | 4 +- .../nn/transferlearning/TransferLearning.java | 36 +- .../nn/updater/BaseMultiLayerUpdater.java | 3 +- .../nn/updater/LayerUpdater.java | 2 +- .../embeddings/WeightInitEmbedding.java | 2 +- .../nn/workspace/LayerWorkspaceMgr.java | 6 +- .../org/deeplearning4j/optimize/Solver.java | 2 +- .../listeners/CheckpointListener.java | 34 +- .../CollectScoresIterationListener.java | 9 +- .../listeners/PerformanceListener.java | 2 +- .../listeners/TimeIterationListener.java | 8 +- .../optimize/solvers/BackTrackLineSearch.java | 6 +- .../optimize/solvers/LBFGS.java | 2 +- .../util/Convolution1DUtils.java | 38 +- .../util/Convolution3DUtils.java | 39 +- .../deeplearning4j/util/ConvolutionUtils.java | 76 +-- .../util/CrashReportingUtil.java | 10 +- .../deeplearning4j/util/OutputLayerUtil.java | 9 +- .../deeplearning4j/util/ValidationUtils.java | 3 +- .../ParameterServerTrainerContext.java | 8 +- .../EarlyStoppingParallelTrainer.java | 12 +- .../parallelism/ParallelInference.java | 22 +- .../parallelism/ParallelWrapper.java | 4 +- .../observers/BasicInferenceObserver.java | 2 +- .../observers/BatchedInferenceObservable.java | 18 +- .../InplaceParallelInferenceTest.java | 8 +- .../parallelism/ParallelInferenceTest.java | 14 +- .../main/ParallelWrapperMainTest.java | 4 +- .../nd4j/python4j/PythonContextManager.java | 4 +- .../org/nd4j/python4j/PythonExecutioner.java | 2 +- .../java/org/nd4j/python4j/PythonObject.java | 12 +- .../java/org/nd4j/python4j/PythonProcess.java | 10 +- .../java/org/nd4j/python4j/PythonTypes.java | 10 +- .../src/test/java/PythonNumpyBasicTest.java | 4 +- .../test/java/PythonNumpyCollectionsTest.java | 2 +- .../test/java/PythonNumpyMultiThreadTest.java | 2 +- .../api/stats/StatsCalculationHelper.java | 6 +- .../worker/ExecuteWorkerPathMDSFlatMap.java | 2 +- .../data/BatchAndExportDataSetsFunction.java | 2 +- .../BatchAndExportMultiDataSetsFunction.java | 2 +- ...litDataSetExamplesPairFlatMapFunction.java | 2 +- .../datavec/DataVecByteDataSetFunction.java | 8 +- .../spark/datavec/RDDMiniBatches.java | 4 +- .../spark/datavec/RecordReaderFunction.java | 4 +- .../BaseSparkEarlyStoppingTrainer.java | 8 +- .../SparkDataSetLossCalculator.java | 6 +- .../SparkEarlyStoppingGraphTrainer.java | 2 +- .../SparkEarlyStoppingTrainer.java | 2 +- .../SparkLossCalculatorComputationGraph.java | 6 +- .../spark/impl/SparkListenable.java | 2 +- .../HashingBalancedPartitioner.java | 6 +- .../impl/evaluation/EvaluationRunner.java | 4 +- .../impl/graph/SparkComputationGraph.java | 6 +- ...VaeReconstructionErrorWithKeyFunction.java | 4 +- ...GVaeReconstructionProbWithKeyFunction.java | 4 +- .../ScoreFlatMapFunctionCGDataSet.java | 6 +- .../ScoreFlatMapFunctionCGMultiDataSet.java | 6 +- .../impl/multilayer/SparkDl4jMultiLayer.java | 6 +- ...VaeReconstructionErrorWithKeyFunction.java | 4 +- .../VaeReconstructionProbWithKeyFunction.java | 4 +- .../ParameterAveragingTrainingWorker.java | 8 +- ...ParameterAveragingTrainingMasterStats.java | 18 +- ...ParameterAveragingTrainingWorkerStats.java | 2 +- .../iterator/PathSparkDataSetIterator.java | 4 +- .../spark/iterator/SparkADSI.java | 2 +- .../spark/iterator/SparkAMDSI.java | 2 +- 
.../spark/stats/StatsUtils.java | 3 +- .../spark/time/NTPTimeSource.java | 2 +- .../deeplearning4j/spark/util/MLLibUtil.java | 2 +- .../deeplearning4j/spark/util/SparkUtils.java | 7 +- .../util/serde/StorageLevelDeserializer.java | 2 +- .../util/serde/StorageLevelSerializer.java | 2 +- .../spark/TestEarlyStoppingSpark.java | 2 +- .../TestEarlyStoppingSparkCompGraph.java | 2 +- .../org/deeplearning4j/spark/TestKryo.java | 34 +- .../spark/datavec/MiniBatchTests.java | 4 +- .../datavec/TestDataVecDataSetFunctions.java | 8 +- .../impl/graph/TestSparkComputationGraph.java | 2 +- .../impl/multilayer/TestMiscFunctions.java | 6 +- .../multilayer/TestSparkDl4jMultiLayer.java | 5 +- ...arameterAveragingSparkVsSingleMachine.java | 2 +- .../stats/TestTrainingStatsCollection.java | 10 +- .../spark/ui/TestListeners.java | 2 +- .../spark/util/TestRepartitioning.java | 10 +- .../spark/util/TestValidation.java | 13 +- .../word2vec/FirstIterationFunction.java | 40 +- .../embeddings/word2vec/NegativeHolder.java | 4 +- .../word2vec/SecondIterationFunction.java | 34 +- .../embeddings/word2vec/SentenceBatch.java | 2 +- .../embeddings/word2vec/VocabHolder.java | 15 +- .../models/embeddings/word2vec/Word2Vec.java | 10 +- .../embeddings/word2vec/Word2VecChange.java | 2 +- .../embeddings/word2vec/Word2VecParam.java | 2 +- .../word2vec/Word2VecPerformer.java | 16 +- .../word2vec/Word2VecPerformerVoid.java | 6 +- .../embeddings/word2vec/Word2VecSetup.java | 2 +- .../word2vec/Word2VecVariables.java | 10 +- .../spark/text/functions/CountCumSum.java | 4 +- .../FoldBetweenPartitionFunction.java | 2 +- .../FoldWithinPartitionFunction.java | 2 +- .../spark/text/functions/TextPipeline.java | 2 +- .../text/functions/TokenizerFunction.java | 4 +- .../UpdateWordFreqAccumulatorFunction.java | 4 +- .../embeddings/word2vec/Word2VecTest.java | 3 +- .../spark/text/TestFunction.java | 2 +- .../spark/text/TextPipelineTest.java | 11 +- .../networking/v1/SilentTrainingDriver.java | 2 +- .../pw/SharedTrainingWrapper.java | 6 +- .../python/ArrayDescriptor.java | 8 +- .../python/DataSetDescriptor.java | 7 +- .../training/SharedTrainingMaster.java | 4 +- .../iterators/VirtualDataSetIteratorTest.java | 4 +- .../train/GradientSharingTrainingTest.java | 2 +- .../deeplearning4j/plot/BarnesHutTsne.java | 6 +- .../java/org/deeplearning4j/plot/Tsne.java | 2 +- .../org/deeplearning4j/plot/Test6058.java | 2 +- .../nativeblas/BaseNativeNDArrayFactory.java | 15 +- .../java/org/nd4j/compression/impl/NoOp.java | 10 +- .../deallocator/GarbageStateReference.java | 2 +- .../deallocator/NativeRandomDeallocator.java | 2 +- .../nd4j/storage/CompressedRamStorage.java | 4 +- .../cpu/nativecpu/CpuNDArrayFactory.java | 13 +- .../linalg/cpu/nativecpu/CpuTADManager.java | 4 +- .../nativecpu/DirectShapeInfoProvider.java | 4 +- .../linalg/cpu/nativecpu/blas/CpuLapack.java | 14 +- .../linalg/cpu/nativecpu/blas/CpuLevel1.java | 2 +- .../linalg/cpu/nativecpu/blas/CpuLevel2.java | 2 +- .../linalg/cpu/nativecpu/blas/CpuLevel3.java | 2 +- .../nativecpu/buffer/BaseCpuDataBuffer.java | 12 +- .../cpu/nativecpu/buffer/Utf8Buffer.java | 9 +- .../nativecpu/cache/ConstantBuffersCache.java | 24 +- .../compression/CpuFlexibleThreshold.java | 2 +- .../nativecpu/compression/CpuThreshold.java | 2 +- .../cpu/nativecpu/ops/CpuOpContext.java | 10 +- .../nativecpu/ops/NativeOpExecutioner.java | 38 +- .../workspace/CpuWorkspaceDeallocator.java | 8 +- .../concurrency/DeviceAllocationsTracker.java | 2 +- .../jita/allocator/concurrency/RRWLock.java | 6 +- 
.../jita/allocator/impl/AllocationPoint.java | 14 +- .../jita/allocator/impl/AtomicAllocator.java | 20 +- .../jita/allocator/impl/CudaDeallocator.java | 2 +- .../jita/allocator/impl/MemoryTracker.java | 14 +- .../nd4j/jita/allocator/impl/NestedPoint.java | 9 +- .../jita/allocator/tad/DeviceTADManager.java | 2 +- .../jita/allocator/time/impl/BinaryTimer.java | 10 +- .../jita/allocator/time/impl/SimpleTimer.java | 1 - .../time/providers/OperativeProvider.java | 2 +- .../jita/allocator/time/rings/LockedRing.java | 2 +- .../jita/allocator/utils/AllocationUtils.java | 2 +- .../jita/concurrency/CudaAffinityManager.java | 8 +- .../nd4j/jita/concurrency/EventsProvider.java | 6 +- .../org/nd4j/jita/conf/Configuration.java | 14 +- .../org/nd4j/jita/conf/CudaEnvironment.java | 4 +- .../nd4j/jita/constant/ConstantProtector.java | 4 +- .../ProtectedCudaConstantHandler.java | 2 +- .../ProtectedCudaShapeInfoProvider.java | 8 +- .../jita/handler/impl/CudaZeroHandler.java | 21 +- .../nd4j/jita/workspace/CudaWorkspace.java | 4 +- .../workspace/CudaWorkspaceDeallocator.java | 6 +- .../nd4j/linalg/jcublas/CublasPointer.java | 8 +- .../linalg/jcublas/JCublasNDArrayFactory.java | 25 +- .../linalg/jcublas/blas/JcublasLapack.java | 44 +- .../linalg/jcublas/blas/JcublasLevel1.java | 8 +- .../linalg/jcublas/blas/JcublasLevel2.java | 6 +- .../linalg/jcublas/blas/JcublasLevel3.java | 12 +- .../jcublas/buffer/BaseCudaDataBuffer.java | 7 +- .../linalg/jcublas/buffer/CudaUtf8Buffer.java | 9 +- .../linalg/jcublas/context/CudaContext.java | 2 +- .../ops/executioner/CudaExecutioner.java | 16 +- .../ops/executioner/CudaGridExecutioner.java | 50 +- .../ops/executioner/CudaOpContext.java | 8 +- cavis-native/cavis-native-lib/build.gradle | 2 +- .../nd4j/aeron/ipc/AeronNDArrayPublisher.java | 2 +- .../java/org/nd4j/aeron/ipc/AeronUtil.java | 16 +- .../aeron/ipc/NDArrayFragmentHandler.java | 4 +- .../org/nd4j/aeron/ipc/NDArrayMessage.java | 2 +- .../ipc/chunk/InMemoryChunkAccumulator.java | 2 +- .../ndarrayholder/InMemoryNDArrayHolder.java | 4 +- .../nd4j/aeron/ipc/LargeNdArrayIpcTest.java | 6 +- .../org/nd4j/aeron/ipc/NdArrayIpcTest.java | 8 +- .../org/nd4j/common/base/Preconditions.java | 4 +- .../collection/CompactHeapStringList.java | 2 +- .../common/collection/IntArrayKeyMap.java | 4 +- .../common/collection/IntArrayKeySet.java | 2 +- .../collection/MultiDimensionalMap.java | 2 +- .../collection/MultiDimensionalSet.java | 2 +- .../common/holder/ObjectMapperHolder.java | 2 +- .../io/AbstractFileResolvingResource.java | 2 +- .../org/nd4j/common/io/AbstractResource.java | 4 +- .../org/nd4j/common/io/ClassPathResource.java | 2 +- .../org/nd4j/common/io/CollectionUtils.java | 17 +- .../java/org/nd4j/common/io/ObjectUtils.java | 6 +- .../org/nd4j/common/io/ReflectionUtils.java | 4 +- .../java/org/nd4j/common/io/StringUtils.java | 14 +- .../java/org/nd4j/common/io/VfsUtils.java | 40 +- .../nd4j/common/primitives/CounterMap.java | 2 +- .../serde/JsonDeserializerAtomicBoolean.java | 2 +- .../serde/JsonDeserializerAtomicDouble.java | 2 +- .../serde/JsonSerializerAtomicBoolean.java | 2 +- .../serde/JsonSerializerAtomicDouble.java | 2 +- .../org/nd4j/common/resources/Resources.java | 6 +- .../resources/strumpf/ResourceFile.java | 5 +- .../resources/strumpf/StrumpfResolver.java | 12 +- .../java/org/nd4j/common/tools/BTools.java | 4 +- .../main/java/org/nd4j/common/tools/SIS.java | 2 +- .../org/nd4j/common/util/ArchiveUtils.java | 4 +- .../java/org/nd4j/common/util/ArrayUtil.java | 10 +- .../main/java/org/nd4j/common/util/Index.java | 9 
+- .../java/org/nd4j/common/util/MathUtils.java | 4 +- .../java/org/nd4j/common/util/Rational.java | 4 +- .../nd4j/common/util/SynchronizedTable.java | 2 +- .../common/function/FunctionalUtilsTest.java | 6 +- .../org/nd4j/common/loader/TestFileBatch.java | 2 +- .../org/nd4j/common/tools/InfoValuesTest.java | 4 +- .../java/org/nd4j/common/tools/SISTest.java | 2 +- .../background/BackgroundDaemonStarter.java | 4 +- .../RemoteParameterServerClientTests.java | 6 +- .../ParameterServerClientPartialTest.java | 6 +- .../client/ParameterServerClientTest.java | 6 +- .../updater/SoftSyncParameterUpdater.java | 2 +- .../updater/SynchronousParameterUpdater.java | 2 +- .../storage/InMemoryUpdateStorage.java | 2 +- .../updater/storage/NoUpdateStorage.java | 2 +- .../logic/RetransmissionHandler.java | 2 +- .../completion/FrameCompletionHandler.java | 10 +- .../logic/storage/BaseStorage.java | 2 +- .../distributed/messages/Frame.java | 2 +- .../intercom/DistributedCbowDotMessage.java | 2 +- .../intercom/DistributedSgDotMessage.java | 2 +- .../training/impl/CbowTrainer.java | 8 +- .../training/impl/SkipGramTrainer.java | 10 +- .../distributed/transport/BaseTransport.java | 4 +- .../transport/RoutedTransport.java | 14 +- .../distributed/util/NetworkInformation.java | 3 +- .../distributed/v2/ModelParameterServer.java | 28 +- .../v2/chunks/impl/FileChunksTracker.java | 6 +- .../v2/chunks/impl/InmemoryChunksTracker.java | 2 +- .../params/UpdaterParametersMessage.java | 2 +- .../v2/transport/impl/BaseTransport.java | 4 +- .../v2/transport/impl/DummyTransport.java | 4 +- .../distributed/v2/util/MeshOrganizer.java | 8 +- .../VoidParameterServerStressTest.java | 8 +- .../distributed/VoidParameterServerTest.java | 15 +- .../distributed/logic/ClipboardTest.java | 4 +- .../logic/FrameCompletionHandlerTest.java | 2 +- .../logic/routing/InterleavedRouterTest.java | 2 +- .../distributed/messages/VoidMessageTest.java | 2 +- .../aggregations/VoidAggregationTest.java | 2 +- .../transport/RoutedTransportTest.java | 2 +- .../util/NetworkOrganizerTest.java | 20 +- .../v2/util/MeshOrganizerTest.java | 2 +- .../v2/util/MessageSplitterTest.java | 2 +- .../node/ParameterServerNodeTest.java | 8 +- .../status/play/StorageTests.java | 4 +- .../conversion/DummyDeAllocator.java | 2 +- .../tensorflow/conversion/TensorDataType.java | 34 +- .../conversion/TensorflowConversion.java | 6 +- .../conversion/graphrunner/GraphRunner.java | 8 +- .../ConvolutionalIterationListener.java | 10 +- .../ui/components/chart/Chart.java | 4 +- .../ui/components/chart/ChartHistogram.java | 6 +- .../components/chart/ChartHorizontalBar.java | 4 +- .../ui/components/chart/ChartLine.java | 8 +- .../ui/components/chart/ChartScatter.java | 6 +- .../ui/components/chart/ChartStackedArea.java | 4 +- .../ui/components/chart/ChartTimeline.java | 4 +- .../decorator/DecoratorAccordion.java | 4 +- .../ui/components/table/ComponentTable.java | 2 +- .../ui/components/text/ComponentText.java | 4 +- .../ui/standalone/StaticPageUtil.java | 3 +- .../word2vec/NearestNeighborsQuery.java | 3 +- .../ui/model/stats/BaseStatsListener.java | 4 +- .../ui/model/stats/impl/SbeUtil.java | 3 +- .../model/storage/InMemoryStatsStorage.java | 4 +- .../storage/mapdb/MapDBStatsStorage.java | 14 +- .../storage/sqlite/J7FileStatsStorage.java | 8 +- .../ui/model/weights/HistogramBin.java | 5 +- .../beans/CompactModelAndGradient.java | 9 +- .../org/deeplearning4j/ui/VertxUIServer.java | 14 +- .../org/deeplearning4j/ui/api/UIServer.java | 2 +- .../deeplearning4j/ui/i18n/DefaultI18N.java | 4 +- 
.../deeplearning4j/ui/i18n/I18NProvider.java | 2 +- .../module/remote/RemoteReceiverModule.java | 2 +- .../ui/module/train/TrainModule.java | 14 +- .../ui/module/train/TrainModuleUtils.java | 2 +- .../ui/module/tsne/TsneModule.java | 2 +- .../org/deeplearning4j/zoo/ModelMetaData.java | 2 +- .../java/org/deeplearning4j/zoo/ZooModel.java | 4 +- .../zoo/util/darknet/DarknetLabels.java | 4 +- .../org/deeplearning4j/zoo/TestImageNet.java | 4 +- .../deeplearning4j/zoo/TestInstantiation.java | 2 +- vsconfig.gradle | 2 +- 1274 files changed, 4725 insertions(+), 5023 deletions(-) diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java index d3a7179f6..7813efe6a 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java @@ -139,7 +139,6 @@ public class BrianTest /*extends BaseDL4JTest*/ { //.setExecutorEnv("spark.executor.cores", "2") //.setExecutorEnv("spark.executor.memory", "2g") //.set("spark.submit.deployMode", "client") - ; /* SparkSession spark = SparkSession @@ -240,7 +239,7 @@ public class BrianTest /*extends BaseDL4JTest*/ { */ TransformProcess tp = new TransformProcess.Builder(inputSchema) .removeAllColumnsExceptFor("country_code", "lat", "lon") - .stringToCategorical("country_code", Arrays.asList(new String[] {"GR", "FR", "DE", "CH"})) + .stringToCategorical("country_code", Arrays.asList("GR", "FR", "DE", "CH")) .filter(new FilterInvalidValues()) .categoricalToOneHot("country_code") .build(); diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java index 436016352..be62228c1 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java @@ -225,7 +225,7 @@ public class BrianTest2 /*extends BaseDL4JTest*/ { */ TransformProcess tp = new TransformProcess.Builder(inputSchema) .removeAllColumnsExceptFor("country_code", "lat", "lon") - .stringToCategorical("country_code", Arrays.asList(new String[] {"GR", "FR", "DE", "CH"})) + .stringToCategorical("country_code", Arrays.asList("GR", "FR", "DE", "CH")) .filter(new FilterInvalidValues()) .categoricalToOneHot("country_code") .build(); diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java index 29e80ce99..fbc0d60a3 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java @@ -91,10 +91,10 @@ public class IntegrationTestRunner { public static final double MAX_REL_ERROR_SCORES = 1e-4; - private static List> layerClasses = new ArrayList<>(); - private static List> preprocClasses = new ArrayList<>(); - private static List> graphVertexClasses = new ArrayList<>(); - private static List> evaluationClasses = new ArrayList<>(); + private static final List> layerClasses = new ArrayList<>(); + private static final List> preprocClasses = new ArrayList<>(); + private static final List> graphVertexClasses = new ArrayList<>(); + private static final List> evaluationClasses = new ArrayList<>(); private static Map, Integer> layerConfClassesSeen = new HashMap<>(); private static Map, Integer> 
preprocessorConfClassesSeen = new HashMap<>(); diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java index 4ecc4dd2a..d65a0a9cc 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java @@ -67,8 +67,8 @@ public class CNN1DTestCases { testOverfitting = false; } - int miniBatchSize = 16; - int exampleLength = 128; + final int miniBatchSize = 16; + final int exampleLength = 128; @Override public ModelType modelType() { diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java index 8b5cf6358..3b351e277 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java @@ -271,11 +271,11 @@ public class CNN2DTestCases { public static TestCase getYoloHouseNumbers() { return new TestCase() { - private int width = 416; - private int height = 416; - private int nChannels = 3; - private int gridWidth = 13; - private int gridHeight = 13; + private final int width = 416; + private final int height = 416; + private final int nChannels = 3; + private final int gridWidth = 13; + private final int gridHeight = 13; { testName = "YOLOHouseNumbers"; diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java index 4c8448c63..f856d5159 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java @@ -108,7 +108,7 @@ public class CNN3DTestCases { public MultiDataSet getGradientsTestData() throws Exception { Nd4j.getRandom().setSeed(12345); //NCDHW format - INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8}); + INDArray arr = Nd4j.rand(2, 3, 8, 8, 8); INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10); return new org.nd4j.linalg.dataset.MultiDataSet(arr, labels); } @@ -135,6 +135,6 @@ public class CNN3DTestCases { } }; - }; + } } diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java index 025f1ab54..a2cf437fe 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java @@ -93,8 +93,8 @@ public class RNNTestCases { minAbsErrorParamsPostTraining = 2e-3; } - private int miniBatchSize = 32; - private int exampleLength = 200; + private final int miniBatchSize = 32; + private final int exampleLength = 200; @Override diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/misc/CharacterIterator.java 
b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/misc/CharacterIterator.java index a7be40676..4d038abf8 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/misc/CharacterIterator.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/misc/CharacterIterator.java @@ -31,23 +31,24 @@ import java.io.File; import java.io.IOException; import java.net.URL; import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.util.*; public class CharacterIterator implements DataSetIterator { //Valid characters - private char[] validCharacters; + private final char[] validCharacters; //Maps each character to an index ind the input/output - private Map charToIdxMap; + private final Map charToIdxMap; //All characters of the input file (after filtering to only those that are valid - private char[] fileCharacters; + private final char[] fileCharacters; //Length of each example/minibatch (number of characters) - private int exampleLength; + private final int exampleLength; //Size of each minibatch (number of examples) - private int miniBatchSize; - private Random rng; + private final int miniBatchSize; + private final Random rng; //Offsets for the start of each example - private LinkedList exampleStartOffsets = new LinkedList<>(); + private final LinkedList exampleStartOffsets = new LinkedList<>(); /** * @param textFilePath Path to text file to use for generating samples @@ -299,7 +300,7 @@ public class CharacterIterator implements DataSetIterator { if (!f.exists()) throw new IOException("File does not exist: " + fileLocation); //Download problem? char[] validCharacters = CharacterIterator.getMinimalCharacterSet(); //Which characters are allowed? 
Others will be removed - return new CharacterIterator(fileLocation, Charset.forName("UTF-8"), + return new CharacterIterator(fileLocation, StandardCharsets.UTF_8, miniBatchSize, sequenceLength, validCharacters, new Random(12345)); } diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java index 98ec32dd0..2c28a2d3e 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/samediff/SameDiffCNNCases.java @@ -305,7 +305,7 @@ public class SameDiffCNNCases { // [minibatch,8,1,1,1] - int channels_height_width_depth = 8 * 1 * 1 * 1; + int channels_height_width_depth = 8; SDVariable layer1_reshaped = layer1.reshape(-1, channels_height_width_depth); @@ -331,7 +331,7 @@ public class SameDiffCNNCases { public Map getGradientsTestDataSameDiff() throws Exception { Nd4j.getRandom().setSeed(12345); //NCDHW format - INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8}); + INDArray arr = Nd4j.rand(2, 3, 8, 8, 8); INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10); Map map = new HashMap<>(); @@ -357,7 +357,7 @@ public class SameDiffCNNCases { Nd4j.getRandom().setSeed(12345); List> list = new ArrayList<>(); - INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8}); + INDArray arr = Nd4j.rand(2, 3, 8, 8, 8); list.add(Collections.singletonMap("in", arr)); @@ -368,7 +368,7 @@ public class SameDiffCNNCases { public MultiDataSet getGradientsTestData() throws Exception { Nd4j.getRandom().setSeed(12345); //NCDHW format - INDArray arr = Nd4j.rand(new int[]{2, 3, 8, 8, 8}); + INDArray arr = Nd4j.rand(2, 3, 8, 8, 8); INDArray labels = org.deeplearning4j.integration.TestUtils.randomOneHot(2, 10); return new org.nd4j.linalg.dataset.MultiDataSet(arr, labels); } diff --git a/build_requirements.md b/build_requirements.md index 602190b95..77d54050b 100644 --- a/build_requirements.md +++ b/build_requirements.md @@ -141,4 +141,8 @@ groupId:artifactId:packaging:classifier:version In your case it should work with -edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0 \ No newline at end of file +edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0 + + +Native cpu code under linux needs libc6-dev +/lib/x86_64-linux-gnu/libm.so.6: version `GLIBC_2.29' not found \ No newline at end of file diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/conf/Configuration.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/conf/Configuration.java index 71b7f7c2a..922b31aed 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/conf/Configuration.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/conf/Configuration.java @@ -266,7 +266,7 @@ public class Configuration implements Iterable>, Writa reloadConfiguration(); } - private static Pattern varPat = Pattern.compile("\\$\\{[^\\}\\$\u0020]+\\}"); + private static final Pattern varPat = Pattern.compile("\\$\\{[^\\}\\$\u0020]+\\}"); private String substituteVars(String expr) { if (expr == null) { @@ -555,7 +555,7 @@ public class Configuration implements Iterable>, Writa } /** - * Get the value of the name property as a Pattern. + * Get the value of the name property as a {@code Pattern}. 
* If no such property is specified, or if the specified value is not a valid * Pattern, then DefaultValue is returned. * diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/formats/output/OutputFormat.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/formats/output/OutputFormat.java index e66e37a6d..14322d888 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/formats/output/OutputFormat.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/formats/output/OutputFormat.java @@ -27,7 +27,7 @@ import org.datavec.api.records.writer.RecordWriter; public interface OutputFormat { - public static final String OUTPUT_PATH = "org.nd4j.outputpath"; + String OUTPUT_PATH = "org.nd4j.outputpath"; /** * Create a record writer diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/BinaryComparable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/BinaryComparable.java index 4f19f0b78..a75fe0b30 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/BinaryComparable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/BinaryComparable.java @@ -34,7 +34,7 @@ public abstract class BinaryComparable implements Comparable { /** * Compare bytes from {#getBytes()}. - * @see org.apache.hadoop.io.WritableComparator#compareBytes(byte[],int,int,byte[],int,int) + * {@code org.apache.hadoop.io.WritableComparator#compareBytes(byte[], int, int, byte[], int, int)} */ public int compareTo(BinaryComparable other) { if (this == other) @@ -63,7 +63,7 @@ public abstract class BinaryComparable implements Comparable { /** * Return a hash of the bytes returned from {#getBytes()}. - * @see org.apache.hadoop.io.WritableComparator#hashBytes(byte[],int) + * {@code org.apache.hadoop.io.WritableComparator#hashBytes(byte[],int)} */ public int hashCode() { return WritableComparator.hashBytes(getBytes(), getLength()); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/DataInputBuffer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/DataInputBuffer.java index 7491f95bd..be57e50a3 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/DataInputBuffer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/DataInputBuffer.java @@ -50,7 +50,7 @@ public class DataInputBuffer extends DataInputStream { } } - private Buffer buffer; + private final Buffer buffer; /** Constructs a new empty buffer. */ public DataInputBuffer() { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/DataOutputBuffer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/DataOutputBuffer.java index 105ee2717..a43022885 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/DataOutputBuffer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/DataOutputBuffer.java @@ -44,7 +44,7 @@ public class DataOutputBuffer extends DataOutputStream { public void write(DataInput in, int len) throws IOException { int newcount = count + len; if (newcount > buf.length) { - byte newbuf[] = new byte[Math.max(buf.length << 1, newcount)]; + byte[] newbuf = new byte[Math.max(buf.length << 1, newcount)]; System.arraycopy(buf, 0, newbuf, 0, count); buf = newbuf; } @@ -53,7 +53,7 @@ public class DataOutputBuffer extends DataOutputStream { } } - private Buffer buffer; + private final Buffer buffer; /** Constructs a new empty buffer. 
*/ public DataOutputBuffer() { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/RawComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/RawComparator.java index 4e3d056eb..0b4f83662 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/RawComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/RawComparator.java @@ -25,6 +25,6 @@ import java.util.Comparator; public interface RawComparator extends Comparator { - public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2); + int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparator.java index 16cc4f35b..bcd2b8074 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparator.java @@ -31,7 +31,7 @@ import java.util.HashMap; public class WritableComparator implements RawComparator { - private static HashMap comparators = new HashMap<>(); // registry + private static final HashMap comparators = new HashMap<>(); // registry /** Get a comparator for a {@link WritableComparable} implementation. */ public static synchronized WritableComparator get(Class c) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableUtils.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableUtils.java index ebac8c856..7070ce47b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableUtils.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableUtils.java @@ -229,7 +229,7 @@ public final class WritableUtils { /** * Serializes an integer to a binary stream with zero-compressed encoding. - * For -120 <= i <= 127, only one byte is used with the actual value. + * For -120 &lt;= i &lt;= 127, only one byte is used with the actual value. * For other values of i, the first byte value indicates whether the * integer is positive or negative, and the number of bytes that follow. * If the first byte value v is between -121 and -124, the following integer @@ -248,7 +248,7 @@ public final class WritableUtils { /** * Serializes a long to a binary stream with zero-compressed encoding. - * For -112 <= i <= 127, only one byte is used with the actual value. + * For -112 &lt;= i &lt;= 127, only one byte is used with the actual value. * For other values of i, the first byte value indicates whether the * long is positive or negative, and the number of bytes that follow. 
* If the first byte value v is between -113 and -120, the following long diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/LabelWriterConverter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/LabelWriterConverter.java index 6dc7c807e..470f88417 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/LabelWriterConverter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/LabelWriterConverter.java @@ -27,7 +27,7 @@ import org.datavec.api.writable.Writable; import java.util.List; public class LabelWriterConverter implements WritableConverter { - private List labels; + private final List labels; public LabelWriterConverter(List labels) { this.labels = labels; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathLabelGenerator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathLabelGenerator.java index 5995c4967..d5bb50d2a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathLabelGenerator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathLabelGenerator.java @@ -35,7 +35,7 @@ public interface PathLabelGenerator extends Serializable { * If true: infer the set of possible label classes, and convert these to integer indexes. If when true, the * returned Writables should be text writables.
*
- * For regression use cases (or PathLabelGenerator classification instances that do their own label -> integer + * For regression use cases (or PathLabelGenerator classification instances that do their own label -> integer * assignment), this should return false. * * @return whether label classes should be inferred diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/serializers/SerializationFactory.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/serializers/SerializationFactory.java index fc60b8fdc..b57ee475e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/serializers/SerializationFactory.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/serializers/SerializationFactory.java @@ -35,7 +35,7 @@ public class SerializationFactory extends Configured { private static final Logger LOG = LoggerFactory.getLogger(SerializationFactory.class.getName()); - private List> serializations = new ArrayList<>(); + private final List> serializations = new ArrayList<>(); /** *

@@ -47,7 +47,7 @@ public class SerializationFactory { public SerializationFactory(Configuration conf) { super(conf); for (String serializerName : conf.getStrings("io.serializations", - new String[] {"org.apache.hadoop.io.serializer.WritableSerialization"})) { + "org.apache.hadoop.io.serializer.WritableSerialization")) { add(conf, serializerName); } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/Buffer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/Buffer.java index 8c6dbfa1f..496af6f72 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/Buffer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/Buffer.java @@ -113,7 +113,7 @@ public class Buffer implements Comparable, Cloneable { /** * Change the capacity of the backing storage. - * The data is preserved if newCapacity >= getCount(). + * The data is preserved if newCapacity &gt;= getCount(). * @param newCapacity The new capacity in bytes. */ public void setCapacity(int newCapacity) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/IOUtils.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/IOUtils.java index 2a4793ada..dc1c6cae1 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/IOUtils.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/IOUtils.java @@ -209,9 +209,7 @@ public class IOUtils { * @return */ static String toCSVBuffer(Buffer buf) { - StringBuilder sb = new StringBuilder("#"); - sb.append(buf.toString()); - return sb.toString(); + return "#" + buf.toString(); } /** @@ -441,7 +439,7 @@ public class IOUtils { /** * Serializes a long to a binary stream with zero-compressed encoding. - * For -112 <= i <= 127, only one byte is used with the actual value. + * For -112 &lt;= i &lt;= 127, only one byte is used with the actual value. * For other values of i, the first byte value indicates whether the * long is positive or negative, and the number of bytes that follow. * If the first byte value v is between -113 and -120, the following long diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java index 14d1da31d..a72793529 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java @@ -99,8 +99,6 @@ public interface RecordReader extends AutoCloseable, Serializable, Configurable /** * Reset record reader iterator - * - * @return */ void reset(); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java index 96693a1a3..905854f03 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java @@ -39,7 +39,7 @@ import java.util.List; */ public class ComposableRecordReader extends BaseRecordReader { - private RecordReader[] readers; + private final RecordReader[] readers; public ComposableRecordReader(RecordReader...
readers) { this.readers = readers; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java index e01d93ed1..ab436407a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java @@ -35,7 +35,7 @@ import java.util.List; public class ConcatenatingRecordReader extends BaseRecordReader { - private RecordReader[] readers; + private final RecordReader[] readers; public ConcatenatingRecordReader(RecordReader... readers) { this.readers = readers; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java index 376205d47..a9448b981 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java @@ -196,7 +196,7 @@ public class FileRecordReader extends BaseRecordReader { while ((line = br.readLine()) != null) { sb.append(line).append("\n"); } - return Collections.singletonList((Writable) new Text(sb.toString())); + return Collections.singletonList(new Text(sb.toString())); } @Override diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java index 94314393e..b05b739df 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java @@ -200,7 +200,7 @@ public class LineRecordReader extends BaseRecordReader { //Here: we are reading a single line from the DataInputStream BufferedReader br = new BufferedReader(new InputStreamReader(dataInputStream)); String line = br.readLine(); - return Collections.singletonList((Writable) new Text(line)); + return Collections.singletonList(new Text(line)); } protected Iterator getIterator(int location) { @@ -265,7 +265,7 @@ public class LineRecordReader extends BaseRecordReader { throw new IllegalArgumentException( "Invalid metadata; expected RecordMetaDataLine instance; got: " + rmd); } - list.add(new Triple<>(count++, (RecordMetaDataLine) rmd, (List) null)); + list.add(new Triple<>(count++, (RecordMetaDataLine) rmd, null)); if (rmd.getURI() != null) uris.add(rmd.getURI()); } @@ -332,7 +332,7 @@ public class LineRecordReader extends BaseRecordReader { throw new IllegalStateException("Could not get line " + nextLineIdx + " from URI " + currentURI + ": has only " + currentLineIdx + " lines"); } - t.setThird(Collections.singletonList(new Text(line))); + t.setThird(Collections.singletonList(new Text(line))); } } else { //Not URI based: String split, etc @@ -347,7 +347,7 @@ public class LineRecordReader extends BaseRecordReader { line = iterator.next(); currentLineIdx++; } - t.setThird(Collections.singletonList(new Text(line))); + t.setThird(Collections.singletonList(new Text(line))); } closeIfRequired(iterator); } diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java index c87b541f8..e33f0a9ec 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java @@ -43,7 +43,7 @@ public class CollectionSequenceRecordReader extends BaseRecordReader implements /** * - * @param records Collection of sequences. For example, List>> where the inner two lists + * @param records Collection of sequences. For example, {@code List>>} where the inner two lists * are a sequence, and the outer list/collection is a list of sequences */ public CollectionSequenceRecordReader(Collection>> records) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVMultiSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVMultiSequenceRecordReader.java index abcc113ae..5e4571f81 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVMultiSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVMultiSequenceRecordReader.java @@ -45,9 +45,9 @@ public class CSVMultiSequenceRecordReader extends CSVRecordReader implements Seq PAD } - private String sequenceSeparatorRegex; - private Mode mode; - private Writable padValue; + private final String sequenceSeparatorRegex; + private final Mode mode; + private final Writable padValue; /** * Create a sequence reader using the default value for skip lines (0), the default delimiter (',') and the default diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVNLinesSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVNLinesSequenceRecordReader.java index 71faf9d81..86e9c3c64 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVNLinesSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVNLinesSequenceRecordReader.java @@ -41,7 +41,7 @@ public class CSVNLinesSequenceRecordReader extends CSVRecordReader implements Se public static final String LINES_PER_SEQUENCE = NAME_SPACE + ".nlinespersequence"; private int nLinesPerSequence; - private String delimiter; + private final String delimiter; /** * No-arg constructor with the default number of lines per sequence (10) @@ -124,7 +124,7 @@ public class CSVNLinesSequenceRecordReader extends CSVRecordReader implements Se "Invalid metadata; expected RecordMetaDataLineInterval instance; got: " + rmd); } list.add(new Triple<>(count++, (RecordMetaDataLineInterval) rmd, - (List>) new ArrayList>())); + new ArrayList>())); } //Sort by starting line number: diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVVariableSlidingWindowRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVVariableSlidingWindowRecordReader.java index 1a25a2ab4..02a94f8d8 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVVariableSlidingWindowRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVVariableSlidingWindowRecordReader.java @@ -39,8 +39,8 @@ public class CSVVariableSlidingWindowRecordReader extends CSVRecordReader implem public static final String LINES_PER_SEQUENCE = NAME_SPACE + ".nlinespersequence"; private int maxLinesPerSequence; - private String delimiter; - private int stride; + private final String delimiter; + private final int stride; private LinkedList> queue; private boolean exhausted; @@ -60,7 +60,7 @@ public class CSVVariableSlidingWindowRecordReader extends CSVRecordReader implem /** * @param maxLinesPerSequence Number of lines in each sequence, use default delemiter(,) between entries in the same line - * @param stride Number of lines between records (increment window > 1 line) + * @param stride Number of lines between records (increment window > 1 line) */ public CSVVariableSlidingWindowRecordReader(int maxLinesPerSequence, int stride) { this(maxLinesPerSequence, 0, stride, String.valueOf(CSVRecordReader.DEFAULT_DELIMITER)); @@ -68,7 +68,7 @@ public class CSVVariableSlidingWindowRecordReader extends CSVRecordReader implem /** * @param maxLinesPerSequence Number of lines in each sequence, use default delemiter(,) between entries in the same line - * @param stride Number of lines between records (increment window > 1 line) + * @param stride Number of lines between records (increment window > 1 line) */ public CSVVariableSlidingWindowRecordReader(int maxLinesPerSequence, int stride, String delimiter) { this(maxLinesPerSequence, 0, stride, String.valueOf(CSVRecordReader.DEFAULT_DELIMITER)); @@ -78,7 +78,7 @@ public class CSVVariableSlidingWindowRecordReader extends CSVRecordReader implem * * @param maxLinesPerSequence Number of lines in each sequences * @param skipNumLines Number of lines to skip at the start of the file (only skipped once, not per sequence) - * @param stride Number of lines between records (increment window > 1 line) + * @param stride Number of lines between records (increment window > 1 line) * @param delimiter Delimiter between entries in the same line, for example "," */ public CSVVariableSlidingWindowRecordReader(int maxLinesPerSequence, int skipNumLines, int stride, String delimiter) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/SerializableCSVParser.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/SerializableCSVParser.java index f8b033633..f9222fb42 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/SerializableCSVParser.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/SerializableCSVParser.java @@ -302,7 +302,7 @@ public class SerializableCSVParser implements Serializable { } /** - * precondition: sb.length() > 0 + * precondition: sb.length() > 0 * * @param sb A sequence of characters to examine * @return true if every character in the sequence is whitespace diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemoryRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemoryRecordReader.java index 105b2068a..d9023b46a 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemoryRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemoryRecordReader.java @@ -114,8 +114,6 @@ public class InMemoryRecordReader implements RecordReader { /** * Reset record reader iterator - * - * @return */ @Override public void reset() { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemorySequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemorySequenceRecordReader.java index f97e5f28e..76be03200 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemorySequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemorySequenceRecordReader.java @@ -195,8 +195,6 @@ public class InMemorySequenceRecordReader implements SequenceRecordReader { /** * Reset record reader iterator - * - * @return */ @Override public void reset() { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/FieldSelection.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/FieldSelection.java index 08644df9a..e3c36bb53 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/FieldSelection.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/FieldSelection.java @@ -31,8 +31,8 @@ public class FieldSelection implements Serializable { public static final Writable DEFAULT_MISSING_VALUE = new Text(""); - private List fieldPaths; - private List valueIfMissing; + private final List fieldPaths; + private final List valueIfMissing; private FieldSelection(Builder builder) { this.fieldPaths = builder.fieldPaths; @@ -53,8 +53,8 @@ public class FieldSelection implements Serializable { public static class Builder { - private List fieldPaths = new ArrayList<>(); - private List valueIfMissing = new ArrayList<>(); + private final List fieldPaths = new ArrayList<>(); + private final List valueIfMissing = new ArrayList<>(); /** diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineRecordReader.java index 17f348e54..e759b6aa6 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineRecordReader.java @@ -29,8 +29,8 @@ import com.fasterxml.jackson.databind.ObjectMapper; public class JacksonLineRecordReader extends LineRecordReader { - private FieldSelection selection; - private ObjectMapper mapper; + private final FieldSelection selection; + private final ObjectMapper mapper; public JacksonLineRecordReader(FieldSelection selection, ObjectMapper mapper) { this.selection = selection; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java index 7b27cae0f..3c2d81f69 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java @@ -39,8 +39,8 @@ import java.util.NoSuchElementException; public class JacksonLineSequenceRecordReader extends FileRecordReader implements SequenceRecordReader { - private FieldSelection selection; - private ObjectMapper mapper; + private final FieldSelection selection; + private final ObjectMapper mapper; /** * diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java index 8e5e571e7..de0d41573 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java @@ -45,12 +45,12 @@ public class JacksonRecordReader extends BaseRecordReader { private static final TypeReference> typeRef = new TypeReference>() {}; - private FieldSelection selection; - private ObjectMapper mapper; - private boolean shuffle; - private long rngSeed; - private PathLabelGenerator labelGenerator; - private int labelPosition; + private final FieldSelection selection; + private final ObjectMapper mapper; + private final boolean shuffle; + private final long rngSeed; + private final PathLabelGenerator labelGenerator; + private final int labelPosition; private InputSplit is; private Random r; @Getter @Setter diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/MatlabRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/MatlabRecordReader.java index b9e52f33a..419c82c4d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/MatlabRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/MatlabRecordReader.java @@ -35,7 +35,7 @@ import java.util.List; public class MatlabRecordReader extends FileRecordReader { - private List> records = new ArrayList<>(); + private final List> records = new ArrayList<>(); private Iterator> currIter; @Override diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/SVMLightRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/SVMLightRecordReader.java index 4534162bd..6d5bc5ea1 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/SVMLightRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/SVMLightRecordReader.java @@ -96,8 +96,6 @@ public class SVMLightRecordReader extends LineRecordReader { * Set configuration. 
* * @param conf DataVec configuration - * @throws IOException - * @throws InterruptedException */ @Override public void setConf(Configuration conf) { @@ -181,7 +179,7 @@ public class SVMLightRecordReader extends LineRecordReader { if (index < 0) throw new NumberFormatException(""); } catch (NumberFormatException e) { - String msg = String.format("Feature index must be positive integer (found %s)", featureTokens[i].toString()); + String msg = String.format("Feature index must be positive integer (found %s)", featureTokens[i]); throw new NumberFormatException(msg); } @@ -218,7 +216,7 @@ public class SVMLightRecordReader extends LineRecordReader { if (index < 0) throw new NumberFormatException(""); } catch (NumberFormatException e) { - String msg = String.format("Multilabel index must be positive integer (found %s)", labelTokens[i].toString()); + String msg = String.format("Multilabel index must be positive integer (found %s)", labelTokens[i]); throw new NumberFormatException(msg); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexLineRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexLineRecordReader.java index 298a5d931..3a216d784 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexLineRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexLineRecordReader.java @@ -41,11 +41,11 @@ import java.util.regex.Pattern; public class RegexLineRecordReader extends LineRecordReader { public final static String SKIP_NUM_LINES = NAME_SPACE + ".skipnumlines"; - private String regex; + private final String regex; private int skipNumLines; - private Pattern pattern; + private final Pattern pattern; private int numLinesSkipped; - private int currLine = 0; + private final int currLine = 0; public RegexLineRecordReader(String regex, int skipNumLines) { this.regex = regex; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java index 41b9f2e1b..ebf685d50 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java @@ -61,11 +61,11 @@ public class RegexSequenceRecordReader extends FileRecordReader implements Seque public static final Logger LOG = LoggerFactory.getLogger(RegexSequenceRecordReader.class); - private String regex; + private final String regex; private int skipNumLines; - private Pattern pattern; + private final Pattern pattern; private transient Charset charset; - private LineErrorHandling errorHandling; + private final LineErrorHandling errorHandling; public RegexSequenceRecordReader(String regex, int skipNumLines) { this(regex, skipNumLines, DEFAULT_CHARSET, DEFAULT_ERROR_HANDLING); @@ -92,7 +92,7 @@ public class RegexSequenceRecordReader extends FileRecordReader implements Seque @Override public List> sequenceRecord(URI uri, DataInputStream dataInputStream) throws IOException { - String fileContents = IOUtils.toString(new BufferedInputStream(dataInputStream), charset.name()); + String fileContents = IOUtils.toString(new BufferedInputStream(dataInputStream), charset); return 
loadSequence(fileContents, uri); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java index 160b2c134..2b8a38d58 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java @@ -145,8 +145,6 @@ public class TransformProcessRecordReader implements RecordReader { /** * Reset record reader iterator - * - * @return */ @Override public void reset() { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java index 7023e70b4..cb9213dea 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java @@ -195,8 +195,6 @@ public class TransformProcessSequenceRecordReader implements SequenceRecordReade /** * Reset record reader iterator - * - * @return */ @Override public void reset() { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/SVMLightRecordWriter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/SVMLightRecordWriter.java index 56db1df58..c15f9ede6 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/SVMLightRecordWriter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/SVMLightRecordWriter.java @@ -94,7 +94,7 @@ public class SVMLightRecordWriter extends FileRecordWriter { @Override public PartitionMetaData write(List record) throws IOException { if (!record.isEmpty()) { - List recordList = record instanceof List ? (List) record : new ArrayList<>(record); + List recordList = record instanceof List ? record : new ArrayList<>(record); /* Infer label columns, if necessary. 
The default is * to assume that last column is a label and that the @@ -198,7 +198,7 @@ public class SVMLightRecordWriter extends FileRecordWriter { } // Remove extra label delimiter at beginning - String line = result.substring(1).toString(); + String line = result.substring(1); out.write(line.getBytes()); out.write(NEW_LINE.getBytes()); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/BaseInputSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/BaseInputSplit.java index 428a1df2e..7b26e65cf 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/BaseInputSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/BaseInputSplit.java @@ -124,9 +124,7 @@ public abstract class BaseInputSplit implements InputSplit { for (int i = 0; i < weights.length; i++) { List uris = new ArrayList<>(); - for (int j = partitions[i]; j < partitions[i + 1]; j++) { - uris.add(paths[j]); - } + uris.addAll(Arrays.asList(paths).subList(partitions[i], partitions[i + 1])); splits[i] = new CollectionInputSplit(uris); } return splits; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java index 97183f346..c6fb48d3d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/FileSplit.java @@ -138,7 +138,7 @@ public class FileSplit extends BaseInputSplit { return addNewLocation(new File(rootDir, UUID.randomUUID().toString()).toURI().toString()); else { //add a file in the same directory as the file with the same extension as the original file - return addNewLocation(new File(rootDir.getParent(), UUID.randomUUID().toString() + "." + FilenameUtils.getExtension(rootDir.getAbsolutePath())).toURI().toString()); + return addNewLocation(new File(rootDir.getParent(), UUID.randomUUID() + "." 
+ FilenameUtils.getExtension(rootDir.getAbsolutePath())).toURI().toString()); } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputStreamInputSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputStreamInputSplit.java index 7bd514745..fadb215cb 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputStreamInputSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/InputStreamInputSplit.java @@ -31,7 +31,7 @@ import java.util.Iterator; public class InputStreamInputSplit implements InputSplit { private InputStream is; - private URI[] location; + private final URI[] location; /** * Instantiate with the given @@ -130,7 +130,7 @@ public class InputStreamInputSplit implements InputSplit { public Iterator locationsPathIterator() { if(location.length >= 1) return Collections.singletonList(location[0].getPath()).iterator(); - return Arrays.asList("").iterator(); + return Collections.singletonList("").iterator(); } @Override diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/ListStringSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/ListStringSplit.java index d979bdad7..0d714e603 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/ListStringSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/ListStringSplit.java @@ -33,7 +33,7 @@ import java.util.List; * has delimited data of some kind. */ public class ListStringSplit implements InputSplit { - private List> data; + private final List> data; public ListStringSplit(List> data) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java index c61b1d591..b534e8d12 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/NumberedFileInputSplit.java @@ -43,12 +43,12 @@ public class NumberedFileInputSplit implements InputSplit { * the index of the file, possibly zero-padded to x digits if the pattern is in the form %0xd. 
* @param minIdxInclusive Minimum index/number (starting number in sequence of files, inclusive) * @param maxIdxInclusive Maximum index/number (last number in sequence of files, inclusive) - * @see {NumberedFileInputSplitTest} + * */ public NumberedFileInputSplit(String baseString, int minIdxInclusive, int maxIdxInclusive) { Matcher m = p.matcher(baseString); if (baseString == null || !m.find()) { - throw new IllegalArgumentException("Base String must match this regular expression: " + p.toString()); + throw new IllegalArgumentException("Base String must match this regular expression: " + p); } this.baseString = baseString; this.minIdx = minIdxInclusive; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StringSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StringSplit.java index 8db924475..21f97fef7 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StringSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/split/StringSplit.java @@ -31,7 +31,7 @@ import java.util.Iterator; * @author Adam Gibson */ public class StringSplit implements InputSplit { - private String data; + private final String data; public StringSplit(String data) { this.data = data; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java index dfd848ec3..9673c9a4f 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java @@ -449,7 +449,7 @@ public class TransformProcess implements Serializable { /** * Infer the categories for the given record reader for a particular column * Note that each "column index" is a column in the context of: - * List record = ...; + * {@code List record = ...;} * record.get(columnIndex); * * Note that anything passed in as a column will be automatically converted to a @@ -483,7 +483,7 @@ public class TransformProcess implements Serializable { * if you have more than one column you plan on inferring categories for) * * Note that each "column index" is a column in the context of: - * List record = ...; + * {@code List record = ...;} * record.get(columnIndex); * * @@ -607,8 +607,8 @@ public class TransformProcess implements Serializable { */ public static class Builder { - private List actionList = new ArrayList<>(); - private Schema initialSchema; + private final List actionList = new ArrayList<>(); + private final Schema initialSchema; public Builder(Schema initialSchema) { this.initialSchema = initialSchema; @@ -1274,7 +1274,7 @@ public class TransformProcess implements Serializable { * not be modified. * * @param columnName Name of the column in which to do replacement - * @param mapping Map of oldValues -> newValues + * @param mapping Map of oldValues -> newValues */ public Builder stringMapTransform(String columnName, Map mapping) { return transform(new StringMapTransform(columnName, mapping)); @@ -1358,7 +1358,8 @@ public class TransformProcess implements Serializable { * Keys in the map are the regular expressions; the Values in the map are their String replacements. * For example: *

+ * * * * @@ -1378,7 +1379,7 @@ public class TransformProcess implements Serializable { * * * - * + * * * * diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/NDArrayAnalysis.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/NDArrayAnalysis.java index c97d7c744..f4fff17c7 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/NDArrayAnalysis.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/columns/NDArrayAnalysis.java @@ -55,7 +55,7 @@ public class NDArrayAnalysis implements ColumnAnalysis { public String toString() { Map sortedCountsByRank = new LinkedHashMap<>(); List keys = - new ArrayList<>(countsByRank == null ? Collections.emptySet() : countsByRank.keySet()); + new ArrayList<>(countsByRank == null ? Collections.emptySet() : countsByRank.keySet()); Collections.sort(keys); for (Integer i : keys) { sortedCountsByRank.put(i, countsByRank.get(i)); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/IntegerAnalysisCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/IntegerAnalysisCounter.java index 0028e5e1d..0a37ac1d4 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/IntegerAnalysisCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/IntegerAnalysisCounter.java @@ -101,8 +101,8 @@ public class IntegerAnalysisCounter implements AnalysisCounter countsByRank = new HashMap<>(); + private final Map countsByRank = new HashMap<>(); private double minValue = Double.MAX_VALUE; private double maxValue = -Double.MAX_VALUE; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/StringAnalysisCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/StringAnalysisCounter.java index c31f35ad8..a18237513 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/StringAnalysisCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/StringAnalysisCounter.java @@ -83,7 +83,7 @@ public class StringAnalysisCounter implements AnalysisCounter counts = new HashMap<>(); + private final HashMap counts = new HashMap<>(); - private List stateNames; + private final List stateNames; public CategoricalHistogramCounter(List stateNames) { this.stateNames = stateNames; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/json/TDigestDeserializer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/json/TDigestDeserializer.java index dd4289906..9b2650e94 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/json/TDigestDeserializer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/json/TDigestDeserializer.java @@ -34,8 +34,8 @@ import java.io.ObjectInputStream; public class TDigestDeserializer extends JsonDeserializer { @Override - public TDigest deserialize(JsonParser jp, DeserializationContext d) throws IOException, JsonProcessingException { - JsonNode node = (JsonNode)jp.getCodec().readTree(jp); + public TDigest deserialize(JsonParser jp, DeserializationContext d) throws IOException { + 
JsonNode node = jp.getCodec().readTree(jp); String field = node.get("digest").asText(); Base64 b = new Base64(); byte[] bytes = b.decode(field); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/json/TDigestSerializer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/json/TDigestSerializer.java index c3bd4517a..e2ad09f0a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/json/TDigestSerializer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/json/TDigestSerializer.java @@ -33,7 +33,7 @@ import java.io.ObjectOutputStream; public class TDigestSerializer extends JsonSerializer { @Override - public void serialize(TDigest td, JsonGenerator j, SerializerProvider sp) throws IOException, JsonProcessingException { + public void serialize(TDigest td, JsonGenerator j, SerializerProvider sp) throws IOException { try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){ oos.writeObject(td); oos.close(); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/bytes/BytesQualityAnalysisState.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/bytes/BytesQualityAnalysisState.java index 409387600..f6c6e8c3c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/bytes/BytesQualityAnalysisState.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/bytes/BytesQualityAnalysisState.java @@ -29,7 +29,7 @@ import org.datavec.api.writable.Writable; public class BytesQualityAnalysisState implements QualityAnalysisState { @Getter - private BytesQuality bytesQuality; + private final BytesQuality bytesQuality; public BytesQualityAnalysisState() { this.bytesQuality = new BytesQuality(); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAnalysisState.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAnalysisState.java index 5dc13406a..44aaac563 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAnalysisState.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAnalysisState.java @@ -31,8 +31,8 @@ public class CategoricalQualityAnalysisState implements QualityAnalysisState=, !=, etc) + * @param op Operation {@code (<, >=, !=, etc)} * @param value Value to use in the condition */ public DoubleColumnCondition(String columnName, ConditionOp op, double value) { @@ -54,7 +54,7 @@ public class DoubleColumnCondition extends BaseColumnCondition { * * @param column Column to check for the condition * @param sequenceConditionMode Mode for handling sequence data - * @param op Operation (<, >=, !=, etc) + * @param op Operation {@code (<, >=, !=, etc)} * @param value Value to use in the condition */ public DoubleColumnCondition(String column, SequenceConditionMode sequenceConditionMode, ConditionOp op, diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/FloatColumnCondition.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/FloatColumnCondition.java index be8ab40e6..f4d40b45e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/FloatColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/FloatColumnCondition.java @@ -42,7 +42,7 @@ public class FloatColumnCondition extends BaseColumnCondition { * Uses default sequence condition mode, {@link BaseColumnCondition#DEFAULT_SEQUENCE_CONDITION_MODE} * * @param columnName Column to check for the condition - * @param op Operation (<, >=, !=, etc) + * @param op Operation {@code (<, >=, !=, etc)} * @param value Value to use in the condition */ public FloatColumnCondition(String columnName, ConditionOp op, float value) { @@ -54,7 +54,7 @@ public class FloatColumnCondition extends BaseColumnCondition { * * @param column Column to check for the condition * @param sequenceConditionMode Mode for handling sequence data - * @param op Operation (<, >=, !=, etc) + * @param op Operation {@code (<, >=, !=, etc)} * @param value Value to use in the condition */ public FloatColumnCondition(String column, SequenceConditionMode sequenceConditionMode, ConditionOp op, diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/IntegerColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/IntegerColumnCondition.java index bd55caed5..0029eb044 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/IntegerColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/IntegerColumnCondition.java @@ -42,7 +42,7 @@ public class IntegerColumnCondition extends BaseColumnCondition { * Uses default sequence condition mode, {@link BaseColumnCondition#DEFAULT_SEQUENCE_CONDITION_MODE} * * @param columnName Column to check for the condition - * @param op Operation (<, >=, !=, etc) + * @param op Operation {@code (<, >=, !=, etc)} * @param value Value to use in the condition */ public IntegerColumnCondition(String columnName, ConditionOp op, int value) { @@ -54,7 +54,7 @@ public class IntegerColumnCondition extends BaseColumnCondition { * * @param column Column to check for the condition * @param sequenceConditionMode Mode for handling sequence data - * @param op Operation (<, >=, !=, etc) + * @param op Operation {@code (<, >=, !=, etc)} * @param value Value to use in the condition */ public IntegerColumnCondition(String column, SequenceConditionMode sequenceConditionMode, ConditionOp op, diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/LongColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/LongColumnCondition.java index 5855628fa..a83be4fcf 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/LongColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/LongColumnCondition.java @@ -42,7 +42,7 @@ public class LongColumnCondition extends BaseColumnCondition { * Uses default sequence condition mode, {@link BaseColumnCondition#DEFAULT_SEQUENCE_CONDITION_MODE} * * @param columnName Column to check for the condition - * @param op Operation (<, >=, !=, etc) + * @param op Operation {@code (<, >=, !=, 
etc)} * @param value Value to use in the condition */ public LongColumnCondition(String columnName, ConditionOp op, long value) { @@ -54,7 +54,7 @@ public class LongColumnCondition extends BaseColumnCondition { * * @param column Column to check for the condition * @param sequenceConditionMode Mode for handling sequence data - * @param op Operation (<, >=, !=, etc) + * @param op Operation {@code (<, >=, !=, etc)} * @param value Value to use in the condition */ public LongColumnCondition(String column, SequenceConditionMode sequenceConditionMode, ConditionOp op, long value) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TimeColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TimeColumnCondition.java index 590ef4522..00c2714ce 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TimeColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TimeColumnCondition.java @@ -42,7 +42,7 @@ public class TimeColumnCondition extends BaseColumnCondition { * Uses default sequence condition mode, {@link BaseColumnCondition#DEFAULT_SEQUENCE_CONDITION_MODE} * * @param columnName Column to check for the condition - * @param op Operation (<, >=, !=, etc) + * @param op Operation {@code (<, >=, !=, etc)} * @param value Time value (in epoch millisecond format) to use in the condition */ public TimeColumnCondition(String columnName, ConditionOp op, long value) { @@ -54,7 +54,7 @@ public class TimeColumnCondition extends BaseColumnCondition { * * @param column Column to check for the condition * @param sequenceConditionMode Mode for handling sequence data - * @param op Operation (<, >=, !=, etc) + * @param op Operation {@code (<, >=, !=, etc)} * @param value Time value (in epoch millisecond format) to use in the condition */ public TimeColumnCondition(String column, SequenceConditionMode sequenceConditionMode, ConditionOp op, long value) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/FilterInvalidValues.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/FilterInvalidValues.java index 54b6cfe07..3a5a35b68 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/FilterInvalidValues.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/FilterInvalidValues.java @@ -111,24 +111,18 @@ public class FilterInvalidValues implements Filter { private boolean filterColumn(List row, int i) { ColumnMetaData meta = schema.getMetaData(i); if (row.get(i) instanceof Float) { - if (!meta.isValid(new FloatWritable((Float) row.get(i)))) - return true; + return !meta.isValid(new FloatWritable((Float) row.get(i))); } else if (row.get(i) instanceof Double) { - if (!meta.isValid(new DoubleWritable((Double) row.get(i)))) - return true; + return !meta.isValid(new DoubleWritable((Double) row.get(i))); } else if (row.get(i) instanceof String) { - if (!meta.isValid(new Text(((String) row.get(i)).toString()))) - return true; + return !meta.isValid(new Text(((String) row.get(i)))); } else if (row.get(i) instanceof Integer) { - if (!meta.isValid(new IntWritable((Integer) row.get(i)))) - return true; + return !meta.isValid(new IntWritable((Integer) row.get(i))); } else if (row.get(i) instanceof Long) { - if (!meta.isValid(new LongWritable((Long) row.get(i)))) - return true; + 
return !meta.isValid(new LongWritable((Long) row.get(i))); } else if (row.get(i) instanceof Boolean) { - if (!meta.isValid(new BooleanWritable((Boolean) row.get(i)))) - return true; + return !meta.isValid(new BooleanWritable((Boolean) row.get(i))); } return false; } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/join/Join.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/join/Join.java index d723c1448..d71b3c0c5 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/join/Join.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/join/Join.java @@ -96,7 +96,7 @@ public class Join implements Serializable { public static class Builder { - private JoinType joinType; + private final JoinType joinType; private Schema leftSchema; private Schema rightSchema; private String[] joinColumnsLeft; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BinaryMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BinaryMetaData.java index 3acb56ded..91a9238f1 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BinaryMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BinaryMetaData.java @@ -84,9 +84,8 @@ public class BinaryMetaData extends BaseColumnMetaData { @Override public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("BinaryMetaData(name=\"").append(name).append("\","); - sb.append(")"); - return sb.toString(); + String sb = "BinaryMetaData(name=\"" + name + "\"," + + ")"; + return sb; } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BooleanMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BooleanMetaData.java index 5fae67985..66d8872b1 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BooleanMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BooleanMetaData.java @@ -84,9 +84,8 @@ public class BooleanMetaData extends BaseColumnMetaData { @Override public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("BooleanMetaData(name=\"").append(name).append("\","); - sb.append(")"); - return sb.toString(); + String sb = "BooleanMetaData(name=\"" + name + "\"," + + ")"; + return sb; } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/DoubleMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/DoubleMetaData.java index 6a3aee77c..aaa85a489 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/DoubleMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/DoubleMetaData.java @@ -84,10 +84,7 @@ public class DoubleMetaData extends BaseColumnMetaData { if (minAllowedValue != null && d < minAllowedValue) return false; - if (maxAllowedValue != null && d > maxAllowedValue) - return false; - - return true; + return maxAllowedValue == null || !(d > maxAllowedValue); } /** @@ -115,10 +112,7 @@ public class DoubleMetaData extends BaseColumnMetaData { if (minAllowedValue != null && d < minAllowedValue) return false; - if (maxAllowedValue != null && d > maxAllowedValue) - return false; - - return true; + return 
maxAllowedValue == null || !(d > maxAllowedValue); } @Override diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/FloatMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/FloatMetaData.java index 69f087433..7bcb7abe2 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/FloatMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/FloatMetaData.java @@ -84,10 +84,7 @@ public class FloatMetaData extends BaseColumnMetaData { if (minAllowedValue != null && d < minAllowedValue) return false; - if (maxAllowedValue != null && d > maxAllowedValue) - return false; - - return true; + return maxAllowedValue == null || d <= maxAllowedValue; } /** @@ -115,10 +112,7 @@ public class FloatMetaData extends BaseColumnMetaData { if (minAllowedValue != null && d < minAllowedValue) return false; - if (maxAllowedValue != null && d > maxAllowedValue) - return false; - - return true; + return maxAllowedValue == null || d <= maxAllowedValue; } @Override diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/IntegerMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/IntegerMetaData.java index 2bf3a2bdc..c856da307 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/IntegerMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/IntegerMetaData.java @@ -65,9 +65,7 @@ public class IntegerMetaData extends BaseColumnMetaData { if (minAllowedValue != null && value < minAllowedValue) return false; - if (maxAllowedValue != null && value > maxAllowedValue) - return false; - return true; + return maxAllowedValue == null || value <= maxAllowedValue; } /** @@ -90,9 +88,7 @@ public class IntegerMetaData extends BaseColumnMetaData { if (minAllowedValue != null && value < minAllowedValue) return false; - if (maxAllowedValue != null && value > maxAllowedValue) - return false; - return true; + return maxAllowedValue == null || value <= maxAllowedValue; } @Override diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/LongMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/LongMetaData.java index 66a49874d..01119430e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/LongMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/LongMetaData.java @@ -66,10 +66,7 @@ public class LongMetaData extends BaseColumnMetaData { } if (minAllowedValue != null && value < minAllowedValue) return false; - if (maxAllowedValue != null && value > maxAllowedValue) - return false; - - return true; + return maxAllowedValue == null || value <= maxAllowedValue; } /** @@ -92,10 +89,7 @@ public class LongMetaData extends BaseColumnMetaData { if (minAllowedValue != null && value < minAllowedValue) return false; - if (maxAllowedValue != null && value > maxAllowedValue) - return false; - - return true; + return maxAllowedValue == null || value <= maxAllowedValue; } @Override diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregatorImpls.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregatorImpls.java index 18fbf9af6..ce1b2b94d 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregatorImpls.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregatorImpls.java @@ -97,9 +97,9 @@ public class AggregatorImpls { } else if (a instanceof Float || b instanceof Float) { return new Float(a.floatValue() + b.floatValue()); } else if (a instanceof Long || b instanceof Long) { - return new Long(a.longValue() + b.longValue()); + return Long.valueOf(a.longValue() + b.longValue()); } else { - return new Integer(a.intValue() + b.intValue()); + return Integer.valueOf(a.intValue() + b.intValue()); } } @@ -146,9 +146,9 @@ public class AggregatorImpls { } else if (a instanceof Float || b instanceof Float) { return new Float(a.floatValue() * b.floatValue()); } else if (a instanceof Long || b instanceof Long) { - return new Long(a.longValue() * b.longValue()); + return Long.valueOf(a.longValue() * b.longValue()); } else { - return new Integer(a.intValue() * b.intValue()); + return Integer.valueOf(a.intValue() * b.intValue()); } } @@ -347,7 +347,7 @@ public class AggregatorImpls { * of the square root of the arithmetic mean of squared differences to the mean, corrected with Bessel's correction. * * See https://en.wikipedia.org/wiki/Unbiased_estimation_of_standard_deviation - * This is computed with Welford's method for increased numerical stability & aggregability. + * This is computed with Welford's method for increased numerical stability & aggregability. */ public static class AggregableStdDev implements IAggregableReduceOp { @@ -402,7 +402,7 @@ public class AggregatorImpls { * of the square root of the arithmetic mean of squared differences to the mean. * * See https://en.wikipedia.org/wiki/Unbiased_estimation_of_standard_deviation - * This is computed with Welford's method for increased numerical stability & aggregability. + * This is computed with Welford's method for increased numerical stability & aggregability. */ public static class AggregableUncorrectedStdDev extends AggregableStdDev { @@ -418,7 +418,7 @@ public class AggregatorImpls { * of the arithmetic mean of squared differences to the mean, corrected with Bessel's correction. * * See https://en.wikipedia.org/wiki/Unbiased_estimation_of_standard_deviation - * This is computed with Welford's method for increased numerical stability & aggregability. + * This is computed with Welford's method for increased numerical stability & aggregability. */ public static class AggregableVariance implements IAggregableReduceOp { @@ -474,7 +474,7 @@ public class AggregatorImpls { * of the arithmetic mean of squared differences to the mean. * * See https://en.wikipedia.org/wiki/Variance#Population_variance_and_sample_variance - * This is computed with Welford's method for increased numerical stability & aggregability. + * This is computed with Welford's method for increased numerical stability & aggregability. */ public static class AggregablePopulationVariance extends AggregableVariance { @@ -491,7 +491,7 @@ public class AggregatorImpls { * here. * * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting - * a nonzero `sp > p` in HyperLogLogPlus(p, sp) would trigger sparse + * a nonzero `sp > p` in HyperLogLogPlus(p, sp) would trigger sparse * representation of registers, which may reduce the memory consumption * and increase accuracy when the cardinality is small. 
* @param @@ -501,7 +501,7 @@ public class AggregatorImpls { private float p = 0.05f; @Getter - private HyperLogLogPlus hll = new HyperLogLogPlus((int) Math.ceil(2.0 * Math.log(1.054 / p) / Math.log(2)), 0); + private final HyperLogLogPlus hll = new HyperLogLogPlus((int) Math.ceil(2.0 * Math.log(1.054 / p) / Math.log(2)), 0); public AggregableCountUnique(float precision) { this.p = precision; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java index 6f44cac42..eb25b7b56 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java @@ -36,7 +36,7 @@ public class DispatchWithConditionOp extends DispatchOp @Getter @NonNull - private List conditions; + private final List conditions; public DispatchWithConditionOp(List>> ops, List conds) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableColumnReduction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableColumnReduction.java index 6ed205b8b..b6db27c0f 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableColumnReduction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableColumnReduction.java @@ -37,14 +37,13 @@ public interface AggregableColumnReduction extends Serializable, ColumnOp { * and NOT the single row * (as is usually the case for {@code List} instances * - * @param columnData The Writable objects for a column * @return Writable containing the reduced data */ IAggregableReduceOp> reduceOp(); /** * Post-reduce: what is the name of the column? - * For example, "myColumn" -> "mean(myColumn)" + * For example, "myColumn" -> "mean(myColumn)" * * @param columnInputName Name of the column before reduction * @return Name of the column after the reduction diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/ColumnReduction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/ColumnReduction.java index 96a066c39..57a9fecf3 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/ColumnReduction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/ColumnReduction.java @@ -43,7 +43,7 @@ public interface ColumnReduction extends Serializable, ColumnOp { /** * Post-reduce: what is the name of the column? 
- * For example, "myColumn" -> "mean(myColumn)" + * For example, "myColumn" -> "mean(myColumn)" * * @param columnInputName Name of the column before reduction * @return Name of the column after the reduction diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/Reducer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/Reducer.java index 8536198f9..0979773a3 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/Reducer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/Reducer.java @@ -291,11 +291,11 @@ public class Reducer implements IAssociativeReducer { public static class Builder { - private ReduceOp defaultOp; - private Map> opMap = new HashMap<>(); - private Map customReductions = new HashMap<>(); - private Map conditionalReductions = new HashMap<>(); - private Set ignoreInvalidInColumns = new HashSet<>(); + private final ReduceOp defaultOp; + private final Map> opMap = new HashMap<>(); + private final Map customReductions = new HashMap<>(); + private final Map conditionalReductions = new HashMap<>(); + private final Set ignoreInvalidInColumns = new HashSet<>(); private String[] keyColumns; @@ -480,7 +480,6 @@ public class Reducer implements IAssociativeReducer { * ignored/excluded. * * @param column Name of the column to execute the conditional reduction on - * @param outputName Name of the column, after the reduction has been executed * @param reductions Reductions to execute * @param condition Condition to use in the reductions */ @@ -500,7 +499,6 @@ public class Reducer implements IAssociativeReducer { * * @param column Name of the column to execute the conditional reduction on * @param outputName Name of the column, after the reduction has been executed - * @param reductions Reductions to execute * @param condition Condition to use in the reductions */ public Builder conditionalReduction(String column, String outputName, ReduceOp reduction, Condition condition) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/impl/GeographicMidpointReduction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/impl/GeographicMidpointReduction.java index 27933596f..b3538c8a7 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/impl/GeographicMidpointReduction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/impl/GeographicMidpointReduction.java @@ -69,7 +69,7 @@ public class GeographicMidpointReduction implements AggregableColumnReduction { @Override public List getColumnOutputMetaData(List newColumnName, ColumnMetaData columnInputMeta) { - return Collections.singletonList(new StringMetaData(newColumnName.get(0))); + return Collections.singletonList(new StringMetaData(newColumnName.get(0))); } @Override @@ -111,7 +111,7 @@ public class GeographicMidpointReduction implements AggregableColumnReduction { public static class AverageCoordinateReduceOp implements IAggregableReduceOp> { private static final double PI_180 = Math.PI / 180.0; - private String delim; + private final String delim; private double sumx; private double sumy; @@ -186,7 +186,7 @@ public class GeographicMidpointReduction implements AggregableColumnReduction { Preconditions.checkState(!Double.isNaN(longDeg), "Final longitude is NaN"); String str = latDeg + delim + longDeg; - return Collections.singletonList(new Text(str)); + 
return Collections.singletonList(new Text(str)); } } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java index 1e6b4c87c..bc7fa2a98 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java @@ -24,7 +24,7 @@ import org.datavec.api.writable.Writable; public class TypeConversion { - private static TypeConversion SINGLETON = new TypeConversion(); + private static final TypeConversion SINGLETON = new TypeConversion(); private TypeConversion() {} diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SplitMaxLengthSequence.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SplitMaxLengthSequence.java index 2dca4077e..1d80c1f5c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SplitMaxLengthSequence.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SplitMaxLengthSequence.java @@ -44,7 +44,7 @@ public class SplitMaxLengthSequence implements SequenceSplit { /** * @param maxSequenceLength max length of sequences * @param equalSplits if true: split larger sequences into equal sized subsequences. If false: split into - * n maxSequenceLength sequences, and (if necessary) 1 with 1 <= length < maxSequenceLength + * n maxSequenceLength sequences, and (if necessary) 1 with 1 <= length < maxSequenceLength */ public SplitMaxLengthSequence(@JsonProperty("maxSequenceLength") int maxSequenceLength, @JsonProperty("equalSplits") boolean equalSplits) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/BaseSerializer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/BaseSerializer.java index 169b2b174..b5c1e7ceb 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/BaseSerializer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/BaseSerializer.java @@ -295,7 +295,7 @@ public abstract class BaseSerializer { /** * Deserialize an IStringReducer List serialized using {@link #serializeReducerList(List)}, or - * an array serialized using {@link #serialize(IReducer[])} + * an array serialized using {@code #serialize(IReducer[])} * * @param str String representation (YAML/JSON) of the IStringReducer list * @return {@code List} diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java index 7b28c2991..e70c6cb0c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/JsonMappers.java @@ -34,8 +34,8 @@ import com.fasterxml.jackson.datatype.joda.JodaModule; @Slf4j public class JsonMappers { - private static ObjectMapper jsonMapper; - private static ObjectMapper yamlMapper; + private static final ObjectMapper jsonMapper; + private static final ObjectMapper yamlMapper; private static ObjectMapper legacyMapper; //For 1.0.0-alpha and earlier TransformProcess etc 
static { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/JsonSerializer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/JsonSerializer.java index 90d36ec1c..9733f9d8d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/JsonSerializer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/JsonSerializer.java @@ -24,7 +24,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; public class JsonSerializer extends BaseSerializer { - private ObjectMapper om; + private final ObjectMapper om; public JsonSerializer() { this.om = JsonMappers.getMapper(); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/ListWrappers.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/ListWrappers.java index 8e3b2ac56..efec02086 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/ListWrappers.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/ListWrappers.java @@ -37,7 +37,7 @@ public class ListWrappers { @Getter public static class TransformList { - private List list; + private final List list; public TransformList(@JsonProperty("list") List list) { this.list = list; @@ -46,7 +46,7 @@ public class ListWrappers { @Getter public static class FilterList { - private List list; + private final List list; public FilterList(@JsonProperty("list") List list) { this.list = list; @@ -55,7 +55,7 @@ public class ListWrappers { @Getter public static class ConditionList { - private List list; + private final List list; public ConditionList(@JsonProperty("list") List list) { this.list = list; @@ -64,7 +64,7 @@ public class ListWrappers { @Getter public static class ReducerList { - private List list; + private final List list; public ReducerList(@JsonProperty("list") List list) { this.list = list; @@ -73,7 +73,7 @@ public class ListWrappers { @Getter public static class SequenceComparatorList { - private List list; + private final List list; public SequenceComparatorList(@JsonProperty("list") List list) { this.list = list; @@ -82,7 +82,7 @@ public class ListWrappers { @Getter public static class DataActionList { - private List list; + private final List list; public DataActionList(@JsonProperty("list") List list) { this.list = list; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/YamlSerializer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/YamlSerializer.java index 2afe02937..1e7a20846 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/YamlSerializer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/YamlSerializer.java @@ -24,7 +24,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; public class YamlSerializer extends BaseSerializer { - private ObjectMapper om; + private final ObjectMapper om; public YamlSerializer() { this.om = JsonMappers.getMapperYaml(); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/StringReducer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/StringReducer.java index 907bd7d0c..17d3ef39b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/StringReducer.java +++ 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/StringReducer.java @@ -177,10 +177,10 @@ public class StringReducer implements IStringReducer { public static class Builder { - private StringReduceOp defaultOp; - private Map opMap = new HashMap<>(); - private Map customReductions = new HashMap<>(); - private Set ignoreInvalidInColumns = new HashSet<>(); + private final StringReduceOp defaultOp; + private final Map opMap = new HashMap<>(); + private final Map customReductions = new HashMap<>(); + private final Set ignoreInvalidInColumns = new HashSet<>(); private String outputColumnName; private List inputColumns; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnTransform.java index 6bea20d6c..67ef0ea43 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnTransform.java @@ -80,7 +80,7 @@ public abstract class BaseColumnTransform extends BaseTransform implements Colum if (writables.size() != inputSchema.numColumns()) { throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size() + ") does not " + "match expected number of elements (schema: " + inputSchema.numColumns() - + "). Transform = " + toString()); + + "). Transform = " + this); } int n = writables.size(); List out = new ArrayList<>(n); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToIntegerTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToIntegerTransform.java index 5afd6564e..236e0cc8e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToIntegerTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToIntegerTransform.java @@ -96,7 +96,7 @@ public class CategoricalToIntegerTransform extends BaseTransform { if (writables.size() != inputSchema.numColumns()) { throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size() + ") does not " + "match expected number of elements (schema: " + inputSchema.numColumns() - + "). Transform = " + toString()); + + "). Transform = " + this); } int idx = getColumnIdx(); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToOneHotTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToOneHotTransform.java index 56687431c..9a43b80fc 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToOneHotTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToOneHotTransform.java @@ -123,7 +123,7 @@ public class CategoricalToOneHotTransform extends BaseTransform { if (writables.size() != inputSchema.numColumns()) { throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size() + ") does not " + "match expected number of elements (schema: " + inputSchema.numColumns() - + "). 
Transform = " + toString()); + + "). Transform = " + this); } int idx = getColumnIdx(); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/IntegerToCategoricalTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/IntegerToCategoricalTransform.java index e4f9debf9..881b88013 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/IntegerToCategoricalTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/IntegerToCategoricalTransform.java @@ -89,7 +89,7 @@ public class IntegerToCategoricalTransform extends BaseColumnTransform { IntegerToCategoricalTransform o2 = (IntegerToCategoricalTransform) o; - return map != null ? map.equals(o2.map) : o2.map == null; + return Objects.equals(map, o2.map); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/PivotTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/PivotTransform.java index 39bc5c315..04b23f1e9 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/PivotTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/PivotTransform.java @@ -169,7 +169,7 @@ public class PivotTransform extends BaseTransform { if (writables.size() != inputSchema.numColumns()) { throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size() + ") does not " + "match expected number of elements (schema: " + inputSchema.numColumns() - + "). Transform = " + toString()); + + "). Transform = " + this); } int idxKey = inputSchema.getIndexOfColumn(keyColumn); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/DuplicateColumnsTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/DuplicateColumnsTransform.java index 41f857c1a..62f419855 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/DuplicateColumnsTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/DuplicateColumnsTransform.java @@ -112,7 +112,7 @@ public class DuplicateColumnsTransform implements Transform, ColumnOp { if (writables.size() != inputSchema.numColumns()) { throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size() + ") does not " + "match expected number of elements (schema: " + inputSchema.numColumns() - + "). Transform = " + toString()); + + "). 
Transform = " + this); } List out = new ArrayList<>(writables.size() + columnsToDuplicate.size()); int i = 0; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveAllColumnsExceptForTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveAllColumnsExceptForTransform.java index f71ab0d99..52e13cc8b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveAllColumnsExceptForTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveAllColumnsExceptForTransform.java @@ -89,7 +89,7 @@ public class RemoveAllColumnsExceptForTransform extends BaseTransform implements if (writables.size() != inputSchema.numColumns()) { throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size() + ") does not " + "match expected number of elements (schema: " + inputSchema.numColumns() - + "). Transform = " + toString()); + + "). Transform = " + this); } List outList = new ArrayList<>(columnsToKeep.length); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveColumnsTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveColumnsTransform.java index d5177a055..62de1b280 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveColumnsTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveColumnsTransform.java @@ -123,7 +123,7 @@ public class RemoveColumnsTransform extends BaseTransform implements ColumnOp { String toString = StringUtils.join(list, ","); throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size() + ") does not " + "match expected number of elements (schema: " + inputSchema.numColumns() - + "). Transform = " + toString() + " and record " + toString); + + "). Transform = " + this + " and record " + toString); } List outList = new ArrayList<>(writables.size() - columnsToRemove.length); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerToOneHotTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerToOneHotTransform.java index ca27348ae..1bd907723 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerToOneHotTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerToOneHotTransform.java @@ -103,7 +103,7 @@ public class IntegerToOneHotTransform extends BaseTransform { if (writables.size() != inputSchema.numColumns()) { throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size() + ") does not " + "match expected number of elements (schema: " + inputSchema.numColumns() - + "). Transform = " + toString()); + + "). 
Transform = " + this); } int idx = getColumnIdx(); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToCharacterIndexTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToCharacterIndexTransform.java index c882a76a2..20d1b1c2e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToCharacterIndexTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToCharacterIndexTransform.java @@ -57,7 +57,7 @@ public class TextToCharacterIndexTransform extends BaseSequenceExpansionTransfor @Override protected List expandedColumnMetaDatas(List origColumnMeta, List expandedColumnNames) { - return Collections.singletonList(new IntegerMetaData(expandedColumnNames.get(0), 0, characterIndexMap.size()-1)); + return Collections.singletonList(new IntegerMetaData(expandedColumnNames.get(0), 0, characterIndexMap.size()-1)); } @Override @@ -65,7 +65,7 @@ public class TextToCharacterIndexTransform extends BaseSequenceExpansionTransfor if(writableMap == null){ Map> m = new HashMap<>(); for(Map.Entry entry : characterIndexMap.entrySet()){ - m.put(entry.getKey(), Collections.singletonList(new IntWritable(entry.getValue()))); + m.put(entry.getKey(), Collections.singletonList(new IntWritable(entry.getValue()))); } writableMap = m; } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToTermIndexSequenceTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToTermIndexSequenceTransform.java index 9adbf1771..fa2990e78 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToTermIndexSequenceTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToTermIndexSequenceTransform.java @@ -84,7 +84,7 @@ public class TextToTermIndexSequenceTransform extends BaseSequenceExpansionTrans @Override protected List expandedColumnMetaDatas(List origColumnMeta, List expandedColumnNames) { - return Collections.singletonList(new IntegerMetaData(expandedColumnNames.get(0), 0, wordIndexMap.size()-1)); + return Collections.singletonList(new IntegerMetaData(expandedColumnNames.get(0), 0, wordIndexMap.size()-1)); } @Override @@ -92,7 +92,7 @@ public class TextToTermIndexSequenceTransform extends BaseSequenceExpansionTrans if(writableMap == null){ Map> m = new HashMap<>(); for(Map.Entry entry : wordIndexMap.entrySet()) { - m.put(entry.getKey(), Collections.singletonList(new IntWritable(entry.getValue()))); + m.put(entry.getKey(), Collections.singletonList(new IntWritable(entry.getValue()))); } writableMap = m; } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceDifferenceTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceDifferenceTransform.java index 61bc30796..4ba0e8968 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceDifferenceTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceDifferenceTransform.java @@ -68,7 +68,7 @@ public class SequenceDifferenceTransform implements Transform { * * @param columnName Name of the column to perform the operation on. 
     * @param newColumnName New name for the column. May be same as the origina lcolumn name
-     * @param lookback      Lookback period, in number of time steps. Must be > 0
+     * @param lookback      Lookback period, in number of time steps. Must be > 0
      */
     public SequenceDifferenceTransform(String columnName, String newColumnName, int lookback) {
         this(columnName, newColumnName, lookback, FirstStepMode.Default, null);
@@ -80,7 +80,7 @@ public class SequenceDifferenceTransform implements Transform {
      *
      * @param columnName    Name of the column to perform the operation on.
      * @param newColumnName New name for the column. May be same as the origina lcolumn name
-     * @param lookback      Lookback period, in number of time steps. Must be > 0
+     * @param lookback      Lookback period, in number of time steps. Must be > 0
      * @param firstStepMode see {@link FirstStepMode}
      * @param specifiedValueWritable Must be null if using FirstStepMode.Default, or non-null if using FirstStepMode.SpecifiedValue
      */
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCategoricalSetTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCategoricalSetTransform.java
index 83d56fd7e..108da34e7 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCategoricalSetTransform.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCategoricalSetTransform.java
@@ -123,7 +123,7 @@ public class StringListToCategoricalSetTransform extends BaseTransform {
         if (writables.size() != inputSchema.numColumns()) {
             throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size()
                             + ") does not " + "match expected number of elements (schema: " + inputSchema.numColumns()
-                            + "). Transform = " + toString());
+                            + "). Transform = " + this);
         }
         int n = writables.size();
         List out = new ArrayList<>(n);
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java
index 9f3ff0dcf..e682dc099 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java
@@ -168,7 +168,7 @@ public class StringListToCountsNDArrayTransform extends BaseTransform {
         if (writables.size() != inputSchema.numColumns()) {
             throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size()
                             + ") does not " + "match expected number of elements (schema: " + inputSchema.numColumns()
-                            + "). Transform = " + toString());
+                            + "). Transform = " + this);
         }
         int n = writables.size();
         List out = new ArrayList<>(n);
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/DeriveColumnsFromTimeTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/DeriveColumnsFromTimeTransform.java
index d1e290f7a..425b4cc68 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/DeriveColumnsFromTimeTransform.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/DeriveColumnsFromTimeTransform.java
@@ -147,7 +147,7 @@ public class DeriveColumnsFromTimeTransform implements Transform {
         if (writables.size() != inputSchema.numColumns()) {
             throw new IllegalStateException("Cannot execute transform: input writables list length (" + writables.size()
                             + ") does not " + "match expected number of elements (schema: " + inputSchema.numColumns()
-                            + "). Transform = " + toString());
+                            + "). Transform = " + this);
         }

         int i = 0;
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/components/RenderableComponentHistogram.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/components/RenderableComponentHistogram.java
index 7efa3d894..d18e9489c 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/components/RenderableComponentHistogram.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/components/RenderableComponentHistogram.java
@@ -56,9 +56,9 @@ public class RenderableComponentHistogram extends RenderableComponent {
     public static class Builder {

         private String title;
-        private List lowerBounds = new ArrayList<>();
-        private List upperBounds = new ArrayList<>();
-        private List yValues = new ArrayList<>();
+        private final List lowerBounds = new ArrayList<>();
+        private final List upperBounds = new ArrayList<>();
+        private final List yValues = new ArrayList<>();
         private int marginTop = 60;
         private int marginBottom = 60;
         private int marginLeft = 60;
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/components/RenderableComponentLineChart.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/components/RenderableComponentLineChart.java
index f2cb8793b..735de97f6 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/components/RenderableComponentLineChart.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/components/RenderableComponentLineChart.java
@@ -65,9 +65,9 @@ public class RenderableComponentLineChart extends RenderableComponent {
     public static class Builder {

         private String title;
-        private List x = new ArrayList<>();
-        private List y = new ArrayList<>();
-        private List seriesNames = new ArrayList<>();
+        private final List x = new ArrayList<>();
+        private final List y = new ArrayList<>();
+        private final List seriesNames = new ArrayList<>();
         private boolean removeAxisHorizontal = false;
         private boolean legend = true;
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ReflectionUtils.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ReflectionUtils.java
index 7bc9618ca..826705f3a 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ReflectionUtils.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ReflectionUtils.java
@@ -113,7 +113,7 @@ public class ReflectionUtils {
     /**
      * Allocate a buffer for each thread that tries to clone objects.
      */
-    private static ThreadLocal cloneBuffers = new ThreadLocal() {
+    private static final ThreadLocal cloneBuffers = new ThreadLocal() {
         protected synchronized CopyInCopyOutBuffer initialValue() {
             return new CopyInCopyOutBuffer();
         }
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/jackson/DateTimeFieldTypeDeserializer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/jackson/DateTimeFieldTypeDeserializer.java
index 198b49755..b2a6e5788 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/jackson/DateTimeFieldTypeDeserializer.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/jackson/DateTimeFieldTypeDeserializer.java
@@ -66,7 +66,7 @@ public class DateTimeFieldTypeDeserializer extends JsonDeserializer {

     @Override
     public void serialize(DateTimeFieldType dateTimeFieldType, JsonGenerator jsonGenerator,
-                    SerializerProvider serializerProvider) throws IOException, JsonProcessingException {
+                    SerializerProvider serializerProvider) throws IOException {
         jsonGenerator.writeStartObject();
         jsonGenerator.writeStringField("fieldType", dateTimeFieldType.getName());
         jsonGenerator.writeEndObject();
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java
index 98cc7d339..360f2aa74 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java
@@ -234,7 +234,7 @@ public class RecordConverter {
     }

    /**
-     * Convert a collection into a `List`, i.e. a record that can be used with other datavec methods.
+     * Convert a collection into a {@code List}, i.e. a record that can be used with other datavec methods.
      * Uses a schema to decide what kind of writable to use.
      *
      * @return a record
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/vector/Vectorizer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/vector/Vectorizer.java
index 3c1529fa8..55c0dba4c 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/vector/Vectorizer.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/vector/Vectorizer.java
@@ -81,7 +81,7 @@ public interface Vectorizer {
      * This allows for neat inheritance and polymorphism
      * for fit and fit/transform among other things
      */
-    public static interface RecordCallBack {
+    interface RecordCallBack {
        /**
         * The record callback
         * @param record
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/BooleanWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/BooleanWritable.java
index bec23e5e2..7fe2aadbc 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/BooleanWritable.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/BooleanWritable.java
@@ -95,7 +95,7 @@ public class BooleanWritable implements WritableComparable {
     public int compareTo(Object o) {
         boolean a = this.value;
         boolean b = ((BooleanWritable) o).value;
-        return ((a == b) ? 0 : (a == false) ? -1 : 1);
+        return ((a == b) ? 0 : (!a) ? -1 : 1);
     }

     public String toString() {
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java
index f2f098cd8..68bf9ebd6 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java
@@ -65,15 +65,15 @@ public class ByteWritable implements WritableComparable {
     public boolean fuzzyEquals(Writable o, double tolerance) {
         double other;
         if (o instanceof IntWritable){
-            other = ((IntWritable) o).toDouble();
+            other = o.toDouble();
         } else if (o instanceof LongWritable) {
-            other = ((LongWritable) o).toDouble();
+            other = o.toDouble();
         } else if (o instanceof ByteWritable) {
-            other = ((ByteWritable) o).toDouble();
+            other = o.toDouble();
         } else if (o instanceof DoubleWritable) {
-            other = ((DoubleWritable) o).toDouble();
+            other = o.toDouble();
         } else if (o instanceof FloatWritable) {
-            other = ((FloatWritable) o).toDouble();
+            other = o.toDouble();
         } else { return false; }
         return DoubleMath.fuzzyEquals(this.value, other, tolerance);
     }
@@ -90,7 +90,7 @@ public class ByteWritable implements WritableComparable {
     }

     public int hashCode() {
-        return (int)value;
+        return value;
     }

     /** Compares two ByteWritables. */
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java
index ed795e958..8a6ef79ed 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java
@@ -69,15 +69,15 @@ public class DoubleWritable implements WritableComparable {
     public boolean fuzzyEquals(Writable o, double tolerance) {
         double other;
         if (o instanceof IntWritable){
-            other = ((IntWritable) o).toDouble();
+            other = o.toDouble();
         } else if (o instanceof LongWritable) {
-            other = ((LongWritable) o).toDouble();
+            other = o.toDouble();
         } else if (o instanceof ByteWritable) {
-            other = ((ByteWritable) o).toDouble();
+            other = o.toDouble();
         } else if (o instanceof DoubleWritable) {
-            other = ((DoubleWritable) o).toDouble();
+            other = o.toDouble();
         } else if (o instanceof FloatWritable) {
-            other = ((FloatWritable) o).toDouble();
+            other = o.toDouble();
         } else { return false; }
         return DoubleMath.fuzzyEquals(this.value, other, tolerance);
     }
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java
index c98bc78f3..783e77b9a 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java
@@ -66,15 +66,15 @@ public class FloatWritable implements WritableComparable {
     public boolean fuzzyEquals(Writable o, double tolerance) {
         double other;
         if (o instanceof IntWritable){
-            other = ((IntWritable) o).toDouble();
+            other = o.toDouble();
         } else if (o instanceof LongWritable) {
-            other = ((LongWritable) o).toDouble();
+            other = o.toDouble();
         } else if (o instanceof ByteWritable) {
-            other = ((ByteWritable) o).toDouble();
+            other = o.toDouble();
         } else if (o instanceof DoubleWritable) {
-            other = ((DoubleWritable) o).toDouble();
+            other = 
o.toDouble(); } else if (o instanceof FloatWritable) { - other = ((FloatWritable) o).toDouble(); + other = o.toDouble(); } else { return false; } return DoubleMath.fuzzyEquals(this.value, other, tolerance); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java index 37d74df2f..56739a8f6 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java @@ -66,15 +66,15 @@ public class IntWritable implements WritableComparable { public boolean fuzzyEquals(Writable o, double tolerance) { double other; if (o instanceof IntWritable){ - other = ((IntWritable) o).toDouble(); + other = o.toDouble(); } else if (o instanceof LongWritable) { - other = ((LongWritable) o).toDouble(); + other = o.toDouble(); } else if (o instanceof ByteWritable) { - other = ((ByteWritable) o).toDouble(); + other = o.toDouble(); } else if (o instanceof DoubleWritable) { - other = ((DoubleWritable) o).toDouble(); + other = o.toDouble(); } else if (o instanceof FloatWritable) { - other = ((FloatWritable) o).toDouble(); + other = o.toDouble(); } else { return false; } return DoubleMath.fuzzyEquals(this.value, other, tolerance); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java index 4b7dc3d35..599bde104 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java @@ -65,15 +65,15 @@ public class LongWritable implements WritableComparable { public boolean fuzzyEquals(Writable o, double tolerance) { double other; if (o instanceof IntWritable){ - other = ((IntWritable) o).toDouble(); + other = o.toDouble(); } else if (o instanceof LongWritable) { - other = ((LongWritable) o).toDouble(); + other = o.toDouble(); } else if (o instanceof ByteWritable) { - other = ((ByteWritable) o).toDouble(); + other = o.toDouble(); } else if (o instanceof DoubleWritable) { - other = ((DoubleWritable) o).toDouble(); + other = o.toDouble(); } else if (o instanceof FloatWritable) { - other = ((FloatWritable) o).toDouble(); + other = o.toDouble(); } else { return false; } return DoubleMath.fuzzyEquals(this.value, other, tolerance); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/NDArrayWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/NDArrayWritable.java index 383ac3aac..cf3d154e3 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/NDArrayWritable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/NDArrayWritable.java @@ -184,7 +184,7 @@ public class NDArrayWritable extends ArrayWritable implements WritableComparable } for (int i = 0; i < array.rank(); i++) { - if (Long.compare(array.size(i), other.array.size(i)) != 0) { + if (array.size(i) != other.array.size(i)) { return Long.compare(array.size(i), other.array.size(i)); } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java index 43dc58036..b36452a0d 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java @@ -39,14 +39,14 @@ import java.text.StringCharacterIterator; public class Text extends BinaryComparable implements WritableComparable { - private static ThreadLocal ENCODER_FACTORY = new ThreadLocal() { + private static final ThreadLocal ENCODER_FACTORY = new ThreadLocal() { protected CharsetEncoder initialValue() { return StandardCharsets.UTF_8.newEncoder().onMalformedInput(CodingErrorAction.REPORT) .onUnmappableCharacter(CodingErrorAction.REPORT); } }; - private static ThreadLocal DECODER_FACTORY = new ThreadLocal() { + private static final ThreadLocal DECODER_FACTORY = new ThreadLocal() { protected CharsetDecoder initialValue() { return StandardCharsets.UTF_8.newDecoder().onMalformedInput(CodingErrorAction.REPORT) .onUnmappableCharacter(CodingErrorAction.REPORT); @@ -106,7 +106,7 @@ public class Text extends BinaryComparable implements WritableComparable> map = new ConcurrentHashMap<>(); - private Map> constructorMap = new ConcurrentHashMap<>(); + private final Map> map = new ConcurrentHashMap<>(); + private final Map> constructorMap = new ConcurrentHashMap<>(); private WritableFactory() { for (WritableType wt : WritableType.values()) { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractWritableRecordBatch.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractWritableRecordBatch.java index b2d4b7621..715d2a674 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractWritableRecordBatch.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractWritableRecordBatch.java @@ -139,7 +139,7 @@ public abstract class AbstractWritableRecordBatch implements List public static class RecordBatchListIterator implements ListIterator> { private int index; - private AbstractWritableRecordBatch underlying; + private final AbstractWritableRecordBatch underlying; public RecordBatchListIterator(AbstractWritableRecordBatch underlying){ this.underlying = underlying; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java index 0817973a7..bfc531e7c 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java @@ -53,15 +53,15 @@ public class CSVLineSequenceRecordReaderTest extends BaseND4JTest { rr.initialize(new FileSplit(source)); List> exp0 = Arrays.asList( - Collections.singletonList(new Text("a")), - Collections.singletonList(new Text("b")), - Collections.singletonList(new Text("c"))); + Collections.singletonList(new Text("a")), + Collections.singletonList(new Text("b")), + Collections.singletonList(new Text("c"))); List> exp1 = Arrays.asList( - Collections.singletonList(new Text("1")), - Collections.singletonList(new Text("2")), - Collections.singletonList(new Text("3")), - Collections.singletonList(new Text("4"))); + Collections.singletonList(new Text("1")), + Collections.singletonList(new Text("2")), + Collections.singletonList(new Text("3")), + Collections.singletonList(new Text("4"))); 
for( int i=0; i<3; i++ ) { int count = 0; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java index 882fc628f..e52f71cc2 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java @@ -78,12 +78,12 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { List> exp0 = new ArrayList<>(); for (String s : "a,b,c,1,2,3,4,x,y".split(",")) { - exp0.add(Collections.singletonList(new Text(s))); + exp0.add(Collections.singletonList(new Text(s))); } List> exp1 = new ArrayList<>(); for (String s : "A,B,C".split(",")) { - exp1.add(Collections.singletonList(new Text(s))); + exp1.add(Collections.singletonList(new Text(s))); } assertEquals(exp0, seqRR.sequenceRecord()); @@ -131,10 +131,10 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { List> exp0 = Arrays.asList( - Arrays.asList(new Text("a"), new Text("1"), new Text("x")), - Arrays.asList(new Text("b"), new Text("2"), new Text("y"))); + Arrays.asList(new Text("a"), new Text("1"), new Text("x")), + Arrays.asList(new Text("b"), new Text("2"), new Text("y"))); - List> exp1 = Collections.singletonList(Arrays.asList(new Text("A"), new Text("B"), new Text("C"))); + List> exp1 = Collections.singletonList(Arrays.asList(new Text("A"), new Text("B"), new Text("C"))); assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord()); @@ -181,10 +181,10 @@ public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { List> exp0 = Arrays.asList( - Arrays.asList(new Text("a"), new Text("1"), new Text("x")), - Arrays.asList(new Text("b"), new Text("PAD"), new Text("PAD"))); + Arrays.asList(new Text("a"), new Text("1"), new Text("x")), + Arrays.asList(new Text("b"), new Text("PAD"), new Text("PAD"))); - List> exp1 = Collections.singletonList(Arrays.asList(new Text("A"), new Text("B"), new Text("C"))); + List> exp1 = Collections.singletonList(Arrays.asList(new Text("A"), new Text("B"), new Text("C"))); assertEquals(exp0, seqRR.sequenceRecord()); assertEquals(exp1, seqRR.sequenceRecord()); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java index 0b54b7147..16ed450df 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java @@ -265,19 +265,19 @@ public class CSVRecordReaderTest extends BaseND4JTest { Assertions.assertThrows(NoSuchElementException.class, () -> { final int numLines = 4; - final List lineList = Arrays.asList((Writable) new IntWritable(numLines - 1), - (Writable) new Text("one"), (Writable) new Text("two"), (Writable) new Text("three")); + final List lineList = Arrays.asList(new IntWritable(numLines - 1), + new Text("one"), new Text("two"), new Text("three")); String header = ",one,two,three"; List lines = new ArrayList<>(); for (int i = 0; i < numLines; i++) - lines.add(Integer.toString(i) + header); + lines.add(i + header); File tempFile = 
File.createTempFile("csvSkipLines", ".csv"); FileUtils.writeLines(tempFile, lines); CSVRecordReader rr = new CSVRecordReader(numLines, ','); rr.initialize(new FileSplit(tempFile)); rr.reset(); - assertTrue(!rr.hasNext()); + assertFalse(rr.hasNext()); rr.next(); }); } @@ -285,12 +285,12 @@ public class CSVRecordReaderTest extends BaseND4JTest { @Test public void testCsvSkipAllButOneLine() throws IOException, InterruptedException { final int numLines = 4; - final List lineList = Arrays.asList(new Text(Integer.toString(numLines - 1)), + final List lineList = Arrays.asList(new Text(Integer.toString(numLines - 1)), new Text("one"), new Text("two"), new Text("three")); String header = ",one,two,three"; List lines = new ArrayList<>(); for (int i = 0; i < numLines; i++) - lines.add(Integer.toString(i) + header); + lines.add(i + header); File tempFile = File.createTempFile("csvSkipLines", ".csv"); FileUtils.writeLines(tempFile, lines); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java index 883f0e0d4..3fcb9e9f5 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java @@ -100,12 +100,12 @@ public class JacksonLineRecordReaderTest extends BaseND4JTest { rr.initialize(new CollectionInputSplit(u)); List> expSeq0 = new ArrayList<>(); - expSeq0.add(Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"))); - expSeq0.add(Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"))); - expSeq0.add(Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"))); + expSeq0.add(Arrays.asList(new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"))); + expSeq0.add(Arrays.asList(new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"))); + expSeq0.add(Arrays.asList(new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"))); List> expSeq1 = new ArrayList<>(); - expSeq1.add(Arrays.asList((Writable) new Text("aValue3"), new Text("bValue3"), new Text("cxValue3"))); + expSeq1.add(Arrays.asList(new Text("aValue3"), new Text("bValue3"), new Text("cxValue3"))); int count = 0; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java index b6f13adbd..08b94fdec 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java @@ -65,7 +65,7 @@ public class JacksonRecordReaderTest extends BaseND4JTest { //For third JSON file: c:x:value is missing ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); - File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID().toString()); + File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID()); FileUtils.forceMkdir(f); cpr.copyDirectory(f); String path = new File(f, "json_test_%d.txt").getAbsolutePath(); @@ -83,7 +83,7 @@ public class JacksonRecordReaderTest extends 
BaseND4JTest { //Exact same information as JSON format, but in YAML format ClassPathResource cpr = new ClassPathResource("datavec-api/yaml/"); - File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID().toString()); + File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID()); FileUtils.forceMkdir(f); cpr.copyDirectory(f); String path = new File(f, "yaml_test_%d.txt").getAbsolutePath(); @@ -102,7 +102,7 @@ public class JacksonRecordReaderTest extends BaseND4JTest { //Exact same information as JSON format, but in XML format ClassPathResource cpr = new ClassPathResource("datavec-api/xml/"); - File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID().toString()); + File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID()); FileUtils.forceMkdir(f); cpr.copyDirectory(f); String path = new File(f, "xml_test_%d.txt").getAbsolutePath(); @@ -126,17 +126,17 @@ public class JacksonRecordReaderTest extends BaseND4JTest { private static void testJacksonRecordReader(RecordReader rr) { List json0 = rr.next(); - List exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")); + List exp0 = Arrays.asList(new Text("aValue0"), new Text("bValue0"), new Text("cxValue0")); assertEquals(exp0, json0); List json1 = rr.next(); List exp1 = - Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")); + Arrays.asList(new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")); assertEquals(exp1, json1); List json2 = rr.next(); List exp2 = - Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")); + Arrays.asList(new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")); assertEquals(exp2, json2); assertFalse(rr.hasNext()); @@ -153,7 +153,7 @@ public class JacksonRecordReaderTest extends BaseND4JTest { public void testAppendingLabels() throws Exception { ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); - File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID().toString()); + File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID()); FileUtils.forceMkdir(f); cpr.copyDirectory(f); String path = new File(f, "json_test_%d.txt").getAbsolutePath(); @@ -165,15 +165,15 @@ public class JacksonRecordReaderTest extends BaseND4JTest { new LabelGen()); rr.initialize(is); - List exp0 = Arrays.asList((Writable) new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), + List exp0 = Arrays.asList(new Text("aValue0"), new Text("bValue0"), new Text("cxValue0"), new IntWritable(0)); assertEquals(exp0, rr.next()); - List exp1 = Arrays.asList((Writable) new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), + List exp1 = Arrays.asList(new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1"), new IntWritable(1)); assertEquals(exp1, rr.next()); - List exp2 = Arrays.asList((Writable) new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), + List exp2 = Arrays.asList(new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX"), new IntWritable(2)); assertEquals(exp2, rr.next()); @@ -182,15 +182,15 @@ public class JacksonRecordReaderTest extends BaseND4JTest { new LabelGen(), 0); rr.initialize(is); - exp0 = Arrays.asList((Writable) new IntWritable(0), new Text("aValue0"), new Text("bValue0"), + exp0 = Arrays.asList(new IntWritable(0), new Text("aValue0"), new Text("bValue0"), new 
Text("cxValue0")); assertEquals(exp0, rr.next()); - exp1 = Arrays.asList((Writable) new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), + exp1 = Arrays.asList(new IntWritable(1), new Text("aValue1"), new Text("MISSING_B"), new Text("cxValue1")); assertEquals(exp1, rr.next()); - exp2 = Arrays.asList((Writable) new IntWritable(2), new Text("aValue2"), new Text("bValue2"), + exp2 = Arrays.asList(new IntWritable(2), new Text("aValue2"), new Text("bValue2"), new Text("MISSING_CX")); assertEquals(exp2, rr.next()); } @@ -198,7 +198,7 @@ public class JacksonRecordReaderTest extends BaseND4JTest { @Test public void testAppendingLabelsMetaData() throws Exception { ClassPathResource cpr = new ClassPathResource("datavec-api/json/"); - File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID().toString()); + File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID()); FileUtils.forceMkdir(f); cpr.copyDirectory(f); String path = new File(f, "json_test_%d.txt").getAbsolutePath(); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java index a2d6622b3..481e3a8ba 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java @@ -58,11 +58,11 @@ public class RegexRecordReaderTest extends BaseND4JTest { RecordReader rr = new RegexLineRecordReader(regex, 1); rr.initialize(new FileSplit(new ClassPathResource("datavec-api/logtestdata/logtestfile0.txt").getFile())); - List exp0 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), + List exp0 = Arrays.asList(new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!")); - List exp1 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), + List exp1 = Arrays.asList(new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!")); - List exp2 = Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), + List exp2 = Arrays.asList(new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!")); assertEquals(exp0, rr.next()); assertEquals(exp1, rr.next()); @@ -125,20 +125,20 @@ public class RegexRecordReaderTest extends BaseND4JTest { rr.initialize(is); List> exp0 = new ArrayList<>(); - exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), + exp0.add(Arrays.asList(new Text("2016-01-01 23:59:59.001"), new Text("1"), new Text("DEBUG"), new Text("First entry message!"))); - exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), + exp0.add(Arrays.asList(new Text("2016-01-01 23:59:59.002"), new Text("2"), new Text("INFO"), new Text("Second entry message!"))); - exp0.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), + exp0.add(Arrays.asList(new Text("2016-01-01 23:59:59.003"), new Text("3"), new Text("WARN"), new Text("Third entry message!"))); List> exp1 = new ArrayList<>(); - exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), + exp1.add(Arrays.asList(new 
Text("2016-01-01 23:59:59.011"), new Text("11"), new Text("DEBUG"), new Text("First entry message!"))); - exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), + exp1.add(Arrays.asList(new Text("2016-01-01 23:59:59.012"), new Text("12"), new Text("INFO"), new Text("Second entry message!"))); - exp1.add(Arrays.asList((Writable) new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), + exp1.add(Arrays.asList(new Text("2016-01-01 23:59:59.013"), new Text("13"), new Text("WARN"), new Text("Third entry message!"))); assertEquals(exp0, rr.sequenceRecord()); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java index decbf0275..c3287ffa6 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java @@ -43,13 +43,13 @@ public class TestCollectionRecordReaders extends BaseND4JTest { List>> listOfSequences = new ArrayList<>(); List> sequence1 = new ArrayList<>(); - sequence1.add(Arrays.asList((Writable) new IntWritable(0), new IntWritable(1))); - sequence1.add(Arrays.asList((Writable) new IntWritable(2), new IntWritable(3))); + sequence1.add(Arrays.asList(new IntWritable(0), new IntWritable(1))); + sequence1.add(Arrays.asList(new IntWritable(2), new IntWritable(3))); listOfSequences.add(sequence1); List> sequence2 = new ArrayList<>(); - sequence2.add(Arrays.asList((Writable) new IntWritable(4), new IntWritable(5))); - sequence2.add(Arrays.asList((Writable) new IntWritable(6), new IntWritable(7))); + sequence2.add(Arrays.asList(new IntWritable(4), new IntWritable(5))); + sequence2.add(Arrays.asList(new IntWritable(6), new IntWritable(7))); listOfSequences.add(sequence2); SequenceRecordReader seqRR = new CollectionSequenceRecordReader(listOfSequences); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java index ee2c9b091..3645c034a 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java @@ -36,6 +36,7 @@ import org.nd4j.common.io.ClassPathResource; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -74,18 +75,18 @@ public class TransformProcessRecordReaderTests extends BaseND4JTest { public void simpleTransformTestSequence() { List> sequence = new ArrayList<>(); //First window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0), + sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0), new IntWritable(0))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1), + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1), new IntWritable(0))); - 
sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2), + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2), new IntWritable(0))); Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) .addColumnInteger("intcolumn").addColumnInteger("intcolumn2").build(); TransformProcess transformProcess = new TransformProcess.Builder(schema).removeColumns("intcolumn2").build(); InMemorySequenceRecordReader inMemorySequenceRecordReader = - new InMemorySequenceRecordReader(Arrays.asList(sequence)); + new InMemorySequenceRecordReader(Collections.singletonList(sequence)); TransformProcessSequenceRecordReader transformProcessSequenceRecordReader = new TransformProcessSequenceRecordReader(inMemorySequenceRecordReader, transformProcess); List> next = transformProcessSequenceRecordReader.sequenceRecord(); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java index 885a75ec0..f0013e516 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java @@ -168,7 +168,7 @@ public class LibSvmRecordWriterTest extends BaseND4JTest { arr3.putScalar(0, 13); arr3.putScalar(1, 14); arr3.putScalar(2, 15); - List record = Arrays.asList((Writable) new DoubleWritable(1), + List record = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), @@ -204,7 +204,7 @@ public class LibSvmRecordWriterTest extends BaseND4JTest { arr3.putScalar(0, 0); arr3.putScalar(1, 1); arr3.putScalar(2, 0); - List record = Arrays.asList((Writable) new DoubleWritable(1), + List record = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), @@ -241,7 +241,7 @@ public class LibSvmRecordWriterTest extends BaseND4JTest { arr3.putScalar(0, 0); arr3.putScalar(1, 1); arr3.putScalar(2, 0); - List record = Arrays.asList((Writable) new DoubleWritable(1), + List record = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), @@ -273,7 +273,7 @@ public class LibSvmRecordWriterTest extends BaseND4JTest { @Test public void testNonIntegerButValidMultilabel() throws Exception { - List record = Arrays.asList((Writable) new IntWritable(3), + List record = Arrays.asList(new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0)); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); @@ -296,7 +296,7 @@ public class LibSvmRecordWriterTest extends BaseND4JTest { @Test public void nonIntegerMultilabel() throws Exception { Assertions.assertThrows(NumberFormatException.class, () -> { - List record = Arrays.asList((Writable) new IntWritable(3), + List record = Arrays.asList(new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2)); File tempFile = File.createTempFile("LibSvmRecordWriter", ".txt"); @@ -319,7 +319,7 @@ public class LibSvmRecordWriterTest extends BaseND4JTest { @Test public void nonBinaryMultilabel() throws Exception { - List record = Arrays.asList((Writable) new IntWritable(0), + List record = Arrays.asList(new IntWritable(0), new IntWritable(1), new IntWritable(2)); File tempFile = 
File.createTempFile("LibSvmRecordWriter", ".txt"); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java index d38611cc4..48ee43c47 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java @@ -165,7 +165,7 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { arr3.putScalar(0, 13); arr3.putScalar(1, 14); arr3.putScalar(2, 15); - List record = Arrays.asList((Writable) new DoubleWritable(1), + List record = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), @@ -201,7 +201,7 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { arr3.putScalar(0, 0); arr3.putScalar(1, 1); arr3.putScalar(2, 0); - List record = Arrays.asList((Writable) new DoubleWritable(1), + List record = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), @@ -238,7 +238,7 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { arr3.putScalar(0, 0); arr3.putScalar(1, 1); arr3.putScalar(2, 0); - List record = Arrays.asList((Writable) new DoubleWritable(1), + List record = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(arr2), new IntWritable(2), new DoubleWritable(3), @@ -270,7 +270,7 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { @Test public void testNonIntegerButValidMultilabel() throws Exception { - List record = Arrays.asList((Writable) new IntWritable(3), + List record = Arrays.asList(new IntWritable(3), new IntWritable(2), new DoubleWritable(1.0)); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); @@ -293,7 +293,7 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { @Test public void nonIntegerMultilabel() throws Exception { Assertions.assertThrows(NumberFormatException.class, () -> { - List record = Arrays.asList((Writable) new IntWritable(3), + List record = Arrays.asList(new IntWritable(3), new IntWritable(2), new DoubleWritable(1.2)); File tempFile = File.createTempFile("SVMLightRecordWriter", ".txt"); @@ -317,7 +317,7 @@ public class SVMLightRecordWriterTest extends BaseND4JTest { @Test public void nonBinaryMultilabel() throws Exception { Assertions.assertThrows(NumberFormatException.class, () -> { - List record = Arrays.asList((Writable) new IntWritable(0), + List record = Arrays.asList(new IntWritable(0), new IntWritable(1), new IntWritable(2)); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java index 43d274151..7048e610b 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java @@ -137,7 +137,7 @@ public class InputSplitTests extends BaseND4JTest { FileSplit boostrap = new FileSplit(tmpDir); Assertions.assertTrue(boostrap.needsBootstrapForWrite()); boostrap.bootStrapForWrite(); - Assertions.assertTrue(tmpDir.listFiles() != null); + Assertions.assertNotNull(tmpDir.listFiles()); } @Test @@ -156,6 +156,7 @@ public class InputSplitTests extends BaseND4JTest { for (int i = 0; i < 
paths2.length; i++) { if (!paths2[i].toString().startsWith("file:///label0/")) { notOnlyFirstLabel = true; + break; } } Assertions.assertTrue(notOnlyFirstLabel); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java index 09b01cf8d..c53099d0f 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java @@ -36,10 +36,7 @@ import java.io.IOException; import java.io.InputStream; import java.net.URI; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Random; +import java.util.*; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; @@ -67,9 +64,9 @@ public class TestStreamInputSplit extends BaseND4JTest { rr.initialize(is); List> exp = new ArrayList<>(); - exp.add(Arrays.asList(new Text("a"), new Text("b"), new Text("c"))); - exp.add(Arrays.asList(new Text("d"), new Text("e"), new Text("f"))); - exp.add(Arrays.asList(new Text("1"), new Text("2"), new Text("3"))); + exp.add(Arrays.asList(new Text("a"), new Text("b"), new Text("c"))); + exp.add(Arrays.asList(new Text("d"), new Text("e"), new Text("f"))); + exp.add(Arrays.asList(new Text("1"), new Text("2"), new Text("3"))); List> act = new ArrayList<>(); while(rr.hasNext()){ @@ -111,10 +108,10 @@ public class TestStreamInputSplit extends BaseND4JTest { List>> exp = new ArrayList<>(); exp.add(Arrays.asList( - Arrays.asList(new Text("a"), new Text("b"), new Text("c")), - Arrays.asList(new Text("d"), new Text("e"), new Text("f")))); - exp.add(Arrays.asList( - Arrays.asList(new Text("1"), new Text("2"), new Text("3")))); + Arrays.asList(new Text("a"), new Text("b"), new Text("c")), + Arrays.asList(new Text("d"), new Text("e"), new Text("f")))); + exp.add(Collections.singletonList( + Arrays.asList(new Text("1"), new Text("2"), new Text("3")))); List>> act = new ArrayList<>(); while (rr.hasNext()) { diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java index 7a968ddfe..d478c934b 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java @@ -51,8 +51,8 @@ public class TestTransformProcess extends BaseND4JTest { .doubleMathOp("col2", MathOp.Add, 1.0) .build(); - List in = Arrays.asList(new Text("Text"), new DoubleWritable(2.0)); - List exp = Arrays.asList(new Text("Text"), new DoubleWritable(3.0)); + List in = Arrays.asList(new Text("Text"), new DoubleWritable(2.0)); + List exp = Arrays.asList(new Text("Text"), new DoubleWritable(3.0)); List out = transformProcess.execute(in); assertEquals(exp, out); @@ -73,11 +73,11 @@ public class TestTransformProcess extends BaseND4JTest { .build(); String s = "in text"; - List input = Collections.singletonList(new Text(s)); + List input = Collections.singletonList(new Text(s)); List> expSeq = new ArrayList<>(s.length()); for( int i = 0; isingletonList(new IntWritable(m.get(s.charAt(i))))); + expSeq.add(Collections.singletonList(new 
IntWritable(m.get(s.charAt(i))))); } diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java index f49e0c4d4..1f4af7292 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java @@ -44,20 +44,20 @@ public class TestConditions extends BaseND4JTest { Condition condition = new IntegerColumnCondition("column", SequenceConditionMode.Or, ConditionOp.LessThan, 0); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new IntWritable(-1)))); - assertTrue(condition.condition(Collections.singletonList((Writable) new IntWritable(-2)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new IntWritable(0)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new IntWritable(1)))); + assertTrue(condition.condition(Collections.singletonList(new IntWritable(-1)))); + assertTrue(condition.condition(Collections.singletonList(new IntWritable(-2)))); + assertFalse(condition.condition(Collections.singletonList(new IntWritable(0)))); + assertFalse(condition.condition(Collections.singletonList(new IntWritable(1)))); Set set = new HashSet<>(); set.add(0); set.add(3); condition = new IntegerColumnCondition("column", SequenceConditionMode.Or, ConditionOp.InSet, set); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new IntWritable(0)))); - assertTrue(condition.condition(Collections.singletonList((Writable) new IntWritable(3)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new IntWritable(1)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new IntWritable(2)))); + assertTrue(condition.condition(Collections.singletonList(new IntWritable(0)))); + assertTrue(condition.condition(Collections.singletonList(new IntWritable(3)))); + assertFalse(condition.condition(Collections.singletonList(new IntWritable(1)))); + assertFalse(condition.condition(Collections.singletonList(new IntWritable(2)))); } @Test @@ -67,19 +67,19 @@ public class TestConditions extends BaseND4JTest { Condition condition = new LongColumnCondition("column", SequenceConditionMode.Or, ConditionOp.NotEqual, 5L); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new LongWritable(0)))); - assertTrue(condition.condition(Collections.singletonList((Writable) new LongWritable(1)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new LongWritable(5)))); + assertTrue(condition.condition(Collections.singletonList(new LongWritable(0)))); + assertTrue(condition.condition(Collections.singletonList(new LongWritable(1)))); + assertFalse(condition.condition(Collections.singletonList(new LongWritable(5)))); Set set = new HashSet<>(); set.add(0L); set.add(3L); condition = new LongColumnCondition("column", SequenceConditionMode.Or, ConditionOp.NotInSet, set); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new LongWritable(5)))); - assertTrue(condition.condition(Collections.singletonList((Writable) new LongWritable(10)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new LongWritable(0)))); - 
assertFalse(condition.condition(Collections.singletonList((Writable) new LongWritable(3)))); + assertTrue(condition.condition(Collections.singletonList(new LongWritable(5)))); + assertTrue(condition.condition(Collections.singletonList(new LongWritable(10)))); + assertFalse(condition.condition(Collections.singletonList(new LongWritable(0)))); + assertFalse(condition.condition(Collections.singletonList(new LongWritable(3)))); } @Test @@ -90,20 +90,20 @@ public class TestConditions extends BaseND4JTest { new DoubleColumnCondition("column", SequenceConditionMode.Or, ConditionOp.GreaterOrEqual, 0); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new DoubleWritable(0.0)))); - assertTrue(condition.condition(Collections.singletonList((Writable) new DoubleWritable(0.5)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new DoubleWritable(-0.5)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new DoubleWritable(-1)))); + assertTrue(condition.condition(Collections.singletonList(new DoubleWritable(0.0)))); + assertTrue(condition.condition(Collections.singletonList(new DoubleWritable(0.5)))); + assertFalse(condition.condition(Collections.singletonList(new DoubleWritable(-0.5)))); + assertFalse(condition.condition(Collections.singletonList(new DoubleWritable(-1)))); Set set = new HashSet<>(); set.add(0.0); set.add(3.0); condition = new DoubleColumnCondition("column", SequenceConditionMode.Or, ConditionOp.InSet, set); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new DoubleWritable(0.0)))); - assertTrue(condition.condition(Collections.singletonList((Writable) new DoubleWritable(3.0)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new DoubleWritable(1.0)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new DoubleWritable(2.0)))); + assertTrue(condition.condition(Collections.singletonList(new DoubleWritable(0.0)))); + assertTrue(condition.condition(Collections.singletonList(new DoubleWritable(3.0)))); + assertFalse(condition.condition(Collections.singletonList(new DoubleWritable(1.0)))); + assertFalse(condition.condition(Collections.singletonList(new DoubleWritable(2.0)))); } @@ -115,20 +115,20 @@ public class TestConditions extends BaseND4JTest { new FloatColumnCondition("column", SequenceConditionMode.Or, ConditionOp.GreaterOrEqual, 0); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new FloatWritable(0.0f)))); - assertTrue(condition.condition(Collections.singletonList((Writable) new FloatWritable(0.5f)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new FloatWritable(-0.5f)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new FloatWritable(-1f)))); + assertTrue(condition.condition(Collections.singletonList(new FloatWritable(0.0f)))); + assertTrue(condition.condition(Collections.singletonList(new FloatWritable(0.5f)))); + assertFalse(condition.condition(Collections.singletonList(new FloatWritable(-0.5f)))); + assertFalse(condition.condition(Collections.singletonList(new FloatWritable(-1f)))); Set set = new HashSet(); set.add(0.0f); set.add(3.0f); condition = new FloatColumnCondition("column", SequenceConditionMode.Or, ConditionOp.InSet, set); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new FloatWritable(0.0f)))); - 
assertTrue(condition.condition(Collections.singletonList((Writable) new FloatWritable(3.0f)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new FloatWritable(1.0f)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new FloatWritable(2.0f)))); + assertTrue(condition.condition(Collections.singletonList(new FloatWritable(0.0f)))); + assertTrue(condition.condition(Collections.singletonList(new FloatWritable(3.0f)))); + assertFalse(condition.condition(Collections.singletonList(new FloatWritable(1.0f)))); + assertFalse(condition.condition(Collections.singletonList(new FloatWritable(2.0f)))); } @Test @@ -138,18 +138,18 @@ public class TestConditions extends BaseND4JTest { Condition condition = new StringColumnCondition("column", SequenceConditionMode.Or, ConditionOp.Equal, "value"); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new Text("value")))); - assertFalse(condition.condition(Collections.singletonList((Writable) new Text("not_value")))); + assertTrue(condition.condition(Collections.singletonList(new Text("value")))); + assertFalse(condition.condition(Collections.singletonList(new Text("not_value")))); Set set = new HashSet<>(); set.add("in set"); set.add("also in set"); condition = new StringColumnCondition("column", SequenceConditionMode.Or, ConditionOp.InSet, set); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new Text("in set")))); - assertTrue(condition.condition(Collections.singletonList((Writable) new Text("also in set")))); - assertFalse(condition.condition(Collections.singletonList((Writable) new Text("not in the set")))); - assertFalse(condition.condition(Collections.singletonList((Writable) new Text(":)")))); + assertTrue(condition.condition(Collections.singletonList(new Text("in set")))); + assertTrue(condition.condition(Collections.singletonList(new Text("also in set")))); + assertFalse(condition.condition(Collections.singletonList(new Text("not in the set")))); + assertFalse(condition.condition(Collections.singletonList(new Text(":)")))); } @Test @@ -160,18 +160,18 @@ public class TestConditions extends BaseND4JTest { new CategoricalColumnCondition("column", SequenceConditionMode.Or, ConditionOp.Equal, "alpha"); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new Text("alpha")))); - assertFalse(condition.condition(Collections.singletonList((Writable) new Text("beta")))); - assertFalse(condition.condition(Collections.singletonList((Writable) new Text("gamma")))); + assertTrue(condition.condition(Collections.singletonList(new Text("alpha")))); + assertFalse(condition.condition(Collections.singletonList(new Text("beta")))); + assertFalse(condition.condition(Collections.singletonList(new Text("gamma")))); Set set = new HashSet<>(); set.add("alpha"); set.add("beta"); condition = new StringColumnCondition("column", SequenceConditionMode.Or, ConditionOp.InSet, set); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new Text("alpha")))); - assertTrue(condition.condition(Collections.singletonList((Writable) new Text("beta")))); - assertFalse(condition.condition(Collections.singletonList((Writable) new Text("gamma")))); + assertTrue(condition.condition(Collections.singletonList(new Text("alpha")))); + assertTrue(condition.condition(Collections.singletonList(new Text("beta")))); + 
assertFalse(condition.condition(Collections.singletonList(new Text("gamma")))); } @Test @@ -183,18 +183,18 @@ public class TestConditions extends BaseND4JTest { 1451606400000L); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new LongWritable(1451606400000L)))); - assertTrue(condition.condition(Collections.singletonList((Writable) new LongWritable(1451606400000L - 1L)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new LongWritable(1451606400000L + 1L)))); + assertTrue(condition.condition(Collections.singletonList(new LongWritable(1451606400000L)))); + assertTrue(condition.condition(Collections.singletonList(new LongWritable(1451606400000L - 1L)))); + assertFalse(condition.condition(Collections.singletonList(new LongWritable(1451606400000L + 1L)))); assertFalse(condition - .condition(Collections.singletonList((Writable) new LongWritable(1451606400000L + 1000L)))); + .condition(Collections.singletonList(new LongWritable(1451606400000L + 1000L)))); Set set = new HashSet<>(); set.add(1451606400000L); condition = new TimeColumnCondition("column", SequenceConditionMode.Or, ConditionOp.InSet, set); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new LongWritable(1451606400000L)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new LongWritable(1451606400000L + 1L)))); + assertTrue(condition.condition(Collections.singletonList(new LongWritable(1451606400000L)))); + assertFalse(condition.condition(Collections.singletonList(new LongWritable(1451606400000L + 1L)))); } @Test @@ -206,22 +206,22 @@ public class TestConditions extends BaseND4JTest { Condition condition = new StringRegexColumnCondition("column", "abc.*"); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new Text("abc")))); - assertTrue(condition.condition(Collections.singletonList((Writable) new Text("abcdefghijk")))); - assertTrue(condition.condition(Collections.singletonList((Writable) new Text("abc more text \tetc")))); - assertFalse(condition.condition(Collections.singletonList((Writable) new Text("ab")))); - assertFalse(condition.condition(Collections.singletonList((Writable) new Text("also doesn't match")))); - assertFalse(condition.condition(Collections.singletonList((Writable) new Text(" abc")))); + assertTrue(condition.condition(Collections.singletonList(new Text("abc")))); + assertTrue(condition.condition(Collections.singletonList(new Text("abcdefghijk")))); + assertTrue(condition.condition(Collections.singletonList(new Text("abc more text \tetc")))); + assertFalse(condition.condition(Collections.singletonList(new Text("ab")))); + assertFalse(condition.condition(Collections.singletonList(new Text("also doesn't match")))); + assertFalse(condition.condition(Collections.singletonList(new Text(" abc")))); //Check application on non-String columns schema = TestTransforms.getSchema(ColumnType.Integer); condition = new StringRegexColumnCondition("column", "123\\d*"); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new IntWritable(123)))); - assertTrue(condition.condition(Collections.singletonList((Writable) new IntWritable(123456)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new IntWritable(-123)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new IntWritable(456789)))); + 
assertTrue(condition.condition(Collections.singletonList(new IntWritable(123)))); + assertTrue(condition.condition(Collections.singletonList(new IntWritable(123456)))); + assertFalse(condition.condition(Collections.singletonList(new IntWritable(-123)))); + assertFalse(condition.condition(Collections.singletonList(new IntWritable(456789)))); } @Test @@ -231,10 +231,10 @@ public class TestConditions extends BaseND4JTest { Condition condition = new NullWritableColumnCondition("column"); condition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) NullWritable.INSTANCE))); - assertTrue(condition.condition(Collections.singletonList((Writable) new NullWritable()))); - assertFalse(condition.condition(Collections.singletonList((Writable) new IntWritable(0)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new Text("1")))); + assertTrue(condition.condition(Collections.singletonList(NullWritable.INSTANCE))); + assertTrue(condition.condition(Collections.singletonList(new NullWritable()))); + assertFalse(condition.condition(Collections.singletonList(new IntWritable(0)))); + assertFalse(condition.condition(Collections.singletonList(new Text("1")))); } @Test @@ -248,16 +248,16 @@ public class TestConditions extends BaseND4JTest { Condition notCondition = BooleanCondition.NOT(condition); notCondition.setInputSchema(schema); - assertTrue(condition.condition(Collections.singletonList((Writable) new IntWritable(-1)))); - assertTrue(condition.condition(Collections.singletonList((Writable) new IntWritable(-2)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new IntWritable(0)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new IntWritable(1)))); + assertTrue(condition.condition(Collections.singletonList(new IntWritable(-1)))); + assertTrue(condition.condition(Collections.singletonList(new IntWritable(-2)))); + assertFalse(condition.condition(Collections.singletonList(new IntWritable(0)))); + assertFalse(condition.condition(Collections.singletonList(new IntWritable(1)))); //Expect opposite for not condition: - assertFalse(notCondition.condition(Collections.singletonList((Writable) new IntWritable(-1)))); - assertFalse(notCondition.condition(Collections.singletonList((Writable) new IntWritable(-2)))); - assertTrue(notCondition.condition(Collections.singletonList((Writable) new IntWritable(0)))); - assertTrue(notCondition.condition(Collections.singletonList((Writable) new IntWritable(1)))); + assertFalse(notCondition.condition(Collections.singletonList(new IntWritable(-1)))); + assertFalse(notCondition.condition(Collections.singletonList(new IntWritable(-2)))); + assertTrue(notCondition.condition(Collections.singletonList(new IntWritable(0)))); + assertTrue(notCondition.condition(Collections.singletonList(new IntWritable(1)))); } @Test @@ -274,10 +274,10 @@ public class TestConditions extends BaseND4JTest { Condition andCondition = BooleanCondition.AND(condition1, condition2); andCondition.setInputSchema(schema); - assertFalse(andCondition.condition(Collections.singletonList((Writable) new IntWritable(-1)))); - assertTrue(andCondition.condition(Collections.singletonList((Writable) new IntWritable(-2)))); - assertFalse(andCondition.condition(Collections.singletonList((Writable) new IntWritable(0)))); - assertFalse(andCondition.condition(Collections.singletonList((Writable) new IntWritable(1)))); + assertFalse(andCondition.condition(Collections.singletonList(new IntWritable(-1)))); + 
assertTrue(andCondition.condition(Collections.singletonList(new IntWritable(-2)))); + assertFalse(andCondition.condition(Collections.singletonList(new IntWritable(0)))); + assertFalse(andCondition.condition(Collections.singletonList(new IntWritable(1)))); } @@ -288,15 +288,15 @@ public class TestConditions extends BaseND4JTest { Condition condition = new InvalidValueColumnCondition("column"); condition.setInputSchema(schema); - assertFalse(condition.condition(Collections.singletonList((Writable) new IntWritable(-1)))); //Not invalid -> condition does not apply - assertFalse(condition.condition(Collections.singletonList((Writable) new IntWritable(-2)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new LongWritable(1000)))); - assertFalse(condition.condition(Collections.singletonList((Writable) new Text("1000")))); - assertTrue(condition.condition(Collections.singletonList((Writable) new Text("text")))); - assertTrue(condition.condition(Collections.singletonList((Writable) new Text("NaN")))); + assertFalse(condition.condition(Collections.singletonList(new IntWritable(-1)))); //Not invalid -> condition does not apply + assertFalse(condition.condition(Collections.singletonList(new IntWritable(-2)))); + assertFalse(condition.condition(Collections.singletonList(new LongWritable(1000)))); + assertFalse(condition.condition(Collections.singletonList(new Text("1000")))); + assertTrue(condition.condition(Collections.singletonList(new Text("text")))); + assertTrue(condition.condition(Collections.singletonList(new Text("NaN")))); assertTrue(condition.condition( - Collections.singletonList((Writable) new LongWritable(1L + (long) Integer.MAX_VALUE)))); - assertTrue(condition.condition(Collections.singletonList((Writable) new DoubleWritable(3.14159)))); + Collections.singletonList(new LongWritable(1L + (long) Integer.MAX_VALUE)))); + assertTrue(condition.condition(Collections.singletonList(new DoubleWritable(3.14159)))); } @Test @@ -304,14 +304,14 @@ public class TestConditions extends BaseND4JTest { Condition c = new SequenceLengthCondition(ConditionOp.LessThan, 2); - List> l1 = Arrays.asList(Collections.singletonList(NullWritable.INSTANCE)); + List> l1 = Collections.singletonList(Collections.singletonList(NullWritable.INSTANCE)); - List> l2 = Arrays.asList(Collections.singletonList(NullWritable.INSTANCE), - Collections.singletonList(NullWritable.INSTANCE)); + List> l2 = Arrays.asList(Collections.singletonList(NullWritable.INSTANCE), + Collections.singletonList(NullWritable.INSTANCE)); - List> l3 = Arrays.asList(Collections.singletonList(NullWritable.INSTANCE), - Collections.singletonList(NullWritable.INSTANCE), - Collections.singletonList(NullWritable.INSTANCE)); + List> l3 = Arrays.asList(Collections.singletonList(NullWritable.INSTANCE), + Collections.singletonList(NullWritable.INSTANCE), + Collections.singletonList(NullWritable.INSTANCE)); assertTrue(c.conditionSequence(l1)); assertFalse(c.conditionSequence(l2)); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java index 0b339bffa..1f937609e 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java @@ -45,9 +45,9 @@ public class TestFilters extends BaseND4JTest { @Test public void testFilterNumColumns() { List> list = new 
ArrayList<>(); - list.add(Collections.singletonList((Writable) new IntWritable(-1))); - list.add(Collections.singletonList((Writable) new IntWritable(0))); - list.add(Collections.singletonList((Writable) new IntWritable(2))); + list.add(Collections.singletonList(new IntWritable(-1))); + list.add(Collections.singletonList(new IntWritable(0))); + list.add(Collections.singletonList(new IntWritable(2))); Schema schema = new Schema.Builder().addColumnInteger("intCol", 0, 10) //Only values in the range 0 to 10 are ok .addColumnDouble("doubleCol", -100.0, 100.0) //-100 to 100 only; no NaN or infinite @@ -56,7 +56,7 @@ public class TestFilters extends BaseND4JTest { for (int i = 0; i < list.size(); i++) assertTrue(numColumns.removeExample(list.get(i))); - List correct = Arrays.asList(new IntWritable(0), new DoubleWritable(2)); + List correct = Arrays.asList(new IntWritable(0), new DoubleWritable(2)); assertFalse(numColumns.removeExample(correct)); } @@ -65,9 +65,9 @@ public class TestFilters extends BaseND4JTest { public void testFilterInvalidValues() { List> list = new ArrayList<>(); - list.add(Collections.singletonList((Writable) new IntWritable(-1))); - list.add(Collections.singletonList((Writable) new IntWritable(0))); - list.add(Collections.singletonList((Writable) new IntWritable(2))); + list.add(Collections.singletonList(new IntWritable(-1))); + list.add(Collections.singletonList(new IntWritable(0))); + list.add(Collections.singletonList(new IntWritable(2))); Schema schema = new Schema.Builder().addColumnInteger("intCol", 0, 10) //Only values in the range 0 to 10 are ok .addColumnDouble("doubleCol", -100.0, 100.0) //-100 to 100 only; no NaN or infinite @@ -77,16 +77,16 @@ public class TestFilters extends BaseND4JTest { filter.setInputSchema(schema); //Test valid examples: - assertFalse(filter.removeExample(asList((Writable) new IntWritable(0), new DoubleWritable(0)))); - assertFalse(filter.removeExample(asList((Writable) new IntWritable(10), new DoubleWritable(0)))); - assertFalse(filter.removeExample(asList((Writable) new IntWritable(0), new DoubleWritable(-100)))); - assertFalse(filter.removeExample(asList((Writable) new IntWritable(0), new DoubleWritable(100)))); + assertFalse(filter.removeExample(asList(new IntWritable(0), new DoubleWritable(0)))); + assertFalse(filter.removeExample(asList(new IntWritable(10), new DoubleWritable(0)))); + assertFalse(filter.removeExample(asList(new IntWritable(0), new DoubleWritable(-100)))); + assertFalse(filter.removeExample(asList(new IntWritable(0), new DoubleWritable(100)))); //Test invalid: - assertTrue(filter.removeExample(asList((Writable) new IntWritable(-1), new DoubleWritable(0)))); - assertTrue(filter.removeExample(asList((Writable) new IntWritable(11), new DoubleWritable(0)))); - assertTrue(filter.removeExample(asList((Writable) new IntWritable(0), new DoubleWritable(-101)))); - assertTrue(filter.removeExample(asList((Writable) new IntWritable(0), new DoubleWritable(101)))); + assertTrue(filter.removeExample(asList(new IntWritable(-1), new DoubleWritable(0)))); + assertTrue(filter.removeExample(asList(new IntWritable(11), new DoubleWritable(0)))); + assertTrue(filter.removeExample(asList(new IntWritable(0), new DoubleWritable(-101)))); + assertTrue(filter.removeExample(asList(new IntWritable(0), new DoubleWritable(101)))); } @Test @@ -98,11 +98,11 @@ public class TestFilters extends BaseND4JTest { Filter filter = new ConditionFilter(condition); - assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(10)))); - 
assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(1)))); - assertFalse(filter.removeExample(Collections.singletonList((Writable) new IntWritable(0)))); - assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-1)))); - assertTrue(filter.removeExample(Collections.singletonList((Writable) new IntWritable(-10)))); + assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(10)))); + assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(1)))); + assertFalse(filter.removeExample(Collections.singletonList(new IntWritable(0)))); + assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-1)))); + assertTrue(filter.removeExample(Collections.singletonList(new IntWritable(-10)))); } } diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java index c41ebb165..ad056ccef 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java @@ -47,20 +47,20 @@ public class TestJoin extends BaseND4JTest { Schema secondSchema = new Schema.Builder().addColumnString("keyColumn").addColumnsInteger("second0").build(); List> first = new ArrayList<>(); - first.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1))); - first.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11))); + first.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1))); + first.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11))); List> second = new ArrayList<>(); - second.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(100))); - second.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(110))); + second.add(Arrays.asList(new Text("key0"), new IntWritable(100))); + second.add(Arrays.asList(new Text("key1"), new IntWritable(110))); Join join = new Join.Builder(Join.JoinType.Inner).setJoinColumns("keyColumn") .setSchemas(firstSchema, secondSchema).build(); List> expected = new ArrayList<>(); - expected.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1), + expected.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1), new IntWritable(100))); - expected.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11), + expected.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11), new IntWritable(110))); @@ -74,9 +74,9 @@ public class TestJoin extends BaseND4JTest { //Check joining with null values: expected = new ArrayList<>(); - expected.add(Arrays.asList((Writable) new Text("key0"), new IntWritable(0), new IntWritable(1), + expected.add(Arrays.asList(new Text("key0"), new IntWritable(0), new IntWritable(1), NullWritable.INSTANCE)); - expected.add(Arrays.asList((Writable) new Text("key1"), new IntWritable(10), new IntWritable(11), + expected.add(Arrays.asList(new Text("key1"), new IntWritable(10), new IntWritable(11), NullWritable.INSTANCE)); for (int i = 0; i < first.size(); i++) { List out = join.joinExamples(first.get(i), null); @@ -84,9 +84,9 @@ public class TestJoin extends BaseND4JTest { } expected = new ArrayList<>(); - expected.add(Arrays.asList((Writable) new Text("key0"), NullWritable.INSTANCE, 
NullWritable.INSTANCE, + expected.add(Arrays.asList(new Text("key0"), NullWritable.INSTANCE, NullWritable.INSTANCE, new IntWritable(100))); - expected.add(Arrays.asList((Writable) new Text("key1"), NullWritable.INSTANCE, NullWritable.INSTANCE, + expected.add(Arrays.asList(new Text("key1"), NullWritable.INSTANCE, NullWritable.INSTANCE, new IntWritable(110))); for (int i = 0; i < first.size(); i++) { List out = join.joinExamples(null, second.get(i)); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java index caadceb15..6106e37a3 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java @@ -26,11 +26,12 @@ import org.nd4j.common.tests.BaseND4JTest; import java.util.*; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; public class AggregableMultiOpTest extends BaseND4JTest { - private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + private final List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); @Test public void testMulti() throws Exception { @@ -38,18 +39,18 @@ public class AggregableMultiOpTest extends BaseND4JTest { AggregatorImpls.AggregableSum as = new AggregatorImpls.AggregableSum<>(); AggregableMultiOp multi = new AggregableMultiOp<>(Arrays.asList(af, as)); - assertTrue(multi.getOperations().size() == 2); + assertEquals(2, multi.getOperations().size()); for (int i = 0; i < intList.size(); i++) { multi.accept(intList.get(i)); } // mutablility - assertTrue(as.get().toDouble() == 45D); - assertTrue(af.get().toInt() == 1); + assertEquals(45D, as.get().toDouble()); + assertEquals(1, af.get().toInt()); List res = multi.get(); - assertTrue(res.get(1).toDouble() == 45D); - assertTrue(res.get(0).toInt() == 1); + assertEquals(45D, res.get(1).toDouble()); + assertEquals(1, res.get(0).toInt()); AggregatorImpls.AggregableFirst rf = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableSum rs = new AggregatorImpls.AggregableSum<>(); @@ -60,12 +61,12 @@ public class AggregableMultiOpTest extends BaseND4JTest { } List revRes = reverse.get(); - assertTrue(revRes.get(1).toDouble() == 45D); - assertTrue(revRes.get(0).toInt() == 9); + assertEquals(45D, revRes.get(1).toDouble()); + assertEquals(9, revRes.get(0).toInt()); multi.combine(reverse); List combinedRes = multi.get(); - assertTrue(combinedRes.get(1).toDouble() == 90D); - assertTrue(combinedRes.get(0).toInt() == 1); + assertEquals(90D, combinedRes.get(1).toDouble()); + assertEquals(1, combinedRes.get(0).toInt()); } } diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java index 8cfd5e979..0cad8f9f0 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java @@ -33,8 +33,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class AggregatorImplsTest extends BaseND4JTest { - private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); - 
private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); + private final List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + private final List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); @Test public void aggregableFirstTest() { @@ -48,7 +48,7 @@ public class AggregatorImplsTest extends BaseND4JTest { for (int i = 0; i < stringList.size(); i++) { firstS.accept(stringList.get(i)); } - assertTrue(firstS.get().toString().equals("arakoa")); + assertEquals("arakoa", firstS.get().toString()); AggregatorImpls.AggregableFirst reverse = new AggregatorImpls.AggregableFirst<>(); @@ -72,7 +72,7 @@ public class AggregatorImplsTest extends BaseND4JTest { for (int i = 0; i < stringList.size(); i++) { lastS.accept(stringList.get(i)); } - assertTrue(lastS.get().toString().equals("acceptance")); + assertEquals("acceptance", lastS.get().toString()); AggregatorImpls.AggregableLast reverse = new AggregatorImpls.AggregableLast<>(); @@ -182,7 +182,7 @@ public class AggregatorImplsTest extends BaseND4JTest { for (int i = 0; i < intList.size(); i++) { mn.accept(intList.get(i)); } - assertEquals(9l, (long) mn.getCount()); + assertEquals(9L, (long) mn.getCount()); assertEquals(5D, mn.get().toDouble(), 0.001); @@ -191,7 +191,7 @@ public class AggregatorImplsTest extends BaseND4JTest { reverse.accept(intList.get(intList.size() - i - 1)); } mn.combine(reverse); - assertEquals(18l, (long) mn.getCount()); + assertEquals(18L, (long) mn.getCount()); assertEquals(5D, mn.get().toDouble(), 0.001); } diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java index a04d6f57a..a8c9aace5 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java @@ -29,35 +29,36 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; public class DispatchOpTest extends BaseND4JTest { - private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); - private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); + private final List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + private final List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); @Test public void testDispatchSimple() { AggregatorImpls.AggregableFirst af = new AggregatorImpls.AggregableFirst<>(); AggregatorImpls.AggregableSum as = new AggregatorImpls.AggregableSum<>(); AggregableMultiOp multiaf = - new AggregableMultiOp<>(Collections.>singletonList(af)); + new AggregableMultiOp<>(Collections.singletonList(af)); AggregableMultiOp multias = - new AggregableMultiOp<>(Collections.>singletonList(as)); + new AggregableMultiOp<>(Collections.singletonList(as)); DispatchOp parallel = - new DispatchOp<>(Arrays.>>asList(multiaf, multias)); + new DispatchOp<>(Arrays.asList(multiaf, multias)); - assertTrue(multiaf.getOperations().size() == 1); - assertTrue(multias.getOperations().size() == 1); - assertTrue(parallel.getOperations().size() == 2); + assertEquals(1, multiaf.getOperations().size()); + 
assertEquals(1, multias.getOperations().size()); + assertEquals(2, parallel.getOperations().size()); for (int i = 0; i < intList.size(); i++) { parallel.accept(Arrays.asList(intList.get(i), intList.get(i))); } List res = parallel.get(); - assertTrue(res.get(1).toDouble() == 45D); - assertTrue(res.get(0).toInt() == 1); + assertEquals(45D, res.get(1).toDouble()); + assertEquals(1, res.get(0).toInt()); } @@ -73,20 +74,20 @@ public class DispatchOpTest extends BaseND4JTest { DispatchOp parallel = new DispatchOp<>( - Arrays.>>asList(multi, otherMulti)); + Arrays.asList(multi, otherMulti)); - assertTrue(multi.getOperations().size() == 2); - assertTrue(otherMulti.getOperations().size() == 2); - assertTrue(parallel.getOperations().size() == 2); + assertEquals(2, multi.getOperations().size()); + assertEquals(2, otherMulti.getOperations().size()); + assertEquals(2, parallel.getOperations().size()); for (int i = 0; i < intList.size(); i++) { parallel.accept(Arrays.asList(intList.get(i), intList.get(i))); } List res = parallel.get(); - assertTrue(res.get(1).toDouble() == 45D); - assertTrue(res.get(0).toInt() == 1); - assertTrue(res.get(3).toDouble() == 9); - assertTrue(res.get(2).toInt() == 9); + assertEquals(45D, res.get(1).toDouble()); + assertEquals(1, res.get(0).toInt()); + assertEquals(9, res.get(3).toDouble()); + assertEquals(9, res.get(2).toInt()); } diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java index 80d7d7eee..b42f75f6f 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java @@ -46,10 +46,10 @@ public class TestMultiOpReduce extends BaseND4JTest { public void testMultiOpReducerDouble() { List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(0))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(1))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new DoubleWritable(2))); + inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(0))); + inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(1))); + inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2))); + inputs.add(Arrays.asList(new Text("someKey"), new DoubleWritable(2))); Map exp = new LinkedHashMap<>(); exp.put(ReduceOp.Min, 0.0); @@ -90,10 +90,10 @@ public class TestMultiOpReduce extends BaseND4JTest { public void testReducerInteger() { List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(0))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(1))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(2))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(2))); + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(0))); + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(1))); + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(2))); + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(2))); Map exp = new LinkedHashMap<>(); exp.put(ReduceOp.Min, 0.0); @@ -135,10 +135,10 @@ public class TestMultiOpReduce extends 
BaseND4JTest { public void testReduceString() { List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("3"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("4"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("1"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("2"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("3"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("4"))); Map exp = new LinkedHashMap<>(); exp.put(ReduceOp.Append, "1234"); @@ -171,12 +171,12 @@ public class TestMultiOpReduce extends BaseND4JTest { public void testReduceIntegerIgnoreInvalidValues() { List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("0"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(2))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("ignore me"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("also ignore me"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("0"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("1"))); + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(2))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("ignore me"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("also ignore me"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("2"))); Map exp = new LinkedHashMap<>(); @@ -238,16 +238,16 @@ public class TestMultiOpReduce extends BaseND4JTest { public void testCustomReductions() { List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(1), new Text("zero"), + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(1), new Text("zero"), new DoubleWritable(0))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(2), new Text("one"), + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(2), new Text("one"), new DoubleWritable(1))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(3), new Text("two"), + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(3), new Text("two"), new DoubleWritable(2))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(4), new Text("three"), + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(4), new Text("three"), new DoubleWritable(3))); - List expected = Arrays.asList((Writable) new Text("someKey"), new IntWritable(10), new Text("one"), + List expected = Arrays.asList(new Text("someKey"), new IntWritable(10), new Text("one"), new DoubleWritable(1)); @@ -288,16 +288,16 @@ public class TestMultiOpReduce extends BaseND4JTest { public void testCustomReductionsWithCondition() { List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(1), new Text("zero"), + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(1), new Text("zero"), new DoubleWritable(0))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(2), new Text("one"), + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(2), new Text("one"), new 
DoubleWritable(1))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(3), new Text("two"), + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(3), new Text("two"), new DoubleWritable(2))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new IntWritable(4), new Text("three"), + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(4), new Text("three"), new DoubleWritable(3))); - List expected = Arrays.asList((Writable) new Text("someKey"), new IntWritable(10), new IntWritable(3), + List expected = Arrays.asList(new Text("someKey"), new IntWritable(10), new IntWritable(3), new DoubleWritable(1)); @@ -341,7 +341,7 @@ public class TestMultiOpReduce extends BaseND4JTest { public IAggregableReduceOp> reduceOp() { //For testing: let's take the second value return new AggregableMultiOp<>(Collections - .>singletonList(new AggregableSecond())); + .singletonList(new AggregableSecond())); } @Override @@ -483,12 +483,12 @@ public class TestMultiOpReduce extends BaseND4JTest { .addColumnString("filterCol").addColumnString("textCol").build(); List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(1), new Text("a"), new Text("zero"))); - inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(2), new Text("b"), new Text("one"))); - inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(3), new Text("a"), new Text("two"))); - inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(4), new Text("b"), new Text("three"))); - inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(5), new Text("a"), new Text("three"))); - inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(6), new Text("b"), new Text("three"))); + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(1), new Text("a"), new Text("zero"))); + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(2), new Text("b"), new Text("one"))); + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(3), new Text("a"), new Text("two"))); + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(4), new Text("b"), new Text("three"))); + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(5), new Text("a"), new Text("three"))); + inputs.add(Arrays.asList(new Text("someKey"), new IntWritable(6), new Text("b"), new Text("three"))); Condition condition = new StringColumnCondition("filterCol", ConditionOp.Equal, "a"); @@ -504,7 +504,7 @@ public class TestMultiOpReduce extends BaseND4JTest { accumulator.accept(inputs.get(i)); } List out = accumulator.get(); - List expected = Arrays.asList(new Text("someKey"), new IntWritable(1 + 3 + 5), + List expected = Arrays.asList(new Text("someKey"), new IntWritable(1 + 3 + 5), new LongWritable(2), new LongWritable(4)); assertEquals(4, out.size()); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java index f7aa89170..c5867e5dd 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java @@ -43,7 +43,7 @@ public class TestReductions extends BaseND4JTest { Text t2 = new Text("41.8781136,-87.6297982"); Text t3 = new Text("33.7489954,-84.3879824"); - List list = Arrays.asList(t1, t1, t1, t2, t2, t3); + List list = Arrays.asList(t1, t1, t1, t2, t2, 
t3); GeographicMidpointReduction reduction = new GeographicMidpointReduction(","); @@ -68,8 +68,8 @@ public class TestReductions extends BaseND4JTest { //Test multiple reductions - list = Arrays.asList(t1, t1, t2); - List list2 = Arrays.asList(t1, t2, t3); + list = Arrays.asList(t1, t1, t2); + List list2 = Arrays.asList(t1, t2, t3); reduceOp = reduction.reduceOp(); for(Writable w : list){ reduceOp.accept(w); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java index 1bb9ae62a..6add32050 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java @@ -52,15 +52,15 @@ public class TestReduceSequenceByWindowFunction extends BaseND4JTest { //Create some data. List> sequence = new ArrayList<>(); //First window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2))); //Second window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3))); //Third window: empty //Fourth window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5))); Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) .addColumnInteger("intcolumn").build(); @@ -79,17 +79,17 @@ public class TestReduceSequenceByWindowFunction extends BaseND4JTest { assertEquals(4, postApply.size()); - List exp0 = Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0 + 1 + 2)); + List exp0 = Arrays.asList(new LongWritable(1451606400000L), new IntWritable(1 + 2)); assertEquals(exp0, postApply.get(0)); - List exp1 = Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3)); + List exp1 = Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3)); assertEquals(exp1, postApply.get(1)); // here, takefirst of an empty window -> nullwritable makes more sense - List exp2 = Arrays.asList((Writable) NullWritable.INSTANCE, NullWritable.INSTANCE); + List exp2 = Arrays.asList(NullWritable.INSTANCE, NullWritable.INSTANCE); assertEquals(exp2, postApply.get(2)); - List exp3 = Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(9)); + List exp3 = Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(9)); assertEquals(exp3, postApply.get(3)); 
} diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java index c26eaec61..cae0795c4 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java @@ -46,13 +46,13 @@ public class TestSequenceSplit extends BaseND4JTest { .build(); List> inputSequence = new ArrayList<>(); - inputSequence.add(Arrays.asList((Writable) new LongWritable(0), new Text("t0"))); - inputSequence.add(Arrays.asList((Writable) new LongWritable(1000), new Text("t1"))); + inputSequence.add(Arrays.asList(new LongWritable(0), new Text("t0"))); + inputSequence.add(Arrays.asList(new LongWritable(1000), new Text("t1"))); //Second split: 74 seconds later - inputSequence.add(Arrays.asList((Writable) new LongWritable(75000), new Text("t2"))); - inputSequence.add(Arrays.asList((Writable) new LongWritable(100000), new Text("t3"))); + inputSequence.add(Arrays.asList(new LongWritable(75000), new Text("t2"))); + inputSequence.add(Arrays.asList(new LongWritable(100000), new Text("t3"))); //Third split: 1 minute and 1 milliseconds later - inputSequence.add(Arrays.asList((Writable) new LongWritable(160001), new Text("t4"))); + inputSequence.add(Arrays.asList(new LongWritable(160001), new Text("t4"))); SequenceSplit seqSplit = new SequenceSplitTimeSeparation("time", 1, TimeUnit.MINUTES); seqSplit.setInputSchema(schema); @@ -61,13 +61,13 @@ public class TestSequenceSplit extends BaseND4JTest { assertEquals(3, splits.size()); List> exp0 = new ArrayList<>(); - exp0.add(Arrays.asList((Writable) new LongWritable(0), new Text("t0"))); - exp0.add(Arrays.asList((Writable) new LongWritable(1000), new Text("t1"))); + exp0.add(Arrays.asList(new LongWritable(0), new Text("t0"))); + exp0.add(Arrays.asList(new LongWritable(1000), new Text("t1"))); List> exp1 = new ArrayList<>(); - exp1.add(Arrays.asList((Writable) new LongWritable(75000), new Text("t2"))); - exp1.add(Arrays.asList((Writable) new LongWritable(100000), new Text("t3"))); + exp1.add(Arrays.asList(new LongWritable(75000), new Text("t2"))); + exp1.add(Arrays.asList(new LongWritable(100000), new Text("t3"))); List> exp2 = new ArrayList<>(); - exp2.add(Arrays.asList((Writable) new LongWritable(160001), new Text("t4"))); + exp2.add(Arrays.asList(new LongWritable(160001), new Text("t4"))); assertEquals(exp0, splits.get(0)); assertEquals(exp1, splits.get(1)); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java index ff45a3f3e..f1becef3c 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java @@ -49,15 +49,15 @@ public class TestWindowFunctions extends BaseND4JTest { //Create some data. 
List> sequence = new ArrayList<>(); //First window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2))); //Second window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3))); //Third window: empty //Fourth window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5))); Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) .addColumnInteger("intcolumn").build(); @@ -100,15 +100,15 @@ public class TestWindowFunctions extends BaseND4JTest { //Create some data. List> sequence = new ArrayList<>(); //First window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2))); //Second window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3))); //Third window: empty //Fourth window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5))); Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) .addColumnInteger("intcolumn").build(); @@ -150,15 +150,15 @@ public class TestWindowFunctions extends BaseND4JTest { //Create some data. 
List> sequence = new ArrayList<>(); //First window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2))); //Second window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 1000L), new IntWritable(3))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 1000L), new IntWritable(3))); //Third window: empty //Fourth window: - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3000L), new IntWritable(4))); - sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 3100L), new IntWritable(5))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3000L), new IntWritable(4))); + sequence.add(Arrays.asList(new LongWritable(1451606400000L + 3100L), new IntWritable(5))); Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) .addColumnInteger("intcolumn").build(); @@ -188,13 +188,13 @@ public class TestWindowFunctions extends BaseND4JTest { //Create some data. List> sequence = new ArrayList<>(); //First window: - sequence.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); - sequence.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); - sequence.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); - sequence.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3))); - sequence.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4))); - sequence.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5))); - sequence.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7))); + sequence.add(Arrays.asList(new LongWritable(0), new IntWritable(0))); + sequence.add(Arrays.asList(new LongWritable(100), new IntWritable(1))); + sequence.add(Arrays.asList(new LongWritable(200), new IntWritable(2))); + sequence.add(Arrays.asList(new LongWritable(1000), new IntWritable(3))); + sequence.add(Arrays.asList(new LongWritable(1500), new IntWritable(4))); + sequence.add(Arrays.asList(new LongWritable(2000), new IntWritable(5))); + sequence.add(Arrays.asList(new LongWritable(5000), new IntWritable(7))); Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC) @@ -207,32 +207,32 @@ public class TestWindowFunctions extends BaseND4JTest { //First window: -1000 to 1000 List> exp0 = new ArrayList<>(); - exp0.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); - exp0.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1))); - exp0.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2))); + exp0.add(Arrays.asList(new LongWritable(0), new IntWritable(0))); + exp0.add(Arrays.asList(new LongWritable(100), new IntWritable(1))); + exp0.add(Arrays.asList(new LongWritable(200), new IntWritable(2))); //Second window: 0 to 2000 List> exp1 = new ArrayList<>(); - exp1.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0))); - exp1.add(Arrays.asList((Writable) new LongWritable(100), new 
IntWritable(1)));
-        exp1.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
-        exp1.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
-        exp1.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
+        exp1.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
+        exp1.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
+        exp1.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
+        exp1.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
+        exp1.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
         //Third window: 1000 to 3000
         List<List<Writable>> exp2 = new ArrayList<>();
-        exp2.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
-        exp2.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
-        exp2.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
+        exp2.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
+        exp2.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
+        exp2.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
         //Fourth window: 2000 to 4000
         List<List<Writable>> exp3 = new ArrayList<>();
-        exp3.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
+        exp3.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
         //Fifth window: 3000 to 5000
         List<List<Writable>> exp4 = new ArrayList<>();
         //Sixth window: 4000 to 6000
         List<List<Writable>> exp5 = new ArrayList<>();
-        exp5.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
+        exp5.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
         //Seventh window: 5000 to 7000
         List<List<Writable>> exp6 = new ArrayList<>();
-        exp6.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
+        exp6.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
         List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp4, exp5, exp6);
@@ -250,13 +250,13 @@ public class TestWindowFunctions extends BaseND4JTest {
         //Create some data.
         List<List<Writable>> sequence = new ArrayList<>();
         //First window:
-        sequence.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
-        sequence.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
-        sequence.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
-        sequence.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
-        sequence.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
-        sequence.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
-        sequence.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
+        sequence.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
+        sequence.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
+        sequence.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
+        sequence.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
+        sequence.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
+        sequence.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
+        sequence.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
         Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
@@ -272,31 +272,31 @@ public class TestWindowFunctions extends BaseND4JTest {
         //First window: -1000 to 1000
         List<List<Writable>> exp0 = new ArrayList<>();
-        exp0.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
-        exp0.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
-        exp0.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
+        exp0.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
+        exp0.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
+        exp0.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
         //Second window: 0 to 2000
         List<List<Writable>> exp1 = new ArrayList<>();
-        exp1.add(Arrays.asList((Writable) new LongWritable(0), new IntWritable(0)));
-        exp1.add(Arrays.asList((Writable) new LongWritable(100), new IntWritable(1)));
-        exp1.add(Arrays.asList((Writable) new LongWritable(200), new IntWritable(2)));
-        exp1.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
-        exp1.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
+        exp1.add(Arrays.asList(new LongWritable(0), new IntWritable(0)));
+        exp1.add(Arrays.asList(new LongWritable(100), new IntWritable(1)));
+        exp1.add(Arrays.asList(new LongWritable(200), new IntWritable(2)));
+        exp1.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
+        exp1.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
         //Third window: 1000 to 3000
         List<List<Writable>> exp2 = new ArrayList<>();
-        exp2.add(Arrays.asList((Writable) new LongWritable(1000), new IntWritable(3)));
-        exp2.add(Arrays.asList((Writable) new LongWritable(1500), new IntWritable(4)));
-        exp2.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
+        exp2.add(Arrays.asList(new LongWritable(1000), new IntWritable(3)));
+        exp2.add(Arrays.asList(new LongWritable(1500), new IntWritable(4)));
+        exp2.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
         //Fourth window: 2000 to 4000
         List<List<Writable>> exp3 = new ArrayList<>();
-        exp3.add(Arrays.asList((Writable) new LongWritable(2000), new IntWritable(5)));
+        exp3.add(Arrays.asList(new LongWritable(2000), new IntWritable(5)));
         //Fifth window: 3000 to 5000 -> Empty: excluded
         //Sixth window: 4000 to 6000
         List<List<Writable>> exp5 = new ArrayList<>();
-        exp5.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
+        exp5.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
         //Seventh window: 5000 to 7000
         List<List<Writable>> exp6 = new ArrayList<>();
-        exp6.add(Arrays.asList((Writable) new LongWritable(5000), new IntWritable(7)));
+        exp6.add(Arrays.asList(new LongWritable(5000), new IntWritable(7)));
         List<List<List<Writable>>> windowsExp = Arrays.asList(exp0, exp1, exp2, exp3, exp5, exp6);
diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java
index f7eaa85ad..13730f588 100644
--- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java
+++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java
@@ -37,9 +37,9 @@ public class TestReduce extends BaseND4JTest {
     public void testReducerDouble() {
         List<List<Writable>> inputs = new ArrayList<>();
-        inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2")));
-        inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2")));
-        inputs.add(Arrays.asList((Writable) new Text("1"), new Text("2")));
+        inputs.add(Arrays.asList(new Text("1"), new Text("2")));
+        inputs.add(Arrays.asList(new Text("1"), new Text("2")));
+        inputs.add(Arrays.asList(new Text("1"), new Text("2")));
         Map<StringReduceOp, String> exp = new LinkedHashMap<>();
         exp.put(StringReduceOp.MERGE, "12");
diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java
index c0468b916..9a6f8b1d3 100644
--- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java
+++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java
@@ -122,9 +122,9 @@ public class TestTransforms extends BaseND4JTest {
         assertNotNull(meta.getMaxAllowedValue());
         assertEquals(2, (int) meta.getMaxAllowedValue());
-        assertEquals(0, transform.map(Collections.singletonList((Writable) new Text("zero"))).get(0).toInt());
-        assertEquals(1, transform.map(Collections.singletonList((Writable) new Text("one"))).get(0).toInt());
-        assertEquals(2, transform.map(Collections.singletonList((Writable) new Text("two"))).get(0).toInt());
+        assertEquals(0, transform.map(Collections.singletonList(new Text("zero"))).get(0).toInt());
+        assertEquals(1, transform.map(Collections.singletonList(new Text("one"))).get(0).toInt());
+        assertEquals(2, transform.map(Collections.singletonList(new Text("two"))).get(0).toInt());
     }
     @Test
@@ -147,11 +147,11 @@ public class TestTransforms extends BaseND4JTest {
         }
         assertEquals(Arrays.asList(new IntWritable(1), new IntWritable(0), new IntWritable(0)),
-                transform.map(Collections.singletonList((Writable) new Text("zero"))));
+                transform.map(Collections.singletonList(new Text("zero"))));
         assertEquals(Arrays.asList(new IntWritable(0), new IntWritable(1), new IntWritable(0)),
-                transform.map(Collections.singletonList((Writable) new Text("one"))));
+                transform.map(Collections.singletonList(new Text("one"))));
         assertEquals(Arrays.asList(new IntWritable(0), new IntWritable(0), new IntWritable(1)),
-                transform.map(Collections.singletonList((Writable) new Text("two"))));
+                transform.map(Collections.singletonList(new Text("two"))));
     }
     @Test
@@ -177,16 +177,16 @@ public class TestTransforms extends BaseND4JTest {
         assertEquals(columnTypesExp, 
out.getColumnTypes()); //Expand (second,100) into (0,100,0). Leave the remaining columns as is - List e1 = Arrays.asList(new DoubleWritable(1), new DoubleWritable(0), new DoubleWritable(100), + List e1 = Arrays.asList(new DoubleWritable(1), new DoubleWritable(0), new DoubleWritable(100), new DoubleWritable(0), new DoubleWritable(-1)); - List a1 = t.map(Arrays.asList(new DoubleWritable(1), new Text("second"), new DoubleWritable(100), + List a1 = t.map(Arrays.asList(new DoubleWritable(1), new Text("second"), new DoubleWritable(100), new DoubleWritable(-1))); assertEquals(e1,a1); //Expand (third,200) into (0,0,200). Leave the remaining columns as is - List e2 = Arrays.asList(new DoubleWritable(1), new DoubleWritable(0), new DoubleWritable(0), + List e2 = Arrays.asList(new DoubleWritable(1), new DoubleWritable(0), new DoubleWritable(0), new DoubleWritable(200), new DoubleWritable(-1)); - List a2 = t.map(Arrays.asList(new DoubleWritable(1), new Text("third"), new DoubleWritable(200), + List a2 = t.map(Arrays.asList(new DoubleWritable(1), new Text("third"), new DoubleWritable(200), new DoubleWritable(-1))); assertEquals(e2,a2); } @@ -205,11 +205,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(Arrays.asList("zero", "one", "two"), meta.getStateNames()); assertEquals(Collections.singletonList((Writable) new Text("zero")), - transform.map(Collections.singletonList((Writable) new IntWritable(0)))); + transform.map(Collections.singletonList(new IntWritable(0)))); assertEquals(Collections.singletonList((Writable) new Text("one")), - transform.map(Collections.singletonList((Writable) new IntWritable(1)))); + transform.map(Collections.singletonList(new IntWritable(1)))); assertEquals(Collections.singletonList((Writable) new Text("two")), - transform.map(Collections.singletonList((Writable) new IntWritable(2)))); + transform.map(Collections.singletonList(new IntWritable(2)))); } @Test @@ -228,11 +228,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(Arrays.asList("column[3]", "column[4]", "column[5]"), out.getColumnNames()); assertEquals(Arrays.asList(new IntWritable(1), new IntWritable(0), new IntWritable(0)), - transform.map(Collections.singletonList((Writable) new IntWritable(3)))); + transform.map(Collections.singletonList(new IntWritable(3)))); assertEquals(Arrays.asList(new IntWritable(0), new IntWritable(1), new IntWritable(0)), - transform.map(Collections.singletonList((Writable) new IntWritable(4)))); + transform.map(Collections.singletonList(new IntWritable(4)))); assertEquals(Arrays.asList(new IntWritable(0), new IntWritable(0), new IntWritable(1)), - transform.map(Collections.singletonList((Writable) new IntWritable(5)))); + transform.map(Collections.singletonList(new IntWritable(5)))); } @Test @@ -249,11 +249,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(Arrays.asList("zero", "one", "two"), meta.getStateNames()); assertEquals(Collections.singletonList((Writable) new Text("zero")), - transform.map(Collections.singletonList((Writable) new Text("zero")))); + transform.map(Collections.singletonList(new Text("zero")))); assertEquals(Collections.singletonList((Writable) new Text("one")), - transform.map(Collections.singletonList((Writable) new Text("one")))); + transform.map(Collections.singletonList(new Text("one")))); assertEquals(Collections.singletonList((Writable) new Text("two")), - transform.map(Collections.singletonList((Writable) new Text("two")))); + transform.map(Collections.singletonList(new Text("two")))); } @Test @@ 
-350,7 +350,7 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.Integer, out.getMetaData(1).getColumnType()); assertEquals(Arrays.asList(new Text("one"), new IntWritable(1)), - transform.map(Arrays.asList((Writable) new DoubleWritable(1.0), new Text("one"), + transform.map(Arrays.asList(new DoubleWritable(1.0), new Text("one"), new IntWritable(1), new LongWritable(1L)))); } @@ -369,7 +369,7 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.Integer, out.getMetaData(1).getColumnType()); assertEquals(Arrays.asList(new Text("one"), new IntWritable(1)), - transform.map(Arrays.asList((Writable) new DoubleWritable(1.0), new Text("one"), + transform.map(Arrays.asList(new DoubleWritable(1.0), new Text("one"), new IntWritable(1), new LongWritable(1L)))); } @@ -386,11 +386,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.Integer, out.getMetaData(0).getColumnType()); assertEquals(Collections.singletonList((Writable) new IntWritable(0)), - transform.map(Collections.singletonList((Writable) new IntWritable(0)))); + transform.map(Collections.singletonList(new IntWritable(0)))); assertEquals(Collections.singletonList((Writable) new IntWritable(1)), - transform.map(Collections.singletonList((Writable) new IntWritable(1)))); + transform.map(Collections.singletonList(new IntWritable(1)))); assertEquals(Collections.singletonList((Writable) new IntWritable(1000)), - transform.map(Collections.singletonList((Writable) new Text("")))); + transform.map(Collections.singletonList(new Text("")))); } @Test @@ -405,11 +405,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.Integer, out.getMetaData(0).getColumnType()); assertEquals(Collections.singletonList((Writable) new IntWritable(0)), - transform.map(Collections.singletonList((Writable) new IntWritable(0)))); + transform.map(Collections.singletonList(new IntWritable(0)))); assertEquals(Collections.singletonList((Writable) new IntWritable(1)), - transform.map(Collections.singletonList((Writable) new IntWritable(1)))); + transform.map(Collections.singletonList(new IntWritable(1)))); assertEquals(Collections.singletonList((Writable) new IntWritable(1000)), - transform.map(Collections.singletonList((Writable) new Text("")))); + transform.map(Collections.singletonList(new Text("")))); } @Test @@ -434,13 +434,13 @@ public class TestTransforms extends BaseND4JTest { double loge2 = Math.log(2); assertEquals(0.0, - transform.map(Collections.singletonList((Writable) new DoubleWritable(min))).get(0).toDouble(), + transform.map(Collections.singletonList(new DoubleWritable(min))).get(0).toDouble(), 1e-6); double d = scale * Math.log((10 - min) / (mu - min) + 1) / loge2; - assertEquals(d, transform.map(Collections.singletonList((Writable) new DoubleWritable(10))).get(0).toDouble(), + assertEquals(d, transform.map(Collections.singletonList(new DoubleWritable(10))).get(0).toDouble(), 1e-6); d = scale * Math.log((3 - min) / (mu - min) + 1) / loge2; - assertEquals(d, transform.map(Collections.singletonList((Writable) new DoubleWritable(3))).get(0).toDouble(), + assertEquals(d, transform.map(Collections.singletonList(new DoubleWritable(3))).get(0).toDouble(), 1e-6); } @@ -466,22 +466,22 @@ public class TestTransforms extends BaseND4JTest { assertEquals(1, meta2.getMaxAllowedValue(), 1e-6); - assertEquals(0.0, transform.map(Collections.singletonList((Writable) new DoubleWritable(0))).get(0).toDouble(), + assertEquals(0.0, transform.map(Collections.singletonList(new 
DoubleWritable(0))).get(0).toDouble(), 1e-6); assertEquals(1.0, - transform.map(Collections.singletonList((Writable) new DoubleWritable(100))).get(0).toDouble(), + transform.map(Collections.singletonList(new DoubleWritable(100))).get(0).toDouble(), 1e-6); - assertEquals(0.5, transform.map(Collections.singletonList((Writable) new DoubleWritable(50))).get(0).toDouble(), + assertEquals(0.5, transform.map(Collections.singletonList(new DoubleWritable(50))).get(0).toDouble(), 1e-6); assertEquals(-1.0, - transform2.map(Collections.singletonList((Writable) new DoubleWritable(0))).get(0).toDouble(), + transform2.map(Collections.singletonList(new DoubleWritable(0))).get(0).toDouble(), 1e-6); assertEquals(1.0, - transform2.map(Collections.singletonList((Writable) new DoubleWritable(100))).get(0).toDouble(), + transform2.map(Collections.singletonList(new DoubleWritable(100))).get(0).toDouble(), 1e-6); assertEquals(0.0, - transform2.map(Collections.singletonList((Writable) new DoubleWritable(50))).get(0).toDouble(), + transform2.map(Collections.singletonList(new DoubleWritable(50))).get(0).toDouble(), 1e-6); } @@ -504,13 +504,13 @@ public class TestTransforms extends BaseND4JTest { assertNull(meta.getMaxAllowedValue()); - assertEquals(0.0, transform.map(Collections.singletonList((Writable) new DoubleWritable(mu))).get(0).toDouble(), + assertEquals(0.0, transform.map(Collections.singletonList(new DoubleWritable(mu))).get(0).toDouble(), 1e-6); double d = (10 - mu) / sigma; - assertEquals(d, transform.map(Collections.singletonList((Writable) new DoubleWritable(10))).get(0).toDouble(), + assertEquals(d, transform.map(Collections.singletonList(new DoubleWritable(10))).get(0).toDouble(), 1e-6); d = (-2 - mu) / sigma; - assertEquals(d, transform.map(Collections.singletonList((Writable) new DoubleWritable(-2))).get(0).toDouble(), + assertEquals(d, transform.map(Collections.singletonList(new DoubleWritable(-2))).get(0).toDouble(), 1e-6); } @@ -532,10 +532,10 @@ public class TestTransforms extends BaseND4JTest { assertNull(meta.getMaxAllowedValue()); - assertEquals(0.0, transform.map(Collections.singletonList((Writable) new DoubleWritable(mu))).get(0).toDouble(), + assertEquals(0.0, transform.map(Collections.singletonList(new DoubleWritable(mu))).get(0).toDouble(), 1e-6); assertEquals(10 - mu, - transform.map(Collections.singletonList((Writable) new DoubleWritable(10))).get(0).toDouble(), + transform.map(Collections.singletonList(new DoubleWritable(10))).get(0).toDouble(), 1e-6); } @@ -552,11 +552,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); assertEquals(Collections.singletonList((Writable) new Text("one")), - transform.map(Collections.singletonList((Writable) new Text("one")))); + transform.map(Collections.singletonList(new Text("one")))); assertEquals(Collections.singletonList((Writable) new Text("two")), - transform.map(Collections.singletonList((Writable) new Text("two")))); + transform.map(Collections.singletonList(new Text("two")))); assertEquals(Collections.singletonList((Writable) new Text("replacement")), - transform.map(Collections.singletonList((Writable) new Text("this should be replaced")))); + transform.map(Collections.singletonList(new Text("this should be replaced")))); } @Test @@ -571,13 +571,13 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); assertEquals(Collections.singletonList((Writable) new Text("one")), - 
transform.map(Collections.singletonList((Writable) new Text("one ")))); + transform.map(Collections.singletonList(new Text("one ")))); assertEquals(Collections.singletonList((Writable) new Text("two")), - transform.map(Collections.singletonList((Writable) new Text("two\t")))); + transform.map(Collections.singletonList(new Text("two\t")))); assertEquals(Collections.singletonList((Writable) new Text("three")), - transform.map(Collections.singletonList((Writable) new Text("three\n")))); + transform.map(Collections.singletonList(new Text("three\n")))); assertEquals(Collections.singletonList((Writable) new Text("one")), - transform.map(Collections.singletonList((Writable) new Text(" o n e\t")))); + transform.map(Collections.singletonList(new Text(" o n e\t")))); } @Test @@ -592,11 +592,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); assertEquals(Collections.singletonList((Writable) new Text("one")), - transform.map(Collections.singletonList((Writable) new Text("one")))); + transform.map(Collections.singletonList(new Text("one")))); assertEquals(Collections.singletonList((Writable) new Text("newvalue")), - transform.map(Collections.singletonList((Writable) new Text("")))); + transform.map(Collections.singletonList(new Text("")))); assertEquals(Collections.singletonList((Writable) new Text("three")), - transform.map(Collections.singletonList((Writable) new Text("three")))); + transform.map(Collections.singletonList(new Text("three")))); } @Test @@ -611,11 +611,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); assertEquals(Collections.singletonList((Writable) new Text("one_AppendThis")), - transform.map(Collections.singletonList((Writable) new Text("one")))); + transform.map(Collections.singletonList(new Text("one")))); assertEquals(Collections.singletonList((Writable) new Text("two_AppendThis")), - transform.map(Collections.singletonList((Writable) new Text("two")))); + transform.map(Collections.singletonList(new Text("two")))); assertEquals(Collections.singletonList((Writable) new Text("three_AppendThis")), - transform.map(Collections.singletonList((Writable) new Text("three")))); + transform.map(Collections.singletonList(new Text("three")))); } @Test @@ -637,17 +637,17 @@ public class TestTransforms extends BaseND4JTest { } assertEquals(Arrays.asList(new Text("false"), new Text("false"), new Text("false")), - transform.map(Collections.singletonList((Writable) new Text("")))); + transform.map(Collections.singletonList(new Text("")))); assertEquals(Arrays.asList(new Text("true"), new Text("false"), new Text("false")), - transform.map(Collections.singletonList((Writable) new Text("a")))); + transform.map(Collections.singletonList(new Text("a")))); assertEquals(Arrays.asList(new Text("false"), new Text("true"), new Text("false")), - transform.map(Collections.singletonList((Writable) new Text("b")))); + transform.map(Collections.singletonList(new Text("b")))); assertEquals(Arrays.asList(new Text("false"), new Text("false"), new Text("true")), - transform.map(Collections.singletonList((Writable) new Text("c")))); + transform.map(Collections.singletonList(new Text("c")))); assertEquals(Arrays.asList(new Text("true"), new Text("false"), new Text("true")), - transform.map(Collections.singletonList((Writable) new Text("a,c")))); + transform.map(Collections.singletonList(new Text("a,c")))); assertEquals(Arrays.asList(new Text("true"), new Text("true"), 
new Text("true")), - transform.map(Collections.singletonList((Writable) new Text("a,b,c")))); + transform.map(Collections.singletonList(new Text("a,b,c")))); } @Test @@ -665,11 +665,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); assertEquals(Collections.singletonList((Writable) new Text("ONE")), - transform.map(Collections.singletonList((Writable) new Text("one")))); + transform.map(Collections.singletonList(new Text("one")))); assertEquals(Collections.singletonList((Writable) new Text("TWO")), - transform.map(Collections.singletonList((Writable) new Text("two")))); + transform.map(Collections.singletonList(new Text("two")))); assertEquals(Collections.singletonList((Writable) new Text("three")), - transform.map(Collections.singletonList((Writable) new Text("three")))); + transform.map(Collections.singletonList(new Text("three")))); } @@ -721,9 +721,9 @@ public class TestTransforms extends BaseND4JTest { long out2 = 1435708799000L; assertEquals(Collections.singletonList((Writable) new LongWritable(out1)), - transform.map(Collections.singletonList((Writable) new Text(in1)))); + transform.map(Collections.singletonList(new Text(in1)))); assertEquals(Collections.singletonList((Writable) new LongWritable(out2)), - transform.map(Collections.singletonList((Writable) new Text(in2)))); + transform.map(Collections.singletonList(new Text(in2)))); //Check serialization: things like DateTimeFormatter etc aren't serializable, hence we need custom serialization :/ ByteArrayOutputStream baos = new ByteArrayOutputStream(); @@ -737,9 +737,9 @@ public class TestTransforms extends BaseND4JTest { Transform deserialized = (Transform) ois.readObject(); assertEquals(Collections.singletonList((Writable) new LongWritable(out1)), - deserialized.map(Collections.singletonList((Writable) new Text(in1)))); + deserialized.map(Collections.singletonList(new Text(in1)))); assertEquals(Collections.singletonList((Writable) new LongWritable(out2)), - deserialized.map(Collections.singletonList((Writable) new Text(in2)))); + deserialized.map(Collections.singletonList(new Text(in2)))); } @@ -792,9 +792,9 @@ public class TestTransforms extends BaseND4JTest { out2.add(new Text("2015-06-30 23:59:59")); assertEquals(out1, - transform.map(Arrays.asList((Writable) new LongWritable(in1), new Text("otherColumnValue")))); + transform.map(Arrays.asList(new LongWritable(in1), new Text("otherColumnValue")))); assertEquals(out2, - transform.map(Arrays.asList((Writable) new LongWritable(in2), new Text("otherColumnValue")))); + transform.map(Arrays.asList(new LongWritable(in2), new Text("otherColumnValue")))); @@ -810,9 +810,9 @@ public class TestTransforms extends BaseND4JTest { Transform deserialized = (Transform) ois.readObject(); assertEquals(out1, deserialized - .map(Arrays.asList((Writable) new LongWritable(in1), new Text("otherColumnValue")))); + .map(Arrays.asList(new LongWritable(in1), new Text("otherColumnValue")))); assertEquals(out2, deserialized - .map(Arrays.asList((Writable) new LongWritable(in2), new Text("otherColumnValue")))); + .map(Arrays.asList(new LongWritable(in2), new Text("otherColumnValue")))); } @@ -839,8 +839,8 @@ public class TestTransforms extends BaseND4JTest { assertEquals(expOutTypes.get(i), out.getType(i)); } - List inList = Arrays.asList((Writable) new Text("one"), new IntWritable(2), new LongWritable(3L)); - List outList = Arrays.asList((Writable) new Text("one"), new IntWritable(2), new IntWritable(2), + List inList = 
Arrays.asList(new Text("one"), new IntWritable(2), new LongWritable(3L)); + List outList = Arrays.asList(new Text("one"), new IntWritable(2), new IntWritable(2), new LongWritable(3L), new LongWritable(3L)); assertEquals(outList, transform.map(inList)); @@ -861,11 +861,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(5, (int) meta.getMaxAllowedValue()); assertEquals(Collections.singletonList((Writable) new IntWritable(-5)), - transform.map(Collections.singletonList((Writable) new IntWritable(-1)))); + transform.map(Collections.singletonList(new IntWritable(-1)))); assertEquals(Collections.singletonList((Writable) new IntWritable(0)), - transform.map(Collections.singletonList((Writable) new IntWritable(0)))); + transform.map(Collections.singletonList(new IntWritable(0)))); assertEquals(Collections.singletonList((Writable) new IntWritable(5)), - transform.map(Collections.singletonList((Writable) new IntWritable(1)))); + transform.map(Collections.singletonList(new IntWritable(1)))); } @Test @@ -885,11 +885,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(Arrays.asList((Writable) new IntWritable(1), new Text("something"), new IntWritable(2), new IntWritable(3)), - transform.map(Arrays.asList((Writable) new IntWritable(1), new Text("something"), + transform.map(Arrays.asList(new IntWritable(1), new Text("something"), new IntWritable(2)))); assertEquals(Arrays.asList((Writable) new IntWritable(100), new Text("something2"), new IntWritable(21), new IntWritable(121)), - transform.map(Arrays.asList((Writable) new IntWritable(100), new Text("something2"), + transform.map(Arrays.asList(new IntWritable(100), new Text("something2"), new IntWritable(21)))); } @@ -908,11 +908,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(5, (long) meta.getMaxAllowedValue()); assertEquals(Collections.singletonList((Writable) new LongWritable(-5)), - transform.map(Collections.singletonList((Writable) new LongWritable(-1)))); + transform.map(Collections.singletonList(new LongWritable(-1)))); assertEquals(Collections.singletonList((Writable) new LongWritable(0)), - transform.map(Collections.singletonList((Writable) new LongWritable(0)))); + transform.map(Collections.singletonList(new LongWritable(0)))); assertEquals(Collections.singletonList((Writable) new LongWritable(5)), - transform.map(Collections.singletonList((Writable) new LongWritable(1)))); + transform.map(Collections.singletonList(new LongWritable(1)))); } @Test @@ -932,11 +932,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(Arrays.asList((Writable) new LongWritable(1), new Text("something"), new LongWritable(2), new LongWritable(3)), - transform.map(Arrays.asList((Writable) new LongWritable(1), new Text("something"), + transform.map(Arrays.asList(new LongWritable(1), new Text("something"), new LongWritable(2)))); assertEquals(Arrays.asList((Writable) new LongWritable(100), new Text("something2"), new LongWritable(21), new LongWritable(121)), - transform.map(Arrays.asList((Writable) new LongWritable(100), new Text("something2"), + transform.map(Arrays.asList(new LongWritable(100), new Text("something2"), new LongWritable(21)))); } @@ -952,9 +952,9 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.Time, out.getType(0)); assertEquals(Collections.singletonList((Writable) new LongWritable(1000 + 43200000)), - transform.map(Collections.singletonList((Writable) new LongWritable(1000)))); + transform.map(Collections.singletonList(new LongWritable(1000)))); 
assertEquals(Collections.singletonList((Writable) new LongWritable(1452441600000L + 43200000)), - transform.map(Collections.singletonList((Writable) new LongWritable(1452441600000L)))); + transform.map(Collections.singletonList(new LongWritable(1452441600000L)))); } @Test @@ -972,11 +972,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(5.0, meta.getMaxAllowedValue(), 1e-6); assertEquals(Collections.singletonList((Writable) new DoubleWritable(-5)), - transform.map(Collections.singletonList((Writable) new DoubleWritable(-1)))); + transform.map(Collections.singletonList(new DoubleWritable(-1)))); assertEquals(Collections.singletonList((Writable) new DoubleWritable(0)), - transform.map(Collections.singletonList((Writable) new DoubleWritable(0)))); + transform.map(Collections.singletonList(new DoubleWritable(0)))); assertEquals(Collections.singletonList((Writable) new DoubleWritable(5)), - transform.map(Collections.singletonList((Writable) new DoubleWritable(1)))); + transform.map(Collections.singletonList(new DoubleWritable(1)))); } @Test @@ -992,11 +992,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.String, out.getType(1)); assertEquals(Arrays.asList(new DoubleWritable(Math.sin(1)), new Text("0")), - transform.map(Arrays.asList(new DoubleWritable(1), new Text("0")))); + transform.map(Arrays.asList(new DoubleWritable(1), new Text("0")))); assertEquals(Arrays.asList(new DoubleWritable(Math.sin(2)), new Text("1")), - transform.map(Arrays.asList(new DoubleWritable(2), new Text("1")))); + transform.map(Arrays.asList(new DoubleWritable(2), new Text("1")))); assertEquals(Arrays.asList(new DoubleWritable(Math.sin(3)), new Text("2")), - transform.map(Arrays.asList(new DoubleWritable(3), new Text("2")))); + transform.map(Arrays.asList(new DoubleWritable(3), new Text("2")))); } @Test @@ -1016,11 +1016,11 @@ public class TestTransforms extends BaseND4JTest { assertEquals(Arrays.asList((Writable) new Text("something"), new DoubleWritable(1.0), new DoubleWritable(2.1), new DoubleWritable(3.1)), - transform.map(Arrays.asList((Writable) new Text("something"), new DoubleWritable(1.0), + transform.map(Arrays.asList(new Text("something"), new DoubleWritable(1.0), new DoubleWritable(2.1)))); assertEquals(Arrays.asList((Writable) new Text("something2"), new DoubleWritable(100.0), new DoubleWritable(21.1), new DoubleWritable(121.1)), - transform.map(Arrays.asList((Writable) new Text("something2"), new DoubleWritable(100.0), + transform.map(Arrays.asList(new Text("something2"), new DoubleWritable(100.0), new DoubleWritable(21.1)))); } @@ -1061,10 +1061,10 @@ public class TestTransforms extends BaseND4JTest { assertEquals(Arrays.asList(ColumnType.Integer, ColumnType.String, ColumnType.Double), out.getColumnTypes()); assertEquals(Arrays.asList((Writable) new IntWritable(1), new Text("one"), new DoubleWritable(1.1)), transform - .map(Arrays.asList((Writable) new DoubleWritable(1.1), new Text("one"), new IntWritable(1)))); + .map(Arrays.asList(new DoubleWritable(1.1), new Text("one"), new IntWritable(1)))); assertEquals(Arrays.asList((Writable) new IntWritable(2), new Text("two"), new DoubleWritable(200.2)), transform - .map(Arrays.asList((Writable) new DoubleWritable(200.2), new Text("two"), new IntWritable(2)))); + .map(Arrays.asList(new DoubleWritable(200.2), new Text("two"), new IntWritable(2)))); } @Test @@ -1078,15 +1078,15 @@ public class TestTransforms extends BaseND4JTest { transform.setInputSchema(schema); assertEquals(Collections.singletonList((Writable) 
new IntWritable(10)), - transform.map(Collections.singletonList((Writable) new IntWritable(10)))); + transform.map(Collections.singletonList(new IntWritable(10)))); assertEquals(Collections.singletonList((Writable) new IntWritable(1)), - transform.map(Collections.singletonList((Writable) new IntWritable(1)))); + transform.map(Collections.singletonList(new IntWritable(1)))); assertEquals(Collections.singletonList((Writable) new IntWritable(0)), - transform.map(Collections.singletonList((Writable) new IntWritable(0)))); + transform.map(Collections.singletonList(new IntWritable(0)))); assertEquals(Collections.singletonList((Writable) new IntWritable(0)), - transform.map(Collections.singletonList((Writable) new IntWritable(-1)))); + transform.map(Collections.singletonList(new IntWritable(-1)))); assertEquals(Collections.singletonList((Writable) new IntWritable(0)), - transform.map(Collections.singletonList((Writable) new IntWritable(-10)))); + transform.map(Collections.singletonList(new IntWritable(-10)))); } @Test @@ -1100,15 +1100,15 @@ public class TestTransforms extends BaseND4JTest { transform.setInputSchema(schema); assertEquals(Collections.singletonList((Writable) new IntWritable(1)), - transform.map(Collections.singletonList((Writable) new IntWritable(10)))); + transform.map(Collections.singletonList(new IntWritable(10)))); assertEquals(Collections.singletonList((Writable) new IntWritable(1)), - transform.map(Collections.singletonList((Writable) new IntWritable(1)))); + transform.map(Collections.singletonList(new IntWritable(1)))); assertEquals(Collections.singletonList((Writable) new IntWritable(1)), - transform.map(Collections.singletonList((Writable) new IntWritable(0)))); + transform.map(Collections.singletonList(new IntWritable(0)))); assertEquals(Collections.singletonList((Writable) new IntWritable(0)), - transform.map(Collections.singletonList((Writable) new IntWritable(-1)))); + transform.map(Collections.singletonList(new IntWritable(-1)))); assertEquals(Collections.singletonList((Writable) new IntWritable(0)), - transform.map(Collections.singletonList((Writable) new IntWritable(-10)))); + transform.map(Collections.singletonList(new IntWritable(-10)))); } @Test @@ -1119,11 +1119,11 @@ public class TestTransforms extends BaseND4JTest { Transform transform = new ConditionalCopyValueTransform("third", "second", condition); transform.setInputSchema(schema); - List list = Arrays.asList((Writable) new Text("first"), new Text("second"), new Text("third")); + List list = Arrays.asList(new Text("first"), new Text("second"), new Text("third")); assertEquals(list, transform.map(list)); - list = Arrays.asList((Writable) new Text("first"), new Text("second"), new Text("")); - List exp = Arrays.asList((Writable) new Text("first"), new Text("second"), new Text("second")); + list = Arrays.asList(new Text("first"), new Text("second"), new Text("")); + List exp = Arrays.asList(new Text("first"), new Text("second"), new Text("second")); assertEquals(exp, transform.map(list)); } @@ -1133,10 +1133,10 @@ public class TestTransforms extends BaseND4JTest { .addColumnDouble("thirdCol").build(); List> sequence = new ArrayList<>(); - sequence.add(Arrays.asList(new Text("val0"), new IntWritable(10), new DoubleWritable(10))); - sequence.add(Arrays.asList(new Text("val1"), new IntWritable(15), new DoubleWritable(15))); - sequence.add(Arrays.asList(new Text("val2"), new IntWritable(25), new DoubleWritable(25))); - sequence.add(Arrays.asList(new Text("val3"), new IntWritable(40), new DoubleWritable(40))); + 
sequence.add(Arrays.asList(new Text("val0"), new IntWritable(10), new DoubleWritable(10))); + sequence.add(Arrays.asList(new Text("val1"), new IntWritable(15), new DoubleWritable(15))); + sequence.add(Arrays.asList(new Text("val2"), new IntWritable(25), new DoubleWritable(25))); + sequence.add(Arrays.asList(new Text("val3"), new IntWritable(40), new DoubleWritable(40))); Transform t = new SequenceDifferenceTransform("secondCol"); t.setInputSchema(schema); @@ -1144,10 +1144,10 @@ public class TestTransforms extends BaseND4JTest { List> out = t.mapSequence(sequence); List> expected = new ArrayList<>(); - expected.add(Arrays.asList(new Text("val0"), new IntWritable(0), new DoubleWritable(10))); - expected.add(Arrays.asList(new Text("val1"), new IntWritable(15 - 10), new DoubleWritable(15))); - expected.add(Arrays.asList(new Text("val2"), new IntWritable(25 - 15), new DoubleWritable(25))); - expected.add(Arrays.asList(new Text("val3"), new IntWritable(40 - 25), new DoubleWritable(40))); + expected.add(Arrays.asList(new Text("val0"), new IntWritable(0), new DoubleWritable(10))); + expected.add(Arrays.asList(new Text("val1"), new IntWritable(15 - 10), new DoubleWritable(15))); + expected.add(Arrays.asList(new Text("val2"), new IntWritable(25 - 15), new DoubleWritable(25))); + expected.add(Arrays.asList(new Text("val3"), new IntWritable(40 - 25), new DoubleWritable(40))); assertEquals(expected, out); @@ -1160,10 +1160,10 @@ public class TestTransforms extends BaseND4JTest { assertEquals(outputSchema.getColumnNames(), Arrays.asList("firstCol", "secondCol", "newThirdColName")); expected = new ArrayList<>(); - expected.add(Arrays.asList(new Text("val0"), new IntWritable(10), NullWritable.INSTANCE)); - expected.add(Arrays.asList(new Text("val1"), new IntWritable(15), NullWritable.INSTANCE)); - expected.add(Arrays.asList(new Text("val2"), new IntWritable(25), new DoubleWritable(25 - 10))); - expected.add(Arrays.asList(new Text("val3"), new IntWritable(40), new DoubleWritable(40 - 15))); + expected.add(Arrays.asList(new Text("val0"), new IntWritable(10), NullWritable.INSTANCE)); + expected.add(Arrays.asList(new Text("val1"), new IntWritable(15), NullWritable.INSTANCE)); + expected.add(Arrays.asList(new Text("val2"), new IntWritable(25), new DoubleWritable(25 - 10))); + expected.add(Arrays.asList(new Text("val3"), new IntWritable(40), new DoubleWritable(40 - 15))); } @@ -1181,9 +1181,9 @@ public class TestTransforms extends BaseND4JTest { assertEquals(Arrays.asList((Writable) new Text("something"), new DoubleWritable(1.0), new IntWritable(10)), - transform.map(Arrays.asList((Writable) new Text("something"), new DoubleWritable(1.0)))); + transform.map(Arrays.asList(new Text("something"), new DoubleWritable(1.0)))); assertEquals(Arrays.asList((Writable) new Text("something2"), new DoubleWritable(100.0), new IntWritable(10)), - transform.map(Arrays.asList((Writable) new Text("something2"), new DoubleWritable(100.0)))); + transform.map(Arrays.asList(new Text("something2"), new DoubleWritable(100.0)))); } @Test @@ -1202,7 +1202,7 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); assertEquals(Collections.singletonList((Writable) new Text("BoneConeTone")), - transform.map(Collections.singletonList((Writable) new Text("B1midT3")))); + transform.map(Collections.singletonList(new Text("B1midT3")))); // No link map = new HashMap<>(); @@ -1215,7 +1215,7 @@ public class TestTransforms extends BaseND4JTest { assertEquals(ColumnType.String, 
out.getMetaData(0).getColumnType()); assertEquals(Collections.singletonList((Writable) new Text("4.25")), - transform.map(Collections.singletonList((Writable) new Text(" 4.25 ")))); + transform.map(Collections.singletonList(new Text(" 4.25 ")))); } @Test @@ -1234,12 +1234,12 @@ public class TestTransforms extends BaseND4JTest { t.setInputSchema(schema); List> seq = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8))); List> exp = Collections.singletonList( - Arrays.asList(new DoubleWritable(3), new LongWritable(3L), new DoubleWritable(8))); + Arrays.asList(new DoubleWritable(3), new LongWritable(3L), new DoubleWritable(8))); List> act = t.mapSequence(seq); assertEquals(exp, act); @@ -1255,22 +1255,22 @@ public class TestTransforms extends BaseND4JTest { @Test public void testSequenceMovingWindowReduceTransform(){ List> seq = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); List> exp1 = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5), new DoubleWritable((2+5)/2.0)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8), new DoubleWritable((2+5+8)/3.0)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11), new DoubleWritable((5+8+11)/3.0))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5), new DoubleWritable((2+5)/2.0)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8), new DoubleWritable((2+5+8)/3.0)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11), new DoubleWritable((5+8+11)/3.0))); List> exp2 = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2), NullWritable.INSTANCE), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5), NullWritable.INSTANCE), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8), new DoubleWritable((2+5+8)/3.0)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11), new DoubleWritable((5+8+11)/3.0))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new 
DoubleWritable(2), NullWritable.INSTANCE), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5), NullWritable.INSTANCE), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8), new DoubleWritable((2+5+8)/3.0)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11), new DoubleWritable((5+8+11)/3.0))); Schema schema = new SequenceSchema.Builder().addColumnsDouble("col%d",0,2).build(); Schema expOutSchema1 = new SequenceSchema.Builder().addColumnsDouble("col%d",0,2).addColumnDouble("mean(3,col2)").build(); @@ -1296,18 +1296,18 @@ public class TestTransforms extends BaseND4JTest { @Test public void testTrimSequenceTransform(){ List> seq = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); List> expTrimFirst = Arrays.asList( - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); List> expTrimLast = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5))); SequenceTrimTransform tFirst = new SequenceTrimTransform(2, true); SequenceTrimTransform tLast = new SequenceTrimTransform(2, false); @@ -1323,15 +1323,15 @@ public class TestTransforms extends BaseND4JTest { @Test public void testSequenceTrimToLengthTransform(){ List> seq = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); List> expTrimLength3 = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), 
new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8))); Schema s = new Schema.Builder() .addColumnsDouble("first", "second", "third") @@ -1346,8 +1346,8 @@ public class TestTransforms extends BaseND4JTest { List> seq2 = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5))); out = p.executeSequence(seq2); assertEquals(seq2, out); @@ -1361,28 +1361,28 @@ public class TestTransforms extends BaseND4JTest { @Test public void testSequenceTrimToLengthTransformTrimOrPad(){ List> seq = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)), - Arrays.asList(new DoubleWritable(12), new DoubleWritable(13), new DoubleWritable(14))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11)), + Arrays.asList(new DoubleWritable(12), new DoubleWritable(13), new DoubleWritable(14))); List> seq2 = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5))); List> expTrimLength4 = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); Schema s = new Schema.Builder() .addColumnsDouble("first", "second", "third") .build(); TransformProcess p = new TransformProcess.Builder(s) - .trimOrPadSequenceToLength(4, Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902))) + .trimOrPadSequenceToLength(4, Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902))) .build(); List> out = p.executeSequence(seq); @@ -1390,10 +1390,10 @@ public class TestTransforms extends BaseND4JTest { List> exp2 = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - 
Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902)), - Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902)), + Arrays.asList(new DoubleWritable(900), new DoubleWritable(901), new DoubleWritable(902))); out = p.executeSequence(seq2); assertEquals(exp2, out); @@ -1410,21 +1410,21 @@ public class TestTransforms extends BaseND4JTest { public void testSequenceOffsetTransform(){ List> seq = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(11))); Schema schema = new SequenceSchema.Builder().addColumnsDouble("col%d",0,2).build(); //First: test InPlace List> exp1 = Arrays.asList( - Arrays.asList(new DoubleWritable(6), new DoubleWritable(1), new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(4), new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(6), new DoubleWritable(1), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(4), new DoubleWritable(11))); List> exp2 = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(7), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(10), new DoubleWritable(5))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(7), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(10), new DoubleWritable(5))); //In-place + trim SequenceOffsetTransform t_inplace_trim_p2 = new SequenceOffsetTransform(Collections.singletonList("col1"), @@ -1447,15 +1447,15 @@ public class TestTransforms extends BaseND4JTest { t_inplace_specified_m2.setInputSchema(schema); List> exp3 = Arrays.asList( - Arrays.asList(new DoubleWritable(0), NullWritable.INSTANCE, new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), NullWritable.INSTANCE, new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(1), new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(4), new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(0), NullWritable.INSTANCE, new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), NullWritable.INSTANCE, new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(1), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(4), new DoubleWritable(11))); List> exp4 = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(7), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(10), new DoubleWritable(5)), - 
Arrays.asList(new DoubleWritable(6), NullWritable.INSTANCE, new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), NullWritable.INSTANCE, new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(7), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(10), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), NullWritable.INSTANCE, new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), NullWritable.INSTANCE, new DoubleWritable(11))); assertEquals(exp3, t_inplace_specified_p2.mapSequence(seq)); assertEquals(exp4, t_inplace_specified_m2.mapSequence(seq)); @@ -1465,12 +1465,12 @@ public class TestTransforms extends BaseND4JTest { //Second: test NewColumn List> exp1a = Arrays.asList( - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(1), new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(4), new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(1), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(4), new DoubleWritable(11))); List> exp2a = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(7), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(10), new DoubleWritable(5))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(7), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(10), new DoubleWritable(5))); SequenceOffsetTransform t_newcol_trim_p2 = new SequenceOffsetTransform(Collections.singletonList("col1"), 2, SequenceOffsetTransform.OperationType.NewColumn, SequenceOffsetTransform.EdgeHandling.TrimSequence, null); SequenceOffsetTransform t_newcol_trim_m2 = new SequenceOffsetTransform(Collections.singletonList("col1"), @@ -1482,15 +1482,15 @@ public class TestTransforms extends BaseND4JTest { assertEquals(exp2a, t_newcol_trim_m2.mapSequence(seq)); List> exp3a = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), NullWritable.INSTANCE, new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), NullWritable.INSTANCE, new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(1), new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(4), new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), NullWritable.INSTANCE, new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), NullWritable.INSTANCE, new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), new DoubleWritable(1), new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), new DoubleWritable(4), new DoubleWritable(11))); List> exp4a = Arrays.asList( - Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), new DoubleWritable(7), new DoubleWritable(2)), - Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(10), new DoubleWritable(5)), - Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), NullWritable.INSTANCE, new DoubleWritable(8)), - Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), NullWritable.INSTANCE, new DoubleWritable(11))); + Arrays.asList(new DoubleWritable(0), new 
DoubleWritable(1), new DoubleWritable(7), new DoubleWritable(2)), + Arrays.asList(new DoubleWritable(3), new DoubleWritable(4), new DoubleWritable(10), new DoubleWritable(5)), + Arrays.asList(new DoubleWritable(6), new DoubleWritable(7), NullWritable.INSTANCE, new DoubleWritable(8)), + Arrays.asList(new DoubleWritable(9), new DoubleWritable(10), NullWritable.INSTANCE, new DoubleWritable(11))); SequenceOffsetTransform t_newcol_specified_p2 = new SequenceOffsetTransform(Collections.singletonList("col1"), 2, SequenceOffsetTransform.OperationType.NewColumn, SequenceOffsetTransform.EdgeHandling.SpecifiedValue, NullWritable.INSTANCE); @@ -1519,7 +1519,7 @@ public class TestTransforms extends BaseND4JTest { Schema s = new Schema.Builder().addColumnString("inCol").build(); t.setInputSchema(s); - List l = Collections.singletonList(new Text("cat,cat,dog,dog,dog,unknown")); + List l = Collections.singletonList(new Text("cat,cat,dog,dog,dog,unknown")); List out = t.map(l); @@ -1541,7 +1541,7 @@ public class TestTransforms extends BaseND4JTest { Schema s = new Schema.Builder().addColumnString("inCol").build(); t.setInputSchema(s); - List l = Collections.singletonList(new Text("cat,dog,dog,dog,unknown")); + List l = Collections.singletonList(new Text("cat,dog,dog,dog,unknown")); List out = t.map(l); @@ -1559,8 +1559,8 @@ public class TestTransforms extends BaseND4JTest { Schema s = new Schema.Builder().addColumnString("col").addColumnDouble("d").build(); List> inSeq = Arrays.asList( - Arrays.asList(new Text("text"), new DoubleWritable(1.0)), - Arrays.asList(new Text("ab"), new DoubleWritable(2.0))); + Arrays.asList(new Text("text"), new DoubleWritable(1.0)), + Arrays.asList(new Text("ab"), new DoubleWritable(2.0))); Map map = new HashMap<>(); map.put('a', 0); @@ -1570,12 +1570,12 @@ public class TestTransforms extends BaseND4JTest { map.put('x', 4); List> exp = Arrays.asList( - Arrays.asList(new IntWritable(3), new DoubleWritable(1.0)), - Arrays.asList(new IntWritable(2), new DoubleWritable(1.0)), - Arrays.asList(new IntWritable(4), new DoubleWritable(1.0)), - Arrays.asList(new IntWritable(3), new DoubleWritable(1.0)), - Arrays.asList(new IntWritable(0), new DoubleWritable(2.0)), - Arrays.asList(new IntWritable(1), new DoubleWritable(2.0))); + Arrays.asList(new IntWritable(3), new DoubleWritable(1.0)), + Arrays.asList(new IntWritable(2), new DoubleWritable(1.0)), + Arrays.asList(new IntWritable(4), new DoubleWritable(1.0)), + Arrays.asList(new IntWritable(3), new DoubleWritable(1.0)), + Arrays.asList(new IntWritable(0), new DoubleWritable(2.0)), + Arrays.asList(new IntWritable(1), new DoubleWritable(2.0))); Transform t = new TextToCharacterIndexTransform("col", "newName", map, false); t.setInputSchema(s); @@ -1603,8 +1603,8 @@ public class TestTransforms extends BaseND4JTest { .build(); List vocab = Arrays.asList("zero", "one", "two", "three"); List> inSeq = Arrays.asList( - Arrays.asList(new Text("a"), new Text("zero four two"), new DoubleWritable(4.2)), - Arrays.asList(new Text("b"), new Text("six one two four three five"), new DoubleWritable(87.9))); + Arrays.asList(new Text("a"), new Text("zero four two"), new DoubleWritable(4.2)), + Arrays.asList(new Text("b"), new Text("six one two four three five"), new DoubleWritable(87.9))); Schema expSchema = new Schema.Builder() .addColumnString("ID") @@ -1612,11 +1612,11 @@ public class TestTransforms extends BaseND4JTest { .addColumnDouble("FEATURE") .build(); List> exp = Arrays.asList( - Arrays.asList(new Text("a"), new IntWritable(0), new 
DoubleWritable(4.2)), - Arrays.asList(new Text("a"), new IntWritable(2), new DoubleWritable(4.2)), - Arrays.asList(new Text("b"), new IntWritable(1), new DoubleWritable(87.9)), - Arrays.asList(new Text("b"), new IntWritable(2), new DoubleWritable(87.9)), - Arrays.asList(new Text("b"), new IntWritable(3), new DoubleWritable(87.9))); + Arrays.asList(new Text("a"), new IntWritable(0), new DoubleWritable(4.2)), + Arrays.asList(new Text("a"), new IntWritable(2), new DoubleWritable(4.2)), + Arrays.asList(new Text("b"), new IntWritable(1), new DoubleWritable(87.9)), + Arrays.asList(new Text("b"), new IntWritable(2), new DoubleWritable(87.9)), + Arrays.asList(new Text("b"), new IntWritable(3), new DoubleWritable(87.9))); Transform t = new TextToTermIndexSequenceTransform("TEXT", "INDEXSEQ", vocab, " ", false); t.setInputSchema(schema); @@ -1664,16 +1664,16 @@ public class TestTransforms extends BaseND4JTest { assertEquals(Arrays.asList(ColumnType.String, ColumnType.Double, ColumnType.Categorical, ColumnType.Categorical), s2.getColumnTypes()); List> in = Arrays.asList( - Arrays.asList(new Text("a"), new DoubleWritable(3.14159), new Text("8e-4")), - Arrays.asList(new Text("b"), new DoubleWritable(2.71828), new Text("7e2")), - Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("6e8")), - Arrays.asList(new Text("c"), new DoubleWritable(-2), new Text("non numerical"))); + Arrays.asList(new Text("a"), new DoubleWritable(3.14159), new Text("8e-4")), + Arrays.asList(new Text("b"), new DoubleWritable(2.71828), new Text("7e2")), + Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("6e8")), + Arrays.asList(new Text("c"), new DoubleWritable(-2), new Text("non numerical"))); List> expected = Arrays.asList( - Arrays.asList(new Text("a"), new DoubleWritable(3.14159), new Text("3"), new Text("8")), - Arrays.asList(new Text("b"), new DoubleWritable(2.71828), new Text("2"), new Text("7")), - Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("1"), new Text("6")), - Arrays.asList(new Text("c"), new DoubleWritable(-2), new Text("2"), new Text("Other"))); + Arrays.asList(new Text("a"), new DoubleWritable(3.14159), new Text("3"), new Text("8")), + Arrays.asList(new Text("b"), new DoubleWritable(2.71828), new Text("2"), new Text("7")), + Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("1"), new Text("6")), + Arrays.asList(new Text("c"), new DoubleWritable(-2), new Text("2"), new Text("Other"))); List> out = new ArrayList<>(); for(List i : in){ diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java index 8c4a44687..06c75574f 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java @@ -54,11 +54,11 @@ public class TestNDArrayWritableTransforms extends BaseND4JTest { TransformProcess tp = new TransformProcess.Builder(s).ndArrayScalarOpTransform("col1", MathOp.Add, 100).build(); - List in = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)), + List in = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)), new Text("str0")); List out = tp.execute(in); - List exp = Arrays.asList(new 
DoubleWritable(0), + List exp = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10).addi(100)), new Text("str0")); assertEquals(exp, out); @@ -81,12 +81,12 @@ public class TestNDArrayWritableTransforms extends BaseND4JTest { assertEquals(expColNames, tp.getFinalSchema().getColumnNames()); - List in = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)), + List in = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0).castTo(DataType.DOUBLE))); List out = tp.execute(in); List exp = - Arrays.asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)), + Arrays.asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE,0, 10, 1).reshape(1,10)), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0).castTo(DataType.DOUBLE)), new NDArrayWritable(Nd4j.linspace(DataType.DOUBLE, 0, 10, 1).addi(2.0).reshape(1,10))); @@ -111,11 +111,11 @@ public class TestNDArrayWritableTransforms extends BaseND4JTest { assertEquals(expColNames, tp.getFinalSchema().getColumnNames()); - List in = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)), + List in = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(Nd4j.linspace(0, 9, 10)), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 2.0))); List out = tp.execute(in); - List exp = Arrays.asList(new DoubleWritable(0), + List exp = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(Transforms.sin(Nd4j.linspace(0, 9, 10))), new NDArrayWritable(Transforms.sqrt(Nd4j.valueArrayOf(1, 10, 2.0)))); @@ -145,11 +145,11 @@ public class TestNDArrayWritableTransforms extends BaseND4JTest { INDArray arr2 = Nd4j.rand(1, 10); double cosine = Transforms.cosineSim(arr1, arr2); - List in = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(arr1.dup()), + List in = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(arr1.dup()), new NDArrayWritable(arr2.dup())); List out = tp.execute(in); - List exp = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(arr1), + List exp = Arrays.asList(new DoubleWritable(0), new NDArrayWritable(arr1), new NDArrayWritable(arr2), new DoubleWritable(cosine)); assertEquals(exp, out); diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java index e531d040f..f6f5ab4b0 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java @@ -28,6 +28,7 @@ import org.nd4j.common.tests.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -37,7 +38,7 @@ public class ParseDoubleTransformTest extends BaseND4JTest { public void testDoubleTransform() { List record = new ArrayList<>(); record.add(new Text("0.0")); - List transformed = Arrays.asList(new DoubleWritable(0.0)); + List transformed = Collections.singletonList(new DoubleWritable(0.0)); assertEquals(transformed, new ParseDoubleTransform().map(record)); } diff --git 
a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java index 7b13b03d7..c774bc7cb 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java @@ -151,7 +151,7 @@ public class TestUI extends BaseND4JTest { List> sequence = new ArrayList<>(nSteps); for (int i = 0; i < nSteps; i++) { String c = "s" + i % 3; - sequence.add(Arrays.asList(new DoubleWritable(Math.sin(i / 10.0)), new Text(c), + sequence.add(Arrays.asList(new DoubleWritable(Math.sin(i / 10.0)), new Text(c), new Text(String.valueOf(i)))); } diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java index ed9c01793..40d74cfcf 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java @@ -85,7 +85,7 @@ public class RecordConverterTest extends BaseND4JTest { @Test public void testNDArrayWritableConcat() { - List l = Arrays.asList(new DoubleWritable(1), + List l = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1, 3}, DataType.FLOAT)), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[]{6, 7, 8}, new long[]{1, 3}, DataType.FLOAT)), new IntWritable(9), new IntWritable(1)); @@ -99,8 +99,8 @@ public class RecordConverterTest extends BaseND4JTest { @Test public void testNDArrayWritableConcatToMatrix(){ - List l1 = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(5)); - List l2 = Arrays.asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[]{7, 8, 9}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(10)); + List l1 = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[]{2, 3, 4}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(5)); + List l2 = Arrays.asList(new DoubleWritable(6), new NDArrayWritable(Nd4j.create(new double[]{7, 8, 9}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(10)); INDArray exp = Nd4j.create(new double[][]{ {1,2,3,4,5}, @@ -113,7 +113,7 @@ public class RecordConverterTest extends BaseND4JTest { @Test public void testToRecordWithListOfObject(){ - final List list = Arrays.asList((Object)3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L); + final List list = Arrays.asList(3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L); final Schema schema = new Schema.Builder() .addColumnInteger("a") .addColumnFloat("b") diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java index 767742e4a..0d5acbbb4 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java @@ -47,7 +47,7 @@ public class WritableTest extends BaseND4JTest { assertEquals(new FloatWritable(1), new FloatWritable(1)); assertEquals(new Text("Hello"), new Text("Hello")); assertEquals(new BytesWritable("Hello".getBytes()),new 
BytesWritable("Hello".getBytes())); - INDArray ndArray = Nd4j.rand(new int[]{1, 100}); + INDArray ndArray = Nd4j.rand(1, 100); assertEquals(new NDArrayWritable(ndArray), new NDArrayWritable(ndArray)); assertEquals(new NullWritable(), new NullWritable()); @@ -61,7 +61,7 @@ public class WritableTest extends BaseND4JTest { public void testBytesWritableIndexing() { byte[] doubleWrite = new byte[16]; ByteBuffer wrapped = ByteBuffer.wrap(doubleWrite); - Buffer buffer = (Buffer) wrapped; + Buffer buffer = wrapped; wrapped.putDouble(1.0); wrapped.putDouble(2.0); buffer.rewind(); @@ -88,8 +88,8 @@ public class WritableTest extends BaseND4JTest { @Test public void testIntLongWritable() { - assertEquals(new IntWritable(1), new LongWritable(1l)); - assertEquals(new LongWritable(2l), new IntWritable(2)); + assertEquals(new IntWritable(1), new LongWritable(1L)); + assertEquals(new LongWritable(2L), new IntWritable(2)); long l = 1L << 34; // those would cast to the same Int @@ -134,8 +134,8 @@ public class WritableTest extends BaseND4JTest { for( int i=0; i<5; i++ ){ orig.get(0).add(Nd4j.rand(1,10)); - orig.get(1).add(Nd4j.rand(new int[]{1,5,6})); - orig.get(2).add(Nd4j.rand(new int[]{1,3,4,5})); + orig.get(1).add(Nd4j.rand(1,5,6)); + orig.get(2).add(Nd4j.rand(1,3,4,5)); } List> origByExample = new ArrayList<>(); //Outer list over examples, inner list over writables diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java index 0bf5637a8..9d88cdb1c 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java @@ -187,7 +187,7 @@ public class ArrowConverter { break; } - return Nd4j.create(buffer,new int[] {cols,1}); + return Nd4j.create(buffer, cols,1); } @@ -658,7 +658,7 @@ public class ArrowConverter { * @return the created vectors */ public static List toArrowColumnsStringSingle(final BufferAllocator bufferAllocator, final Schema schema, List dataVecRecord) { - return toArrowColumnsString(bufferAllocator,schema, Arrays.asList(dataVecRecord)); + return toArrowColumnsString(bufferAllocator,schema, Collections.singletonList(dataVecRecord)); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java index 32472e582..322de25a5 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java @@ -31,12 +31,13 @@ import org.datavec.arrow.ArrowConverter; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class ArrowRecordWriter implements RecordWriter { private Configuration configuration; - private Schema schema; + private final Schema schema; private Partitioner partitioner; public ArrowRecordWriter(Schema schema) { @@ -63,7 +64,7 @@ public class ArrowRecordWriter implements RecordWriter { @Override public PartitionMetaData write(List record) throws IOException { - return 
writeBatch(Arrays.asList(record)); + return writeBatch(Collections.singletonList(record)); } @Override diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java index a2c1902d3..3c52adc80 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java @@ -65,7 +65,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; @Slf4j public class ArrowConverterTest extends BaseND4JTest { - private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); + private static final BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); @TempDir public File testDir; @@ -80,7 +80,7 @@ public class ArrowConverterTest extends BaseND4JTest { int numRows = 4; List> ret = new ArrayList<>(numRows); for(int i = 0; i < numRows; i++) { - ret.add(Arrays.asList(new NDArrayWritable(Nd4j.linspace(1,4,4).reshape(1, 4)))); + ret.add(Collections.singletonList(new NDArrayWritable(Nd4j.linspace(1, 4, 4).reshape(1, 4)))); } List fieldVectors = ArrowConverter.toArrowColumns(bufferAllocator, schema, ret); @@ -144,7 +144,7 @@ public class ArrowConverterTest extends BaseND4JTest { List fieldVectors = ArrowConverter.toArrowColumnsStringSingle(bufferAllocator, schema.build(), single); List> records = ArrowConverter.toArrowWritables(fieldVectors, schema.build()); List> assertion = new ArrayList<>(); - assertion.add(Arrays.asList(new IntWritable(0),new IntWritable(1))); + assertion.add(Arrays.asList(new IntWritable(0),new IntWritable(1))); assertEquals(assertion,records); List> batch = new ArrayList<>(); @@ -156,8 +156,8 @@ public class ArrowConverterTest extends BaseND4JTest { List> batchRecords = ArrowConverter.toArrowWritables(fieldVectorsBatch, schema.build()); List> assertionBatch = new ArrayList<>(); - assertionBatch.add(Arrays.asList(new IntWritable(0),new IntWritable(0))); - assertionBatch.add(Arrays.asList(new IntWritable(1),new IntWritable(1))); + assertionBatch.add(Arrays.asList(new IntWritable(0),new IntWritable(0))); + assertionBatch.add(Arrays.asList(new IntWritable(1),new IntWritable(1))); assertEquals(assertionBatch,batchRecords); @@ -175,14 +175,14 @@ public class ArrowConverterTest extends BaseND4JTest { } List> input = Arrays.asList( - Arrays.asList(new LongWritable(0),new LongWritable(1)), - Arrays.asList(new LongWritable(2),new LongWritable(3)) + Arrays.asList(new LongWritable(0),new LongWritable(1)), + Arrays.asList(new LongWritable(2),new LongWritable(3)) ); List fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input); ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build()); - List assertion = Arrays.asList(new LongWritable(4), new LongWritable(5)); - writableRecordBatch.set(1, Arrays.asList(new LongWritable(4),new LongWritable(5))); + List assertion = Arrays.asList(new LongWritable(4), new LongWritable(5)); + writableRecordBatch.set(1, Arrays.asList(new LongWritable(4),new LongWritable(5))); List recordTest = writableRecordBatch.get(1); assertEquals(assertion,recordTest); } @@ -197,14 +197,14 @@ public class ArrowConverterTest extends BaseND4JTest { } List> input = Arrays.asList( - Arrays.asList(new IntWritable(0),new 
IntWritable(1)), - Arrays.asList(new IntWritable(2),new IntWritable(3)) + Arrays.asList(new IntWritable(0),new IntWritable(1)), + Arrays.asList(new IntWritable(2),new IntWritable(3)) ); List fieldVector = ArrowConverter.toArrowColumns(bufferAllocator,schema.build(),input); ArrowWritableRecordBatch writableRecordBatch = new ArrowWritableRecordBatch(fieldVector,schema.build()); - List assertion = Arrays.asList(new IntWritable(4), new IntWritable(5)); - writableRecordBatch.set(1, Arrays.asList(new IntWritable(4),new IntWritable(5))); + List assertion = Arrays.asList(new IntWritable(4), new IntWritable(5)); + writableRecordBatch.set(1, Arrays.asList(new IntWritable(4),new IntWritable(5))); List recordTest = writableRecordBatch.get(1); assertEquals(assertion,recordTest); } @@ -218,7 +218,7 @@ public class ArrowConverterTest extends BaseND4JTest { } for(int i = 0; i < 5; i++) { - List> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i))); + List> arr = Collections.singletonList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i))); entries.add(arr); } @@ -249,7 +249,7 @@ public class ArrowConverterTest extends BaseND4JTest { } for(int i = 0; i < 5; i++) { - List> arr = Arrays.asList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i))); + List> arr = Collections.singletonList(Arrays.asList(String.valueOf(i), String.valueOf(i), String.valueOf(i))); entries.add(arr); } @@ -266,7 +266,7 @@ public class ArrowConverterTest extends BaseND4JTest { File f = testDir; - File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID().toString() + ".arrorw"); + File tmpFile = new File(f, "tmp-arrow-file-" + UUID.randomUUID() + ".arrorw"); FileOutputStream outputStream = new FileOutputStream(tmpFile); tmpFile.deleteOnExit(); ArrowConverter.writeRecordBatchTo(recordsToWrite.getRight(),recordsToWrite.getFirst(),outputStream); @@ -302,7 +302,7 @@ public class ArrowConverterTest extends BaseND4JTest { assertEquals(matrix.rows(),vectors.size()); INDArray vector = Nd4j.linspace(1,4,4); - val vectors2 = ArrowConverter.convertToArrowVector(vector,Arrays.asList("test"), ColumnType.Double,bufferAllocator); + val vectors2 = ArrowConverter.convertToArrowVector(vector, Collections.singletonList("test"), ColumnType.Double,bufferAllocator); assertEquals(1,vectors2.size()); assertEquals(matrix.length(),vectors2.get(0).getValueCount()); @@ -440,7 +440,7 @@ public class ArrowConverterTest extends BaseND4JTest { File tmp = tmpDataFile(recordsToWrite); RecordReader recordReader = new ArrowRecordReader(); RecordMetaDataIndex recordMetaDataIndex = new RecordMetaDataIndex(0,tmp.toURI(),ArrowRecordReader.class); - recordReader.loadFromMetaData(Arrays.asList(recordMetaDataIndex)); + recordReader.loadFromMetaData(Collections.singletonList(recordMetaDataIndex)); Record record = recordReader.nextRecord(); assertEquals(2,record.getRecord().size()); @@ -474,7 +474,7 @@ public class ArrowConverterTest extends BaseND4JTest { File f = testDir; //send file - File tmp = new File(f,"tmp-file-" + UUID.randomUUID().toString()); + File tmp = new File(f,"tmp-file-" + UUID.randomUUID()); tmp.mkdirs(); File tmpFile = new File(tmp,"data.arrow"); tmpFile.deleteOnExit(); @@ -487,8 +487,8 @@ public class ArrowConverterTest extends BaseND4JTest { private Pair>> recordToWrite() { List> records = new ArrayList<>(); - records.add(Arrays.asList(new DoubleWritable(0.0),new DoubleWritable(0.0))); - records.add(Arrays.asList(new DoubleWritable(0.0),new DoubleWritable(0.0))); + 
records.add(Arrays.asList(new DoubleWritable(0.0),new DoubleWritable(0.0))); + records.add(Arrays.asList(new DoubleWritable(0.0),new DoubleWritable(0.0))); Schema.Builder schemaBuilder = new Schema.Builder(); for(int i = 0; i < 2; i++) { schemaBuilder.addColumnFloat("col-" + i); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java index b39b88013..e3c1471fe 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java @@ -42,7 +42,7 @@ import static org.junit.jupiter.api.Assertions.assertFalse; public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { - private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); + private static final BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); @Test @@ -54,9 +54,9 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { List> timeStep = Arrays.asList( - Arrays.asList(new IntWritable(0),new IntWritable(1),new IntWritable(2)), - Arrays.asList(new IntWritable(1),new IntWritable(2),new IntWritable(3)), - Arrays.asList(new IntWritable(4),new IntWritable(5),new IntWritable(6)) + Arrays.asList(new IntWritable(0),new IntWritable(1),new IntWritable(2)), + Arrays.asList(new IntWritable(1),new IntWritable(2),new IntWritable(3)), + Arrays.asList(new IntWritable(4),new IntWritable(5),new IntWritable(6)) ); int numTimeSteps = 5; @@ -87,13 +87,13 @@ public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { .addColumnDouble("dbl"); List> firstSeq = Arrays.asList( - Arrays.asList(new Text("00"),new IntWritable(0),new DoubleWritable(2.0)), - Arrays.asList(new Text("01"),new IntWritable(1),new DoubleWritable(2.1)), - Arrays.asList(new Text("02"),new IntWritable(2),new DoubleWritable(2.2))); + Arrays.asList(new Text("00"),new IntWritable(0),new DoubleWritable(2.0)), + Arrays.asList(new Text("01"),new IntWritable(1),new DoubleWritable(2.1)), + Arrays.asList(new Text("02"),new IntWritable(2),new DoubleWritable(2.2))); List> secondSeq = Arrays.asList( - Arrays.asList(new Text("10"),new IntWritable(10),new DoubleWritable(12.0)), - Arrays.asList(new Text("11"),new IntWritable(11),new DoubleWritable(12.1))); + Arrays.asList(new Text("10"),new IntWritable(10),new DoubleWritable(12.0)), + Arrays.asList(new Text("11"),new IntWritable(11),new DoubleWritable(12.1))); List>> sequences = Arrays.asList(firstSeq, secondSeq); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/Wave.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/Wave.java index db75a546f..909b57601 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/Wave.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/Wave.java @@ -61,7 +61,7 @@ public class Wave implements Serializable { initWaveWithInputStream(inputStream); inputStream.close(); } catch (IOException e) { - System.out.println(e.toString()); + System.out.println(e); } 
} @@ -96,7 +96,7 @@ public class Wave implements Serializable { data = new byte[inputStream.available()]; inputStream.read(data); } catch (IOException e) { - System.err.println(e.toString()); + System.err.println(e); } // end load data } else { @@ -132,7 +132,7 @@ public class Wave implements Serializable { waveHeader.setSubChunk2Size(subChunk2Size); byte[] trimmedData = new byte[(int) subChunk2Size]; - System.arraycopy(data, (int) leftTrimNumberOfSample, trimmedData, 0, (int) subChunk2Size); + System.arraycopy(data, leftTrimNumberOfSample, trimmedData, 0, (int) subChunk2Size); data = trimmedData; } else { System.err.println("Trim error: Negative length"); @@ -303,10 +303,9 @@ public class Wave implements Serializable { } public String toString() { - StringBuilder sb = new StringBuilder(waveHeader.toString()); - sb.append("\n"); - sb.append("length: " + timestamp()); - return sb.toString(); + String sb = waveHeader.toString() + "\n" + + "length: " + timestamp(); + return sb; } public double[] getNormalizedAmplitudes() { diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/WaveHeader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/WaveHeader.java index 3f7af014a..0c5e462a7 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/WaveHeader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/WaveHeader.java @@ -36,7 +36,7 @@ public class WaveHeader { public static final String DATA_HEADER = "data"; public static final int HEADER_BYTE_LENGTH = 44; // 44 bytes for header - private boolean valid; + private final boolean valid; private String chunkId; // 4 bytes private long chunkSize; // unsigned 4 bytes, little endian private String format; // 4 bytes @@ -82,7 +82,7 @@ public class WaveHeader { // little endian chunkSize = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 | (long) (headerBuffer[pointer++] & 0xff) << 16 - | (long) (headerBuffer[pointer++] & 0xff << 24); + | (long) (headerBuffer[pointer++] & 0xffL << 24); format = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++]}); subChunk1Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], @@ -90,16 +90,16 @@ public class WaveHeader { subChunk1Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 | (long) (headerBuffer[pointer++] & 0xff) << 16 | (long) (headerBuffer[pointer++] & 0xff) << 24; - audioFormat = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); - channels = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); + audioFormat = (headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8; + channels = (headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8; sampleRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 | (long) (headerBuffer[pointer++] & 0xff) << 16 | (long) (headerBuffer[pointer++] & 0xff) << 24; byteRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 | (long) (headerBuffer[pointer++] & 0xff) << 16 | (long) (headerBuffer[pointer++] & 0xff) << 24; - blockAlign = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); - bitsPerSample = (int) 
((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8); + blockAlign = (headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8; + bitsPerSample = (headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8; subChunk2Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++]}); subChunk2Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8 @@ -122,7 +122,7 @@ public class WaveHeader { } // check the format is support - if (chunkId.toUpperCase().equals(RIFF_HEADER) && format.toUpperCase().equals(WAVE_HEADER) && audioFormat == 1) { + if (chunkId.equalsIgnoreCase(RIFF_HEADER) && format.equalsIgnoreCase(WAVE_HEADER) && audioFormat == 1) { return true; } else { System.err.println("WaveHeader: Unsupported header format"); @@ -197,7 +197,7 @@ public class WaveHeader { } this.sampleRate = sampleRate; - this.byteRate = sampleRate * bitsPerSample / 8; + this.byteRate = (long) sampleRate * bitsPerSample / 8; this.chunkSize = newSubChunk2Size + 36; this.subChunk2Size = newSubChunk2Size; } @@ -252,32 +252,31 @@ public class WaveHeader { public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("chunkId: " + chunkId); - sb.append("\n"); - sb.append("chunkSize: " + chunkSize); - sb.append("\n"); - sb.append("format: " + format); - sb.append("\n"); - sb.append("subChunk1Id: " + subChunk1Id); - sb.append("\n"); - sb.append("subChunk1Size: " + subChunk1Size); - sb.append("\n"); - sb.append("audioFormat: " + audioFormat); - sb.append("\n"); - sb.append("channels: " + channels); - sb.append("\n"); - sb.append("sampleRate: " + sampleRate); - sb.append("\n"); - sb.append("byteRate: " + byteRate); - sb.append("\n"); - sb.append("blockAlign: " + blockAlign); - sb.append("\n"); - sb.append("bitsPerSample: " + bitsPerSample); - sb.append("\n"); - sb.append("subChunk2Id: " + subChunk2Id); - sb.append("\n"); - sb.append("subChunk2Size: " + subChunk2Size); - return sb.toString(); + String sb = "chunkId: " + chunkId + + "\n" + + "chunkSize: " + chunkSize + + "\n" + + "format: " + format + + "\n" + + "subChunk1Id: " + subChunk1Id + + "\n" + + "subChunk1Size: " + subChunk1Size + + "\n" + + "audioFormat: " + audioFormat + + "\n" + + "channels: " + channels + + "\n" + + "sampleRate: " + sampleRate + + "\n" + + "byteRate: " + byteRate + + "\n" + + "blockAlign: " + blockAlign + + "\n" + + "bitsPerSample: " + bitsPerSample + + "\n" + + "subChunk2Id: " + subChunk2Id + + "\n" + + "subChunk2Size: " + subChunk2Size; + return sb; } } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/dsp/WindowFunction.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/dsp/WindowFunction.java index c0d7b6253..6e859aab1 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/dsp/WindowFunction.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/dsp/WindowFunction.java @@ -39,15 +39,15 @@ public class WindowFunction { } public void setWindowType(String w) { - if (w.toUpperCase().equals("RECTANGULAR")) + if (w.equalsIgnoreCase("RECTANGULAR")) windowType = RECTANGULAR; - if (w.toUpperCase().equals("BARTLETT")) + if (w.equalsIgnoreCase("BARTLETT")) windowType = BARTLETT; - if (w.toUpperCase().equals("HANNING")) + if (w.equalsIgnoreCase("HANNING")) windowType = HANNING; - if 
(w.toUpperCase().equals("HAMMING")) + if (w.equalsIgnoreCase("HAMMING")) windowType = HAMMING; - if (w.toUpperCase().equals("BLACKMAN")) + if (w.equalsIgnoreCase("BLACKMAN")) windowType = BLACKMAN; } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/extension/NormalizedSampleAmplitudes.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/extension/NormalizedSampleAmplitudes.java index 9a8eaba58..76381efa9 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/extension/NormalizedSampleAmplitudes.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/extension/NormalizedSampleAmplitudes.java @@ -26,7 +26,7 @@ import org.datavec.audio.Wave; */ public class NormalizedSampleAmplitudes { - private Wave wave; + private final Wave wave; private double[] normalizedAmplitudes; // normalizedAmplitudes[sampleNumber]=normalizedAmplitudeInTheFrame public NormalizedSampleAmplitudes(Wave wave) { @@ -43,12 +43,9 @@ public class NormalizedSampleAmplitudes { if (normalizedAmplitudes == null) { - boolean signed = true; + boolean signed = wave.getWaveHeader().getBitsPerSample() != 8; // usually 8bit is unsigned - if (wave.getWaveHeader().getBitsPerSample() == 8) { - signed = false; - } short[] amplitudes = wave.getSampleAmplitudes(); int numSamples = amplitudes.length; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/extension/Spectrogram.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/extension/Spectrogram.java index fdc680e1d..9d91f9b66 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/extension/Spectrogram.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/extension/Spectrogram.java @@ -31,11 +31,11 @@ public class Spectrogram { public static final int SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE = 1024; public static final int SPECTROGRAM_DEFAULT_OVERLAP_FACTOR = 0; // 0 for no overlapping - private Wave wave; + private final Wave wave; private double[][] spectrogram; // relative spectrogram private double[][] absoluteSpectrogram; // absolute spectrogram - private int fftSampleSize; // number of sample in fft, the value needed to be a number to power of 2 - private int overlapFactor; // 1/overlapFactor overlapping, e.g. 1/4=25% overlapping + private final int fftSampleSize; // number of sample in fft, the value needed to be a number to power of 2 + private final int overlapFactor; // 1/overlapFactor overlapping, e.g. 
1/4=25% overlapping private int numFrames; // number of frames of the spectrogram private int framesPerSecond; // frame per second of the spectrogram private int numFrequencyUnit; // number of y-axis unit diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintManager.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintManager.java index efa481a91..38435166d 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintManager.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintManager.java @@ -42,11 +42,11 @@ import java.util.List; @Slf4j public class FingerprintManager { - private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); - private int sampleSizePerFrame = fingerprintProperties.getSampleSizePerFrame(); - private int overlapFactor = fingerprintProperties.getOverlapFactor(); - private int numRobustPointsPerFrame = fingerprintProperties.getNumRobustPointsPerFrame(); - private int numFilterBanks = fingerprintProperties.getNumFilterBanks(); + private final FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); + private final int sampleSizePerFrame = fingerprintProperties.getSampleSizePerFrame(); + private final int overlapFactor = fingerprintProperties.getOverlapFactor(); + private final int numRobustPointsPerFrame = fingerprintProperties.getNumRobustPointsPerFrame(); + private final int numFilterBanks = fingerprintProperties.getNumFilterBanks(); /** * Constructor diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintSimilarity.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintSimilarity.java index c76756310..6ca4335ec 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintSimilarity.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintSimilarity.java @@ -27,7 +27,7 @@ import org.datavec.audio.properties.FingerprintProperties; */ public class FingerprintSimilarity { - private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); + private final FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); private int mostSimilarFramePosition; private float score; private float similarity; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintSimilarityComputer.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintSimilarityComputer.java index 222bb5e67..3f832a884 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintSimilarityComputer.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintSimilarityComputer.java @@ -27,7 +27,7 @@ import java.util.List; */ public class FingerprintSimilarityComputer { - private FingerprintSimilarity fingerprintSimilarity; + private final FingerprintSimilarity fingerprintSimilarity; byte[] fingerprint1, fingerprint2; /** 
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRank.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRank.java index 3378f5c09..9d31ccde9 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRank.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRank.java @@ -19,5 +19,5 @@ package org.datavec.audio.fingerprint; import java.util.List; public interface MapRank { - public List getOrderedKeyList(int numKeys, boolean sharpLimit); + List getOrderedKeyList(int numKeys, boolean sharpLimit); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankDouble.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankDouble.java index a24ba0959..376e6f361 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankDouble.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankDouble.java @@ -21,7 +21,7 @@ import java.util.Map.Entry; public class MapRankDouble implements MapRank { - private Map map; + private final Map map; private boolean acsending = true; public MapRankDouble(Map map, boolean acsending) { @@ -95,7 +95,7 @@ public class MapRankDouble implements MapRank { } while (true) { - double targetValue = (Double) listArr[index]; + double targetValue = listArr[index]; Iterator passedMapIterator = passedMap.entrySet().iterator(); while (passedMapIterator.hasNext()) { Entry entry = passedMapIterator.next(); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankInteger.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankInteger.java index befbcbe00..aa218d6e9 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankInteger.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankInteger.java @@ -21,7 +21,7 @@ import java.util.Map.Entry; public class MapRankInteger implements MapRank { - private Map map; + private final Map map; private boolean acsending = true; public MapRankInteger(Map map, boolean acsending) { @@ -95,7 +95,7 @@ public class MapRankInteger implements MapRank { } while (true) { - int targetValue = (Integer) listArr[index]; + int targetValue = listArr[index]; Iterator passedMapIterator = passedMap.entrySet().iterator(); while (passedMapIterator.hasNext()) { Entry entry = passedMapIterator.next(); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/PairManager.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/PairManager.java index ff18c34c9..e072c7266 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/PairManager.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/PairManager.java @@ -33,16 +33,16 @@ import java.util.List; public class PairManager { FingerprintProperties 
fingerprintProperties = FingerprintProperties.getInstance(); - private int numFilterBanks = fingerprintProperties.getNumFilterBanks(); - private int bandwidthPerBank = fingerprintProperties.getNumFrequencyUnits() / numFilterBanks; - private int anchorPointsIntervalLength = fingerprintProperties.getAnchorPointsIntervalLength(); - private int numAnchorPointsPerInterval = fingerprintProperties.getNumAnchorPointsPerInterval(); - private int maxTargetZoneDistance = fingerprintProperties.getMaxTargetZoneDistance(); - private int numFrequencyUnits = fingerprintProperties.getNumFrequencyUnits(); + private final int numFilterBanks = fingerprintProperties.getNumFilterBanks(); + private final int bandwidthPerBank = fingerprintProperties.getNumFrequencyUnits() / numFilterBanks; + private final int anchorPointsIntervalLength = fingerprintProperties.getAnchorPointsIntervalLength(); + private final int numAnchorPointsPerInterval = fingerprintProperties.getNumAnchorPointsPerInterval(); + private final int maxTargetZoneDistance = fingerprintProperties.getMaxTargetZoneDistance(); + private final int numFrequencyUnits = fingerprintProperties.getNumFrequencyUnits(); - private int maxPairs; - private boolean isReferencePairing; - private HashMap stopPairTable = new HashMap<>(); + private final int maxPairs; + private final boolean isReferencePairing; + private final HashMap stopPairTable = new HashMap<>(); /** * Constructor diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortDouble.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortDouble.java index 258cbc888..0c127a484 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortDouble.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortDouble.java @@ -18,8 +18,8 @@ package org.datavec.audio.fingerprint; public class QuickSortDouble extends QuickSort { - private int[] indexes; - private double[] array; + private final int[] indexes; + private final double[] array; public QuickSortDouble(double[] array) { this.array = array; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortIndexPreserved.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortIndexPreserved.java index 61e391d71..77aa155ca 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortIndexPreserved.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortIndexPreserved.java @@ -18,7 +18,7 @@ package org.datavec.audio.fingerprint; public class QuickSortIndexPreserved { - private QuickSort quickSort; + private final QuickSort quickSort; public QuickSortIndexPreserved(int[] array) { quickSort = new QuickSortInteger(array); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortInteger.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortInteger.java index 178553865..63db39ac3 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortInteger.java +++ 
b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortInteger.java @@ -18,8 +18,8 @@ package org.datavec.audio.fingerprint; public class QuickSortInteger extends QuickSort { - private int[] indexes; - private int[] array; + private final int[] indexes; + private final int[] array; public QuickSortInteger(int[] array) { this.array = array; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortShort.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortShort.java index 8b4324b7e..91c275215 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortShort.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortShort.java @@ -18,8 +18,8 @@ package org.datavec.audio.fingerprint; public class QuickSortShort extends QuickSort { - private int[] indexes; - private short[] array; + private final int[] indexes; + private final short[] array; public QuickSortShort(short[] array) { this.array = array; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/processor/IntensityProcessor.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/processor/IntensityProcessor.java index 083ac4765..18521fd6f 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/processor/IntensityProcessor.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/processor/IntensityProcessor.java @@ -18,7 +18,7 @@ package org.datavec.audio.processor; public interface IntensityProcessor { - public void execute(); + void execute(); - public double[][] getIntensities(); + double[][] getIntensities(); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/processor/RobustIntensityProcessor.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/processor/RobustIntensityProcessor.java index 1d884855c..667f8194a 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/processor/RobustIntensityProcessor.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/processor/RobustIntensityProcessor.java @@ -20,7 +20,7 @@ package org.datavec.audio.processor; public class RobustIntensityProcessor implements IntensityProcessor { private double[][] intensities; - private int numPointsPerFrame; + private final int numPointsPerFrame; public RobustIntensityProcessor(double[][] intensities, int numPointsPerFrame) { this.intensities = intensities; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/properties/FingerprintProperties.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/properties/FingerprintProperties.java index db69a0b36..2109d60eb 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/properties/FingerprintProperties.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/properties/FingerprintProperties.java @@ -20,24 +20,24 @@ public class FingerprintProperties { protected 
static FingerprintProperties instance = null; - private int numRobustPointsPerFrame = 4; // number of points in each frame, i.e. top 4 intensities in fingerprint - private int sampleSizePerFrame = 2048; // number of audio samples in a frame, it is suggested to be the FFT Size - private int overlapFactor = 4; // 8 means each move 1/8 nSample length. 1 means no overlap, better 1,2,4,8 ... 32 - private int numFilterBanks = 4; + private final int numRobustPointsPerFrame = 4; // number of points in each frame, i.e. top 4 intensities in fingerprint + private final int sampleSizePerFrame = 2048; // number of audio samples in a frame, it is suggested to be the FFT Size + private final int overlapFactor = 4; // 8 means each move 1/8 nSample length. 1 means no overlap, better 1,2,4,8 ... 32 + private final int numFilterBanks = 4; - private int upperBoundedFrequency = 1500; // low pass - private int lowerBoundedFrequency = 400; // high pass - private int fps = 5; // in order to have 5fps with 2048 sampleSizePerFrame, wave's sample rate need to be 10240 (sampleSizePerFrame*fps) - private int sampleRate = sampleSizePerFrame * fps; // the audio's sample rate needed to resample to this in order to fit the sampleSizePerFrame and fps - private int numFramesInOneSecond = overlapFactor * fps; // since the overlap factor affects the actual number of fps, so this value is used to evaluate how many frames in one second eventually + private final int upperBoundedFrequency = 1500; // low pass + private final int lowerBoundedFrequency = 400; // high pass + private final int fps = 5; // in order to have 5fps with 2048 sampleSizePerFrame, wave's sample rate need to be 10240 (sampleSizePerFrame*fps) + private final int sampleRate = sampleSizePerFrame * fps; // the audio's sample rate needed to resample to this in order to fit the sampleSizePerFrame and fps + private final int numFramesInOneSecond = overlapFactor * fps; // since the overlap factor affects the actual number of fps, so this value is used to evaluate how many frames in one second eventually - private int refMaxActivePairs = 1; // max. active pairs per anchor point for reference songs - private int sampleMaxActivePairs = 10; // max. active pairs per anchor point for sample clip - private int numAnchorPointsPerInterval = 10; - private int anchorPointsIntervalLength = 4; // in frames (5fps,4 overlap per second) - private int maxTargetZoneDistance = 4; // in frame (5fps,4 overlap per second) + private final int refMaxActivePairs = 1; // max. active pairs per anchor point for reference songs + private final int sampleMaxActivePairs = 10; // max. 
active pairs per anchor point for sample clip + private final int numAnchorPointsPerInterval = 10; + private final int anchorPointsIntervalLength = 4; // in frames (5fps,4 overlap per second) + private final int maxTargetZoneDistance = 4; // in frame (5fps,4 overlap per second) - private int numFrequencyUnits = (upperBoundedFrequency - lowerBoundedFrequency + 1) / fps + 1; // num frequency units + private final int numFrequencyUnits = (upperBoundedFrequency - lowerBoundedFrequency + 1) / fps + 1; // num frequency units public static FingerprintProperties getInstance() { if (instance == null) { diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/CodecRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/CodecRecordReader.java index 9e32b9bc0..475fe932b 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/CodecRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/CodecRecordReader.java @@ -126,7 +126,7 @@ public class CodecRecordReader extends BaseCodecRecordReader { /** Ugly workaround to a bug in JCodec: https://github.com/jcodec/jcodec/issues/24 */ private static class FixedByteBufferSeekableByteChannel extends ByteBufferSeekableByteChannel { - private ByteBuffer backing; + private final ByteBuffer backing; public FixedByteBufferSeekableByteChannel(ByteBuffer backing) { super(backing); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java index 6925bc47d..2ec31f426 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java @@ -47,7 +47,7 @@ public class ExcelRecordReader extends FileRecordReader { private Iterator sheetIterator; private Iterator rows; // Create a DataFormatter to format and get each cell's value as String - private DataFormatter dataFormatter = new DataFormatter(); + private final DataFormatter dataFormatter = new DataFormatter(); private Workbook currWorkBook; //we should ensure that the number of columns is consistent across all worksheets private int numColumns = -1; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java index d5e9e3439..44d6409c2 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java @@ -121,7 +121,7 @@ public class CoordinatesReduction implements AggregableColumnReduction { throw new UnsupportedOperationException(); } - private IAggregableReduceOp> reducer; + private final IAggregableReduceOp> reducer; @Override public IAggregableReduceOp> reduceOp() { @@ -132,11 +132,11 @@ public class CoordinatesReduction implements AggregableColumnReduction { public static class 
CoordinateAggregableReduceOp implements IAggregableReduceOp> { - private int nOps; - private Supplier>> initialOpValue; + private final int nOps; + private final Supplier>> initialOpValue; @Getter - private ArrayList>> perCoordinateOps; // of size coords() - private String delimiter; + private final ArrayList>> perCoordinateOps; // of size coords() + private final String delimiter; public CoordinateAggregableReduceOp(int n, Supplier>> initialOp, String delim) { diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java index 55fd5855a..14d89576e 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java @@ -40,12 +40,12 @@ public class TestGeoReduction { public void testCustomReductions() { List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1#5"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2#6"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("3#7"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("4#8"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("1#5"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("2#6"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("3#7"))); + inputs.add(Arrays.asList(new Text("someKey"), new Text("4#8"))); - List expected = Arrays.asList((Writable) new Text("someKey"), new Text("10.0#26.0")); + List expected = Arrays.asList(new Text("someKey"), new Text("10.0#26.0")); Schema schema = new Schema.Builder().addColumnString("key").addColumnString("coord").build(); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java index d91d34b95..d6249b756 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java @@ -71,10 +71,10 @@ public class TestGeoTransforms { out.getColumnTypes()); assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)), - transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10")))); + transform.map(Arrays.asList(new Text("-30"), new Text("20"), new Text("10")))); assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"), new DoubleWritable(Math.sqrt(160))), - transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), + transform.map(Arrays.asList(new Text("50|40"), new Text("10|-20"), new Text("10|5")))); } @@ -94,7 +94,7 @@ public class TestGeoTransforms { double latitude = 51.5142; double longitude = -0.0931; - List writables = transform.map(Collections.singletonList((Writable) new Text(in))); + List writables = transform.map(Collections.singletonList(new Text(in))); assertEquals(1, writables.size()); String[] 
coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER"); assertEquals(2, coordinates.length); @@ -112,7 +112,7 @@ public class TestGeoTransforms { ObjectInputStream ois = new ObjectInputStream(bais); Transform deserialized = (Transform) ois.readObject(); - writables = deserialized.map(Collections.singletonList((Writable) new Text(in))); + writables = deserialized.map(Collections.singletonList(new Text(in))); assertEquals(1, writables.size()); coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER"); //System.out.println(Arrays.toString(coordinates)); @@ -141,7 +141,7 @@ public class TestGeoTransforms { assertEquals(1, out.getColumnMetaData().size()); assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); - List writables = transform.map(Collections.singletonList((Writable) new Text(in))); + List writables = transform.map(Collections.singletonList(new Text(in))); assertEquals(1, writables.size()); assertEquals(location, writables.get(0).toString()); //System.out.println(location); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java index 01d5b2f84..66bb5742f 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java @@ -40,7 +40,7 @@ public class ConfigurationUtil { String baseConfPathTrimmed = baseConfPath.trim(); - if (false == "/".equals(baseConfPathTrimmed.endsWith("/"))) { + if (!"/".equals(baseConfPathTrimmed.endsWith("/"))) { baseConfPathTrimmed += "/"; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java index f5b28847e..3bc0e7111 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java @@ -42,10 +42,10 @@ import java.util.List; */ public class MapFileReader implements Closeable { - private MapFile.Reader[] readers; - private IndexToKey indexToKey; - private Class recordClass; - private List> recordIndexesEachReader; + private final MapFile.Reader[] readers; + private final IndexToKey indexToKey; + private final Class recordClass; + private final List> recordIndexesEachReader; private Long numRecords; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java index df649f8e4..23909e2bc 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java @@ -66,7 +66,7 @@ public class MapFileRecordReader implements RecordReader { private long 
numRecords; private long position; - private Random rng; + private final Random rng; private int[] order; /** diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java index 3a0513132..03f071eae 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java @@ -68,7 +68,7 @@ public class MapFileSequenceRecordReader implements SequenceRecordReader { private long numSequences; private long position; - private Random rng; + private final Random rng; private int[] order; /** diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java index d5595e53d..fa159c36c 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java @@ -77,27 +77,27 @@ public class TestMapFileRecordReader { seqMap = new HashMap<>(); seqMap.put(new LongWritable(0), new SequenceRecordWritable(Arrays.asList( - Arrays.asList(new Text("zero"), new IntWritable(0), + Arrays.asList(new Text("zero"), new IntWritable(0), new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))), - Arrays.asList(new Text("one"), new IntWritable(1), + Arrays.asList(new Text("one"), new IntWritable(1), new DoubleWritable(1.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 1.0))), - Arrays.asList(new Text("two"), new IntWritable(2), + Arrays.asList(new Text("two"), new IntWritable(2), new DoubleWritable(2.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 2.0)))))); seqMap.put(new LongWritable(1), new SequenceRecordWritable(Arrays.asList( - Arrays.asList(new Text("Bzero"), new IntWritable(10), + Arrays.asList(new Text("Bzero"), new IntWritable(10), new DoubleWritable(10), new NDArrayWritable(Nd4j.valueArrayOf(10, 10.0))), - Arrays.asList(new Text("Bone"), new IntWritable(11), + Arrays.asList(new Text("Bone"), new IntWritable(11), new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))), - Arrays.asList(new Text("Btwo"), new IntWritable(12), + Arrays.asList(new Text("Btwo"), new IntWritable(12), new DoubleWritable(12.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 12.0)))))); seqMap.put(new LongWritable(2), new SequenceRecordWritable(Arrays.asList( - Arrays.asList(new Text("Czero"), new IntWritable(20), + Arrays.asList(new Text("Czero"), new IntWritable(20), new DoubleWritable(20), new NDArrayWritable(Nd4j.valueArrayOf(10, 20.0))), - Arrays.asList(new Text("Cone"), new IntWritable(21), + Arrays.asList(new Text("Cone"), new IntWritable(21), new DoubleWritable(21.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 21.0))), - Arrays.asList(new Text("Ctwo"), new IntWritable(22), + Arrays.asList(new Text("Ctwo"), new IntWritable(22), new DoubleWritable(22.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 
22.0)))))); @@ -125,17 +125,17 @@ public class TestMapFileRecordReader { recordMap = new HashMap<>(); recordMap.put(new LongWritable(0), - new RecordWritable(Arrays.asList(new Text("zero"), + new RecordWritable(Arrays.asList(new Text("zero"), new IntWritable(0), new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))))); recordMap.put(new LongWritable(1), - new RecordWritable(Arrays.asList(new Text("one"), + new RecordWritable(Arrays.asList(new Text("one"), new IntWritable(11), new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))))); recordMap.put(new LongWritable(2), - new RecordWritable(Arrays.asList(new Text("two"), + new RecordWritable(Arrays.asList(new Text("two"), new IntWritable(22), new DoubleWritable(22.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 22.0))))); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java index 7b50373c8..81be8ce7c 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java @@ -96,11 +96,11 @@ public class TestMapFileRecordReaderMultipleParts { for (int i = 0; i < 9; i++) { seqMap.put(new LongWritable(i), new SequenceRecordWritable(Arrays.asList( - Arrays.asList(new Text(i + "-0"), new IntWritable(3 * i), + Arrays.asList(new Text(i + "-0"), new IntWritable(3 * i), new DoubleWritable(3 * i)), - Arrays.asList(new Text(i + "-1"), + Arrays.asList(new Text(i + "-1"), new IntWritable(3 * i + 1), new DoubleWritable(3 * i + 1.0)), - Arrays.asList(new Text(i + "-2"), + Arrays.asList(new Text(i + "-2"), new IntWritable(3 * i + 2), new DoubleWritable(3 * i + 2.0))))); } @@ -141,7 +141,7 @@ public class TestMapFileRecordReaderMultipleParts { recordMap = new HashMap<>(); for (int i = 0; i < 9; i++) { - recordMap.put(new LongWritable(i), new RecordWritable(Arrays.asList( + recordMap.put(new LongWritable(i), new RecordWritable(Arrays.asList( new Text(String.valueOf(i)), new IntWritable(i), new DoubleWritable(i)))); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java index ff420241b..1a3999e05 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java @@ -96,11 +96,11 @@ public class TestMapFileRecordReaderMultiplePartsSomeEmpty { for (int i = 0; i < 6; i++) { seqMap.put(new LongWritable(i), new SequenceRecordWritable(Arrays.asList( - Arrays.asList(new Text(i + "-0"), new IntWritable(3 * i), + Arrays.asList(new Text(i + "-0"), new IntWritable(3 * i), new DoubleWritable(3 * i)), - Arrays.asList(new Text(i + "-1"), + Arrays.asList(new Text(i + "-1"), new IntWritable(3 * i + 1), new 
DoubleWritable(3 * i + 1.0)), - Arrays.asList(new Text(i + "-2"), + Arrays.asList(new Text(i + "-2"), new IntWritable(3 * i + 2), new DoubleWritable(3 * i + 2.0))))); } @@ -146,7 +146,7 @@ public class TestMapFileRecordReaderMultiplePartsSomeEmpty { recordMap = new HashMap<>(); for (int i = 0; i < 6; i++) { - recordMap.put(new LongWritable(i), new RecordWritable(Arrays.asList( + recordMap.put(new LongWritable(i), new RecordWritable(Arrays.asList( new Text(String.valueOf(i)), new IntWritable(i), new DoubleWritable(i)))); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java index 0a677f063..662b54148 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/CifarLoader.java @@ -248,14 +248,11 @@ public class CifarLoader extends NativeImageLoader implements Serializable { File f; if (train) { f = new File(trainFilesSerialized + 1 + ".ser"); - if (!f.exists()) - return false; + return f.exists(); } else { f = new File(testFilesSerialized); - if (!f.exists()) - return false; + return f.exists(); } - return true; } /** @@ -315,9 +312,9 @@ public class CifarLoader extends NativeImageLoader implements Serializable { } for (int i = 0; i < result.numExamples(); i++) { INDArray newFeatures = result.get(i).getFeatures(); - newFeatures.tensorAlongDimension(0, new int[] {0, 2, 3}).divi(255); - newFeatures.tensorAlongDimension(1, new int[] {0, 2, 3}).subi(uMean).divi(uStd); - newFeatures.tensorAlongDimension(2, new int[] {0, 2, 3}).subi(vMean).divi(vStd); + newFeatures.tensorAlongDimension(0, 0, 2, 3).divi(255); + newFeatures.tensorAlongDimension(1, 0, 2, 3).subi(uMean).divi(uStd); + newFeatures.tensorAlongDimension(2, 0, 2, 3).subi(vMean).divi(vStd); result.get(i).setFeatures(newFeatures); } result.save(fileName); @@ -372,8 +369,8 @@ public class CifarLoader extends NativeImageLoader implements Serializable { for (DataSet data : result) { try { if (useSpecialPreProcessCifar) { - INDArray uChannel = data.getFeatures().tensorAlongDimension(1, new int[] {0, 2, 3}); - INDArray vChannel = data.getFeatures().tensorAlongDimension(2, new int[] {0, 2, 3}); + INDArray uChannel = data.getFeatures().tensorAlongDimension(1, 0, 2, 3); + INDArray vChannel = data.getFeatures().tensorAlongDimension(2, 0, 2, 3); uTempMean = uChannel.meanNumber().doubleValue(); // TODO INDArray.var result is incorrect based on dimensions passed in thus using manual uStd += varManual(uChannel, uTempMean); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java index dc75e7e1c..700978ff6 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/LFWLoader.java @@ -158,17 +158,14 @@ public class LFWLoader extends BaseImageLoader implements Serializable { public boolean imageFilesExist() { if (useSubset) { File f = new File(BASE_DIR, lfwSubsetData.get("filesFilenameUnzipped")); - if (!f.exists()) - return false; + return f.exists(); } else 
{ File f = new File(BASE_DIR, lfwData.get("filesFilenameUnzipped")); if (!f.exists()) return false; f = new File(BASE_DIR, lfwLabel.get("filesFilenameUnzipped")); - if (!f.exists()) - return false; + return f.exists(); } - return true; } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java index bda972a86..3cf702a94 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/loader/NativeImageLoader.java @@ -59,8 +59,8 @@ import static org.bytedeco.opencv.global.opencv_imgproc.*; */ public class NativeImageLoader extends BaseImageLoader { private static final int MIN_BUFFER_STEP_SIZE = 64 * 1024; - private byte[] buffer = null; - private Mat bufferMat = null; + private final byte[] buffer = null; + private final Mat bufferMat = null; @Getter public static final String[] ALLOWED_FORMATS = {"bmp", "gif", "jpg", "jpeg", "jp2", "pbm", "pgm", "ppm", "pnm", @@ -239,7 +239,7 @@ public class NativeImageLoader extends BaseImageLoader { tempPix = pix = pix2; int channels = pix.d() / 8; dtype = CV_8UC(channels); - Mat mat = new Mat(height, width, dtype, pix.data(), 4 * pix.wpl()); + Mat mat = new Mat(height, width, dtype, pix.data(), 4L * pix.wpl()); mat2 = new Mat(height, width, CV_8UC(channels)); // swap bytes if needed int[] swap = {0, channels - 1, 1, channels - 2, 2, channels - 3, 3, channels - 4}, @@ -408,7 +408,7 @@ public class NativeImageLoader extends BaseImageLoader { ret.data().offset() * Nd4j.sizeOfDataType(ret.data().dataType())); if (pointer instanceof FloatPointer) { - FloatIndexer retidx = FloatIndexer.create((FloatPointer) pagedPointer.asFloatPointer(), + FloatIndexer retidx = FloatIndexer.create(pagedPointer.asFloatPointer(), new long[] {channels, rows, cols}, new long[] {stride[0], stride[1], stride[2]}, direct); if (idx instanceof UByteIndexer) { UByteIndexer ubyteidx = (UByteIndexer) idx; @@ -453,7 +453,7 @@ public class NativeImageLoader extends BaseImageLoader { } retidx.release(); } else if (pointer instanceof DoublePointer) { - DoubleIndexer retidx = DoubleIndexer.create((DoublePointer) pagedPointer.asDoublePointer(), + DoubleIndexer retidx = DoubleIndexer.create(pagedPointer.asDoublePointer(), new long[] {channels, rows, cols}, new long[] {stride[0], stride[1], stride[2]}, direct); if (idx instanceof UByteIndexer) { UByteIndexer ubyteidx = (UByteIndexer) idx; @@ -871,14 +871,13 @@ public class NativeImageLoader extends BaseImageLoader { PIX pix = pixa.pix(i); currentD = asMatrix(convert(pix)); pixDestroy(pix); - switch (this.multiPageMode) { - case MINIBATCH: - index = new INDArrayIndex[]{NDArrayIndex.point(i),NDArrayIndex.all(), NDArrayIndex.all(),NDArrayIndex.all(),NDArrayIndex.all()}; - break; -// case CHANNELS: + if (this.multiPageMode == MultiPageMode.MINIBATCH) { + index = new INDArrayIndex[]{NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all()}; + // case CHANNELS: // index = new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all(),NDArrayIndex.all()}; // break; - default: throw new UnsupportedOperationException("Unsupported MultiPageMode: " + multiPageMode); + } else { + throw new 
UnsupportedOperationException("Unsupported MultiPageMode: " + multiPageMode); } data.put(index , currentD.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.all(),NDArrayIndex.all())); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistDbFile.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistDbFile.java index f39ea4ef6..03a347cb3 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistDbFile.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistDbFile.java @@ -25,7 +25,7 @@ import java.io.IOException; import java.io.RandomAccessFile; public abstract class MnistDbFile extends RandomAccessFile { - private int count; + private final int count; /** diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistFetcher.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistFetcher.java index df6cede7c..75202afea 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistFetcher.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistFetcher.java @@ -34,7 +34,7 @@ import java.net.URL; public class MnistFetcher { private File fileDir; - private static Logger log = LoggerFactory.getLogger(MnistFetcher.class); + private static final Logger log = LoggerFactory.getLogger(MnistFetcher.class); private static final String trainingFilesURL = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz"; private static final String trainingFilesFilename = "images-idx1-ubyte.gz"; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistImageFile.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistImageFile.java index f13fd9ca9..c5d10183f 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistImageFile.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/MnistImageFile.java @@ -25,8 +25,8 @@ import java.io.IOException; public class MnistImageFile extends MnistDbFile { - private int rows; - private int cols; + private final int rows; + private final int cols; /** * Creates new MNIST database image file ready for reading. 
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/draw/DrawReconstruction.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/draw/DrawReconstruction.java index d58a279f1..beae5c881 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/draw/DrawReconstruction.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/mnist/draw/DrawReconstruction.java @@ -35,7 +35,7 @@ public class DrawReconstruction { public JFrame frame; BufferedImage img; - private INDArray data; + private final INDArray data; private int width = 28; private int height = 28; public String title = "TEST"; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java index 48502f95d..4a2426ac4 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java @@ -122,7 +122,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { } protected boolean containsFormat(String format) { - for (String format2 : imageLoader.getALLOWED_FORMATS()) + for (String format2 : BaseImageLoader.getALLOWED_FORMATS()) if (format.endsWith("." + format2)) return true; return false; @@ -235,7 +235,7 @@ public abstract class BaseImageRecordReader extends BaseRecordReader { try { NDArrayWritable ndArrayWritable = new NDArrayWritable(imageLoader.asMatrix(inputStreamInputSplit.getIs())); finishedInputStreamSplit = true; - return Arrays.asList(ndArrayWritable); + return Collections.singletonList(ndArrayWritable); } catch (IOException e) { log.error("",e); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java index a8f25d876..c7e8657cb 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java @@ -139,9 +139,7 @@ public class ObjectDetectionRecordReader extends BaseImageRecordReader { List imageObjects = labelProvider.getImageObjectsForPath(location); for (ImageObject io : imageObjects) { String name = io.getLabel(); - if (!labelSet.contains(name)) { - labelSet.add(name); - } + labelSet.add(name); } } iter = new FileFromPathIterator(inputSplit.locationsPathIterator()); //This handles randomization internally if necessary diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/impl/SvhnLabelProvider.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/impl/SvhnLabelProvider.java index 9f29e0b17..c098105b6 100644 --- 
a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/impl/SvhnLabelProvider.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/impl/SvhnLabelProvider.java @@ -41,11 +41,11 @@ import java.util.Map; public class SvhnLabelProvider implements ImageObjectLabelProvider { - private static DataType refType = new DataType(PredType.STD_REF_OBJ()); - private static DataType charType = new DataType(PredType.NATIVE_CHAR()); - private static DataType intType = new DataType(PredType.NATIVE_INT()); + private static final DataType refType = new DataType(PredType.STD_REF_OBJ()); + private static final DataType charType = new DataType(PredType.NATIVE_CHAR()); + private static final DataType intType = new DataType(PredType.NATIVE_INT()); - private Map> labelMap; + private final Map> labelMap; public SvhnLabelProvider(File dir) throws IOException { labelMap = new HashMap>(); @@ -74,11 +74,11 @@ public class SvhnLabelProvider implements ImageObjectLabelProvider { PointerPointer labelPtr = new PointerPointer(256); IntPointer intPtr = new IntPointer(256); for (int i = 0; i < n; i++) { - DataSet nameRef = new DataSet(file, namePtr.position(i * ptrSize)); + DataSet nameRef = new DataSet(file, namePtr.position((long) i * ptrSize)); nameRef.read(bytePtr, charType); String filename = bytePtr.getString(); - Group bboxGroup = new Group(file, bboxPtr.position(i * ptrSize)); + Group bboxGroup = new Group(file, bboxPtr.position((long) i * ptrSize)); DataSet topDataset = bboxGroup.openDataSet("top"); DataSet leftDataset = bboxGroup.openDataSet("left"); DataSet heightDataset = bboxGroup.openDataSet("height"); @@ -101,23 +101,23 @@ public class SvhnLabelProvider implements ImageObjectLabelProvider { assert !isFloat || m == 1; for (int j = 0; j < m; j++) { - DataSet topSet = isFloat ? topDataset : new DataSet(file, topPtr.position(j * ptrSize)); + DataSet topSet = isFloat ? topDataset : new DataSet(file, topPtr.position((long) j * ptrSize)); topSet.read(intPtr, intType); int top = intPtr.get(); - DataSet leftSet = isFloat ? leftDataset : new DataSet(file, leftPtr.position(j * ptrSize)); + DataSet leftSet = isFloat ? leftDataset : new DataSet(file, leftPtr.position((long) j * ptrSize)); leftSet.read(intPtr, intType); int left = intPtr.get(); - DataSet heightSet = isFloat ? heightDataset : new DataSet(file, heightPtr.position(j * ptrSize)); + DataSet heightSet = isFloat ? heightDataset : new DataSet(file, heightPtr.position((long) j * ptrSize)); heightSet.read(intPtr, intType); int height = intPtr.get(); - DataSet widthSet = isFloat ? widthDataset : new DataSet(file, widthPtr.position(j * ptrSize)); + DataSet widthSet = isFloat ? widthDataset : new DataSet(file, widthPtr.position((long) j * ptrSize)); widthSet.read(intPtr, intType); int width = intPtr.get(); - DataSet labelSet = isFloat ? labelDataset : new DataSet(file, labelPtr.position(j * ptrSize)); + DataSet labelSet = isFloat ? 
labelDataset : new DataSet(file, labelPtr.position((long) j * ptrSize)); labelSet.read(intPtr, intType); int label = intPtr.get(); if (label == 10) { diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/impl/VocLabelProvider.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/impl/VocLabelProvider.java index 192d06fe2..22a300e40 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/impl/VocLabelProvider.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/impl/VocLabelProvider.java @@ -42,7 +42,7 @@ public class VocLabelProvider implements ImageObjectLabelProvider { private static final String XMAX_TAG = ""; private static final String YMAX_TAG = ""; - private String annotationsDir; + private final String annotationsDir; public VocLabelProvider(@NonNull String baseDirectory){ this.annotationsDir = FilenameUtils.concat(baseDirectory, "Annotations"); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransformProcess.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransformProcess.java index 788a26581..43f90a502 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransformProcess.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransformProcess.java @@ -162,7 +162,7 @@ public class ImageTransformProcess { */ public static class Builder { - private List transformList; + private final List transformList; private int seed = 0; public Builder() { diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java index 2ca05dd55..427878e63 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/LoaderTests.java @@ -70,7 +70,7 @@ public class LoaderTests { File dir = new File(FilenameUtils.concat(System.getProperty("user.home"), "cifar/cifar-10-batches-bin")); CifarLoader cifar = new CifarLoader(false, dir); assertTrue(dir.exists()); - assertTrue(cifar.getLabels() != null); + assertNotNull(cifar.getLabels()); } @Test @@ -171,7 +171,7 @@ public class LoaderTests { CifarLoader loader = new CifarLoader(row, col, channels, train, preProcessCifar); DataSet data = loader.next(numExamples); - long shape[] = data.getFeatures().shape(); + long[] shape = data.getFeatures().shape(); assertEquals(shape.length, 4); assertEquals(shape[0], numExamples); assertEquals(shape[1], channels); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java index 1273b8b31..433f2d25b 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java +++ 
b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/loader/TestImageLoader.java @@ -37,8 +37,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class TestImageLoader { - private static long seed = 10; - private static Random rng = new Random(seed); + private static final long seed = 10; + private static final Random rng = new Random(seed); @Test public void testToIntArrayArray() throws Exception { diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java index e075e8c5d..44f4a31ee 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java @@ -207,7 +207,7 @@ public class TestImageRecordReader { List expLabels = new ArrayList<>(); for(URI u : arr){ String path = u.getPath(); - expLabels.add(testLabel(path.substring(path.length()-5, path.length()))); + expLabels.add(testLabel(path.substring(path.length()-5))); } int count = 0; @@ -280,7 +280,7 @@ public class TestImageRecordReader { @Override public Writable getLabelForPath(String path) { - String filename = path.substring(path.length()-5, path.length()); + String filename = path.substring(path.length()-5); return testLabel(filename); } @@ -336,7 +336,7 @@ public class TestImageRecordReader { List> expLabels = new ArrayList<>(); for(URI u : arr){ String path = u.getPath(); - expLabels.add(testMultiLabel(path.substring(path.length()-5, path.length()))); + expLabels.add(testMultiLabel(path.substring(path.length()-5))); } int count = 0; @@ -411,22 +411,22 @@ public class TestImageRecordReader { private static List testMultiLabel(String filename){ switch(filename){ case "0.jpg": - return Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{1,0}, new long[]{1,2}, DataType.FLOAT)), + return Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{1,0}, new long[]{1,2}, DataType.FLOAT)), new NDArrayWritable(Nd4j.create(new double[]{1,0,0}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(0.0)); case "1.png": - return Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{1,0}, new long[]{1,2}, DataType.FLOAT)), + return Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{1,0}, new long[]{1,2}, DataType.FLOAT)), new NDArrayWritable(Nd4j.create(new double[]{0,1,0}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(1.0)); case "2.jpg": - return Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{1,0}, new long[]{1,2}, DataType.FLOAT)), + return Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{1,0}, new long[]{1,2}, DataType.FLOAT)), new NDArrayWritable(Nd4j.create(new double[]{0,0,1}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(2.0)); case "A.jpg": - return Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{0,1}, new long[]{1,2}, DataType.FLOAT)), + return Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{0,1}, new long[]{1,2}, DataType.FLOAT)), new NDArrayWritable(Nd4j.create(new double[]{1,0,0}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(3.0)); case "B.png": - return Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{0,1}, new long[]{1,2}, DataType.FLOAT)), + return 
Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{0,1}, new long[]{1,2}, DataType.FLOAT)), new NDArrayWritable(Nd4j.create(new double[]{0,1,0}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(4.0)); case "C.jpg": - return Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{0,1}, new long[]{1,2}, DataType.FLOAT)), + return Arrays.asList(new NDArrayWritable(Nd4j.create(new double[]{0,1}, new long[]{1,2}, DataType.FLOAT)), new NDArrayWritable(Nd4j.create(new double[]{0,0,1}, new long[]{1,3}, DataType.FLOAT)), new DoubleWritable(5.0)); default: throw new RuntimeException(filename); @@ -435,7 +435,7 @@ public class TestImageRecordReader { private static class CountingListener implements RecordListener { - private RecordListener listener; + private final RecordListener listener; private int count = 0; public CountingListener(RecordListener listener) { diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java index f1d194769..2d27491f9 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/transform/JsonYamlTest.java @@ -95,11 +95,11 @@ public class JsonYamlTest { imgYaml = itYaml.transform(imgYaml); if (it instanceof RandomCropTransform) { - assertTrue(img.getFrame().imageHeight == imgJson.getFrame().imageHeight); - assertTrue(img.getFrame().imageWidth == imgJson.getFrame().imageWidth); + assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight); + assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth); - assertTrue(img.getFrame().imageHeight == imgYaml.getFrame().imageHeight); - assertTrue(img.getFrame().imageWidth == imgYaml.getFrame().imageWidth); + assertEquals(img.getFrame().imageHeight, imgYaml.getFrame().imageHeight); + assertEquals(img.getFrame().imageWidth, imgYaml.getFrame().imageWidth); } else if (it instanceof FilterImageTransform) { assertEquals(img.getFrame().imageHeight, imgJson.getFrame().imageHeight); assertEquals(img.getFrame().imageWidth, imgJson.getFrame().imageWidth); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/transform/TestImageTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/transform/TestImageTransform.java index ee713e091..593589eab 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/transform/TestImageTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/transform/TestImageTransform.java @@ -76,7 +76,7 @@ public class TestImageTransform { assertEquals( - x, transformed[4], 0); assertEquals( - y, transformed[5], 0); } - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); } @Test @@ -93,7 +93,7 @@ public class TestImageTransform { assertTrue(f.imageWidth <= frame.imageWidth); assertEquals(f.imageChannels, frame.imageChannels); } - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); transform = new CropImageTransform(1, 2, 3, 4); writable = transform.transform(writable); @@ -118,29 +118,29 @@ public class TestImageTransform { assertEquals(f.imageWidth, frame.imageWidth); 
assertEquals(f.imageChannels, frame.imageChannels); } - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); transform = new FlipImageTransform(-2); writable = transform.transform(writable); - float[] transformed = transform.query(new float[] {10, 20}); + float[] transformed = transform.query(10, 20); assertEquals(10, transformed[0], 0); assertEquals(20, transformed[1], 0); transform = new FlipImageTransform(0); writable = transform.transform(writable); - transformed = transform.query(new float[] {30, 40}); + transformed = transform.query(30, 40); assertEquals(30, transformed[0], 0); assertEquals(frame.imageHeight - 40 - 1, transformed[1], 0); transform = new FlipImageTransform(1); writable = transform.transform(writable); - transformed = transform.query(new float[] {50, 60}); + transformed = transform.query(50, 60); assertEquals(frame.imageWidth - 50 - 1, transformed[0], 0); assertEquals(60, transformed[1], 0); transform = new FlipImageTransform(-1); writable = transform.transform(writable); - transformed = transform.query(new float[] {70, 80}); + transformed = transform.query(70, 80); assertEquals(frame.imageWidth - 70 - 1, transformed[0], 0); assertEquals(frame.imageHeight - 80 - 1, transformed[1], 0); } @@ -160,7 +160,7 @@ public class TestImageTransform { assertTrue(f.imageWidth <= 3 * frame.imageWidth / 2); assertEquals(f.imageChannels, frame.imageChannels); } - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); transform = new ScaleImageTransform(frame.imageWidth, 2 * frame.imageHeight); writable = transform.transform(writable); @@ -186,7 +186,7 @@ public class TestImageTransform { assertEquals(f.imageWidth, frame.imageWidth); assertEquals(f.imageChannels, frame.imageChannels); } - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); transform = new RotateImageTransform(0, 0, -90, 0); writable = transform.transform(writable); @@ -212,7 +212,7 @@ public class TestImageTransform { assertEquals(f.imageWidth, frame.imageWidth); assertEquals(f.imageChannels, frame.imageChannels); } - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); transform = new WarpImageTransform(1, 2, 3, 4, 5, 6, 7, 8); writable = transform.transform(writable); @@ -245,11 +245,11 @@ public class TestImageTransform { assertTrue(f.imageWidth <= frame.imageWidth + 20); assertEquals(f.imageChannels, frame.imageChannels); } - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); transform = new MultiImageTransform(new ColorConversionTransform(COLOR_BGR2RGB)); writable = transform.transform(writable); - float[] transformed = transform.query(new float[] {11, 22}); + float[] transformed = transform.query(11, 22); assertEquals(11, transformed[0], 0); assertEquals(22, transformed[1], 0); } @@ -269,7 +269,7 @@ public class TestImageTransform { assertEquals(f.imageWidth, frame.imageWidth); assertEquals(f.imageChannels, frame.imageChannels); } - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); } @Test @@ -284,9 +284,9 @@ public class TestImageTransform { assertEquals(w, writable); } - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); - float[] transformed = transform.query(new float[] {33, 44}); + float[] transformed = transform.query(33, 44); assertEquals(33, transformed[0], 0); assertEquals(44, transformed[1], 0); } @@ -312,9 +312,9 @@ public class TestImageTransform { Frame 
newframe = w.getFrame(); assertNotEquals(frame, newframe); - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); - float[] transformed = transform.query(new float[] {55, 66}); + float[] transformed = transform.query(55, 66); assertEquals(55, transformed[0], 0); assertEquals(66, transformed[1], 0); } @@ -336,9 +336,9 @@ public class TestImageTransform { showTrans.transform(writable); Frame newframe = w.getFrame(); assertNotEquals(frame, newframe); - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); - float[] transformed = transform.query(new float[] {66, 77}); + float[] transformed = transform.query(66, 77); assertEquals(66, transformed[0], 0); assertEquals(77, transformed[1], 0); } @@ -352,10 +352,10 @@ public class TestImageTransform { for (int i = 0; i < 100; i++) { ImageWritable w = transform.transform(writable); Frame f = w.getFrame(); - assertTrue(f.imageHeight == frame.imageHeight / 2); - assertTrue(f.imageWidth == frame.imageWidth / 2); + assertEquals(f.imageHeight, frame.imageHeight / 2); + assertEquals(f.imageWidth, frame.imageWidth / 2); } - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); transform = new RandomCropTransform(frame.imageHeight, frame.imageWidth); writable = transform.transform(writable); @@ -382,15 +382,15 @@ public class TestImageTransform { for (int i = 0; i < 100; i++) { ImageWritable w = transform.transform(writable); Frame f = w.getFrame(); - assertTrue(f.imageHeight == frame.imageHeight / 2); - assertTrue(f.imageWidth == frame.imageWidth / 2); + assertEquals(f.imageHeight, frame.imageHeight / 2); + assertEquals(f.imageWidth, frame.imageWidth / 2); assertEquals(f.imageChannels, frame.imageChannels); } - assertEquals(null, transform.transform(null)); + assertNull(transform.transform(null)); transform = new PipelineImageTransform(new EqualizeHistTransform()); writable = transform.transform(writable); - float[] transformed = transform.query(new float[] {88, 99}); + float[] transformed = transform.query(88, 99); assertEquals(88, transformed[0], 0); assertEquals(99, transformed[1], 0); } @@ -426,7 +426,7 @@ public class TestImageTransform { assertEquals(newFrame.imageHeight, 74); assertEquals(newFrame.imageWidth, 61); - float[] transformed = transform.query(new float[] {88, 32}); + float[] transformed = transform.query(88, 32); assertEquals(0, transformed[0], 0); assertEquals(0, transformed[1], 0); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/PoStagger.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/PoStagger.java index d071a42a4..6ad2d8a1e 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/PoStagger.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/PoStagger.java @@ -161,7 +161,7 @@ public class PoStagger extends CasAnnotator_ImplBase { final List posTags = this.posTagger.tag(sentenceTokenList); - double posProbabilities[] = null; + double[] posProbabilities = null; if (this.probabilityFeature != null) { posProbabilities = this.posTagger.probs(); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/metadata/DefaultVocabCache.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/metadata/DefaultVocabCache.java index 16dff8e5f..9b07d54ed 
100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/metadata/DefaultVocabCache.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/metadata/DefaultVocabCache.java @@ -30,10 +30,10 @@ import org.nd4j.common.util.Index; */ public class DefaultVocabCache implements VocabCache { - private Counter wordFrequencies = new Counter<>(); - private Counter docFrequencies = new Counter<>(); + private final Counter wordFrequencies = new Counter<>(); + private final Counter docFrequencies = new Counter<>(); private int minWordFrequency; - private Index vocabWords = new Index(); + private final Index vocabWords = new Index(); private double numDocs = 0; /** diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/ContextLabelRetriever.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/ContextLabelRetriever.java index 76b0244bd..b0b3da4a4 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/ContextLabelRetriever.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/ContextLabelRetriever.java @@ -35,8 +35,8 @@ import java.util.List; public class ContextLabelRetriever { - private static String BEGIN_LABEL = "<([A-Za-z]+|\\d+)>"; - private static String END_LABEL = ""; + private static final String BEGIN_LABEL = "<([A-Za-z]+|\\d+)>"; + private static final String END_LABEL = ""; private ContextLabelRetriever() {} @@ -66,7 +66,7 @@ public class ContextLabelRetriever { //no labels; add these as NONE and begin the new label if (!currTokens.isEmpty()) { - tokensWithSameLabel.add(new Pair<>("NONE", (List) new ArrayList<>(currTokens))); + tokensWithSameLabel.add(new Pair<>("NONE", new ArrayList<>(currTokens))); currTokens.clear(); } @@ -86,7 +86,7 @@ public class ContextLabelRetriever { Preconditions.checkState(currLabel.equals(endLabel), "Current label begin and end did not match for the parse. 
Was: %s ending with %s", currLabel, endLabel); - tokensWithSameLabel.add(new Pair<>(currLabel, (List) new ArrayList<>(currTokens))); + tokensWithSameLabel.add(new Pair<>(currLabel, new ArrayList<>(currTokens))); currTokens.clear(); @@ -100,7 +100,7 @@ public class ContextLabelRetriever { //no labels; add these as NONE and begin the new label if (!currTokens.isEmpty()) { - tokensWithSameLabel.add(new Pair<>("none", (List) new ArrayList<>(currTokens))); + tokensWithSameLabel.add(new Pair<>("none", new ArrayList<>(currTokens))); currTokens.clear(); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Window.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Window.java index a75b37fd0..be7b46563 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Window.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Window.java @@ -40,8 +40,8 @@ public class Window implements Serializable { private boolean beginLabel; private boolean endLabel; private int median; - private static String BEGIN_LABEL = "<([A-Z]+|\\d+)>"; - private static String END_LABEL = ""; + private static final String BEGIN_LABEL = "<([A-Z]+|\\d+)>"; + private static final String END_LABEL = ""; private int begin, end; /** diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/ConcurrentTokenizer.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/ConcurrentTokenizer.java index d46e68790..4f7980435 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/ConcurrentTokenizer.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/ConcurrentTokenizer.java @@ -121,7 +121,7 @@ public class ConcurrentTokenizer extends AbstractTokenizer { protected void postProcessAnnotations(Span[] tokens, AnnotationFS[] tokenAnnotations) { // if interest if (probabilityFeature != null) { - double tokenProbabilties[] = tokenizer.getTokenProbabilities(); + double[] tokenProbabilties = tokenizer.getTokenProbabilities(); for (int i = 0; i < tokenAnnotations.length; i++) { tokenAnnotations[i].setDoubleValue(probabilityFeature, tokenProbabilties[i]); diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultStreamTokenizer.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultStreamTokenizer.java index 9216ed24a..f6872a768 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultStreamTokenizer.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultStreamTokenizer.java @@ -28,7 +28,7 @@ import java.util.List; */ public class DefaultStreamTokenizer implements Tokenizer { - private StreamTokenizer streamTokenizer; + private final StreamTokenizer streamTokenizer; private TokenPreProcess tokenPreProcess; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultTokenizer.java 
b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultTokenizer.java index 4b393c0d5..c2972a606 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultTokenizer.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultTokenizer.java @@ -30,7 +30,7 @@ public class DefaultTokenizer implements Tokenizer { tokenizer = new StringTokenizer(tokens); } - private StringTokenizer tokenizer; + private final StringTokenizer tokenizer; private TokenPreProcess tokenPreProcess; @Override diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/PosUimaTokenizer.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/PosUimaTokenizer.java index e9e94bb99..a094a1f3e 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/PosUimaTokenizer.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/PosUimaTokenizer.java @@ -41,8 +41,8 @@ import java.util.List; public class PosUimaTokenizer implements Tokenizer { private static AnalysisEngine engine; - private List tokens; - private Collection allowedPosTags; + private final List tokens; + private final Collection allowedPosTags; private int index; private static CAS cas; @@ -85,9 +85,7 @@ public class PosUimaTokenizer implements Tokenizer { String check = token.getCoveredText(); if (check.matches("<[A-Z]+>") || check.matches("")) return false; - else if (token.getPos() != null && !this.allowedPosTags.contains(token.getPos())) - return false; - return true; + else return token.getPos() == null || this.allowedPosTags.contains(token.getPos()); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/UimaTokenizer.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/UimaTokenizer.java index eb14fdabd..18d942005 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/UimaTokenizer.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/UimaTokenizer.java @@ -34,10 +34,10 @@ import java.util.List; */ public class UimaTokenizer implements Tokenizer { - private List tokens; + private final List tokens; private int index; - private static Logger log = LoggerFactory.getLogger(UimaTokenizer.class); - private boolean checkForLabel; + private static final Logger log = LoggerFactory.getLogger(UimaTokenizer.class); + private final boolean checkForLabel; private TokenPreProcess tokenPreProcessor; @@ -73,9 +73,7 @@ public class UimaTokenizer implements Tokenizer { } private boolean valid(String check) { - if (check.matches("<[A-Z]+>") || check.matches("")) - return false; - return true; + return !check.matches("<[A-Z]+>") && !check.matches(""); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/PosUimaTokenizerFactory.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/PosUimaTokenizerFactory.java index 
5419bf9ef..3a7368b02 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/PosUimaTokenizerFactory.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/PosUimaTokenizerFactory.java @@ -41,8 +41,8 @@ import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDesc */ public class PosUimaTokenizerFactory implements TokenizerFactory { - private AnalysisEngine tokenizer; - private Collection allowedPoSTags; + private final AnalysisEngine tokenizer; + private final Collection allowedPoSTags; private TokenPreProcess tokenPreProcess; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/UimaTokenizerFactory.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/UimaTokenizerFactory.java index 7b24244ac..b91ae9a74 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/UimaTokenizerFactory.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/UimaTokenizerFactory.java @@ -40,8 +40,8 @@ import java.io.InputStream; public class UimaTokenizerFactory implements TokenizerFactory { - private UimaResource uimaResource; - private boolean checkForLabel; + private final UimaResource uimaResource; + private final boolean checkForLabel; private static AnalysisEngine defaultAnalysisEngine; private TokenPreProcess preProcess; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/GazeteerTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/GazeteerTransform.java index 1c774e904..d99f629dc 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/GazeteerTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/GazeteerTransform.java @@ -98,7 +98,7 @@ public class GazeteerTransform extends BaseColumnTransform implements BagOfWords @Override public List> mapSequence(List> sequence) { INDArray arr = (INDArray) mapSequence((Object) sequence); - return Collections.singletonList(Collections.singletonList(new NDArrayWritable(arr))); + return Collections.singletonList(Collections.singletonList(new NDArrayWritable(arr))); } @Override diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/MultiNlpTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/MultiNlpTransform.java index b702422e8..8f8c0deb0 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/MultiNlpTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/MultiNlpTransform.java @@ -40,9 +40,9 @@ import java.util.List; */ public class MultiNlpTransform extends BaseColumnTransform implements BagOfWordsTransform { - private BagOfWordsTransform[] transforms; - private String newColumnName; - private List vocabWords; + private final BagOfWordsTransform[] transforms; + private final String newColumnName; + private final List vocabWords; /** * @@ -80,7 +80,7 @@ public 
class MultiNlpTransform extends BaseColumnTransform implements BagOfWords @Override public List> mapSequence(List> sequence) { - return Collections.singletonList(Collections.singletonList(new NDArrayWritable(transformFrom(sequence)))); + return Collections.singletonList(Collections.singletonList(new NDArrayWritable(transformFrom(sequence)))); } @Override diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestGazeteerTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestGazeteerTransform.java index c63ff14c7..7bfbe4eb0 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestGazeteerTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestGazeteerTransform.java @@ -60,7 +60,7 @@ public class TestGazeteerTransform { String[] split = s.split(" "); List> seq = new ArrayList<>(); for(String s2 : split){ - seq.add(Collections.singletonList(new Text(s2))); + seq.add(Collections.singletonList(new Text(s2))); } input.add(seq); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestMultiNLPTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestMultiNLPTransform.java index b0642f2a9..ebb5c52c7 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestMultiNLPTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestMultiNLPTransform.java @@ -52,7 +52,7 @@ public class TestMultiNLPTransform { String[] split = s.split(" "); List> seq = new ArrayList<>(); for(String s2 : split){ - seq.add(Collections.singletonList(new Text(s2))); + seq.add(Collections.singletonList(new Text(s2))); } input.add(seq); } diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransformTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransformTest.java index dded0cc06..dfa2e228a 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransformTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransformTest.java @@ -89,7 +89,7 @@ public class TokenizerBagOfWordsTermSequenceIndexTransformTest { */ List>> input = new ArrayList<>(); - input.add(Arrays.asList(Arrays.asList(new Text(corpus[0])),Arrays.asList(new Text(corpus[1])))); + input.add(Arrays.asList(Collections.singletonList(new Text(corpus[0])), Collections.singletonList(new Text(corpus[1])))); // First: Check TfidfVectorizer vs. 
scikit: @@ -313,7 +313,7 @@ public class TokenizerBagOfWordsTermSequenceIndexTransformTest { //input.add(Arrays.asList(Arrays.asList(new Text(corpus[0])),Arrays.asList(new Text(corpus[1])))); List> seq = new ArrayList<>(); for(String s : corpus){ - seq.add(Collections.singletonList(new Text(s))); + seq.add(Collections.singletonList(new Text(s))); } input.add(seq); diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformExecutor.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformExecutor.java index 1e2bc42ed..2f7328954 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformExecutor.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformExecutor.java @@ -68,7 +68,7 @@ public class LocalTransformExecutor { //returning empty records public final static String LOG_ERROR_PROPERTY = "org.datavec.spark.transform.logerrors"; - private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); + private static final BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); /** * Execute the specified TransformProcess with the given input data
@@ -98,7 +98,7 @@ public class LocalTransformExecutor {
      * Execute the specified TransformProcess with the given input data
     * Note: this method can only be used if the TransformProcess
     * starts with non-sequential data,
-     * but returns sequence
+     * but returns sequence
     * data (after grouping or converting to a sequence as one of the steps)
     *
     * @param inputWritables Input data to process
diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformProcessSequenceRecordReader.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformProcessSequenceRecordReader.java
index e2427b409..8fc5fccae 100644
--- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformProcessSequenceRecordReader.java
+++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformProcessSequenceRecordReader.java
@@ -26,6 +26,7 @@ import org.datavec.api.transform.TransformProcess;
 import org.datavec.api.writable.Writable;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 public class LocalTransformProcessSequenceRecordReader extends TransformProcessSequenceRecordReader {
@@ -36,7 +37,7 @@ public class LocalTransformProcessSequenceRecordReader extends TransformProcessS
     @Override
     public List> sequenceRecord() {
-        return LocalTransformExecutor.executeSequenceToSequence(Arrays.asList(sequenceRecordReader.nextSequence().getSequenceRecord()),transformProcess
+        return LocalTransformExecutor.executeSequenceToSequence(Collections.singletonList(sequenceRecordReader.nextSequence().getSequenceRecord()),transformProcess
         ).get(0);
     }
diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceMergeFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceMergeFunction.java
index 0e0c6afef..e0fdf697e 100644
--- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceMergeFunction.java
+++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceMergeFunction.java
@@ -31,7 +31,7 @@ public class SequenceMergeFunction
                 implements Function>>>, List>> {
-    private SequenceMerge sequenceMerge;
+    private final SequenceMerge sequenceMerge;
     public SequenceMergeFunction(SequenceMerge sequenceMerge) {
         this.sequenceMerge = sequenceMerge;
diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java
index 2ec96607e..25dc7b738 100644
--- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java
+++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java
@@ -41,6 +41,7 @@ import org.nd4j.common.io.ClassPathResource;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -64,18 +65,18 @@ public class LocalTransformProcessRecordReaderTests {
     public void simpleTransformTestSequence() {
         List> sequence = new ArrayList<>();
         //First window:
-        sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L), new IntWritable(0),
+        sequence.add(Arrays.asList(new LongWritable(1451606400000L), new IntWritable(0),
                         new IntWritable(0)));
-        sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 100L), new IntWritable(1),
+        sequence.add(Arrays.asList(new LongWritable(1451606400000L + 100L), new IntWritable(1),
                         new IntWritable(0)));
-        sequence.add(Arrays.asList((Writable) new LongWritable(1451606400000L + 200L), new IntWritable(2),
+        sequence.add(Arrays.asList(new LongWritable(1451606400000L + 200L), new IntWritable(2),
                         new IntWritable(0)));
         Schema schema = new SequenceSchema.Builder().addColumnTime("timecolumn", DateTimeZone.UTC)
                         .addColumnInteger("intcolumn").addColumnInteger("intcolumn2").build();
         TransformProcess transformProcess = new TransformProcess.Builder(schema).removeColumns("intcolumn2").build();
         InMemorySequenceRecordReader inMemorySequenceRecordReader =
-                        new InMemorySequenceRecordReader(Arrays.asList(sequence));
+                        new InMemorySequenceRecordReader(Collections.singletonList(sequence));
         LocalTransformProcessSequenceRecordReader transformProcessSequenceRecordReader =
                         new LocalTransformProcessSequenceRecordReader(inMemorySequenceRecordReader, transformProcess);
         List> next = transformProcessSequenceRecordReader.sequenceRecord();
diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java
index 37a86a2f3..95b6ebfab 100644
--- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java
+++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java
@@ -31,6 +31,7 @@ import org.nd4j.linalg.factory.Nd4j;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -50,7 +51,7 @@ public class TestNDArrayToWritablesFunction {
     @Test
     public void testNDArrayToWritablesArray() throws Exception {
         INDArray arr = Nd4j.arange(5);
-        List expected = Arrays.asList((Writable) new NDArrayWritable(arr));
+        List expected = Collections.singletonList(new NDArrayWritable(arr));
         List actual = new NDArrayToWritablesFunction(true).apply(arr);
         assertEquals(expected, actual);
     }
diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java
index fca45adb1..1086866f2 100644
--- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java
+++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java
@@ -44,7 +44,7 @@ public class TestWritablesToStringFunctions {
     @Test
     public void testWritablesToString() throws Exception {
-        List l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue"));
+        List l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue"));
         String expected = l.get(0).toString() + "," + l.get(1).toString();
         assertEquals(expected, new WritablesToStringFunction(",").apply(l));
@@ -53,8 +53,8 @@ public class TestWritablesToStringFunctions {
     @Test
     public void testSequenceWritablesToString() throws Exception {
-        List> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")),
-                        Arrays.asList(new DoubleWritable(2.5), new Text("otherValue")));
+        List> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")),
+                        Arrays.asList(new DoubleWritable(2.5), new Text("otherValue")));
         String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n"
                         + l.get(1).get(0).toString() + "," + l.get(1).get(1).toString();
diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
index e265136f8..94aa8eeaf 100644
--- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
+++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java
@@ -87,9 +87,9 @@ public class ExecutionTest {
                         .doubleMathOp("col2", MathOp.Add, 10.0).floatMathOp("col3", MathOp.Add, 5f).build();
         List> inputData = new ArrayList<>();
-        inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
-        inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
-        inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
+        inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1), new FloatWritable(0.3f)));
+        inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1), new FloatWritable(1.7f)));
+        inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1), new FloatWritable(3.6f)));
         List> rdd = (inputData);
@@ -103,9 +103,9 @@
         });
         List> expected = new ArrayList<>();
-        expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
-        expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
-        expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1), new FloatWritable(8.6f)));
+        expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1), new FloatWritable(5.3f)));
+        expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1), new FloatWritable(6.7f)));
+        expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1), new FloatWritable(8.6f)));
         assertEquals(expected, out);
     }
@@ -116,9 +116,9 @@
                         .addColumnDouble("col1").addColumnDouble("col2")
                         .addColumnDouble("col3").build();
         List> inputData = new ArrayList<>();
-        inputData.add(Arrays.asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
-        inputData.add(Arrays.asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
-        inputData.add(Arrays.asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
+        inputData.add(Arrays.asList(new IntWritable(0), new DoubleWritable(1), new DoubleWritable(0.1)));
+        inputData.add(Arrays.asList(new IntWritable(1), new DoubleWritable(3), new DoubleWritable(1.1)));
+        inputData.add(Arrays.asList(new IntWritable(2), new DoubleWritable(3), new DoubleWritable(2.1)));
         TransformProcess transformProcess = new TransformProcess.Builder(filterSchema)
                         .filter(new DoubleColumnCondition("col1",ConditionOp.LessThan,1)).build();
         List> execute = LocalTransformExecutor.execute(inputData,
transformProcess); @@ -136,12 +136,12 @@ public class ExecutionTest { List>> inputSequences = new ArrayList<>(); List> seq1 = new ArrayList<>(); - seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); - seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); - seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); + seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1))); + seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1))); + seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1))); List> seq2 = new ArrayList<>(); - seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1))); - seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1))); + seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1))); + seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1))); inputSequences.add(seq1); inputSequences.add(seq2); @@ -159,12 +159,12 @@ public class ExecutionTest { List>> expectedSequence = new ArrayList<>(); List> seq1e = new ArrayList<>(); - seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); - seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); - seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); + seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1))); + seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1))); + seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1))); List> seq2e = new ArrayList<>(); - seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1))); - seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1))); + seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1))); + seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1))); expectedSequence.add(seq1e); expectedSequence.add(seq2e); @@ -177,8 +177,8 @@ public class ExecutionTest { public void testReductionGlobal() { List> in = Arrays.asList( - Arrays.asList(new Text("first"), new DoubleWritable(3.0)), - Arrays.asList(new Text("second"), new DoubleWritable(5.0)) + Arrays.asList(new Text("first"), new DoubleWritable(3.0)), + Arrays.asList(new Text("second"), new DoubleWritable(5.0)) ); List> inData = in; @@ -198,7 +198,7 @@ public class ExecutionTest { List> out = outRdd; - List> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0))); + List> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0))); assertEquals(expOut, out); } @@ -207,10 +207,10 @@ public class ExecutionTest { public void testReductionByKey(){ List> in = Arrays.asList( - Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), - Arrays.asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)), - Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), - Arrays.asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)) + Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)), + Arrays.asList(new IntWritable(0), new 
Text("second"), new DoubleWritable(5.0)), + Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)), + Arrays.asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0)) ); List> inData = in; @@ -233,8 +233,8 @@ public class ExecutionTest { List> out = outRdd; List> expOut = Arrays.asList( - Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), - Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); + Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)), + Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0))); out = new ArrayList<>(out); Collections.sort( diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java index d0e431678..3cca330af 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java @@ -75,10 +75,10 @@ public class TestGeoTransforms { out.getColumnTypes()); assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)), - transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10")))); + transform.map(Arrays.asList(new Text("-30"), new Text("20"), new Text("10")))); assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"), new DoubleWritable(Math.sqrt(160))), - transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), + transform.map(Arrays.asList(new Text("50|40"), new Text("10|-20"), new Text("10|5")))); } @@ -98,7 +98,7 @@ public class TestGeoTransforms { double latitude = 51.5142; double longitude = -0.0931; - List writables = transform.map(Collections.singletonList((Writable) new Text(in))); + List writables = transform.map(Collections.singletonList(new Text(in))); assertEquals(1, writables.size()); String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER"); assertEquals(2, coordinates.length); @@ -116,7 +116,7 @@ public class TestGeoTransforms { ObjectInputStream ois = new ObjectInputStream(bais); Transform deserialized = (Transform) ois.readObject(); - writables = deserialized.map(Collections.singletonList((Writable) new Text(in))); + writables = deserialized.map(Collections.singletonList(new Text(in))); assertEquals(1, writables.size()); coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER"); //System.out.println(Arrays.toString(coordinates)); @@ -145,7 +145,7 @@ public class TestGeoTransforms { assertEquals(1, out.getColumnMetaData().size()); assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); - List writables = transform.map(Collections.singletonList((Writable) new Text(in))); + List writables = transform.map(Collections.singletonList(new Text(in))); assertEquals(1, writables.size()); assertEquals(location, writables.get(0).toString()); //System.out.println(location); diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java index 21ae33b3f..1dd62a88e 100644 --- 
a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java @@ -69,7 +69,7 @@ public class TestPythonTransformProcess { .build() ).build(); - List inputs = Arrays.asList((Writable)new Text("Hello "), new Text("World!")); + List inputs = Arrays.asList(new Text("Hello "), new Text("World!")); List outputs = tp.execute(inputs); assertEquals((outputs.get(0)).toString(), "Hello "); @@ -100,7 +100,7 @@ public class TestPythonTransformProcess { .inputSchema(initialSchema) .build() ).build(); - List inputs = Arrays.asList((Writable)new IntWritable(10), + List inputs = Arrays.asList(new IntWritable(10), new FloatWritable(3.5f), new Text("5"), new DoubleWritable(2.0) @@ -134,7 +134,6 @@ public class TestPythonTransformProcess { .build() ).build(); List inputs = Arrays.asList( - (Writable) new NDArrayWritable(arr1), new NDArrayWritable(arr2) ); @@ -170,7 +169,6 @@ public class TestPythonTransformProcess { .build() ).build(); List inputs = Arrays.asList( - (Writable) new NDArrayWritable(arr1), new NDArrayWritable(arr2) ); @@ -206,7 +204,6 @@ public class TestPythonTransformProcess { ).build(); List inputs = Arrays.asList( - (Writable) new NDArrayWritable(arr1), new NDArrayWritable(arr2) ); @@ -271,7 +268,6 @@ public class TestPythonTransformProcess { List> inputs = new ArrayList<>(); inputs.add( Arrays.asList( - (Writable) new IntWritable(5), new FloatWritable(3.0f), new Text("abcd"), @@ -279,7 +275,6 @@ public class TestPythonTransformProcess { ); inputs.add( Arrays.asList( - (Writable) new IntWritable(-3), new FloatWritable(3.0f), new Text("abcd"), @@ -287,7 +282,6 @@ public class TestPythonTransformProcess { ); inputs.add( Arrays.asList( - (Writable) new IntWritable(5), new FloatWritable(11.2f), new Text("abcd"), @@ -305,7 +299,7 @@ public class TestPythonTransformProcess { .returnAllInputs(true) .build(); List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable)new IntWritable(1))); + inputs.add(Collections.singletonList(new IntWritable(1))); Schema inputSchema = new Builder() .addColumnInteger("a") .build(); @@ -327,7 +321,7 @@ public class TestPythonTransformProcess { .build(); List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)))); + inputs.add(Collections.singletonList(new NDArrayWritable(Nd4j.scalar(1).reshape(1, 1)))); Schema inputSchema = new Builder() .addColumnNDArray("a",new long[]{1,1}) .build(); @@ -360,7 +354,7 @@ public class TestPythonTransformProcess { .build(); List> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)), + inputs.add(Arrays.asList(new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)), new NDArrayWritable(Nd4j.scalar(2).reshape(1,1)))); Schema inputSchema = new Builder() .addColumnNDArray("a",new long[]{1,1}) diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java index adb511603..b7fc564c7 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java @@ -46,27 +46,27 @@ public class TestJoin { 
.addColumnDouble("amount").build(); List> infoList = new ArrayList<>(); - infoList.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"))); - infoList.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"))); - infoList.add(Arrays.asList(new LongWritable(50000), new Text("Customer50000"))); + infoList.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"))); + infoList.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"))); + infoList.add(Arrays.asList(new LongWritable(50000), new Text("Customer50000"))); List> purchaseList = new ArrayList<>(); - purchaseList.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345), + purchaseList.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345), new DoubleWritable(10.00))); - purchaseList.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345), + purchaseList.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345), new DoubleWritable(20.00))); - purchaseList.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765), + purchaseList.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765), new DoubleWritable(30.00))); Join join = new Join.Builder(Join.JoinType.RightOuter).setJoinColumns("customerID") .setSchemas(customerInfoSchema, purchasesSchema).build(); List> expected = new ArrayList<>(); - expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"), + expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"), new LongWritable(1000000), new DoubleWritable(10.00))); - expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"), + expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"), new LongWritable(1000001), new DoubleWritable(20.00))); - expected.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"), + expected.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"), new LongWritable(1000002), new DoubleWritable(30.00))); @@ -100,11 +100,11 @@ public class TestJoin { .setSchemas(purchasesSchema, customerInfoSchema).build(); List> expectedManyToOne = new ArrayList<>(); - expectedManyToOne.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345), + expectedManyToOne.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345), new DoubleWritable(10.00), new Text("Customer12345"))); - expectedManyToOne.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345), + expectedManyToOne.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345), new DoubleWritable(20.00), new Text("Customer12345"))); - expectedManyToOne.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765), + expectedManyToOne.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765), new DoubleWritable(30.00), new Text("Customer98765"))); List> joined2 = LocalTransformExecutor.executeJoin(join2, purchases, info); @@ -138,45 +138,45 @@ public class TestJoin { .addColumnCategorical("otherCategory", Arrays.asList("cat0", "cat1", "cat2")).build(); List> first = new ArrayList<>(); - first.add(Arrays.asList(new LongWritable(0), new Text("cat0"))); - first.add(Arrays.asList(new LongWritable(1), new Text("cat0"))); - first.add(Arrays.asList(new LongWritable(2), new Text("cat1"))); + first.add(Arrays.asList(new LongWritable(0), new Text("cat0"))); + first.add(Arrays.asList(new LongWritable(1), new Text("cat0"))); + first.add(Arrays.asList(new LongWritable(2), new Text("cat1"))); List> second = 
new ArrayList<>(); - second.add(Arrays.asList(new LongWritable(100), new Text("cat0"))); - second.add(Arrays.asList(new LongWritable(101), new Text("cat0"))); - second.add(Arrays.asList(new LongWritable(102), new Text("cat2"))); + second.add(Arrays.asList(new LongWritable(100), new Text("cat0"))); + second.add(Arrays.asList(new LongWritable(101), new Text("cat0"))); + second.add(Arrays.asList(new LongWritable(102), new Text("cat2"))); List> expOuterJoin = new ArrayList<>(); - expOuterJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); - expOuterJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); - expOuterJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); - expOuterJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); - expOuterJoin.add(Arrays.asList(new LongWritable(2), new Text("cat1"), new NullWritable())); - expOuterJoin.add(Arrays.asList(new NullWritable(), new Text("cat2"), new LongWritable(102))); + expOuterJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); + expOuterJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); + expOuterJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); + expOuterJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); + expOuterJoin.add(Arrays.asList(new LongWritable(2), new Text("cat1"), new NullWritable())); + expOuterJoin.add(Arrays.asList(new NullWritable(), new Text("cat2"), new LongWritable(102))); List> expLeftJoin = new ArrayList<>(); - expLeftJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); - expLeftJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); - expLeftJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); - expLeftJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); - expLeftJoin.add(Arrays.asList(new LongWritable(2), new Text("cat1"), new NullWritable())); + expLeftJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); + expLeftJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); + expLeftJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); + expLeftJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); + expLeftJoin.add(Arrays.asList(new LongWritable(2), new Text("cat1"), new NullWritable())); List> expRightJoin = new ArrayList<>(); - expRightJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); - expRightJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); - expRightJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); - expRightJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); - expRightJoin.add(Arrays.asList(new NullWritable(), new Text("cat2"), new LongWritable(102))); + expRightJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); + expRightJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); + expRightJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); + expRightJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); + 
expRightJoin.add(Arrays.asList(new NullWritable(), new Text("cat2"), new LongWritable(102))); List> expInnerJoin = new ArrayList<>(); - expInnerJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); - expInnerJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); - expInnerJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); - expInnerJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); + expInnerJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); + expInnerJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); + expInnerJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); + expInnerJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); List> firstRDD = (first); List> secondRDD = (second); diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java index 39f3405a9..a14c5f468 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java @@ -45,10 +45,10 @@ public class TestCalculateSortedRank { public void testCalculateSortedRank() { List> data = new ArrayList<>(); - data.add(Arrays.asList((Writable) new Text("0"), new DoubleWritable(0.0))); - data.add(Arrays.asList((Writable) new Text("3"), new DoubleWritable(0.3))); - data.add(Arrays.asList((Writable) new Text("2"), new DoubleWritable(0.2))); - data.add(Arrays.asList((Writable) new Text("1"), new DoubleWritable(0.1))); + data.add(Arrays.asList(new Text("0"), new DoubleWritable(0.0))); + data.add(Arrays.asList(new Text("3"), new DoubleWritable(0.3))); + data.add(Arrays.asList(new Text("2"), new DoubleWritable(0.2))); + data.add(Arrays.asList(new Text("1"), new DoubleWritable(0.1))); List> rdd = (data); diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java index 04a4a5c47..bd3ace8c8 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java @@ -48,12 +48,12 @@ public class TestConvertToSequence { Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build(); List> allExamples = - Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)), - Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)), - Arrays.asList(new Text("k1a"), new Text("k2a"), + Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)), + Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)), + Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)), - Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), - Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0))); 
+ Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), + Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0))); TransformProcess tp = new TransformProcess.Builder(s) .convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time")) @@ -75,13 +75,13 @@ public class TestConvertToSequence { } List> expSeq0 = Arrays.asList( - Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)), - Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)), - Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10))); + Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)), + Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)), + Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10))); List> expSeq1 = Arrays.asList( - Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), - Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10))); + Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)), + Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10))); assertEquals(expSeq0, seq0); assertEquals(expSeq1, seq1); @@ -96,9 +96,9 @@ public class TestConvertToSequence { .build(); List> allExamples = Arrays.asList( - Arrays.asList(new Text("a"), new LongWritable(0)), - Arrays.asList(new Text("b"), new LongWritable(1)), - Arrays.asList(new Text("c"), new LongWritable(2))); + Arrays.asList(new Text("a"), new LongWritable(0)), + Arrays.asList(new Text("b"), new LongWritable(1)), + Arrays.asList(new Text("c"), new LongWritable(2))); TransformProcess tp = new TransformProcess.Builder(s) .convertToSequence() diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/NumpyArray.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/NumpyArray.java index 708184de7..ca597c60b 100644 --- a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/NumpyArray.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/NumpyArray.java @@ -46,8 +46,8 @@ import static org.nd4j.linalg.api.buffer.DataType.FLOAT; @NoArgsConstructor public class NumpyArray { - private static NativeOps nativeOps; - private static Map arrayCache; // Avoids re-allocation of device buffer + private static final NativeOps nativeOps; + private static final Map arrayCache; // Avoids re-allocation of device buffer private long address; private long[] shape; private long[] strides; @@ -62,7 +62,7 @@ public class NumpyArray { } @Builder - public NumpyArray(long address, long[] shape, long strides[], DataType dtype, boolean copy) { + public NumpyArray(long address, long[] shape, long[] strides, DataType dtype, boolean copy) { this.address = address; this.shape = shape; this.strides = strides; @@ -81,11 +81,11 @@ public class NumpyArray { return new NumpyArray(nd4jArray.dup()); } - public NumpyArray(long address, long[] shape, long strides[]) { + public NumpyArray(long address, long[] shape, long[] strides) { this(address, shape, strides, FLOAT, false); } - public NumpyArray(long address, long[] shape, long strides[], DataType dtype) { + public NumpyArray(long address, long[] shape, long[] strides, DataType dtype) { this(address, shape, strides, dtype, false); } diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonCondition.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonCondition.java index e94e5a171..62370246f 100644 --- 
a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonCondition.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonCondition.java @@ -37,7 +37,7 @@ public class PythonCondition implements Condition { private Schema inputSchema; private PythonVariables pyInputs; private PythonTransform pythonTransform; - private String code; + private final String code; public PythonCondition(String pythonCode) { diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonContextManager.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonContextManager.java index c3563bfc2..b46610918 100644 --- a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonContextManager.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonContextManager.java @@ -36,8 +36,8 @@ import java.util.concurrent.atomic.AtomicBoolean; public class PythonContextManager { - private static Set contexts = new HashSet<>(); - private static AtomicBoolean init = new AtomicBoolean(false); + private static final Set contexts = new HashSet<>(); + private static final AtomicBoolean init = new AtomicBoolean(false); private static String currentContext; private static final String MAIN_CONTEXT = "main"; static { diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index dd48cb104..f0e6f4eed 100644 --- a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -89,7 +89,7 @@ import static org.datavec.python.Python.*; public class PythonExecutioner { - private static AtomicBoolean init = new AtomicBoolean(false); + private static final AtomicBoolean init = new AtomicBoolean(false); public final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.datavec.python.path"; public final static String JAVACPP_PYTHON_APPEND_TYPE = "org.datavec.python.javacpp.path.append"; public final static String DEFAULT_APPEND_TYPE = "before"; @@ -139,7 +139,7 @@ public class PythonExecutioner { * and b is the variable value. * @param varName Name of the python variable being set. 
Should be a valid python identifier string * @param pythonObject Value for the python variable - * @throws Exception + * @throws PythonException */ public static void setVariable(String varName, PythonObject pythonObject) throws PythonException{ if (!validateVariableName(varName)){ @@ -345,10 +345,8 @@ public class PythonExecutioner { //// TODO: fix in javacpp File sitePackagesWindows = new File(python.cachePackage(), "site-packages"); File[] packages2 = new File[packages.length + 1]; - for (int i = 0;i < packages.length; i++){ - //System.out.println(packages[i].getAbsolutePath()); - packages2[i] = packages[i]; - } + //System.out.println(packages[i].getAbsolutePath()); + System.arraycopy(packages, 0, packages2, 0, packages.length); packages2[packages.length] = sitePackagesWindows; //System.out.println(sitePackagesWindows.getAbsolutePath()); packages = packages2; @@ -369,7 +367,7 @@ public class PythonExecutioner { sb.append(path); - log.info("Prepending javacpp python path: {}", sb.toString()); + log.info("Prepending javacpp python path: {}", sb); break; case AFTER: sb.append(path); @@ -379,7 +377,7 @@ public class PythonExecutioner { sb.append(java.io.File.pathSeparator); } - log.info("Appending javacpp python path " + sb.toString()); + log.info("Appending javacpp python path " + sb); break; case NONE: log.info("Not appending javacpp path"); @@ -388,7 +386,7 @@ public class PythonExecutioner { } //prepend the javacpp packages - log.info("Final python path: {}", sb.toString()); + log.info("Final python path: {}", sb); Py_SetPath(sb.toString()); } diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonJob.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonJob.java index c50c9bb9e..81894d101 100644 --- a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonJob.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonJob.java @@ -121,7 +121,7 @@ public class PythonJob { PythonObject arg = argsList.get(i); PythonObject val = Python.globals().get(arg); if (val.isNone()) { - throw new PythonException("Input value not received for run() argument: " + arg.toString()); + throw new PythonException("Input value not received for run() argument: " + arg); } runargs.set(arg, val); } @@ -153,7 +153,7 @@ public class PythonJob { PythonObject arg = argsList.get(i); PythonObject val = Python.globals().get(arg); if (val.isNone()) { - throw new PythonException("Input value not received for run() argument: " + arg.toString()); + throw new PythonException("Input value not received for run() argument: " + arg); } runargs.set(arg, val); } diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonObject.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonObject.java index 4a6a617d5..b9d809aab 100644 --- a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonObject.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonObject.java @@ -158,7 +158,7 @@ public class PythonObject { } public PythonObject(int data) { - nativePythonObject = PyLong_FromLong((long) data); + nativePythonObject = PyLong_FromLong(data); } public PythonObject(long data) { @@ -208,7 +208,7 @@ public class PythonObject { } public PythonObject(Object[] data) { - PyObject pyList = PyList_New((long) data.length); + PyObject pyList = PyList_New(data.length); for (int i = 0; i < data.length; i++) { PyList_SetItem(pyList, i, 
j2pyObject(data[i]).nativePythonObject); } @@ -216,7 +216,7 @@ public class PythonObject { } public PythonObject(List data) { - PyObject pyList = PyList_New((long) data.size()); + PyObject pyList = PyList_New(data.size()); for (int i = 0; i < data.size(); i++) { PyList_SetItem(pyList, i, j2pyObject(data.get(i)).nativePythonObject); } @@ -384,9 +384,7 @@ public class PythonObject { public PythonObject call(Object... args) { if (args.length > 0 && args[args.length - 1] instanceof Map) { List args2 = new ArrayList<>(); - for (int i = 0; i < args.length - 1; i++) { - args2.add(args[i]); - } + args2.addAll(Arrays.asList(args).subList(0, args.length - 1)); return call(args2, (Map) args[args.length - 1]); } if (args.length == 0) { @@ -444,7 +442,7 @@ public class PythonObject { } public PythonObject get(int key) { - return get(PyLong_FromLong((long) key)); + return get(PyLong_FromLong(key)); } public PythonObject get(long key) { diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonProcess.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonProcess.java index a8ee56510..8c86cc69a 100644 --- a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonProcess.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonProcess.java @@ -27,12 +27,10 @@ import java.util.Arrays; @Slf4j public class PythonProcess { - private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class); + private static final String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class); public static String runAndReturn(String... arguments)throws IOException, InterruptedException{ String[] allArgs = new String[arguments.length + 1]; - for (int i = 0; i < arguments.length; i++){ - allArgs[i + 1] = arguments[i]; - } + System.arraycopy(arguments, 0, allArgs, 1, arguments.length); allArgs[0] = pythonExecutable; log.info("Executing command: " + Arrays.toString(allArgs)); ProcessBuilder pb = new ProcessBuilder(allArgs); @@ -45,9 +43,7 @@ public class PythonProcess { public static void run(String... 
arguments)throws IOException, InterruptedException{ String[] allArgs = new String[arguments.length + 1]; - for (int i = 0; i < arguments.length; i++){ - allArgs[i + 1] = arguments[i]; - } + System.arraycopy(arguments, 0, allArgs, 1, arguments.length); allArgs[0] = pythonExecutable; log.info("Executing command: " + Arrays.toString(allArgs)); ProcessBuilder pb = new ProcessBuilder(allArgs); diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonType.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonType.java index d0a3f488f..4ac4bce2a 100644 --- a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonType.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonType.java @@ -55,7 +55,7 @@ public abstract class PythonType { } public static PythonType valueOf(String typeName) throws PythonException{ try{ - typeName.valueOf(typeName); + String.valueOf(typeName); } catch (IllegalArgumentException iae){ throw new PythonException("Invalid python type: " + typeName, iae); } diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonUtils.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonUtils.java index d3e991b35..7d0ea15cd 100644 --- a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonUtils.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonUtils.java @@ -132,7 +132,7 @@ public class PythonUtils { pyVars.addBool(colName); break; default: - throw new Exception("Unsupported python input type: " + colType.toString()); + throw new Exception("Unsupported python input type: " + colType); } } @@ -220,7 +220,7 @@ public class PythonUtils { public static Map toMap(JSONObject jsonobj) { Map map = new HashMap<>(); - String[] keys = (String[]) jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]); + String[] keys = jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]); for (String key : keys) { Object value = jsonobj.get(key); if (value instanceof JSONArray) { diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/keras/Model.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/keras/Model.java index d8a9b0651..04d2cafd6 100644 --- a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/keras/Model.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/keras/Model.java @@ -8,7 +8,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; public class Model { - private PythonObject pyModel; + private final PythonObject pyModel; private static PythonObject installAndImportTF() throws PythonException{ diff --git a/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/PythonNumpyTest.java b/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/PythonNumpyTest.java index 1c0721c32..83a8636e6 100644 --- a/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/PythonNumpyTest.java +++ b/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/PythonNumpyTest.java @@ -45,7 +45,7 @@ public class PythonNumpyTest { }; } - private DataType dataType; + private final DataType dataType; public PythonNumpyTest(DataType dataType) { this.dataType = dataType; diff --git a/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java b/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java index 
f6f39d68c..6d536655b 100644 --- a/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java +++ b/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/ScalarAndArrayTest.java @@ -35,7 +35,7 @@ public class ScalarAndArrayTest { }; } - private INDArray indArray; + private final INDArray indArray; public ScalarAndArrayTest(INDArray indArray) { this.indArray = indArray; diff --git a/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/TestPythonList.java b/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/TestPythonList.java index 259431cba..a52c58fd9 100644 --- a/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/TestPythonList.java +++ b/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/TestPythonList.java @@ -92,7 +92,7 @@ public class TestPythonList { ), map }; PythonObject pyList = new PythonObject(objs); - System.out.println(pyList.toString()); + System.out.println(pyList); String expectedStr = "[1, 2, 'a', 3.0, 4, 5.0, [10" + ", 20, 'b', 30.0, 40, 50.0, {'arr': array([1.," + " 2., 3., 4.], dtype=float32), 1: 'a', 'a': [" + diff --git a/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java b/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java index b709e1608..668f820a9 100644 --- a/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java +++ b/cavis-datavec/cavis-datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java @@ -60,12 +60,12 @@ public class TestPythonVariables { BytePointer bp = new BytePointer(arr.data().pointer()); Object[] values = { 1L,1.0,"1",true, Collections.singletonMap("1",1), - new Object[]{1}, Arrays.asList(1), arr, bp + new Object[]{1}, Collections.singletonList(1), arr, bp }; Object[] expectedValues = { 1L,1.0,"1",true, Collections.singletonMap("1",1), - Arrays.asList(1), Arrays.asList(1), arr, bp + Collections.singletonList(1), Collections.singletonList(1), arr, bp }; for(int i = 0; i < types.length; i++) { diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/pairdata/PathToKeyFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/pairdata/PathToKeyFunction.java index 5c43c969f..39187a785 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/pairdata/PathToKeyFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/pairdata/PathToKeyFunction.java @@ -28,8 +28,8 @@ import scala.Tuple3; public class PathToKeyFunction implements PairFunction, String, Tuple3> { - private PathToKeyConverter converter; - private int index; + private final PathToKeyConverter converter; + private final int index; public PathToKeyFunction(int index, PathToKeyConverter converter) { this.index = index; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/DataFrames.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/DataFrames.java index 1e41cc3c1..2a9fdb9fb 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/DataFrames.java +++ 
b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/DataFrames.java @@ -43,6 +43,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import static org.apache.spark.sql.functions.avg; @@ -380,8 +381,7 @@ public class DataFrames { */ public static List toList(String[] input) { List ret = new ArrayList<>(); - for (int i = 0; i < input.length; i++) - ret.add(input[i]); + Collections.addAll(ret, input); return ret; } diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/Normalization.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/Normalization.java index f4d513017..de5511017 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/Normalization.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/Normalization.java @@ -42,7 +42,7 @@ public class Normalization { * rdd */ public static Dataset zeromeanUnitVariance(Dataset frame) { - return zeromeanUnitVariance(frame, Collections.emptyList()); + return zeromeanUnitVariance(frame, Collections.emptyList()); } /** @@ -55,7 +55,7 @@ public class Normalization { * rdd */ public static JavaRDD> zeromeanUnitVariance(Schema schema, JavaRDD> data) { - return zeromeanUnitVariance(schema, data, Collections.emptyList()); + return zeromeanUnitVariance(schema, data, Collections.emptyList()); } /** @@ -67,7 +67,7 @@ public class Normalization { * @return the normalized dataframe per column */ public static Dataset normalize(Dataset dataFrame, double min, double max) { - return normalize(dataFrame, min, max, Collections.emptyList()); + return normalize(dataFrame, min, max, Collections.emptyList()); } /** @@ -82,7 +82,7 @@ public class Normalization { public static JavaRDD> normalize(Schema schema, JavaRDD> data, double min, double max) { Dataset frame = DataFrames.toDataFrame(schema, data); - return DataFrames.toRecords(normalize(frame, min, max, Collections.emptyList())).getSecond(); + return DataFrames.toRecords(normalize(frame, min, max, Collections.emptyList())).getSecond(); } @@ -93,7 +93,7 @@ public class Normalization { * @return the normalized dataframe per column */ public static Dataset normalize(Dataset dataFrame) { - return normalize(dataFrame, 0, 1, Collections.emptyList()); + return normalize(dataFrame, 0, 1, Collections.emptyList()); } /** @@ -104,7 +104,7 @@ public class Normalization { * @return the normalized ata */ public static JavaRDD> normalize(Schema schema, JavaRDD> data) { - return normalize(schema, data, 0, 1, Collections.emptyList()); + return normalize(schema, data, 0, 1, Collections.emptyList()); } diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceMergeFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceMergeFunction.java index 03f3efebd..933fb6d4a 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceMergeFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceMergeFunction.java @@ -31,7 +31,7 @@ import java.util.List; public class 
SequenceMergeFunction implements Function>>>, List>> { - private SequenceMerge sequenceMerge; + private final SequenceMerge sequenceMerge; public SequenceMergeFunction(SequenceMerge sequenceMerge) { this.sequenceMerge = sequenceMerge; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java index 7868cb3c7..faac3ac13 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java @@ -32,8 +32,8 @@ import java.util.*; public class SequenceToRows implements FlatMapFunction>, Row> { - private Schema schema; - private StructType structType; + private final Schema schema; + private final StructType structType; public SequenceToRows(Schema schema) { this.schema = schema; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRow.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRow.java index bac0740d8..6128901ee 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRow.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRow.java @@ -34,8 +34,8 @@ import java.util.List; public class ToRow implements Function, Row> { - private Schema schema; - private StructType structType; + private final Schema schema; + private final StructType structType; public ToRow(Schema schema) { this.schema = schema; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkExport.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkExport.java index a9ed6f13a..549ff7e13 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkExport.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkExport.java @@ -165,7 +165,7 @@ public class SparkExport { throws Exception { baseDir.mkdirs(); if (!baseDir.isDirectory()) - throw new IllegalArgumentException("File is not a directory: " + baseDir.toString()); + throw new IllegalArgumentException("File is not a directory: " + baseDir); String baseDirStr = baseDir.toString(); List fileContents = sequences.map(new SequenceToStringFunction(",")).collect(); @@ -192,7 +192,7 @@ public class SparkExport { String delimiter, String filePrefix, String fileExtension) throws Exception { baseDir.mkdirs(); if (!baseDir.isDirectory()) - throw new IllegalArgumentException("File is not a directory: " + baseDir.toString()); + throw new IllegalArgumentException("File is not a directory: " + baseDir); String baseDirStr = baseDir.toString(); List fileContents = sequences.map(new SequenceToStringFunction(delimiter)).collect(); diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkUtils.java 
b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkUtils.java
index 73ff3617a..e1d038495 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkUtils.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkUtils.java
@@ -274,7 +274,7 @@ public class SparkUtils {
* Register the DataVec writable classes for Kryo
*/
public static void registerKryoClasses(SparkConf conf) {
- List> classes = Arrays.>asList(BooleanWritable.class, ByteWritable.class,
+ List> classes = Arrays.asList(BooleanWritable.class, ByteWritable.class,
DoubleWritable.class, FloatWritable.class, IntWritable.class, LongWritable.class,
NullWritable.class, Text.class);
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/util/SerializableHadoopConfig.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/util/SerializableHadoopConfig.java
index 237e62b5b..280339237 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/util/SerializableHadoopConfig.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/util/SerializableHadoopConfig.java
@@ -30,7 +30,7 @@
public class SerializableHadoopConfig implements Serializable {
- private Map content;
+ private final Map content;
private transient Configuration configuration;
public SerializableHadoopConfig(@NonNull Configuration configuration){
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/TestKryoSerialization.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/TestKryoSerialization.java
index a684fb61d..781de0e5d 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/TestKryoSerialization.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/TestKryoSerialization.java
@@ -63,6 +63,6 @@ public class TestKryoSerialization extends BaseSparkTest {
private T serDe(T in, SerializerInstance si){
ByteBuffer bb = si.serialize(in, null);
- return (T)si.deserialize(bb, null);
+ return si.deserialize(bb, null);
}
}
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java
index 4990cfe03..5143b01eb 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java
@@ -30,6 +30,7 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -49,7 +50,7 @@ public class TestNDArrayToWritablesFunction {
@Test
public void testNDArrayToWritablesArray() throws Exception {
INDArray arr = Nd4j.arange(5);
- List expected = Arrays.asList((Writable) new NDArrayWritable(arr));
+ List expected =
Collections.singletonList(new NDArrayWritable(arr));
List actual = new NDArrayToWritablesFunction(true).call(arr);
assertEquals(expected, actual);
}
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java
index b96041e3f..0e4df00cc 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java
@@ -78,7 +78,7 @@ public class TestPairSequenceRecordReaderBytesFunction extends BaseSparkTest {
Path p = Files.createTempDirectory("dl4j_rrbytesPairOut");
p.toFile().deleteOnExit();
- String outPath = p.toString() + "/out";
+ String outPath = p + "/out";
new File(outPath).deleteOnExit();
toWrite.saveAsNewAPIHadoopFile(outPath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class);
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java
index aa6e5f76d..9afc645ad 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java
@@ -77,7 +77,7 @@ public class TestRecordReaderBytesFunction extends BaseSparkTest {
//Write the sequence file:
Path p = Files.createTempDirectory("dl4j_rrbytesTest");
p.toFile().deleteOnExit();
- String outPath = p.toString() + "/out";
+ String outPath = p + "/out";
filesAsBytes.saveAsNewAPIHadoopFile(outPath, Text.class, BytesWritable.class, SequenceFileOutputFormat.class);
//Load data from sequence file, parse via RecordReader:
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java
index 2f9bc4410..212b9bb64 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java
@@ -73,7 +73,7 @@ public class TestSequenceRecordReaderBytesFunction extends BaseSparkTest {
//Write the sequence file:
Path p = Files.createTempDirectory("dl4j_rrbytesTest");
p.toFile().deleteOnExit();
- String outPath = p.toString() + "/out";
+ String outPath = p + "/out";
filesAsBytes.saveAsNewAPIHadoopFile(outPath, Text.class, BytesWritable.class, SequenceFileOutputFormat.class);
//Load data from sequence file, parse via SequenceRecordReader:
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java
b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java
index 070bda4ed..19847cec0 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java
@@ -77,7 +77,7 @@ public class TestWritablesToStringFunctions extends BaseSparkTest {
@Test
public void testWritablesToString() throws Exception {
- List l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue"));
+ List l = Arrays.asList(new DoubleWritable(1.5), new Text("someValue"));
String expected = l.get(0).toString() + "," + l.get(1).toString();
assertEquals(expected, new WritablesToStringFunction(",").call(l));
@@ -86,8 +86,8 @@ public class TestWritablesToStringFunctions extends BaseSparkTest {
@Test
public void testSequenceWritablesToString() throws Exception {
- List> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")),
- Arrays.asList(new DoubleWritable(2.5), new Text("otherValue")));
+ List> l = Arrays.asList(Arrays.asList(new DoubleWritable(1.5), new Text("someValue")),
+ Arrays.asList(new DoubleWritable(2.5), new Text("otherValue")));
String expected = l.get(0).get(0).toString() + "," + l.get(0).get(1).toString() + "\n"
+ l.get(1).get(0).toString() + "," + l.get(1).get(1).toString();
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java
index a0a10e876..eaafa1d14 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java
@@ -46,11 +46,11 @@ public class TestSparkStorageUtils extends BaseSparkTest {
return;
}
List> l = new ArrayList<>();
- l.add(Arrays.asList(new Text("zero"), new IntWritable(0),
+ l.add(Arrays.asList(new Text("zero"), new IntWritable(0),
new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))));
- l.add(Arrays.asList(new Text("one"), new IntWritable(11),
+ l.add(Arrays.asList(new Text("one"), new IntWritable(11),
new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))));
- l.add(Arrays.asList(new Text("two"), new IntWritable(22),
+ l.add(Arrays.asList(new Text("two"), new IntWritable(22),
new DoubleWritable(22.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 22.0))));
JavaRDD> rdd = sc.parallelize(l);
@@ -92,27 +92,27 @@ public class TestSparkStorageUtils extends BaseSparkTest {
}
List>> l = new ArrayList<>();
l.add(Arrays.asList(
- Arrays.asList(new Text("zero"), new IntWritable(0),
+ Arrays.asList(new Text("zero"), new IntWritable(0),
new DoubleWritable(0), new NDArrayWritable(Nd4j.valueArrayOf(10, 0.0))),
- Arrays.asList(new Text("one"), new IntWritable(1),
+ Arrays.asList(new Text("one"), new IntWritable(1),
new DoubleWritable(1.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 1.0))),
- Arrays.asList(new Text("two"), new IntWritable(2),
+ Arrays.asList(new Text("two"), new IntWritable(2),
new DoubleWritable(2.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 2.0)))));
l.add(Arrays.asList(
- Arrays.asList(new Text("Bzero"), new
IntWritable(10),
+ Arrays.asList(new Text("Bzero"), new IntWritable(10),
new DoubleWritable(10), new NDArrayWritable(Nd4j.valueArrayOf(10, 10.0))),
- Arrays.asList(new Text("Bone"), new IntWritable(11),
+ Arrays.asList(new Text("Bone"), new IntWritable(11),
new DoubleWritable(11.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 11.0))),
- Arrays.asList(new Text("Btwo"), new IntWritable(12),
+ Arrays.asList(new Text("Btwo"), new IntWritable(12),
new DoubleWritable(12.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 12.0)))));
l.add(Arrays.asList(
- Arrays.asList(new Text("Czero"), new IntWritable(20),
+ Arrays.asList(new Text("Czero"), new IntWritable(20),
new DoubleWritable(20), new NDArrayWritable(Nd4j.valueArrayOf(10, 20.0))),
- Arrays.asList(new Text("Cone"), new IntWritable(21),
+ Arrays.asList(new Text("Cone"), new IntWritable(21),
new DoubleWritable(21.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 21.0))),
- Arrays.asList(new Text("Ctwo"), new IntWritable(22),
+ Arrays.asList(new Text("Ctwo"), new IntWritable(22),
new DoubleWritable(22.0), new NDArrayWritable(Nd4j.valueArrayOf(10, 22.0)))));
JavaRDD>> rdd = sc.parallelize(l);
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/DataFramesTests.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/DataFramesTests.java
index 62237f0b4..4a6bca8dc 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/DataFramesTests.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/DataFramesTests.java
@@ -110,15 +110,15 @@ public class DataFramesTests extends BaseSparkTest {
public void testNormalize() {
List> data = new ArrayList<>();
- data.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10)));
- data.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20)));
- data.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30)));
+ data.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10)));
+ data.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20)));
+ data.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30)));
List> expMinMax = new ArrayList<>();
- expMinMax.add(Arrays.asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
- expMinMax.add(Arrays.asList(new DoubleWritable(0.5), new DoubleWritable(0.5)));
- expMinMax.add(Arrays.asList(new DoubleWritable(1.0), new DoubleWritable(1.0)));
+ expMinMax.add(Arrays.asList(new DoubleWritable(0.0), new DoubleWritable(0.0)));
+ expMinMax.add(Arrays.asList(new DoubleWritable(0.5), new DoubleWritable(0.5)));
+ expMinMax.add(Arrays.asList(new DoubleWritable(1.0), new DoubleWritable(1.0)));
double m1 = (1 + 2 + 3) / 3.0;
double s1 = new StandardDeviation().evaluate(new double[] {1, 2, 3});
@@ -127,11 +127,11 @@ public class DataFramesTests extends BaseSparkTest {
List> expStandardize = new ArrayList<>();
expStandardize.add(
- Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2)));
+ Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2)));
expStandardize.add(
- Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2)));
+ Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2)));
expStandardize.add(
- Arrays.asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2)));
+ Arrays.asList(new
DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2)));
JavaRDD> rdd = sc.parallelize(data);
@@ -178,13 +178,13 @@ public class DataFramesTests extends BaseSparkTest {
List>> sequences = new ArrayList<>();
List> seq1 = new ArrayList<>();
- seq1.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10), new DoubleWritable(100)));
- seq1.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20), new DoubleWritable(200)));
- seq1.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30), new DoubleWritable(300)));
+ seq1.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(10), new DoubleWritable(100)));
+ seq1.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(20), new DoubleWritable(200)));
+ seq1.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(30), new DoubleWritable(300)));
List> seq2 = new ArrayList<>();
- seq2.add(Arrays.asList(new DoubleWritable(4), new DoubleWritable(40), new DoubleWritable(400)));
- seq2.add(Arrays.asList(new DoubleWritable(5), new DoubleWritable(50), new DoubleWritable(500)));
+ seq2.add(Arrays.asList(new DoubleWritable(4), new DoubleWritable(40), new DoubleWritable(400)));
+ seq2.add(Arrays.asList(new DoubleWritable(5), new DoubleWritable(50), new DoubleWritable(500)));
sequences.add(seq1);
sequences.add(seq2);
@@ -199,21 +199,21 @@ public class DataFramesTests extends BaseSparkTest {
//Min/max normalization:
List> expSeq1MinMax = new ArrayList<>();
- expSeq1MinMax.add(Arrays.asList(new DoubleWritable((1 - 1.0) / (5.0 - 1.0)),
+ expSeq1MinMax.add(Arrays.asList(new DoubleWritable((1 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((10 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((100 - 100.0) / (500.0 - 100.0))));
- expSeq1MinMax.add(Arrays.asList(new DoubleWritable((2 - 1.0) / (5.0 - 1.0)),
+ expSeq1MinMax.add(Arrays.asList(new DoubleWritable((2 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((20 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((200 - 100.0) / (500.0 - 100.0))));
- expSeq1MinMax.add(Arrays.asList(new DoubleWritable((3 - 1.0) / (5.0 - 1.0)),
+ expSeq1MinMax.add(Arrays.asList(new DoubleWritable((3 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((30 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((300 - 100.0) / (500.0 - 100.0))));
List> expSeq2MinMax = new ArrayList<>();
- expSeq2MinMax.add(Arrays.asList(new DoubleWritable((4 - 1.0) / (5.0 - 1.0)),
+ expSeq2MinMax.add(Arrays.asList(new DoubleWritable((4 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((40 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((400 - 100.0) / (500.0 - 100.0))));
- expSeq2MinMax.add(Arrays.asList(new DoubleWritable((5 - 1.0) / (5.0 - 1.0)),
+ expSeq2MinMax.add(Arrays.asList(new DoubleWritable((5 - 1.0) / (5.0 - 1.0)),
new DoubleWritable((50 - 10.0) / (50.0 - 10.0)),
new DoubleWritable((500 - 100.0) / (500.0 - 100.0))));
@@ -246,17 +246,17 @@ public class DataFramesTests extends BaseSparkTest {
double s3 = new StandardDeviation().evaluate(new double[] {100, 200, 300, 400, 500});
List> expSeq1Std = new ArrayList<>();
- expSeq1Std.add(Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2),
+ expSeq1Std.add(Arrays.asList(new DoubleWritable((1 - m1) / s1), new DoubleWritable((10 - m2) / s2),
new DoubleWritable((100 - m3) / s3)));
- expSeq1Std.add(Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2),
+ expSeq1Std.add(Arrays.asList(new DoubleWritable((2 - m1) / s1), new DoubleWritable((20 - m2) / s2),
new DoubleWritable((200 - m3) / s3)));
- expSeq1Std.add(Arrays.asList(new DoubleWritable((3 - m1) /
s1), new DoubleWritable((30 - m2) / s2),
+ expSeq1Std.add(Arrays.asList(new DoubleWritable((3 - m1) / s1), new DoubleWritable((30 - m2) / s2),
new DoubleWritable((300 - m3) / s3)));
List> expSeq2Std = new ArrayList<>();
- expSeq2Std.add(Arrays.asList(new DoubleWritable((4 - m1) / s1), new DoubleWritable((40 - m2) / s2),
+ expSeq2Std.add(Arrays.asList(new DoubleWritable((4 - m1) / s1), new DoubleWritable((40 - m2) / s2),
new DoubleWritable((400 - m3) / s3)));
- expSeq2Std.add(Arrays.asList(new DoubleWritable((5 - m1) / s1), new DoubleWritable((50 - m2) / s2),
+ expSeq2Std.add(Arrays.asList(new DoubleWritable((5 - m1) / s1), new DoubleWritable((50 - m2) / s2),
new DoubleWritable((500 - m3) / s3)));
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/ExecutionTest.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/ExecutionTest.java
index c863af460..8da6f146b 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/ExecutionTest.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/ExecutionTest.java
@@ -57,9 +57,9 @@ public class ExecutionTest extends BaseSparkTest {
.doubleMathOp("col2", MathOp.Add, 10.0).build();
List> inputData = new ArrayList<>();
- inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
- inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
- inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
+ inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
+ inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
+ inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD> rdd = sc.parallelize(inputData);
@@ -73,9 +73,9 @@ public class ExecutionTest extends BaseSparkTest {
});
List> expected = new ArrayList<>();
- expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
- expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
- expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
+ expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
+ expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
+ expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
assertEquals(expected, out);
}
@@ -91,12 +91,12 @@ public class ExecutionTest extends BaseSparkTest {
List>> inputSequences = new ArrayList<>();
List> seq1 = new ArrayList<>();
- seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
- seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
- seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
+ seq1.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
+ seq1.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
+ seq1.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
List> seq2 = new ArrayList<>();
- seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new
DoubleWritable(3.1)));
- seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
+ seq2.add(Arrays.asList(new IntWritable(3), new Text("state0"), new DoubleWritable(3.1)));
+ seq2.add(Arrays.asList(new IntWritable(4), new Text("state1"), new DoubleWritable(4.1)));
inputSequences.add(seq1);
inputSequences.add(seq2);
@@ -115,12 +115,12 @@ public class ExecutionTest extends BaseSparkTest {
List>> expectedSequence = new ArrayList<>();
List> seq1e = new ArrayList<>();
- seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
- seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
- seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
+ seq1e.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
+ seq1e.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
+ seq1e.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
List> seq2e = new ArrayList<>();
- seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
- seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
+ seq2e.add(Arrays.asList(new IntWritable(3), new IntWritable(0), new DoubleWritable(13.1)));
+ seq2e.add(Arrays.asList(new IntWritable(4), new IntWritable(1), new DoubleWritable(14.1)));
expectedSequence.add(seq1e);
expectedSequence.add(seq2e);
@@ -133,8 +133,8 @@ public class ExecutionTest extends BaseSparkTest {
public void testReductionGlobal() {
List> in = Arrays.asList(
- Arrays.asList(new Text("first"), new DoubleWritable(3.0)),
- Arrays.asList(new Text("second"), new DoubleWritable(5.0))
+ Arrays.asList(new Text("first"), new DoubleWritable(3.0)),
+ Arrays.asList(new Text("second"), new DoubleWritable(5.0))
);
JavaRDD> inData = sc.parallelize(in);
@@ -154,7 +154,7 @@ public class ExecutionTest extends BaseSparkTest {
List> out = outRdd.collect();
- List> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0)));
+ List> expOut = Collections.singletonList(Arrays.asList(new Text("first"), new DoubleWritable(4.0)));
assertEquals(expOut, out);
}
@@ -163,10 +163,10 @@ public class ExecutionTest extends BaseSparkTest {
public void testReductionByKey(){
List> in = Arrays.asList(
- Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)),
- Arrays.asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)),
- Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)),
- Arrays.asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))
+ Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(3.0)),
+ Arrays.asList(new IntWritable(0), new Text("second"), new DoubleWritable(5.0)),
+ Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(30.0)),
+ Arrays.asList(new IntWritable(1), new Text("s"), new DoubleWritable(50.0))
);
JavaRDD> inData = sc.parallelize(in);
@@ -189,8 +189,8 @@ public class ExecutionTest extends BaseSparkTest {
List> out = outRdd.collect();
List> expOut = Arrays.asList(
- Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)),
- Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
+ Arrays.asList(new IntWritable(0), new Text("first"), new DoubleWritable(4.0)),
+ Arrays.asList(new IntWritable(1), new Text("f"), new DoubleWritable(40.0)));
out =
new ArrayList<>(out);
Collections.sort(
@@ -215,15 +215,15 @@ public class ExecutionTest extends BaseSparkTest {
.addColumnDouble("col2").build();
List> inputData = new ArrayList<>();
- inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
- inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
- inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
- inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
- inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
- inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
- inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
- inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
- inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
+ inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
+ inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
+ inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
+ inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
+ inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
+ inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
+ inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
+ inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
+ inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD> rdd = sc.parallelize(inputData);
@@ -254,9 +254,9 @@ public class ExecutionTest extends BaseSparkTest {
.outputSchema(finalSchema).build()
).build();
List> inputData = new ArrayList<>();
- inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
- inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
- inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
+ inputData.add(Arrays.asList(new IntWritable(0), new Text("state2"), new DoubleWritable(0.1)));
+ inputData.add(Arrays.asList(new IntWritable(1), new Text("state1"), new DoubleWritable(1.1)));
+ inputData.add(Arrays.asList(new IntWritable(2), new Text("state0"), new DoubleWritable(2.1)));
JavaRDD> rdd = sc.parallelize(inputData);
@@ -270,9 +270,9 @@ public class ExecutionTest extends BaseSparkTest {
});
List> expected = new ArrayList<>();
- expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
- expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
- expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
+ expected.add(Arrays.asList(new IntWritable(0), new IntWritable(2), new DoubleWritable(10.1)));
+ expected.add(Arrays.asList(new IntWritable(1), new IntWritable(1), new DoubleWritable(11.1)));
+ expected.add(Arrays.asList(new IntWritable(2), new IntWritable(0), new DoubleWritable(12.1)));
assertEquals(expected, out);
}
@@ -299,9 +299,9 @@ public class ExecutionTest extends BaseSparkTest {
INDArray
twos = ones.add(ones);
List> inputData = new ArrayList<>();
- inputData.add(Arrays.asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
- inputData.add(Arrays.asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones)));
- inputData.add(Arrays.asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones)));
+ inputData.add(Arrays.asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
+ inputData.add(Arrays.asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones)));
+ inputData.add(Arrays.asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones)));
JavaRDD> rdd = sc.parallelize(inputData);
@@ -315,9 +315,9 @@ public class ExecutionTest extends BaseSparkTest {
});
List> expected = new ArrayList<>();
- expected.add(Arrays.asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
- expected.add(Arrays.asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones), new NDArrayWritable(ones)));
- expected.add(Arrays.asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones), new NDArrayWritable(twos)));
+ expected.add(Arrays.asList(new IntWritable(0), new NDArrayWritable(zeros), new NDArrayWritable(zeros), new NDArrayWritable(zeros)));
+ expected.add(Arrays.asList(new IntWritable(1), new NDArrayWritable(zeros), new NDArrayWritable(ones), new NDArrayWritable(ones)));
+ expected.add(Arrays.asList(new IntWritable(2), new NDArrayWritable(ones), new NDArrayWritable(ones), new NDArrayWritable(twos)));
}
@Test
@@ -329,14 +329,14 @@ public class ExecutionTest extends BaseSparkTest {
.build();
List> in = Arrays.asList(
- Arrays.asList(new Text("a"), new DoubleWritable(3.14159), new Text("8e-4")),
- Arrays.asList(new Text("a2"), new DoubleWritable(3.14159), new Text("7e-4")),
- Arrays.asList(new Text("b"), new DoubleWritable(2.71828), new Text("7e2")),
- Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("6e8")),
- Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.0")),
- Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.1")),
- Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.2")),
- Arrays.asList(new Text("c"), new DoubleWritable(-2), new Text("non numerical")));
+ Arrays.asList(new Text("a"), new DoubleWritable(3.14159), new Text("8e-4")),
+ Arrays.asList(new Text("a2"), new DoubleWritable(3.14159), new Text("7e-4")),
+ Arrays.asList(new Text("b"), new DoubleWritable(2.71828), new Text("7e2")),
+ Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("6e8")),
+ Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.0")),
+ Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.1")),
+ Arrays.asList(new Text("c"), new DoubleWritable(1.61803), new Text("2.2")),
+ Arrays.asList(new Text("c"), new DoubleWritable(-2), new Text("non numerical")));
//Test Benfords law use case:
TransformProcess tp = new TransformProcess.Builder(s)
@@ -354,7 +354,7 @@ public class ExecutionTest extends BaseSparkTest {
assertEquals(1, out.size());
List l = out.get(0);
- List exp = Arrays.asList(
+ List exp = Arrays.asList(
new IntWritable(0), //0
new IntWritable(0), //1
new IntWritable(3), //2
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java
b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java
index 4fc4f3323..8ed68b55e 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java
@@ -59,13 +59,13 @@ public class TestAnalysis extends BaseSparkTest {
.addColumnNDArray("ndarray", new long[] {1, 10}).build();
List> data = new ArrayList<>();
- data.add(Arrays.asList((Writable) new IntWritable(0), new DoubleWritable(1.0), new LongWritable(1000),
+ data.add(Arrays.asList(new IntWritable(0), new DoubleWritable(1.0), new LongWritable(1000),
new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 100.0))));
- data.add(Arrays.asList((Writable) new IntWritable(5), new DoubleWritable(0.0), new LongWritable(2000),
+ data.add(Arrays.asList(new IntWritable(5), new DoubleWritable(0.0), new LongWritable(2000),
new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 200.0))));
- data.add(Arrays.asList((Writable) new IntWritable(3), new DoubleWritable(10.0), new LongWritable(3000),
+ data.add(Arrays.asList(new IntWritable(3), new DoubleWritable(10.0), new LongWritable(3000),
new Text("A"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 300.0))));
- data.add(Arrays.asList((Writable) new IntWritable(-1), new DoubleWritable(-1.0), new LongWritable(20000),
+ data.add(Arrays.asList(new IntWritable(-1), new DoubleWritable(-1.0), new LongWritable(20000),
new Text("B"), new NDArrayWritable(Nd4j.valueArrayOf(1, 10, 400.0))));
JavaRDD> rdd = sc.parallelize(data);
@@ -253,21 +253,21 @@ public class TestAnalysis extends BaseSparkTest {
public void testSampleMostFrequent() {
List> toParallelize = new ArrayList<>();
- toParallelize.add(Arrays.asList(new Text("a"), new Text("MostCommon")));
- toParallelize.add(Arrays.asList(new Text("b"), new Text("SecondMostCommon")));
- toParallelize.add(Arrays.asList(new Text("c"), new Text("SecondMostCommon")));
- toParallelize.add(Arrays.asList(new Text("d"), new Text("0")));
- toParallelize.add(Arrays.asList(new Text("e"), new Text("MostCommon")));
- toParallelize.add(Arrays.asList(new Text("f"), new Text("ThirdMostCommon")));
- toParallelize.add(Arrays.asList(new Text("c"), new Text("MostCommon")));
- toParallelize.add(Arrays.asList(new Text("h"), new Text("1")));
- toParallelize.add(Arrays.asList(new Text("i"), new Text("SecondMostCommon")));
- toParallelize.add(Arrays.asList(new Text("j"), new Text("2")));
- toParallelize.add(Arrays.asList(new Text("k"), new Text("ThirdMostCommon")));
- toParallelize.add(Arrays.asList(new Text("l"), new Text("MostCommon")));
- toParallelize.add(Arrays.asList(new Text("m"), new Text("3")));
- toParallelize.add(Arrays.asList(new Text("n"), new Text("4")));
- toParallelize.add(Arrays.asList(new Text("o"), new Text("5")));
+ toParallelize.add(Arrays.asList(new Text("a"), new Text("MostCommon")));
+ toParallelize.add(Arrays.asList(new Text("b"), new Text("SecondMostCommon")));
+ toParallelize.add(Arrays.asList(new Text("c"), new Text("SecondMostCommon")));
+ toParallelize.add(Arrays.asList(new Text("d"), new Text("0")));
+ toParallelize.add(Arrays.asList(new Text("e"), new Text("MostCommon")));
+ toParallelize.add(Arrays.asList(new Text("f"), new Text("ThirdMostCommon")));
+ toParallelize.add(Arrays.asList(new Text("c"), new Text("MostCommon")));
+ toParallelize.add(Arrays.asList(new Text("h"), new
Text("1"))); + toParallelize.add(Arrays.asList(new Text("i"), new Text("SecondMostCommon"))); + toParallelize.add(Arrays.asList(new Text("j"), new Text("2"))); + toParallelize.add(Arrays.asList(new Text("k"), new Text("ThirdMostCommon"))); + toParallelize.add(Arrays.asList(new Text("l"), new Text("MostCommon"))); + toParallelize.add(Arrays.asList(new Text("m"), new Text("3"))); + toParallelize.add(Arrays.asList(new Text("n"), new Text("4"))); + toParallelize.add(Arrays.asList(new Text("o"), new Text("5"))); JavaRDD> rdd = sc.parallelize(toParallelize); diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/join/TestJoin.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/join/TestJoin.java index 29da7a0a4..853800a03 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/join/TestJoin.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/join/TestJoin.java @@ -45,27 +45,27 @@ public class TestJoin extends BaseSparkTest { .addColumnDouble("amount").build(); List> infoList = new ArrayList<>(); - infoList.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"))); - infoList.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"))); - infoList.add(Arrays.asList(new LongWritable(50000), new Text("Customer50000"))); + infoList.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"))); + infoList.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"))); + infoList.add(Arrays.asList(new LongWritable(50000), new Text("Customer50000"))); List> purchaseList = new ArrayList<>(); - purchaseList.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345), + purchaseList.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345), new DoubleWritable(10.00))); - purchaseList.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345), + purchaseList.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345), new DoubleWritable(20.00))); - purchaseList.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765), + purchaseList.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765), new DoubleWritable(30.00))); Join join = new Join.Builder(Join.JoinType.RightOuter).setJoinColumns("customerID") .setSchemas(customerInfoSchema, purchasesSchema).build(); List> expected = new ArrayList<>(); - expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"), + expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"), new LongWritable(1000000), new DoubleWritable(10.00))); - expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"), + expected.add(Arrays.asList(new LongWritable(12345), new Text("Customer12345"), new LongWritable(1000001), new DoubleWritable(20.00))); - expected.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"), + expected.add(Arrays.asList(new LongWritable(98765), new Text("Customer98765"), new LongWritable(1000002), new DoubleWritable(30.00))); @@ -99,11 +99,11 @@ public class TestJoin extends BaseSparkTest { .setSchemas(purchasesSchema, customerInfoSchema).build(); List> expectedManyToOne = new ArrayList<>(); - expectedManyToOne.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345), + expectedManyToOne.add(Arrays.asList(new LongWritable(1000000), new LongWritable(12345), 
new DoubleWritable(10.00), new Text("Customer12345"))); - expectedManyToOne.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345), + expectedManyToOne.add(Arrays.asList(new LongWritable(1000001), new LongWritable(12345), new DoubleWritable(20.00), new Text("Customer12345"))); - expectedManyToOne.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765), + expectedManyToOne.add(Arrays.asList(new LongWritable(1000002), new LongWritable(98765), new DoubleWritable(30.00), new Text("Customer98765"))); JavaRDD> joined2 = SparkTransformExecutor.executeJoin(join2, purchases, info); @@ -137,45 +137,45 @@ public class TestJoin extends BaseSparkTest { .addColumnCategorical("otherCategory", Arrays.asList("cat0", "cat1", "cat2")).build(); List> first = new ArrayList<>(); - first.add(Arrays.asList(new LongWritable(0), new Text("cat0"))); - first.add(Arrays.asList(new LongWritable(1), new Text("cat0"))); - first.add(Arrays.asList(new LongWritable(2), new Text("cat1"))); + first.add(Arrays.asList(new LongWritable(0), new Text("cat0"))); + first.add(Arrays.asList(new LongWritable(1), new Text("cat0"))); + first.add(Arrays.asList(new LongWritable(2), new Text("cat1"))); List> second = new ArrayList<>(); - second.add(Arrays.asList(new LongWritable(100), new Text("cat0"))); - second.add(Arrays.asList(new LongWritable(101), new Text("cat0"))); - second.add(Arrays.asList(new LongWritable(102), new Text("cat2"))); + second.add(Arrays.asList(new LongWritable(100), new Text("cat0"))); + second.add(Arrays.asList(new LongWritable(101), new Text("cat0"))); + second.add(Arrays.asList(new LongWritable(102), new Text("cat2"))); List> expOuterJoin = new ArrayList<>(); - expOuterJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); - expOuterJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); - expOuterJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); - expOuterJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); - expOuterJoin.add(Arrays.asList(new LongWritable(2), new Text("cat1"), new NullWritable())); - expOuterJoin.add(Arrays.asList(new NullWritable(), new Text("cat2"), new LongWritable(102))); + expOuterJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); + expOuterJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); + expOuterJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); + expOuterJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); + expOuterJoin.add(Arrays.asList(new LongWritable(2), new Text("cat1"), new NullWritable())); + expOuterJoin.add(Arrays.asList(new NullWritable(), new Text("cat2"), new LongWritable(102))); List> expLeftJoin = new ArrayList<>(); - expLeftJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); - expLeftJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); - expLeftJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); - expLeftJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); - expLeftJoin.add(Arrays.asList(new LongWritable(2), new Text("cat1"), new NullWritable())); + expLeftJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); + expLeftJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); 
+ expLeftJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); + expLeftJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); + expLeftJoin.add(Arrays.asList(new LongWritable(2), new Text("cat1"), new NullWritable())); List> expRightJoin = new ArrayList<>(); - expRightJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); - expRightJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); - expRightJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); - expRightJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); - expRightJoin.add(Arrays.asList(new NullWritable(), new Text("cat2"), new LongWritable(102))); + expRightJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); + expRightJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); + expRightJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); + expRightJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); + expRightJoin.add(Arrays.asList(new NullWritable(), new Text("cat2"), new LongWritable(102))); List> expInnerJoin = new ArrayList<>(); - expInnerJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); - expInnerJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); - expInnerJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); - expInnerJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); + expInnerJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(100))); + expInnerJoin.add(Arrays.asList(new LongWritable(0), new Text("cat0"), new LongWritable(101))); + expInnerJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(100))); + expInnerJoin.add(Arrays.asList(new LongWritable(1), new Text("cat0"), new LongWritable(101))); JavaRDD> firstRDD = sc.parallelize(first); JavaRDD> secondRDD = sc.parallelize(second); diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java index 6ff564418..daf2794f2 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java @@ -44,10 +44,10 @@ public class TestCalculateSortedRank extends BaseSparkTest { public void testCalculateSortedRank() { List> data = new ArrayList<>(); - data.add(Arrays.asList((Writable) new Text("0"), new DoubleWritable(0.0))); - data.add(Arrays.asList((Writable) new Text("3"), new DoubleWritable(0.3))); - data.add(Arrays.asList((Writable) new Text("2"), new DoubleWritable(0.2))); - data.add(Arrays.asList((Writable) new Text("1"), new DoubleWritable(0.1))); + data.add(Arrays.asList(new Text("0"), new DoubleWritable(0.0))); + data.add(Arrays.asList(new Text("3"), new DoubleWritable(0.3))); + data.add(Arrays.asList(new Text("2"), new DoubleWritable(0.2))); + data.add(Arrays.asList(new Text("1"), new DoubleWritable(0.1))); JavaRDD> rdd = 
sc.parallelize(data);
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java
index 7faca7235..ad545172c 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java
@@ -46,12 +46,12 @@ public class TestConvertToSequence extends BaseSparkTest {
Schema s = new Schema.Builder().addColumnsString("key1", "key2").addColumnLong("time").build();
List> allExamples =
- Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
- Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
- Arrays.asList(new Text("k1a"), new Text("k2a"),
+ Arrays.asList(Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)),
+ Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)),
+ Arrays.asList(new Text("k1a"), new Text("k2a"),
new LongWritable(-10)),
- Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
- Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
+ Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
+ Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)));
TransformProcess tp = new TransformProcess.Builder(s)
.convertToSequence(Arrays.asList("key1", "key2"), new NumericalColumnComparator("time"))
@@ -73,13 +73,13 @@ public class TestConvertToSequence extends BaseSparkTest {
}
List> expSeq0 = Arrays.asList(
- Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
- Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
- Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
+ Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(-10)),
+ Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(0)),
+ Arrays.asList(new Text("k1a"), new Text("k2a"), new LongWritable(10)));
List> expSeq1 = Arrays.asList(
- Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
- Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
+ Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(5)),
+ Arrays.asList(new Text("k1b"), new Text("k2b"), new LongWritable(10)));
assertEquals(expSeq0, seq0);
assertEquals(expSeq1, seq1);
@@ -94,9 +94,9 @@ public class TestConvertToSequence extends BaseSparkTest {
.build();
List> allExamples = Arrays.asList(
- Arrays.asList(new Text("a"), new LongWritable(0)),
- Arrays.asList(new Text("b"), new LongWritable(1)),
- Arrays.asList(new Text("c"), new LongWritable(2)));
+ Arrays.asList(new Text("a"), new LongWritable(0)),
+ Arrays.asList(new Text("b"), new LongWritable(1)),
+ Arrays.asList(new Text("c"), new LongWritable(2)));
TransformProcess tp = new TransformProcess.Builder(s)
.convertToSequence()
diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/util/TestSparkUtil.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/util/TestSparkUtil.java
index a2dd04ce0..1ed67934b 100644
--- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/util/TestSparkUtil.java
+++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/util/TestSparkUtil.java
@@ -46,8 +46,8 @@ public class TestSparkUtil extends BaseSparkTest {
return;
}
List> l = new ArrayList<>();
- l.add(Arrays.asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1)));
- l.add(Arrays.asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2)));
+ l.add(Arrays.asList(new Text("abc"), new DoubleWritable(2.0), new IntWritable(-1)));
+ l.add(Arrays.asList(new Text("def"), new DoubleWritable(4.0), new IntWritable(-2)));
File f = File.createTempFile("testSparkUtil", "txt");
f.deleteOnExit();
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java
index 0bf569f6d..2ea351b38 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java
@@ -27,7 +27,7 @@ import org.nd4j.common.primitives.Pair;
import java.util.*;
public class Operands {
- private Map map = new LinkedHashMap<>();
+ private final Map map = new LinkedHashMap<>();
/**
* This method allows to pass array to the node identified by its name
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
index c7920422f..8a642e7a2 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
@@ -734,7 +734,7 @@ public abstract class DifferentialFunction {
* @return The data types of the outputs
*/
public List calculateOutputDataTypes(List dataTypes){
- throw new UnsupportedOperationException("Op type of " + getClass().getName() + " and name " + this.toString() + " did not override calculateOutputDataTypes()! This function has not been implemented for " + getClass().getName());
+ throw new UnsupportedOperationException("Op type of " + getClass().getName() + " and name " + this + " did not override calculateOutputDataTypes()! This function has not been implemented for " + getClass().getName());
}
@@ -746,9 +746,9 @@ public abstract class DifferentialFunction {
DifferentialFunction that = (DifferentialFunction) o;
if (inPlace != that.inPlace) return false;
- if (scalarValue != null ? !scalarValue.equals(that.scalarValue) : that.scalarValue != null) return false;
+ if (!Objects.equals(scalarValue, that.scalarValue)) return false;
if (!Arrays.equals(dimensions, that.dimensions)) return false;
- return ownName != null ?
ownName.equals(that.ownName) : that.ownName == null;
+ return Objects.equals(ownName, that.ownName);
}
@Override
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/ListenerResponse.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/ListenerResponse.java
index 7ab5a262a..81dcee826 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/ListenerResponse.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/ListenerResponse.java
@@ -21,5 +21,5 @@ package org.nd4j.autodiff.listeners;
public enum ListenerResponse {
- CONTINUE, STOP;
+ CONTINUE, STOP
}
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java
index 43bf09b4d..4461f290d 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/Loss.java
@@ -96,7 +96,7 @@ public class Loss {
public static Loss sum(List losses) {
if (losses.isEmpty())
- return new Loss(Collections.emptyList(), new double[0]);
+ return new Loss(Collections.emptyList(), new double[0]);
double[] lossValues = new double[losses.get(0).losses.length];
List lossNames = new ArrayList<>(losses.get(0).lossNames);
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java
index 7bc8f044e..8c4b12823 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/CheckpointListener.java
@@ -44,27 +44,27 @@ import java.util.concurrent.TimeUnit;
@Slf4j
public class CheckpointListener extends BaseListener implements Serializable {
- private enum KeepMode {ALL, LAST, LAST_AND_EVERY};
+ private enum KeepMode {ALL, LAST, LAST_AND_EVERY}
- private File rootDir;
- private String fileNamePrefix;
- private KeepMode keepMode;
- private int keepLast;
- private int keepEvery;
- private boolean logSaving;
- private boolean deleteExisting;
- private boolean saveUpdaterState;
+ private final File rootDir;
+ private final String fileNamePrefix;
+ private final KeepMode keepMode;
+ private final int keepLast;
+ private final int keepEvery;
+ private final boolean logSaving;
+ private final boolean deleteExisting;
+ private final boolean saveUpdaterState;
- private Integer saveEveryNEpochs;
- private Integer saveEveryNIterations;
- private boolean saveEveryNIterSinceLast;
- private Long saveEveryAmount;
- private TimeUnit saveEveryUnit;
+ private final Integer saveEveryNEpochs;
+ private final Integer saveEveryNIterations;
+ private final boolean saveEveryNIterSinceLast;
+ private final Long saveEveryAmount;
+ private final TimeUnit saveEveryUnit;
private Long saveEveryMs;
- private boolean saveEverySinceLast;
+ private final boolean saveEverySinceLast;
private int lastCheckpointNum = -1;
- private File checkpointRecordFile;
+ private final File checkpointRecordFile;
private Checkpoint lastCheckpoint;
private long startTime = -1;
@@ -168,7 +168,6 @@ public class CheckpointListener extends BaseListener implements Serializable {
long lastSaveTime = (lastCheckpoint != null ?
lastCheckpoint.getTimestamp() : startTime);
if((time - lastSaveTime) >= saveEveryMs){
saveCheckpoint(sd, at);
- return;
}
} else {
//Save periodically, regardless of when last model was saved
@@ -176,7 +175,6 @@
if((time - lastSave) > saveEveryMs){
saveCheckpoint(sd, at);
lastSaveEveryMsNoSinceLast = time;
- return;
}
}
}
@@ -215,7 +213,6 @@
//Finally: determine if we should delete some old models...
if(keepMode == null || keepMode == KeepMode.ALL){
- return;
} else if(keepMode == KeepMode.LAST){
List checkpoints = availableCheckpoints();
Iterator iter = checkpoints.iterator();
@@ -423,7 +420,7 @@ public class CheckpointListener extends BaseListener implements Serializable {
public static class Builder {
- private File rootDir;
+ private final File rootDir;
private String fileNamePrefix = "SameDiff";
private KeepMode keepMode;
private int keepLast;
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java
index a862a2fe9..ab4423020 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java
@@ -186,7 +186,7 @@ public class ExecDebuggingListener extends BaseListener {
sb.append("Nd4j.exec(op);\n");
}
- System.out.print(sb.toString());
+ System.out.print(sb);
}
private static String createString(INDArray arr){
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/impl/HistoryListener.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/impl/HistoryListener.java
index 4e03b6efd..9569122fb 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/impl/HistoryListener.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/impl/HistoryListener.java
@@ -42,14 +42,14 @@ public class HistoryListener extends BaseEvaluationListener {
@Setter
private ListenerEvaluations evaluations;
- private List trainingHistory = new ArrayList<>();
- private List validationHistory = new ArrayList<>();
+ private final List trainingHistory = new ArrayList<>();
+ private final List validationHistory = new ArrayList<>();
private LossCurve loss = null;
private long startTime;
private long endTime;
- private List validationTimes = new ArrayList<>();
+ private final List validationTimes = new ArrayList<>();
private long validationStartTime;
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java
index f5bc5c8b6..e9ef1bcf2 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/impl/UIListener.java
@@ -83,19 +83,19 @@ public class UIListener extends BaseListener {
public enum HistogramType {PARAMETERS, PARAMETER_GRADIENTS, PARAMETER_UPDATES, ACTIVATIONS, ACTIVATION_GRADIENTS}
- private FileMode fileMode;
- private File logFile;
- private int lossPlotFreq;
- private int performanceStatsFrequency;
- private int updateRatioFrequency;
- private UpdateRatio updateRatioType;
- private int histogramFrequency;
- private
HistogramType[] histogramTypes; - private int opProfileFrequency; - private Map, List> trainEvalMetrics; - private int trainEvalFrequency; - private TestEvaluation testEvaluation; - private int learningRateFrequency; + private final FileMode fileMode; + private final File logFile; + private final int lossPlotFreq; + private final int performanceStatsFrequency; + private final int updateRatioFrequency; + private final UpdateRatio updateRatioType; + private final int histogramFrequency; + private final HistogramType[] histogramTypes; + private final int opProfileFrequency; + private final Map, List> trainEvalMetrics; + private final int trainEvalFrequency; + private final TestEvaluation testEvaluation; + private final int learningRateFrequency; private MultiDataSet currentIterDataSet; @@ -535,7 +535,7 @@ public class UIListener extends BaseListener { public static class Builder { private FileMode fileMode = FileMode.CREATE_OR_APPEND; - private File logFile; + private final File logFile; private int lossPlotFreq = 1; private int performanceStatsFrequency = -1; //Disabled by default diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java index c2d20756f..0b702e259 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/profiler/ProfilingListener.java @@ -64,8 +64,8 @@ public class ProfilingListener extends BaseListener { private boolean logActive = false; private long opStartNano; - private Writer writer; - private ObjectMapper json; + private final Writer writer; + private final ObjectMapper json; private final Thread fileWritingThread; private final BlockingQueue writeQueue; @@ -209,7 +209,7 @@ public class ProfilingListener extends BaseListener { .pid((int)pid) .tid(tid) .ph(Phase.X) - .args(Collections.singletonMap("name", op.getName())) + .args(Collections.singletonMap("name", op.getName())) .build(); writeQueue.add(event); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java index 68520efe7..3d81b729e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/profiler/comparison/ProfileAnalyzer.java @@ -551,7 +551,7 @@ public class ProfileAnalyzer { } - private static Map TF_PROFILE_ALIASES = new HashMap<>(); + private static final Map TF_PROFILE_ALIASES = new HashMap<>(); static { TF_PROFILE_ALIASES.put("_MklSoftmax", "Softmax"); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java index 02295f330..2fe2cee27 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/EvaluationRecord.java @@ -39,7 +39,7 @@ import org.nd4j.evaluation.IMetric; public class EvaluationRecord { private Map> evaluations; - private Map, IEvaluation> classEvaluations = new HashMap<>(); + private final Map, IEvaluation> classEvaluations = new HashMap<>(); 
private boolean isEmpty = true; public EvaluationRecord(Map> evaluations) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java index dd642eb0e..9fe7f4814 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/History.java @@ -33,13 +33,13 @@ import org.nd4j.evaluation.IMetric; @Getter public class History { - private List trainingHistory; - private List validationHistory; + private final List trainingHistory; + private final List validationHistory; - private LossCurve lossCurve; + private final LossCurve lossCurve; - private long trainingTimeMillis; - private List validationTimesMillis; + private final long trainingTimeMillis; + private final List validationTimesMillis; public History(List training, List validation, LossCurve loss, long trainingTimeMillis, List validationTimesMillis){ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java index 8c68dbafd..b0f3a78e7 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/records/LossCurve.java @@ -34,7 +34,7 @@ import org.nd4j.linalg.factory.Nd4j; public class LossCurve { @Getter - private List lossNames; + private final List lossNames; @Getter private INDArray lossValues; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SDIndex.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SDIndex.java index ea08bf132..5c5f50db2 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SDIndex.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SDIndex.java @@ -36,7 +36,7 @@ public class SDIndex { private boolean pointKeepDim; private Long intervalBegin = null; private Long intervalEnd = null; - private Long intervalStrides = 1l; + private Long intervalStrides = 1L; public SDIndex(){} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index 577569ed1..559cf124f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -138,7 +138,7 @@ public class SameDiff extends SDBaseOps { // private DifferentialFunctionFactory functionFactory; // counter for auto-naming variables - private int variableId = 0; + private final int variableId = 0; //////////////////////////////////////// @@ -244,12 +244,12 @@ public class SameDiff extends SDBaseOps { return linalg; } - private Map sameDiffFunctionInstances; + private final Map sameDiffFunctionInstances; - private Table fieldVariableResolutionMapping; + private final Table fieldVariableResolutionMapping; // flag, shows if graph was already registered with libnd4j - private transient AtomicBoolean wasRegistered = new AtomicBoolean(false); + private final transient AtomicBoolean wasRegistered = new AtomicBoolean(false); //debug mode variables @@ -257,11 +257,11 @@ public class SameDiff extends SDBaseOps { private boolean debugMode; @Getter - private Stack argumentInterceptors = 
new Stack<>(); + private final Stack argumentInterceptors = new Stack<>(); @Getter - private Set pausedArgumentInterceptors = new HashSet<>(); + private final Set pausedArgumentInterceptors = new HashSet<>(); - private Set blockNames = new HashSet<>(); + private final Set blockNames = new HashSet<>(); @Getter @Setter @@ -2159,7 +2159,7 @@ public class SameDiff extends SDBaseOps { MultiDataSet ds = iterator.next(); Map placeholderMap = toPlaceholderMap(ds); - Map m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr); + Map m = directExecHelper(placeholderMap, at, ds, Collections.emptyList(), activeListeners, requiredVarsArr); for (Map.Entry> e : variableEvals.entrySet()) { INDArray prediction = m.get(e.getKey()); @@ -5802,8 +5802,10 @@ public class SameDiff extends SDBaseOps { // ensure that there are no variables that look like they are outputs of this op boolean varWithName = false; for (String varName : variables.keySet()) - if (varName.startsWith(name + ":") || varName.equals(name)) + if (varName.startsWith(name + ":") || varName.equals(name)) { varWithName = true; + break; + } if (!ops.containsKey(name) && !varWithName) break; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java index f06cfec9c..442d47f27 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/TrainingConfig.java @@ -73,7 +73,7 @@ public class TrainingConfig { */ public TrainingConfig(IUpdater updater, List regularization, String dataSetFeatureMapping, String dataSetLabelMapping) { this(updater, regularization, true, Collections.singletonList(dataSetFeatureMapping), Collections.singletonList(dataSetLabelMapping), - Collections.emptyList(), Collections.emptyList(), null); + Collections.emptyList(), Collections.emptyList(), null); } /** @@ -154,11 +154,11 @@ public class TrainingConfig { private boolean skipValidation = false; private boolean markLabelsUnused = false; - private Map> trainEvaluations = new HashMap<>(); - private Map trainEvaluationLabels = new HashMap<>(); + private final Map> trainEvaluations = new HashMap<>(); + private final Map trainEvaluationLabels = new HashMap<>(); - private Map> validationEvaluations = new HashMap<>(); - private Map validationEvaluationLabels = new HashMap<>(); + private final Map> validationEvaluations = new HashMap<>(); + private final Map validationEvaluationLabels = new HashMap<>(); /** * Set the updater (such as {@link org.nd4j.linalg.learning.config.Adam}, {@link org.nd4j.linalg.learning.config.Nesterovs} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index 0b0079cb1..d00efcba7 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -167,10 +167,7 @@ public abstract class AbstractSession { For example, we might have a label placeholder, and we're doing inference not training */ for (String s : phNames) { - boolean required = false; - if (variables.contains(s)) { - required = true; - } + boolean required = variables.contains(s); if (!required) { Variable v = 
sameDiff.getVariables().get(s); if (v.getInputsForOp() != null) { @@ -973,8 +970,6 @@ public abstract class AbstractSession { */ protected enum ExecType {OP, VARIABLE, CONSTANT, PLACEHOLDER, SWITCH_L, SWITCH_R, EXEC_START, CONTROL_DEP} - ; - /** * ExecStep represents a single execution step, for a single op (or variable/constant etc) at a specific frame/iteration */ @@ -1022,5 +1017,4 @@ public abstract class AbstractSession { } } - ; } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/FrameIter.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/FrameIter.java index e1a215a36..6cbf7d4bd 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/FrameIter.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/FrameIter.java @@ -32,7 +32,7 @@ public class FrameIter { @Override public String toString() { - return "(\"" + frame + "\"," + iteration + (parentFrame == null ? "" : ",parent=" + parentFrame.toString()) + ")"; + return "(\"" + frame + "\"," + iteration + (parentFrame == null ? "" : ",parent=" + parentFrame) + ")"; } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index e2df94563..8d3b414ab 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -74,7 +74,7 @@ public class InferenceSession extends AbstractSession arrayUseTracker = new IdentityDependencyTracker<>(); - private Map opContexts = new HashMap<>(); + private final Map opContexts = new HashMap<>(); public InferenceSession(@NonNull SameDiff sameDiff) { super(sameDiff); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCacheMemoryMgr.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCacheMemoryMgr.java index b07f4e094..e3624b70b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCacheMemoryMgr.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/memory/ArrayCacheMemoryMgr.java @@ -42,10 +42,10 @@ public class ArrayCacheMemoryMgr extends AbstractMemoryMgr { private final long totalMemBytes; private long currentCacheSize = 0; - private Map arrayStores = new HashMap<>(); + private final Map arrayStores = new HashMap<>(); - private LinkedHashSet lruCache = new LinkedHashSet<>(); - private Map lruCacheValues = new HashMap<>(); + private final LinkedHashSet lruCache = new LinkedHashSet<>(); + private final Map lruCacheValues = new HashMap<>(); /** * Create an ArrayCacheMemoryMgr with default settings as per {@link ArrayCacheMemoryMgr} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java index 44d255345..2d3953ad8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java @@ -188,7 +188,7 @@ public class LegacyOpMapper { return FloorDivOp.class; case 23: return TruncateDivOp.class; - case 24:; + case 24: return And.class; case 25: return 
Or.class; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 8746ef281..0e8c272cd 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -451,11 +451,11 @@ public class OpValidation { private static List nonMappedLibnd4jOps; private static Map,CustomOpDescriptor>> dedupedCustomOps; private static int countTotalLibnd4jOps; - private static Map gradCheckCoverageCountPerClass = new LinkedHashMap<>(); - private static Map fwdPassCoverageCountPerClass = new LinkedHashMap<>(); - private static Map singleOpTestCountPerClass = new LinkedHashMap<>(); - private static Map opsWithTFMappingTFImportCounts = new LinkedHashMap<>(); - private static Map tfMappedOpsImportTestCounts = new LinkedHashMap<>(); + private static final Map gradCheckCoverageCountPerClass = new LinkedHashMap<>(); + private static final Map fwdPassCoverageCountPerClass = new LinkedHashMap<>(); + private static final Map singleOpTestCountPerClass = new LinkedHashMap<>(); + private static final Map opsWithTFMappingTFImportCounts = new LinkedHashMap<>(); + private static final Map tfMappedOpsImportTestCounts = new LinkedHashMap<>(); private static void collectCoverageInformation(TestCase testCase) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java index f98b766a0..e755ea0e9 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/TestCase.java @@ -38,7 +38,7 @@ import java.util.*; @Accessors(fluent = true) @Getter public class TestCase { - public enum TestSerialization {BEFORE_EXEC, AFTER_EXEC, BOTH, NONE}; + public enum TestSerialization {BEFORE_EXEC, AFTER_EXEC, BOTH, NONE} public static final boolean GC_DEFAULT_PRINT = false; public static final boolean GC_DEFAULT_EXIT_FIRST_FAILURE = false; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java index f5afa8b21..534340af5 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/validation/listeners/NonInplaceValidationListener.java @@ -41,11 +41,11 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; public class NonInplaceValidationListener extends BaseListener { @Getter - private static AtomicInteger useCounter = new AtomicInteger(); + private static final AtomicInteger useCounter = new AtomicInteger(); @Getter - private static AtomicInteger passCounter = new AtomicInteger(); + private static final AtomicInteger passCounter = new AtomicInteger(); @Getter - private static AtomicInteger failCounter = new AtomicInteger(); + private static final AtomicInteger failCounter = new AtomicInteger(); protected INDArray[] opInputs; protected INDArray[] opInputsOrig; @@ -64,7 +64,6 @@ public class NonInplaceValidationListener extends BaseListener { Op o = (Op)op.getOp(); if(oc.getInputArray(0) == null){ //No input op - return; } else if(oc.getInputArray(1) == null){ opInputsOrig = new 
INDArray[]{oc.getInputArray(0)}; opInputs = new INDArray[]{oc.getInputArray(0).dup()}; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/context/Nd4jContext.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/context/Nd4jContext.java index 3f1599391..5c13da776 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/context/Nd4jContext.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/context/Nd4jContext.java @@ -30,8 +30,8 @@ import java.util.Properties; @Slf4j public class Nd4jContext implements Serializable { - private Properties conf; - private static Nd4jContext INSTANCE = new Nd4jContext(); + private final Properties conf; + private static final Nd4jContext INSTANCE = new Nd4jContext(); private Nd4jContext() { conf = new Properties(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java index ad8c82573..82e9f7172 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java @@ -28,5 +28,5 @@ public enum ImageResizeMethod { ResizeGaussian, ResizeLanczos3, ResizeLanczos5, - ResizeMitchellcubic; + ResizeMitchellcubic } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java index ec62b7bb3..31c388bb6 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/BaseEvaluation.java @@ -71,7 +71,7 @@ public abstract class BaseEvaluation implements IEvalu } catch (InvalidTypeIdException e) { if (e.getMessage().contains("Could not resolve type id")) { try { - return (T) attempFromLegacyFromJson(json, e); + return attempFromLegacyFromJson(json, e); } catch (Throwable t) { throw new RuntimeException("Cannot deserialize from JSON - JSON is invalid?", t); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/IMetric.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/IMetric.java index 18161d0a1..6372529d5 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/IMetric.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/IMetric.java @@ -25,10 +25,10 @@ public interface IMetric { /** * The {@link IEvaluation} class this metric is for */ - public Class getEvaluationClass(); + Class getEvaluationClass(); /** * Whether this metric should be minimized (aka whether lower values are better). 
*/ - public boolean minimize(); + boolean minimize(); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java index 8e865da20..4f372ae2f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/ConfusionMatrix.java @@ -32,7 +32,7 @@ import java.util.concurrent.ConcurrentHashMap; public class ConfusionMatrix> implements Serializable { @Getter - private volatile Map> matrix; + private final Map> matrix; private List classes; /** diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java index 715e23d37..c3a66aa93 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java @@ -363,7 +363,7 @@ public class EvaluationBinary extends BaseEvaluation { ret += accuracy(i); } - ret /= (double) numLabels(); + ret /= numLabels(); return ret; } @@ -381,7 +381,7 @@ public class EvaluationBinary extends BaseEvaluation { ret += precision(i); } - ret /= (double) numLabels(); + ret /= numLabels(); return ret; } @@ -401,7 +401,7 @@ public class EvaluationBinary extends BaseEvaluation { ret += recall(i); } - ret /= (double) numLabels(); + ret /= numLabels(); return ret; } @@ -420,7 +420,7 @@ public class EvaluationBinary extends BaseEvaluation { ret += f1(i); } - ret /= (double) numLabels(); + ret /= numLabels(); return ret; } @@ -469,7 +469,7 @@ public class EvaluationBinary extends BaseEvaluation { ret += matthewsCorrelation(i); } - ret /= (double) numLabels(); + ret /= numLabels(); return ret; } @@ -496,7 +496,7 @@ public class EvaluationBinary extends BaseEvaluation { ret += gMeasure(i); } - ret /= (double) numLabels(); + ret /= numLabels(); return ret; } @@ -578,7 +578,7 @@ public class EvaluationBinary extends BaseEvaluation { ret += falseAlarmRate(i); } - ret /= (double) numLabels(); + ret /= numLabels(); return ret; } @@ -657,7 +657,7 @@ public class EvaluationBinary extends BaseEvaluation { String label = (labels == null ? 
String.valueOf(i) : labels.get(i)); - List args = Arrays.asList(label, acc, f1, precision, recall, totalCount, + List args = Arrays.asList(label, acc, f1, precision, recall, totalCount, truePositives(i), trueNegatives(i), falsePositives(i), falseNegatives(i)); if (rocBinary != null) { args = new ArrayList<>(args); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/EvaluationLambda.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/EvaluationLambda.java index f228354b6..7f5a05876 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/EvaluationLambda.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/EvaluationLambda.java @@ -26,6 +26,6 @@ import java.util.List; import org.nd4j.linalg.api.ndarray.INDArray; public interface EvaluationLambda { - public T eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, + T eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List recordMetaData); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java index ee3a6966a..71a3c7962 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/MergeLambda.java @@ -23,5 +23,5 @@ package org.nd4j.evaluation.custom; import java.util.List; public interface MergeLambda { - public List merge(List a, List b); + List merge(List a, List b); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/ResultLambda.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/ResultLambda.java index 722efa346..a9175d953 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/ResultLambda.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/custom/ResultLambda.java @@ -23,5 +23,5 @@ package org.nd4j.evaluation.custom; import java.util.List; public interface ResultLambda { - public double toResult(List data); + double toResult(List data); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java index d34bd5c4e..7ec1d3faa 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/regression/RegressionEvaluation.java @@ -58,10 +58,7 @@ public class RegressionEvaluation extends BaseEvaluation { */ @Override public boolean minimize(){ - if(this == R2 || this == PC){ - return false; - } - return true; + return this != R2 && this != PC; } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixDeserializer.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixDeserializer.java index e01f0d51a..3dff297c1 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixDeserializer.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixDeserializer.java @@ -38,7 +38,7 @@ import java.util.Map; public class ConfusionMatrixDeserializer extends JsonDeserializer> { @Override public ConfusionMatrix deserialize(JsonParser jp, DeserializationContext ctxt) - throws IOException, JsonProcessingException { + throws IOException { 
JsonNode n = jp.getCodec().readTree(jp); //Get class names/labels diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java index 4234fed16..9c4d0f833 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ConfusionMatrixSerializer.java @@ -35,7 +35,7 @@ import java.util.Map; public class ConfusionMatrixSerializer extends JsonSerializer> { @Override public void serialize(ConfusionMatrix cm, JsonGenerator gen, SerializerProvider provider) - throws IOException, JsonProcessingException { + throws IOException { List classes = cm.getClasses(); Map> matrix = cm.getMatrix(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ROCArraySerializer.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ROCArraySerializer.java index 6da654cae..d4c3526ed 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ROCArraySerializer.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/serde/ROCArraySerializer.java @@ -33,7 +33,7 @@ public class ROCArraySerializer extends JsonSerializer { @Override public void serialize(ROC[] rocs, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) - throws IOException, JsonProcessingException { + throws IOException { jsonGenerator.writeStartArray(); for (ROC r : rocs) { jsonGenerator.writeStartObject(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatArray.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatArray.java index 2d740a969..23cafa242 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatArray.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatArray.java @@ -37,7 +37,7 @@ public final class FlatArray extends Table { public int shapeLength() { int o = __offset(4); return o != 0 ? __vector_len(o) : 0; } public ByteBuffer shapeAsByteBuffer() { return __vector_as_bytebuffer(4, 8); } public ByteBuffer shapeInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 4, 8); } - public byte buffer(int j) { int o = __offset(6); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; } + public byte buffer(int j) { int o = __offset(6); return o != 0 ? bb.get(__vector(o) + j) : 0; } public int bufferLength() { int o = __offset(6); return o != 0 ? __vector_len(o) : 0; } public ByteBuffer bufferAsByteBuffer() { return __vector_as_bytebuffer(6, 1); } public ByteBuffer bufferInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 6, 1); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatConfiguration.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatConfiguration.java index 8c341212c..28b8d0ff4 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatConfiguration.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatConfiguration.java @@ -37,7 +37,7 @@ public final class FlatConfiguration extends Table { public byte executionMode() { int o = __offset(6); return o != 0 ? bb.get(o + bb_pos) : 0; } public byte profilingMode() { int o = __offset(8); return o != 0 ? bb.get(o + bb_pos) : 0; } public byte outputMode() { int o = __offset(10); return o != 0 ? bb.get(o + bb_pos) : 0; } - public boolean timestats() { int o = __offset(12); return o != 0 ? 
0!=bb.get(o + bb_pos) : false; } + public boolean timestats() { int o = __offset(12); return o != 0 && 0 != bb.get(o + bb_pos); } public long footprintForward() { int o = __offset(14); return o != 0 ? bb.getLong(o + bb_pos) : 0L; } public long footprintBackward() { int o = __offset(16); return o != 0 ? bb.getLong(o + bb_pos) : 0L; } public byte direction() { int o = __offset(18); return o != 0 ? bb.get(o + bb_pos) : 0; } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatNode.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatNode.java index 212116196..ee45ab97d 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatNode.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatNode.java @@ -61,7 +61,7 @@ public final class FlatNode extends Table { public int extraIntegerLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; } public ByteBuffer extraIntegerAsByteBuffer() { return __vector_as_bytebuffer(22, 8); } public ByteBuffer extraIntegerInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 22, 8); } - public boolean extraBools(int j) { int o = __offset(24); return o != 0 ? 0!=bb.get(__vector(o) + j * 1) : false; } + public boolean extraBools(int j) { int o = __offset(24); return o != 0 && 0 != bb.get(__vector(o) + j); } public int extraBoolsLength() { int o = __offset(24); return o != 0 ? __vector_len(o) : 0; } public ByteBuffer extraBoolsAsByteBuffer() { return __vector_as_bytebuffer(24, 1); } public ByteBuffer extraBoolsInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 24, 1); } @@ -79,7 +79,7 @@ public final class FlatNode extends Table { public String opName() { int o = __offset(36); return o != 0 ? __string(o + bb_pos) : null; } public ByteBuffer opNameAsByteBuffer() { return __vector_as_bytebuffer(36, 1); } public ByteBuffer opNameInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 36, 1); } - public byte outputTypes(int j) { int o = __offset(38); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; } + public byte outputTypes(int j) { int o = __offset(38); return o != 0 ? bb.get(__vector(o) + j) : 0; } public int outputTypesLength() { int o = __offset(38); return o != 0 ? __vector_len(o) : 0; } public ByteBuffer outputTypesAsByteBuffer() { return __vector_as_bytebuffer(38, 1); } public ByteBuffer outputTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 38, 1); } @@ -91,7 +91,7 @@ public final class FlatNode extends Table { public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; } public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; } public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; } - public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j * 1) : 0; } + public byte extraTypes(int j) { int o = __offset(48); return o != 0 ? bb.get(__vector(o) + j) : 0; } public int extraTypesLength() { int o = __offset(48); return o != 0 ? 
__vector_len(o) : 0; } public ByteBuffer extraTypesAsByteBuffer() { return __vector_as_bytebuffer(48, 1); } public ByteBuffer extraTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 48, 1); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatProperties.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatProperties.java index 96ff452b3..3371e3a8a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatProperties.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/graph/FlatProperties.java @@ -51,7 +51,7 @@ public final class FlatProperties extends Table { public FlatArray a(int j) { return a(new FlatArray(), j); } public FlatArray a(FlatArray obj, int j) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(__vector(o) + j * 4), bb) : null; } public int aLength() { int o = __offset(12); return o != 0 ? __vector_len(o) : 0; } - public boolean b(int j) { int o = __offset(14); return o != 0 ? 0!=bb.get(__vector(o) + j * 1) : false; } + public boolean b(int j) { int o = __offset(14); return o != 0 && 0 != bb.get(__vector(o) + j); } public int bLength() { int o = __offset(14); return o != 0 ? __vector_len(o) : 0; } public ByteBuffer bAsByteBuffer() { return __vector_as_bytebuffer(14, 1); } public ByteBuffer bInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 14, 1); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java index 93e731576..3a1a0527c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/converters/DifferentialFunctionClassHolder.java @@ -41,16 +41,16 @@ import java.util.*; @Slf4j public class DifferentialFunctionClassHolder { - private Map nodeConverters = ImportClassMapping.getOpNameMapping(); - private Map tensorFlowNames = ImportClassMapping.getTFOpMappingFunctions(); - private Map onnxNames = ImportClassMapping.getOnnxOpMappingFunctions(); - private Map> customOpHashToClass = new HashMap<>(); - private Map>> customOpHashToClasses = new HashMap<>(); //Only contains ops with 1 hash to multiple classes - private List missingOps = new ArrayList<>(); + private final Map nodeConverters = ImportClassMapping.getOpNameMapping(); + private final Map tensorFlowNames = ImportClassMapping.getTFOpMappingFunctions(); + private final Map onnxNames = ImportClassMapping.getOnnxOpMappingFunctions(); + private final Map> customOpHashToClass = new HashMap<>(); + private final Map>> customOpHashToClasses = new HashMap<>(); //Only contains ops with 1 hash to multiple classes + private final List missingOps = new ArrayList<>(); - private Map onnxOpDescriptors; - private Map tensorflowOpDescriptors; - private Map> fieldsForFunction; + private final Map onnxOpDescriptors; + private final Map tensorflowOpDescriptors; + private final Map> fieldsForFunction; private static final Set fieldNamesOpsIgnore = new LinkedHashSet(){{ add("extraArgs"); @@ -71,7 +71,7 @@ public class DifferentialFunctionClassHolder { }}; //When determining fields/properties, where should we terminate the search? 
//We don't wan to include every single field from every single superclass - private static final Set classesToIgnore = new HashSet<>(Arrays.asList( + private static final Set classesToIgnore = new HashSet<>(Collections.singletonList( Object.class // BaseOp.class //Exclude x/y/z, n, numProcessed, extraArgs, etc )); @@ -82,11 +82,11 @@ public class DifferentialFunctionClassHolder { } @Getter - private int countTotalTfOps; + private final int countTotalTfOps; @Getter - private int countTotalMappedOps; + private final int countTotalMappedOps; - private static DifferentialFunctionClassHolder INSTANCE = new DifferentialFunctionClassHolder(); + private static final DifferentialFunctionClassHolder INSTANCE = new DifferentialFunctionClassHolder(); /** * Get the fields for a given {@link DifferentialFunction} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index e6a00e01d..c188fb8c3 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -36,7 +36,7 @@ public class ImportClassMapping { private static final Map TF_OP_NAME_MAP = new HashMap<>(); private static final Map ONNX_OP_NAME_MAP = new HashMap<>(); - private static final List> fnClasses = Arrays.>asList( + private static final List> fnClasses = Arrays.asList( org.nd4j.linalg.api.ops.DynamicCustomOp.class, org.nd4j.linalg.api.ops.NoOp.class, org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater.class, diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/descriptors/properties/adapters/BooleanAdapter.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/descriptors/properties/adapters/BooleanAdapter.java index 602b0d184..f25243100 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/descriptors/properties/adapters/BooleanAdapter.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/descriptors/properties/adapters/BooleanAdapter.java @@ -29,6 +29,6 @@ public class BooleanAdapter implements AttributeAdapter { @Override public void mapAttributeFor(Object inputAttributeValue, Field fieldFor, DifferentialFunction on) { - on.setValueFor(fieldFor, (boolean) inputAttributeValue); + on.setValueFor(fieldFor, inputAttributeValue); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/descriptors/properties/adapters/DataTypeAdapter.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/descriptors/properties/adapters/DataTypeAdapter.java index 2c3cdb373..9552c14b6 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/descriptors/properties/adapters/DataTypeAdapter.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/descriptors/properties/adapters/DataTypeAdapter.java @@ -38,7 +38,7 @@ public class DataTypeAdapter implements AttributeAdapter { val x = dataType.getNumber(); return dtypeConv(x); - }; + } public static org.nd4j.linalg.api.buffer.DataType dtypeConv(int dataType) { @@ -58,5 +58,5 @@ public class DataTypeAdapter implements AttributeAdapter { case DataType.DT_UINT64_VALUE: return org.nd4j.linalg.api.buffer.DataType.UINT64; default: throw new UnsupportedOperationException("DataType isn't supported: " + dataType + " - " + DataType.forNumber(dataType)); } - }; + } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java index 7e1e236fb..c293bde56 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/graphmapper/tf/TFGraphMapper.java @@ -786,7 +786,7 @@ public class TFGraphMapper { } else if (!setList.getBList().isEmpty()) { break; } else if (!setList.getFList().isEmpty()) { - val floats = Floats.toArray((Collection) setList.getFList()); + val floats = Floats.toArray(setList.getFList()); if (adapter != null) { adapter.mapAttributeFor(floats, currentField, on); } else diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMapper.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMapper.java index 16f75542e..6204552a3 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMapper.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/graphmapper/tf/tensors/TFTensorMapper.java @@ -32,7 +32,7 @@ import java.nio.ByteBuffer; */ public interface TFTensorMapper { - enum ValueSource {EMPTY, VALUE_COUNT, BINARY}; + enum ValueSource {EMPTY, VALUE_COUNT, BINARY} DataType dataType(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java index 5338f5e5d..45c59221f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/imports/tensorflow/TensorFlowImportValidator.java @@ -268,8 +268,8 @@ public class TensorFlowImportValidator { return new TFImportStatus( Collections.singletonList(path), - unsupportedOpNames.size() > 0 ? Collections.singletonList(path) : Collections.emptyList(), - Collections.emptyList(), + unsupportedOpNames.size() > 0 ? 
Collections.singletonList(path) : Collections.emptyList(), + Collections.emptyList(), opCount, opNames.size(), opNames, @@ -283,16 +283,16 @@ public class TensorFlowImportValidator { } log.warn("Failed to import model from: " + path + " - not a TensorFlow frozen model in ProtoBuf format?", t); return new TFImportStatus( - Collections.emptyList(), - Collections.emptyList(), + Collections.emptyList(), + Collections.emptyList(), Collections.singletonList(path), 0, 0, - Collections.emptySet(), - Collections.emptyMap(), - Collections.emptySet(), - Collections.emptySet(), - Collections.>emptyMap()); + Collections.emptySet(), + Collections.emptyMap(), + Collections.emptySet(), + Collections.emptySet(), + Collections.emptyMap()); } } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java index 2e7bdf90d..2498b7123 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java @@ -34,7 +34,7 @@ import org.nd4j.common.primitives.Pair; public class ActivationELU extends BaseActivationFunction { public static final double DEFAULT_ALPHA = 1.0; - private double alpha; + private final double alpha; public ActivationELU() { this(DEFAULT_ALPHA); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationGELU.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationGELU.java index f9fe2714b..a66c6e66e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationGELU.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationGELU.java @@ -35,7 +35,7 @@ import org.nd4j.common.primitives.Pair; @Getter public class ActivationGELU extends BaseActivationFunction { - private boolean precise; + private final boolean precise; public ActivationGELU(boolean precise){ this.precise = precise; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java index c6eafce62..e56b654f9 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java @@ -30,9 +30,11 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Pair; /** + * {@code * ⎧ 1, if x > 1 f(x) = ⎨ -1, if x < -1 ⎩ x, otherwise +} */ @EqualsAndHashCode(callSuper = false) @Getter diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java index 9a3d34a65..5459e7aeb 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java @@ -39,7 +39,7 @@ import org.nd4j.common.primitives.Pair; public class ActivationLReLU extends BaseActivationFunction { public static final double DEFAULT_ALPHA = 0.01; - private double alpha; + private final double alpha; public ActivationLReLU() { this(DEFAULT_ALPHA); diff --git 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java index c7bf5778c..23e3b5a4d 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationPReLU.java @@ -32,7 +32,7 @@ import org.nd4j.common.primitives.Pair; @Getter public class ActivationPReLU extends BaseActivationFunction { - private INDArray alpha; + private final INDArray alpha; private long[] sharedAxes = null; public ActivationPReLU(INDArray alpha, long[] sharedAxes) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java index 56b0c9b95..88c662670 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java @@ -39,7 +39,8 @@ public class ActivationRReLU extends BaseActivationFunction { public static final double DEFAULT_L = 1.0 / 8; public static final double DEFAULT_U = 1.0 / 3; - private double l, u; + private final double l; + private final double u; private transient INDArray alpha; //don't need to write to json, when streaming public ActivationRReLU() { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java index d96088dc7..4452f28a3 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU.java @@ -36,9 +36,9 @@ import org.nd4j.linalg.factory.Nd4j; @Getter public class ActivationReLU extends BaseActivationFunction { - private Double max; - private Double threshold; - private Double negativeSlope; + private final Double max; + private final Double threshold; + private final Double negativeSlope; public ActivationReLU(){ this(null, null, null); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationThresholdedReLU.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationThresholdedReLU.java index 9ae0df963..6173e6f3b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationThresholdedReLU.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationThresholdedReLU.java @@ -31,7 +31,7 @@ import org.nd4j.common.primitives.Pair; /** * Thresholded RELU * - * f(x) = x for x > theta, f(x) = 0 otherwise. theta defaults to 1.0 + * f(x) = x for x > theta, f(x) = 0 otherwise. 
theta defaults to 1.0 * * @author Max Pumperla */ @@ -40,7 +40,7 @@ import org.nd4j.common.primitives.Pair; public class ActivationThresholdedReLU extends BaseActivationFunction { public static final double DEFAULT_THETA = 1.0; - private double theta; + private final double theta; public ActivationThresholdedReLU() { this(DEFAULT_THETA); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Blas.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Blas.java index fe5c1ff50..28d89f378 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Blas.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Blas.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.api.blas; public interface Blas { - public enum Vendor { + enum Vendor { UNKNOWN, CUBLAS, OPENBLAS, MKL, } @@ -45,5 +45,5 @@ public interface Blas { * * @return the BLAS library vendor */ - public Vendor getBlasVendor(); + Vendor getBlasVendor(); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/BlasBufferUtil.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/BlasBufferUtil.java index df61d777d..c3ac9321e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/BlasBufferUtil.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/BlasBufferUtil.java @@ -245,8 +245,8 @@ public class BlasBufferUtil { if (toSet.data().allocationMode() == DataBuffer.AllocationMode.HEAP) { Object array = toSet.data().array(); //data is assumed to have already been updated - if (array == data) - return; + if (array == data) { + } else { //copy the data over directly to the underlying array float[] d = (float[]) array; @@ -310,8 +310,8 @@ public class BlasBufferUtil { if (toSet.data().allocationMode() == DataBuffer.AllocationMode.HEAP) { Object array = toSet.data().array(); //data is assumed to have already been updated - if (array == data) - return; + if (array == data) { + } else { //copy the data over directly to the underlying array double[] d = (double[]) array; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java index cf4487de4..3fc641421 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java @@ -36,8 +36,8 @@ public abstract class BaseLapack implements Lapack { if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - int m = (int) A.rows(); - int n = (int) A.columns(); + int m = A.rows(); + int n = A.columns(); INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {1, 1}, A.dataType()).getFirst()); @@ -88,7 +88,7 @@ public abstract class BaseLapack implements Lapack { throw new ND4JArraySizeException(); byte uplo = (byte) (lower ? 'L' : 'U'); // upper or lower part of the factor desired ? - int n = (int) A.columns(); + int n = A.columns(); INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {1, 1}, A.dataType()).getFirst()); @@ -106,7 +106,6 @@ public abstract class BaseLapack implements Lapack { throw new Error("The matrix is not positive definite! 
(potrf fails @ order " + INFO.getInt(0) + ")"); } - return; } @@ -132,8 +131,8 @@ public abstract class BaseLapack implements Lapack { if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - int m = (int) A.rows(); - int n = (int) A.columns(); + int m = A.rows(); + int n = A.columns(); INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new long[] {1, 1}, A.dataType()).getFirst()); @@ -187,9 +186,9 @@ public abstract class BaseLapack implements Lapack { int status = -1; if (A.data().dataType() == DataType.DOUBLE) { - status = dsyev(jobz, uplo, (int) A.rows(), A, V); + status = dsyev(jobz, uplo, A.rows(), A, V); } else if (A.data().dataType() == DataType.FLOAT) { - status = ssyev(jobz, uplo, (int) A.rows(), A, V); + status = ssyev(jobz, uplo, A.rows(), A, V); } else { throw new UnsupportedOperationException(); } @@ -218,8 +217,8 @@ public abstract class BaseLapack implements Lapack { if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - int m = (int) A.rows(); - int n = (int) A.columns(); + int m = A.rows(); + int n = A.columns(); byte jobu = (byte) (U == null ? 'N' : 'A'); byte jobvt = (byte) (VT == null ? 'N' : 'A'); @@ -274,8 +273,8 @@ public abstract class BaseLapack implements Lapack { if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - int m = (int) A.rows(); - int n = (int) A.columns(); + int m = A.rows(); + int n = A.columns(); INDArray L = Nd4j.create(m, n); for (int r = 0; r < m; r++) { @@ -298,8 +297,8 @@ public abstract class BaseLapack implements Lapack { if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - int m = (int) A.rows(); - int n = (int) A.columns(); + int m = A.rows(); + int n = A.columns(); INDArray U = Nd4j.create(n, n); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java index 1743e993e..e331a0d6a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel2.java @@ -119,11 +119,11 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y); if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - dgbmv(order, TransA, (int) A.rows(), (int) A.columns(), KL, KU, alpha, A, (int) A.size(0), X, X.stride(-1), beta, Y, + dgbmv(order, TransA, A.rows(), A.columns(), KL, KU, alpha, A, (int) A.size(0), X, X.stride(-1), beta, Y, Y.stride(-1)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y); - sgbmv(order, TransA, (int) A.rows(), (int) A.columns(), KL, KU, (float) alpha, A, (int) A.size(0), X, X.stride(-1), (float) beta, Y, Y.stride(-1)); + sgbmv(order, TransA, A.rows(), A.columns(), KL, KU, (float) alpha, A, (int) A.size(0), X, X.stride(-1), (float) beta, Y, Y.stride(-1)); } OpExecutionerUtil.checkForAny(Y); @@ -148,10 +148,10 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y); if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE || A.size(0) > 
Integer.MAX_VALUE) throw new ND4JArraySizeException(); - dger(order, (int) A.rows(), (int) A.columns(), alpha, X, X.stride(-1), Y, Y.stride(-1), A, (int) A.size(0)); + dger(order, A.rows(), A.columns(), alpha, X, X.stride(-1), Y, Y.stride(-1), A, (int) A.size(0)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y); - sger(order, (int) A.rows(), (int) A.columns(), (float) alpha, X, X.stride(-1), Y, Y.stride(-1), A, (int) A.size(0)); + sger(order, A.rows(), A.columns(), (float) alpha, X, X.stride(-1), Y, Y.stride(-1), A, (int) A.size(0)); } OpExecutionerUtil.checkForAny(A); @@ -180,11 +180,11 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { } if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X, Y); - dsbmv(order, Uplo, (int) X.length(), (int) A.columns(), alpha, A, (int) A.size(0), X, X.stride(-1), beta, Y, + dsbmv(order, Uplo, (int) X.length(), A.columns(), alpha, A, (int) A.size(0), X, X.stride(-1), beta, Y, Y.stride(-1)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X, Y); - ssbmv(order, Uplo, (int) X.length(), (int) A.columns(), (float) alpha, A, (int) A.size(0), X, X.stride(-1), (float) beta, + ssbmv(order, Uplo, (int) X.length(), A.columns(), (float) alpha, A, (int) A.size(0), X, X.stride(-1), (float) beta, Y, Y.stride(-1)); } @@ -392,10 +392,10 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X); - dtbmv(order, Uplo, TransA, Diag, (int) X.length(), (int) A.columns(), A, (int) A.size(0), X, X.stride(-1)); + dtbmv(order, Uplo, TransA, Diag, (int) X.length(), A.columns(), A, (int) A.size(0), X, X.stride(-1)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X); - stbmv(order, Uplo, TransA, Diag, (int) X.length(), (int) A.columns(), A, (int) A.size(0), X, X.stride(-1)); + stbmv(order, Uplo, TransA, Diag, (int) X.length(), A.columns(), A, (int) A.size(0), X, X.stride(-1)); } } @@ -420,10 +420,10 @@ public abstract class BaseLevel2 extends BaseLevel implements Level2 { if (X.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, X); - dtbsv(order, Uplo, TransA, Diag, (int) X.length(), (int) A.columns(), A, (int) A.size(0), X, X.stride(-1)); + dtbsv(order, Uplo, TransA, Diag, (int) X.length(), A.columns(), A, (int) A.size(0), X, X.stride(-1)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, X); - stbsv(order, Uplo, TransA, Diag, (int) X.length(), (int) A.columns(), A, (int) A.size(0), X, X.stride(-1)); + stbsv(order, Uplo, TransA, Diag, (int) X.length(), A.columns(), A, (int) A.size(0), X, X.stride(-1)); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java index 958396b81..35eb2af49 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLevel3.java @@ -142,10 +142,10 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, B, C); - dsymm(Order, Side, Uplo, (int) C.rows(), (int) C.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0), beta, C, (int) C.size(0)); + dsymm(Order, Side, Uplo, C.rows(), 
C.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0), beta, C, (int) C.size(0)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, B, C); - ssymm(Order, Side, Uplo, (int) C.rows(), (int) C.columns(), (float) alpha, A, (int) A.size(0), B, (int) B.size(0), (float) beta, C, + ssymm(Order, Side, Uplo, C.rows(), C.columns(), (float) alpha, A, (int) A.size(0), B, (int) B.size(0), (float) beta, C, (int) C.size(0)); } @@ -180,10 +180,10 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, C); - dsyrk(Order, Uplo, Trans, (int) C.rows(), 1, alpha, A, (int) A.size(0), beta, C, (int) C.size(0)); + dsyrk(Order, Uplo, Trans, C.rows(), 1, alpha, A, (int) A.size(0), beta, C, (int) C.size(0)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, C); - ssyrk(Order, Uplo, Trans, (int) C.rows(), 1, (float) alpha, A, (int) A.size(0), (float) beta, C, (int) C.size(0)); + ssyrk(Order, Uplo, Trans, C.rows(), 1, (float) alpha, A, (int) A.size(0), (float) beta, C, (int) C.size(0)); } OpExecutionerUtil.checkForAny(C); @@ -218,10 +218,10 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, B, C); - dsyr2k(Order, Uplo, Trans, (int) A.rows(), (int) A.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0), beta, C, (int) C.size(0)); + dsyr2k(Order, Uplo, Trans, A.rows(), A.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0), beta, C, (int) C.size(0)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, B, C); - ssyr2k(Order, Uplo, Trans, (int) A.rows(), (int) A.columns(), (float) alpha, A, (int) A.size(0), B, (int) B.size(0), (float) beta, C, (int) C.size(0)); + ssyr2k(Order, Uplo, Trans, A.rows(), A.columns(), (float) alpha, A, (int) A.size(0), B, (int) B.size(0), (float) beta, C, (int) C.size(0)); } OpExecutionerUtil.checkForAny(C); @@ -257,10 +257,10 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, B, C); - dtrmm(Order, Side, Uplo, TransA, Diag, (int) A.rows(), (int) A.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0)); + dtrmm(Order, Side, Uplo, TransA, Diag, A.rows(), A.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, B, C); - strmm(Order, Side, Uplo, TransA, Diag, (int) A.rows(), (int) A.columns(), (float) alpha, A, (int) A.size(0), B, (int) B.size(0)); + strmm(Order, Side, Uplo, TransA, Diag, A.rows(), A.columns(), (float) alpha, A, (int) A.size(0), B, (int) B.size(0)); } OpExecutionerUtil.checkForAny(C); @@ -295,10 +295,10 @@ public abstract class BaseLevel3 extends BaseLevel implements Level3 { if (A.data().dataType() == DataType.DOUBLE) { DefaultOpExecutioner.validateDataType(DataType.DOUBLE, A, B); - dtrsm(Order, Side, Uplo, TransA, Diag, (int) A.rows(), (int) A.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0)); + dtrsm(Order, Side, Uplo, TransA, Diag, A.rows(), A.columns(), alpha, A, (int) A.size(0), B, (int) B.size(0)); } else { DefaultOpExecutioner.validateDataType(DataType.FLOAT, A, B); - strsm(Order, Side, Uplo, TransA, Diag, (int) A.rows(), (int) A.columns(), (float) alpha, A, (int) A.size(0), B, (int) B.size(0)); + strsm(Order, Side, Uplo, TransA, Diag, A.rows(), A.columns(), (float) 
alpha, A, (int) A.size(0), B, (int) B.size(0)); } OpExecutionerUtil.checkForAny(B); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/GemmParams.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/GemmParams.java index 9d436c542..532bbbe53 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/GemmParams.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/GemmParams.java @@ -78,18 +78,18 @@ public @Data class GemmParams { this.b = copyIfNeccessary(b); this.c = c; if (ordering == 'c') { - this.m = (int) c.columns(); - this.n = (int) c.rows(); - this.k = (int) a.columns(); + this.m = c.columns(); + this.n = c.rows(); + this.k = a.columns(); } else { - this.m = (int) c.rows(); - this.n = (int) c.columns(); - this.k = (int) b.columns(); + this.m = c.rows(); + this.n = c.columns(); + this.k = b.columns(); } - this.lda = (int) a.rows(); - this.ldb = (int) b.rows(); - this.ldc = (int) c.rows(); + this.lda = a.rows(); + this.ldb = b.rows(); + this.ldc = c.rows(); this.transA = 'N'; this.transB = 'N'; @@ -101,15 +101,15 @@ public @Data class GemmParams { this.b = b.dup(a.ordering()); this.c = c; - this.m = (int) c.rows(); - this.n = (int) c.columns(); - this.k = (int) a.columns(); + this.m = c.rows(); + this.n = c.columns(); + this.k = a.columns(); this.ordering = a.ordering(); - this.lda = (int) a.rows(); - this.ldb = (int) b.rows(); - this.ldc = (int) c.rows(); + this.lda = a.rows(); + this.ldb = b.rows(); + this.ldc = c.rows(); this.transA = 'N'; this.transB = 'N'; @@ -124,14 +124,14 @@ public @Data class GemmParams { this.b = copyIfNeccessary(b); this.c = c; - this.m = (int) c.rows(); - this.n = (int) c.columns(); - this.k = (int) a.columns(); + this.m = c.rows(); + this.n = c.columns(); + this.k = a.columns(); //always fortran ordering - this.lda = (int) (this.a.ordering() == 'f' ? this.a.rows() : this.a.columns()); //Leading dimension of a, as declared. But swap if 'c' order - this.ldb = (int) (this.b.ordering() == 'f' ? this.b.rows() : this.b.columns()); //Leading dimension of b, as declared. But swap if 'c' order - this.ldc = (int) c.rows(); + this.lda = this.a.ordering() == 'f' ? this.a.rows() : this.a.columns(); //Leading dimension of a, as declared. But swap if 'c' order + this.ldb = this.b.ordering() == 'f' ? this.b.rows() : this.b.columns(); //Leading dimension of b, as declared. But swap if 'c' order + this.ldc = c.rows(); this.transA = (this.a.ordering() == 'c' ? 'T' : 'N'); this.transB = (this.b.ordering() == 'c' ? 
'T' : 'N'); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/GemvParameters.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/GemvParameters.java index 1e9822a08..7593f3f2b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/GemvParameters.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/GemvParameters.java @@ -44,19 +44,19 @@ public @Data class GemvParameters { if (a.ordering() == 'f' && a.isMatrix()) { - this.m = (int) a.rows(); - this.n = (int) a.columns(); - this.lda = (int) a.rows(); + this.m = a.rows(); + this.n = a.columns(); + this.lda = a.rows(); } else if (a.ordering() == 'c' && a.isMatrix()) { - this.m = (int) a.columns(); - this.n = (int) a.rows(); - this.lda = (int) a.columns(); + this.m = a.columns(); + this.n = a.rows(); + this.lda = a.columns(); aOrdering = 'T'; } else { - this.m = (int) a.rows(); - this.n = (int) a.columns(); + this.m = a.rows(); + this.n = a.columns(); this.lda = (int) a.size(0); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/MMulTranspose.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/MMulTranspose.java index a97049b58..ef7d0b3be 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/MMulTranspose.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/params/MMulTranspose.java @@ -33,7 +33,7 @@ import java.util.Map; @Getter @EqualsAndHashCode public class MMulTranspose implements Serializable { - private static MMulTranspose allFalse = MMulTranspose.builder().build(); + private static final MMulTranspose allFalse = MMulTranspose.builder().build(); private boolean transposeA; private boolean transposeB; private boolean transposeResult; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index ada11553f..ecb2110c2 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -864,17 +864,17 @@ public abstract class BaseDataBuffer implements DataBuffer { case LONG: return ((LongIndexer) indexer).get(i); case UINT32: - return (long) ((UIntIndexer) indexer).get(i); + return ((UIntIndexer) indexer).get(i); case INT: - return (long) ((IntIndexer) indexer).get(i); + return ((IntIndexer) indexer).get(i); case UINT16: - return (long) ((UShortIndexer) indexer).get(i); + return ((UShortIndexer) indexer).get(i); case SHORT: - return (long) ((ShortIndexer) indexer).get(i); + return ((ShortIndexer) indexer).get(i); case BYTE: - return (long) ((ByteIndexer) indexer).get(i); + return ((ByteIndexer) indexer).get(i); case UBYTE: - return (long) ((UByteIndexer) indexer).get(i); + return ((UByteIndexer) indexer).get(i); case BOOL: return ((BooleanIndexer) indexer).get(i) ? 
1L : 0L; default: @@ -908,7 +908,7 @@ public abstract class BaseDataBuffer implements DataBuffer { case SHORT: return ((ShortIndexer) indexer).get(i); case BYTE: - return (short) ((ByteIndexer) indexer).get(i); + return ((ByteIndexer) indexer).get(i); case UINT64: case LONG: return (short) ((LongIndexer) indexer).get(i); @@ -945,7 +945,7 @@ public abstract class BaseDataBuffer implements DataBuffer { case UINT16: return ((UShortIndexer) indexer).get(i); case SHORT: - return (float) ((ShortIndexer) indexer).get(i); + return ((ShortIndexer) indexer).get(i); case BFLOAT16: return ((Bfloat16Indexer) indexer).get(i); case HALF: @@ -953,7 +953,7 @@ public abstract class BaseDataBuffer implements DataBuffer { case UBYTE: return (float) ((UByteIndexer) indexer).get(i); case BYTE: - return (float) ((ByteIndexer) indexer).get(i); + return ((ByteIndexer) indexer).get(i); case UINT64: //Fall through case LONG: return (float) ((LongIndexer) indexer).get(i); @@ -1041,7 +1041,7 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(i, element == 0.0 ? false : true); + ((BooleanIndexer) indexer).put(i, element != 0.0); break; case BYTE: ((ByteIndexer) indexer).put(i, (byte) element); @@ -1137,7 +1137,7 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(i, element == 0 ? false : true); + ((BooleanIndexer) indexer).put(i, element != 0); break; case BYTE: ((ByteIndexer) indexer).put(i, (byte) element); @@ -1233,7 +1233,7 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case BOOL: - ((BooleanIndexer) indexer).put(i, element == 0 ? false : true); + ((BooleanIndexer) indexer).put(i, element != 0); break; case BYTE: ((ByteIndexer) indexer).put(i, (byte) element); @@ -1297,7 +1297,7 @@ public abstract class BaseDataBuffer implements DataBuffer { if (offset() == 0) { return wrappedBuffer().asIntBuffer(); } else - return (IntBuffer) wrappedBuffer().asIntBuffer().position((int) offset()); + return wrappedBuffer().asIntBuffer().position((int) offset()); } @Override @@ -1308,7 +1308,7 @@ public abstract class BaseDataBuffer implements DataBuffer { if (offset() == 0) { return wrappedBuffer().asLongBuffer(); } else - return (LongBuffer) wrappedBuffer().asLongBuffer().position((int) offset()); + return wrappedBuffer().asLongBuffer().position((int) offset()); } @Override @@ -1319,7 +1319,7 @@ public abstract class BaseDataBuffer implements DataBuffer { if (offset() == 0) { return wrappedBuffer().asDoubleBuffer(); } else { - return (DoubleBuffer) wrappedBuffer().asDoubleBuffer().position((int) (offset())); + return wrappedBuffer().asDoubleBuffer().position((int) (offset())); } } @@ -1331,7 +1331,7 @@ public abstract class BaseDataBuffer implements DataBuffer { if (offset() == 0) { return wrappedBuffer().asFloatBuffer(); } else { - return (FloatBuffer) wrappedBuffer().asFloatBuffer().position((int) (offset())); + return wrappedBuffer().asFloatBuffer().position((int) (offset())); } } @@ -1899,10 +1899,7 @@ public abstract class BaseDataBuffer implements DataBuffer { if (released || isAttached() || isConstant()) return false; - if (wrappedDataBuffer != null && wrappedDataBuffer != this) - return false; - - return true; + return wrappedDataBuffer == null || wrappedDataBuffer == this; } protected void markReleased() { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/util/DataTypeUtil.java 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/util/DataTypeUtil.java index 61f4a6b1d..25e49d37f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/util/DataTypeUtil.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/util/DataTypeUtil.java @@ -28,7 +28,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; public class DataTypeUtil { - private volatile transient static DataType dtype; + private volatile static DataType dtype; private static final ReadWriteLock lock = new ReentrantReadWriteLock(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/FirstAxisIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/FirstAxisIterator.java index 80bf4b574..f37bb7d79 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/FirstAxisIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/FirstAxisIterator.java @@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Iterator; public class FirstAxisIterator implements Iterator { - private INDArray iterateOver; + private final INDArray iterateOver; private int i = 0; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/FlatIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/FlatIterator.java index 7c0d8ab0a..c6be2d059 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/FlatIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/FlatIterator.java @@ -26,10 +26,10 @@ import java.util.Iterator; public class FlatIterator implements Iterator { - private int[] shape; + private final int[] shape; private int runningDimension; - private int[] currentCoord; - private int length; + private final int[] currentCoord; + private final int length; private int current = 0; public FlatIterator(int[] shape) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/INDArrayIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/INDArrayIterator.java index e1eb7a98e..e4740b316 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/INDArrayIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/INDArrayIterator.java @@ -29,7 +29,7 @@ import java.util.Iterator; * @author Adam Gibson */ public class INDArrayIterator implements Iterator { - private INDArray iterateOver; + private final INDArray iterateOver; private int i = 0; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/LinearIndexLookup.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/LinearIndexLookup.java index c3308ff23..865283f3e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/LinearIndexLookup.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/LinearIndexLookup.java @@ -26,11 +26,11 @@ import org.nd4j.common.util.ArrayUtil; import java.io.Serializable; public class LinearIndexLookup implements Serializable { - private char ordering; - private long[][] indexes; - private long[] shape; - private boolean[] exists; - private long numIndexes; + private final char ordering; + private final long[][] indexes; + private final long[] shape; + private final boolean[] exists; + private final long numIndexes; /** * diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/NdIndexIterator.java 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/NdIndexIterator.java index cd01f89fd..5c6baa572 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/NdIndexIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/iter/NdIndexIterator.java @@ -31,10 +31,10 @@ import java.util.Map; public class NdIndexIterator implements Iterator { private int length = -1; private int i = 0; - private long[] shape; + private final long[] shape; private char order = 'c'; private boolean cache = false; - private static Map, LinearIndexLookup> lookupMap = new HashMap<>(); + private static final Map, LinearIndexLookup> lookupMap = new HashMap<>(); private LinearIndexLookup lookup; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java index 0f450224d..a9b320f96 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java @@ -30,7 +30,7 @@ import java.util.concurrent.ConcurrentHashMap; @Slf4j public class AllocationsTracker { private static final AllocationsTracker INSTANCE = new AllocationsTracker(); - private Map devices = new ConcurrentHashMap<>(); + private final Map devices = new ConcurrentHashMap<>(); protected AllocationsTracker() { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/BasicMemoryManager.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/BasicMemoryManager.java index f718a9907..2141c3172 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/BasicMemoryManager.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/BasicMemoryManager.java @@ -53,9 +53,9 @@ public abstract class BasicMemoryManager implements MemoryManager { protected Queue intervals = new ConcurrentLinkedQueue<>(); - private ThreadLocal workspace = new ThreadLocal<>(); + private final ThreadLocal workspace = new ThreadLocal<>(); - private ThreadLocal tempWorkspace = new ThreadLocal<>(); + private final ThreadLocal tempWorkspace = new ThreadLocal<>(); /** * This method returns diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java index 21a3f1d16..b2a6fc360 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java @@ -31,7 +31,7 @@ import java.util.concurrent.atomic.AtomicLong; @Slf4j public class DeviceAllocationsTracker { - private Map bytesMap = new HashMap<>(); + private final Map bytesMap = new HashMap<>(); public DeviceAllocationsTracker() { for (val e:AllocationKind.values()) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java index a53e539e9..5934a6d14 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java @@ -37,10 +37,10 @@ import java.util.concurrent.atomic.AtomicLong; 
@Slf4j public class DeallocatorService { - private Thread[] deallocatorThreads; - private ReferenceQueue[] queues; - private Map referenceMap = new ConcurrentHashMap<>(); - private List>> deviceMap = new ArrayList<>(); + private final Thread[] deallocatorThreads; + private final ReferenceQueue[] queues; + private final Map referenceMap = new ConcurrentHashMap<>(); + private final List>> deviceMap = new ArrayList<>(); private final transient AtomicLong counter = new AtomicLong(0); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/pointers/ImmortalFloatPointer.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/pointers/ImmortalFloatPointer.java index d6d2a128c..b6980d7e4 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/pointers/ImmortalFloatPointer.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/pointers/ImmortalFloatPointer.java @@ -26,7 +26,7 @@ import org.bytedeco.javacpp.Pointer; @Slf4j public class ImmortalFloatPointer extends FloatPointer { - private Pointer pointer; + private final Pointer pointer; public ImmortalFloatPointer(PagedPointer pointer) { this.pointer = pointer; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java index e03ae02d8..0ae612d0e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/provider/BasicWorkspaceManager.java @@ -73,7 +73,7 @@ public abstract class BasicWorkspaceManager implements MemoryWorkspaceManager { */ @Override public String getUUID() { - return "Workspace_" + String.valueOf(counter.incrementAndGet()); + return "Workspace_" + counter.incrementAndGet(); } /** @@ -351,7 +351,7 @@ public abstract class BasicWorkspaceManager implements MemoryWorkspaceManager { log.info("Number of workspaces in current thread: {}", map.size()); log.info("Workspace name: Allocated / external (spilled) / external (pinned)"); for (String key : map.keySet()) { - long current = ((Nd4jWorkspace) map.get(key)).getCurrentSize(); + long current = map.get(key).getCurrentSize(); long spilled = ((Nd4jWorkspace) map.get(key)).getSpilledSize(); long pinned = ((Nd4jWorkspace) map.get(key)).getPinnedSize(); log.info(String.format("%-26s %8s / %8s / %8s (%11d / %11d / %11d)", (key + ":"), diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index c948afee7..9887ddadb 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -419,7 +419,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { public BaseNDArray(long newRows, long newColumns, char ordering) { Shape.assertValidOrder(ordering); - this.data = Nd4j.createBuffer((long) newRows * newColumns); + this.data = Nd4j.createBuffer(newRows * newColumns); long[] shape = new long[] {newRows, newColumns}; long[] stride = Nd4j.getStrides(shape, ordering); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, @@ -522,14 +522,14 @@ public abstract class BaseNDArray implements INDArray, Iterable { public BaseNDArray(float[] data, int[] shape, int[] 
stride, long offset, char ordering) { Shape.assertValidOrder(ordering); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data != null && data.length > 0 ? false : true)); + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data == null || data.length <= 0)); if (data != null && data.length > 0) { val perfD = PerformanceTracker.getInstance().helperStartTransaction(); this.data = internalCreateBuffer(data, offset); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, data.length * Nd4j.sizeOfDataType(DataType.FLOAT), MemcpyDirection.HOST_TO_HOST); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, (long) data.length * Nd4j.sizeOfDataType(DataType.FLOAT), MemcpyDirection.HOST_TO_HOST); if (offset >= data.length) throw new IllegalArgumentException("invalid offset: must be < data.length"); @@ -541,7 +541,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { public BaseNDArray(float[] data, long[] shape, long[] stride, long offset, char ordering) { Shape.assertValidOrder(ordering); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data != null && data.length > 0 ? false : true)); + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data == null || data.length <= 0)); if (data != null && data.length > 0) { this.data = Nd4j.createTypedBuffer(data, DataType.FLOAT); if (offset >= data.length) @@ -554,7 +554,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { public BaseNDArray(double[] data, long[] shape, long[] stride, long offset, char ordering) { Shape.assertValidOrder(ordering); setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.DOUBLE, data != null && data.length > 0 ? 
false : true)); + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.DOUBLE, data == null || data.length <= 0)); if (data != null && data.length > 0) { this.data = Nd4j.createBuffer(data, offset); if (offset >= data.length) @@ -673,7 +673,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { val perfX = PerformanceTracker.getInstance().helperStartTransaction(); val buffer = Nd4j.createBuffer(data); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); return buffer; } @@ -682,7 +682,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { val perfX = PerformanceTracker.getInstance().helperStartTransaction(); val buffer = Nd4j.createBuffer(data); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); return buffer; } @@ -691,7 +691,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { val perfX = PerformanceTracker.getInstance().helperStartTransaction(); val buffer = Nd4j.createBuffer(data); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); return buffer; } @@ -700,7 +700,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { val perfX = PerformanceTracker.getInstance().helperStartTransaction(); val buffer = Nd4j.createBuffer(data, offset); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); return buffer; } @@ -709,7 +709,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { val perfX = PerformanceTracker.getInstance().helperStartTransaction(); val buffer = Nd4j.createBuffer(data, offset); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); return buffer; } @@ -1086,7 +1086,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { int length = ArrayUtil.prod(tensorShape); int tensorLength = ArrayUtil.prod(tensorShape); - long offset = index * tensorLength / NDArrayMath.lengthPerSlice(ret2); + long offset = (long) index * tensorLength / NDArrayMath.lengthPerSlice(ret2); if (sliceIdx == 0 && length == NDArrayMath.lengthPerSlice(ret2)) { if (offset > Integer.MAX_VALUE) @@ -1460,7 +1460,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { long size_2 = jvmShapeInfo.javaShapeInformation[1 + 2]; if (size_0 != 1) - offset += dim0 * 
jvmShapeInfo.javaShapeInformation[1 + 0 + 3]; + offset += dim0 * jvmShapeInfo.javaShapeInformation[1 + 3]; if (size_1 != 1) offset += dim1 * jvmShapeInfo.javaShapeInformation[1 + 1 + 3]; if (size_2 != 1) @@ -1900,7 +1900,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public INDArray getWhere(Number comp, Condition condition) { - return BooleanIndexing.chooseFrom(new INDArray[]{this},Arrays.asList(comp.doubleValue()),Collections.emptyList(),condition); + return BooleanIndexing.chooseFrom(new INDArray[]{this}, Collections.singletonList(comp.doubleValue()),Collections.emptyList(),condition); } @Override @@ -1985,7 +1985,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { Preconditions.checkArgument(slice < slices(), "Invalid slice specified: slice %s must be in range 0 (inclusive) to numSlices=%s (exclusive)", slice, slices()); long[] sliceShape = put.shape(); if (Shape.isRowVectorShape(sliceShape)) { - return; } else { long[] requiredShape = ArrayUtil.removeIndex(shape(), 0); @@ -4111,7 +4110,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { if (isVector()) return Nd4j.pullRows(this, 1, rindices); else { - INDArray ret = Nd4j.createUninitialized(this.dataType(), new long[] {rindices.length, columns()}); + INDArray ret = Nd4j.createUninitialized(this.dataType(), rindices.length, columns()); for (int i = 0; i < rindices.length; i++) ret.putRow(i, getRow(rindices[i])); return ret; @@ -4146,8 +4145,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { // Padding remaining dimensions with all() index if too few indices provided if (indexes.length - numNewAxis < this.rank()) { val newIndexes = new INDArrayIndex[this.rank() + numNewAxis]; - for (int e = 0; e < indexes.length; e++) - newIndexes[e] = indexes[e]; + System.arraycopy(indexes, 0, newIndexes, 0, indexes.length); for (int e = indexes.length; e < newIndexes.length; e++) { numAll++; @@ -4312,7 +4310,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { if (isVector()) { return Nd4j.pullRows(this, 0, cindices, this.ordering()); } else { - INDArray ret = Nd4j.createUninitialized(this.dataType(), new long[]{rows(), cindices.length}); + INDArray ret = Nd4j.createUninitialized(this.dataType(), rows(), cindices.length); for (int i = 0; i < cindices.length; i++) ret.putColumn(i, getColumn(cindices[i])); return ret; @@ -5509,8 +5507,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { if (e != 0) { val t = ArrayOptionsHelper.dataType(jvmShapeInfo.javaShapeInformation); - if (t != DataType.UNKNOWN) - return t; + return t; } return DataType.UNKNOWN; @@ -5623,10 +5620,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { @Override public boolean wasClosed() { // data can be null if that's empty array - if (released || (data() != null && data().wasClosed())) - return true; - - return false; + return released || (data() != null && data().wasClosed()); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java index 0fb2db284..be3a172d1 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java @@ -44,6 +44,7 @@ import org.tensorflow.framework.NodeDef; import java.nio.Buffer; import java.util.Arrays; import java.util.Map; +import java.util.Objects; @Data public abstract class BaseOp extends 
DifferentialFunction implements Op { @@ -145,7 +146,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op { if (extraArgs != null) { if (Shape.isZ(dtype) || Shape.isB(dtype)) { - long extraz[] = new long[extraArgs.length]; + long[] extraz = new long[extraArgs.length]; for (int i = 0; i < extraArgs.length; i++) { if (extraArgs[i] instanceof Number) { Number arg = (Number) extraArgs[i]; @@ -156,7 +157,7 @@ public abstract class BaseOp extends DifferentialFunction implements Op { extraArgz = Nd4j.getConstantHandler().getConstantBuffer(extraz, dtype); return extraArgz; } else if (Shape.isR(dtype)) { - double extraz[] = new double[extraArgs.length]; + double[] extraz = new double[extraArgs.length]; for (int i = 0; i < extraArgs.length; i++) { if (!(extraArgs[i] instanceof Number)) continue; @@ -318,12 +319,12 @@ public abstract class BaseOp extends DifferentialFunction implements Op { BaseOp baseOp = (BaseOp) o; - if (x != null ? !x.equals(baseOp.x) : baseOp.x != null) return false; - if (y != null ? !y.equals(baseOp.y) : baseOp.y != null) return false; - if (z != null ? !z.equals(baseOp.z) : baseOp.z != null) return false; + if (!Objects.equals(x, baseOp.x)) return false; + if (!Objects.equals(y, baseOp.y)) return false; + if (!Objects.equals(z, baseOp.z)) return false; // Probably incorrect - comparing Object[] arrays with Arrays.equals if (!Arrays.equals(extraArgs, baseOp.extraArgs)) return false; - return extraArgz != null ? extraArgz.equals(baseOp.extraArgz) : baseOp.extraArgz == null; + return Objects.equals(extraArgz, baseOp.extraArgz); } @Override @@ -369,9 +370,9 @@ public abstract class BaseOp extends DifferentialFunction implements Op { if (z.isR()) return new Double(z.getDouble(0)); else if (z.isZ()) - return new Long(z.getInt(0)); + return Long.valueOf(z.getInt(0)); else if (z.isB()) - return new Integer(z.getInt(0)); + return Integer.valueOf(z.getInt(0)); throw new ND4JIllegalStateException("???"); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java index 8f818b37a..eebc9c8eb 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java @@ -96,8 +96,7 @@ public abstract class BaseOpContext implements OpContext { @Override public void setDArguments(DataType... 
arguments) { fastpath_d.clear(); - for (val v:arguments) - fastpath_d.add(v); + Collections.addAll(fastpath_d, arguments); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java index ef23a963d..d83509801 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformSameOp.java @@ -143,6 +143,6 @@ public abstract class BaseTransformSameOp extends BaseTransformOp implements Tra check = dataType; } } - return Arrays.asList(dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java index 304f40b99..c1fe9da10 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/DynamicCustomOp.java @@ -135,7 +135,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { if(iArguments != null) { for (val a : iArguments) - this.iArguments.add((Long) a.longValue()); + this.iArguments.add(a.longValue()); } bArguments = new ArrayList<>(); dArguments = new ArrayList<>(); @@ -160,7 +160,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { * @param outputs the outputs of the op, may be null */ public DynamicCustomOp(String opName, INDArray[] inputs, INDArray[] outputs) { - this(opName, inputs, outputs, Lists.newArrayList(), Lists.newArrayList()); + this(opName, inputs, outputs, Lists.newArrayList(), Lists.newArrayList()); } /** @@ -313,7 +313,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { private void addIArgument(Integer... 
arg) { for (val a: arg) - addIArgument((Long) a.longValue()); + addIArgument(a.longValue()); } @Override @@ -690,12 +690,12 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { protected long opHash; protected List outputShapes = new ArrayList<>(); - private List inputArguments = new ArrayList<>(); - private List outputArguments = new ArrayList<>(); - private List tArguments = new ArrayList<>(); - private List iArguments = new ArrayList<>(); - private List dArguments = new ArrayList<>(); - private List bArguments = new ArrayList<>(); + private final List inputArguments = new ArrayList<>(); + private final List outputArguments = new ArrayList<>(); + private final List tArguments = new ArrayList<>(); + private final List iArguments = new ArrayList<>(); + private final List dArguments = new ArrayList<>(); + private final List bArguments = new ArrayList<>(); protected DynamicCustomOpsBuilder(String opName, long hash, int numInputs, int numOutputs, boolean inplaceAllowed, int numTArguments, int numIArguments) { this.opHash = hash; @@ -727,8 +727,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { throw new ND4JIllegalStateException("CustomOp [" + opName + "] expects at least " + numInputs + " arguments, but " + inputs.length + " was passed to constructor"); } - for (val in : inputs) - inputArguments.add(in); + Collections.addAll(inputArguments, inputs); return this; } @@ -752,8 +751,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { throw new ND4JIllegalStateException("CustomOp [" + opName + "] expects at least " + numOutputs + " arguments, but " + outputs.length + " was passed to constructor"); } - for (val in : outputs) - outputArguments.add(in); + Collections.addAll(outputArguments, outputs); return this; } @@ -873,8 +871,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp { throw new ND4JIllegalStateException("CustomOp [" + opName + "] expects at least " + numTArguments + " integer arguments, but " + targs.length + " was passed to constructor"); } - for (val in : targs) - tArguments.add(in); + Collections.addAll(tArguments, targs); return this; } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/BaseAggregate.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/BaseAggregate.java index d5aaedd5b..1db99da47 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/BaseAggregate.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/BaseAggregate.java @@ -73,11 +73,11 @@ public abstract class BaseAggregate implements Aggregate { @Override public long getRequiredBatchMemorySize() { - long result = maxIntArrays() * maxIntArraySize() * 4; - result += maxArguments() * 8; // pointers - result += maxShapes() * 8; // pointers - result += maxIndexArguments() * 4; - result += maxRealArguments() * (Nd4j.dataType() == DataType.DOUBLE ? 8 + long result = (long) maxIntArrays() * maxIntArraySize() * 4; + result += maxArguments() * 8L; // pointers + result += maxShapes() * 8L; // pointers + result += maxIndexArguments() * 4L; + result += (long) maxRealArguments() * (Nd4j.dataType() == DataType.DOUBLE ? 8 : Nd4j.dataType() == DataType.FLOAT ? 
4 : 2); result += 5 * 4; // numArgs diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java index 32801a893..97f67e3b6 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/Batch.java @@ -48,12 +48,12 @@ public class Batch { // all aggregates within this batch @Getter - private List aggregates; + private final List aggregates; @Getter - private T sample; + private final T sample; @Getter - private int numAggregates; + private final int numAggregates; /** * This constructor takes List of Aggregates, and builds Batch instance, usable with Nd4j executioner. diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateAxpy.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateAxpy.java index 9f7d5439f..6f61dd546 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateAxpy.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/aggregates/impl/AggregateAxpy.java @@ -27,7 +27,7 @@ import org.nd4j.linalg.factory.Nd4j; @Deprecated public class AggregateAxpy extends BaseAggregate { - private int vectorLength; + private final int vectorLength; public AggregateAxpy(@NonNull INDArray x, @NonNull INDArray y, double alpha) { this.arguments.add(x); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java index 7981803e8..a9cb4a502 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/Flatten.java @@ -32,6 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; @Data @@ -43,10 +44,9 @@ public class Flatten extends DynamicCustomOp { public Flatten(char order, INDArray... inputs) { this.order = order; - for (val in:inputs) - inputArguments.add(in); + Collections.addAll(inputArguments, inputs); - iArguments.add(Long.valueOf((int) this.order)); + iArguments.add(Long.valueOf(this.order)); } public Flatten(INDArray output, INDArray... 
inputs) { @@ -70,6 +70,6 @@ public class Flatten extends DynamicCustomOp { public List calculateOutputDataTypes(List inputDataTypes) { int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); - return Arrays.asList(inputDataTypes.get(0)); + return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java index 751f37ede..8a21451b4 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/FusedBatchNorm.java @@ -78,7 +78,7 @@ public class FusedBatchNorm extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { boolean isNchw = attributesForNode.containsKey("data_format") && attributesForNode.get("data_format").getS().toStringUtf8().equalsIgnoreCase("NCHW"); - boolean training = !attributesForNode.containsKey("is_training") ? true : attributesForNode.get("is_training").getB(); + boolean training = !attributesForNode.containsKey("is_training") || attributesForNode.get("is_training").getB(); addIArgument(isNchw ? 1 : 0); addIArgument(training ? 1 : 0); if(attributesForNode.containsKey("T")){ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java index fe4b9b214..60b19c739 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/custom/LinearSolve.java @@ -67,7 +67,7 @@ public class LinearSolve extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { - boolean adjoint = attributesForNode.containsKey("adjoint") ? 
attributesForNode.get("adjoint").getB() : false; + boolean adjoint = attributesForNode.containsKey("adjoint") && attributesForNode.get("adjoint").getB(); addBArgument(adjoint); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastEqualTo.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastEqualTo.java index d0b5149dc..c20054198 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastEqualTo.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastEqualTo.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseBroadcastBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class BroadcastEqualTo extends BaseBroadcastBoolOp { @@ -67,7 +68,7 @@ public class BroadcastEqualTo extends BaseBroadcastBoolOp { @Override public List doDiff(List f1) { - return Arrays.asList(outputVariables()[0]); + return Collections.singletonList(outputVariables()[0]); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThan.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThan.java index 23c09413f..e3b6f8d4c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThan.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThan.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseBroadcastBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class BroadcastLessThan extends BaseBroadcastBoolOp { @@ -88,6 +89,6 @@ public class BroadcastLessThan extends BaseBroadcastBoolOp { @Override public List doDiff(List f1) { - return Arrays.asList(outputVariables()[0]); + return Collections.singletonList(outputVariables()[0]); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java index dfa830801..cbfedd015 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/bool/BroadcastLessThanOrEqual.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseBroadcastBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class BroadcastLessThanOrEqual extends BaseBroadcastBoolOp { @@ -88,6 +89,6 @@ public class BroadcastLessThanOrEqual extends BaseBroadcastBoolOp { @Override public List doDiff(List f1) { - return Arrays.asList(outputVariables()[0]); + return Collections.singletonList(outputVariables()[0]); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/grid/BaseGridOp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/grid/BaseGridOp.java index 82e11f321..6e38a738b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/grid/BaseGridOp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/grid/BaseGridOp.java @@ -30,6 +30,7 @@ import 
org.nd4j.linalg.api.ops.grid.GridPointers; import org.nd4j.linalg.api.ops.grid.OpDescriptor; import java.util.ArrayList; +import java.util.Collections; import java.util.List; public abstract class BaseGridOp extends BaseOp implements GridOp { @@ -60,9 +61,7 @@ public abstract class BaseGridOp extends BaseOp implements GridOp { } protected BaseGridOp(GridPointers... pointers) { - for (GridPointers ptr : pointers) { - grid.add(ptr); - } + Collections.addAll(grid, pointers); } protected BaseGridOp(List ops) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java index cf5739b6c..2270912bf 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/CropAndResize.java @@ -36,7 +36,8 @@ import java.util.*; @NoArgsConstructor public class CropAndResize extends DynamicCustomOp { - public enum Method {BILINEAR, NEAREST}; + public enum Method {BILINEAR, NEAREST} + protected Method method = Method.BILINEAR; protected double extrapolationValue = 0.0; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeArea.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeArea.java index f37f8c33f..28a2a9488 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeArea.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeArea.java @@ -78,7 +78,7 @@ public class ResizeArea extends DynamicCustomOp { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); val attrC = attributesForNode.get("align_corners"); - this.alignCorners = attrC != null ? attrC.getB() : false; + this.alignCorners = attrC != null && attrC.getB(); addArgs(); } @@ -86,7 +86,7 @@ public class ResizeArea extends DynamicCustomOp { protected void addArgs() { iArguments.clear(); if(height != null && width != null){ - INDArray size = Nd4j.createFromArray(new int[]{height,width}); + INDArray size = Nd4j.createFromArray(height,width); addInputArgument(size); //iArguments.add(Long.valueOf(height)); //iArguments.add(Long.valueOf(width)); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java index 6f39345fe..094844939 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ResizeBilinear.java @@ -84,8 +84,8 @@ public class ResizeBilinear extends DynamicCustomOp { val attrC = attributesForNode.get("align_corners"); val attrH = attributesForNode.get("half_pixel_centers"); - this.alignCorners = attrC != null ? attrC.getB() : false; - this.halfPixelCenters = attrH != null ? 
attrH.getB() : false;
+        this.alignCorners = attrC != null && attrC.getB();
+        this.halfPixelCenters = attrH != null && attrH.getB();
         addArgs();
     }
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java
index 725f8c2d2..302009030 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/AvgPooling2D.java
@@ -274,7 +274,7 @@ public class AvgPooling2D extends DynamicCustomOp {
     public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) {
         val paddingVal = !attributesForNode.containsKey("auto_pad") ? "VALID" : attributesForNode.get("auto_pad").getS().toStringUtf8();
         val kernelShape = attributesForNode.get("kernel_shape").getIntsList();
-        val padding = !attributesForNode.containsKey("pads") ? Arrays.asList(1L) : attributesForNode.get("pads").getIntsList();
+        val padding = !attributesForNode.containsKey("pads") ? Collections.singletonList(1L) : attributesForNode.get("pads").getIntsList();
         val strides = attributesForNode.get("strides").getIntsList();
         Pooling2DConfig pooling2DConfig = Pooling2DConfig.builder()
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java
index 14f594611..cd2899948 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv2DTF.java
@@ -245,7 +245,7 @@ public class DeConv2DTF extends DynamicCustomOp {
         int n = args().length;
         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
         if(!dArguments.isEmpty()) {
-            return Arrays.asList(dArguments.get(0));
+            return Collections.singletonList(dArguments.get(0));
         }
         return Collections.singletonList(inputDataTypes.get(2));
     }
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java
index d897e6fe5..261240ca7 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthToSpace.java
@@ -72,7 +72,7 @@ public class DepthToSpace extends DynamicCustomOp {
         // Gradient to DepthToSpace is just SpaceToDepth of same block size and data format.
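// Editor's illustrative sketch, not part of the original patch: the comment above notes that the gradient
// of DepthToSpace is just SpaceToDepth with the same block size. DepthToSpace only permutes elements, so
// backprop routes the incoming gradient through the inverse permutation, and SpaceToDepth is that inverse.
// The shape arithmetic below assumes NHWC layout; the block size (2) and the example shape are hypothetical.
class DepthSpaceShapeSketch {
    // DepthToSpace on shapes: [n, h, w, c] -> [n, h*b, w*b, c/(b*b)]
    static long[] depthToSpace(long[] nhwc, int b) {
        return new long[]{nhwc[0], nhwc[1] * b, nhwc[2] * b, nhwc[3] / ((long) b * b)};
    }
    // SpaceToDepth on shapes: [n, h, w, c] -> [n, h/b, w/b, c*b*b]
    static long[] spaceToDepth(long[] nhwc, int b) {
        return new long[]{nhwc[0], nhwc[1] / b, nhwc[2] / b, nhwc[3] * b * b};
    }
    public static void main(String[] args) {
        long[] activation = {8, 16, 16, 64};             // hypothetical forward-pass input shape
        long[] output = depthToSpace(activation, 2);     // forward output shape: [8, 32, 32, 16]
        long[] routedGrad = spaceToDepth(output, 2);     // gradient routed back to the input shape
        System.out.println(java.util.Arrays.equals(activation, routedGrad)); // prints true
    }
}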
            SDVariable gradient = i_v.get(0);
            SDVariable ret = new SpaceToDepth(sameDiff, new SDVariable[]{gradient}, blockSize, dataFormat).outputVariable();
-           return Arrays.asList(ret);
+           return Collections.singletonList(ret);
        }
    @Override
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java
index b10c53c0a..f967dcf44 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/LocalResponseNormalization.java
@@ -124,7 +124,7 @@ public class LocalResponseNormalization extends DynamicCustomOp {
                .alpha(alpha)
                .beta(beta)
                .bias(bias)
-               .depth((int) depth)
+               .depth(depth)
                .build();
        this.config = localResponseNormalizationConfig;
        addArgs();
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java
index 58437aa57..c772ad44b 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/SConv2D.java
@@ -79,9 +79,7 @@ public class SConv2D extends Conv2D {
        inputs.add(arg(0));
        inputs.add(f1.get(0));
        SDVariable[] args = args();
-       for( int i=1; i attributesForNode, GraphDef graph) {
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
-       boolean isNHWC = dataFormat == null ? true : dataFormat.equals(DataFormat.NHWC);
+       boolean isNHWC = dataFormat == null || dataFormat.equals(DataFormat.NHWC);
        addIArgument(blockSize, isNHWC ? 1 : 0);
    }
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java
index 9524a82ef..40ba2f959 100644
--- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java
+++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/GRUCellOutputs.java
@@ -32,22 +32,22 @@ public class GRUCellOutputs {
    /**
     * Reset gate output [batchSize, numUnits].
     */
-    private SDVariable r;
+    private final SDVariable r;
    /**
     * Update gate output [batchSize, numUnits].
     */
-    private SDVariable u;
+    private final SDVariable u;
    /**
     * Cell gate output [batchSize, numUnits].
     */
-    private SDVariable c;
+    private final SDVariable c;
    /**
     * Current cell output [batchSize, numUnits].
*/ - private SDVariable h; + private final SDVariable h; public GRUCellOutputs(SDVariable[] outputs){ Preconditions.checkArgument(outputs.length == 4, diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java index 6949a984e..a64a6ec87 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMCellOutputs.java @@ -32,37 +32,37 @@ public class LSTMCellOutputs { /** * Output - input modulation gate activations [batchSize, numUnits]. */ - private SDVariable i; + private final SDVariable i; /** * Activations, cell state (pre tanh) [batchSize, numUnits]. */ - private SDVariable c; + private final SDVariable c; /** * Output - forget gate activations [batchSize, numUnits]. */ - private SDVariable f; + private final SDVariable f; /** * Output - output gate activations [batchSize, numUnits]. */ - private SDVariable o; + private final SDVariable o; /** * Output - input gate activations [batchSize, numUnits]. */ - private SDVariable z; + private final SDVariable z; /** * Cell state, post tanh [batchSize, numUnits]. */ - private SDVariable h; + private final SDVariable h; /** * Current cell output [batchSize, numUnits]. */ - private SDVariable y; + private final SDVariable y; public LSTMCellOutputs(SDVariable[] outputs){ Preconditions.checkArgument(outputs.length == 7, diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java index e1a24e0d7..f0cacbe4c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/LSTMLayerOutputs.java @@ -37,7 +37,7 @@ public class LSTMLayerOutputs { /** * The LSTM layer data format ({@link LSTMDataFormat}. 
*/ - private LSTMDataFormat dataFormat; + private final LSTMDataFormat dataFormat; /** @@ -51,21 +51,21 @@ public class LSTMLayerOutputs { * [sL, 2, bS, nOut] when directionMode == 4 && dataFormat == 3 * numbers mean index in corresponding enums {@link LSTMDataFormat} and {@link LSTMDirectionMode} */ - private SDVariable timeSeriesOutput; + private final SDVariable timeSeriesOutput; /** * cell state at last step cL: * [bS, nOut] when directionMode FWD or BWD * 2, bS, nOut] when directionMode BIDIR_SUM, BIDIR_CONCAT or BIDIR_EXTRA_DIM */ - private SDVariable lastCellStateOutput; + private final SDVariable lastCellStateOutput; /** * output at last step hL: * [bS, nOut] when directionMode FWD or BWD * 2, bS, nOut] when directionMode BIDIR_SUM, BIDIR_CONCAT or BIDIR_EXTRA_DIM */ - private SDVariable lastTimeStepOutput; + private final SDVariable lastTimeStepOutput; public LSTMLayerOutputs(SDVariable[] outputs, LSTMLayerConfig lstmLayerConfig) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java index a6612cc65..603b82ec1 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRUCellOutputs.java @@ -33,12 +33,12 @@ public class SRUCellOutputs { /** * Current cell output [batchSize, numUnits]. */ - private SDVariable h; + private final SDVariable h; /** * Current cell state [batchSize, numUnits]. */ - private SDVariable c; + private final SDVariable c; public SRUCellOutputs(SDVariable[] outputs){ Preconditions.checkArgument(outputs.length == 2, diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java index 5052c16d2..0c66ffd51 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/outputs/SRULayerOutputs.java @@ -35,12 +35,12 @@ public class SRULayerOutputs { /** * Current cell output [batchSize, inSize, timeSeriesLength]. */ - private SDVariable h; + private final SDVariable h; /** * Current cell state [batchSize, inSize, timeSeriesLength]. */ - private SDVariable c; + private final SDVariable c; public SRULayerOutputs(SDVariable[] outputs){ Preconditions.checkArgument(outputs.length == 2, @@ -90,7 +90,7 @@ public class SRULayerOutputs { return lastOutput; } - private SDVariable lastState = null; + private final SDVariable lastState = null; /** * Get c, the state of the cell, for the last time step. 
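The hunks in this commit repeat a small set of mechanical cleanups: single-element Arrays.asList calls become Collections.singletonList, write-once output-holder fields gain final, null-checking ternaries collapse to plain boolean expressions or Objects.equals, and redundant casts and "== true" comparisons are dropped. A minimal standalone sketch of those before/after idioms follows; the class, field, and method names are hypothetical and only the idioms themselves mirror the patch.

    import java.util.Collections;
    import java.util.List;
    import java.util.Map;
    import java.util.Objects;

    // Illustrative sketch of the cleanup patterns applied throughout this commit.
    public class CleanupPatterns {

        // Output holders are effectively write-once, so the field becomes final.
        private final String label;

        public CleanupPatterns(String label) {
            this.label = label;
        }

        // Before: return Arrays.asList(label);
        // After:  Collections.singletonList avoids the varargs array for one element.
        public List<String> singleOutput() {
            return Collections.singletonList(label);
        }

        // Before: boolean isNHWC = dataFormat == null ? true : dataFormat.equals("NHWC");
        // After:  the ternary collapses to a plain boolean expression.
        public static boolean isNhwc(String dataFormat) {
            return dataFormat == null || dataFormat.equals("NHWC");
        }

        // Before: !attrs.containsKey(key) ? false : attrs.get(key) > 0;
        // After:  containsKey && value check, no redundant ternary.
        public static boolean flagSet(Map<String, Long> attrs, String key) {
            return attrs.containsKey(key) && attrs.get(key) > 0;
        }

        // Before: return label != null ? label.equals(that.label) : that.label == null;
        // After:  Objects.equals handles both null cases.
        public boolean sameLabel(CleanupPatterns that) {
            return Objects.equals(label, that.label);
        }
    }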
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java index 53ccdb578..6e3a4441c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java @@ -89,7 +89,7 @@ public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp { public List calculateOutputDataTypes(List inputDataTypes){ Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected 2 input datatypes for %s, got %s", getClass(), inputDataTypes); if(dArguments != null && !dArguments.isEmpty()) - return Arrays.asList(dArguments.get(0)); + return Collections.singletonList(dArguments.get(0)); return Collections.singletonList(inputDataTypes.get(1)); //Same as predictions (logits) } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index be291a9c3..586c39780 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -253,8 +253,8 @@ public class Mmul extends DynamicCustomOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0; - val isTransposeB = !attributesForNode.containsKey("transB") ? false : attributesForNode.get("transB").getI() > 0; + val isTransposeA = attributesForNode.containsKey("transA") && attributesForNode.get("transA").getI() > 0; + val isTransposeB = attributesForNode.containsKey("transB") && attributesForNode.get("transB").getI() > 0; MMulTranspose mMulTranspose = MMulTranspose.builder() .transposeA(isTransposeA).transposeB(isTransposeB) .build(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index c80c7cfb6..fb64d7b7d 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -308,8 +308,8 @@ public class TensorMmul extends DynamicCustomOp { @Override public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map attributesForNode, Onnx.GraphProto graph) { - val isTransposeA = !attributesForNode.containsKey("transA") ? false : attributesForNode.get("transA").getI() > 0; - val isTransposeB = !attributesForNode.containsKey("transB") ? 
false : attributesForNode.get("transB").getI() > 0; + val isTransposeA = attributesForNode.containsKey("transA") && attributesForNode.get("transA").getI() > 0; + val isTransposeB = attributesForNode.containsKey("transB") && attributesForNode.get("transB").getI() > 0; MMulTranspose mMulTranspose = MMulTranspose.builder() .transposeA(isTransposeA).transposeB(isTransposeB) .build(); @@ -325,7 +325,7 @@ public class TensorMmul extends DynamicCustomOp { if (addedEdges != that.addedEdges) return false; if (!Arrays.deepEquals(axes, that.axes)) return false; - return mMulTranspose != null ? mMulTranspose.equals(that.mMulTranspose) : that.mMulTranspose == null; + return Objects.equals(mMulTranspose, that.mMulTranspose); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EqualsWithEps.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EqualsWithEps.java index 154a8661a..170c5274d 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EqualsWithEps.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EqualsWithEps.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class EqualsWithEps extends BaseReduce3Op { @@ -68,6 +69,6 @@ public class EqualsWithEps extends BaseReduce3Op { @Override public List doDiff(List f1) { - return Arrays.asList(outputVariables()[0]); + return Collections.singletonList(outputVariables()[0]); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarDivision.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarDivision.java index 12c1e7abc..8eb55dec3 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarDivision.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarDivision.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class ScalarDivision extends BaseScalarOp { @@ -72,6 +73,6 @@ public class ScalarDivision extends BaseScalarOp { @Override public List doDiff(List i_v1) { SDVariable ret = i_v1.get(0).div(scalarValue.getDouble(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSubtraction.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSubtraction.java index 6ebcb274f..9303ba72e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSubtraction.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSubtraction.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class ScalarSubtraction extends BaseScalarOp { @@ -68,6 +69,6 @@ public class ScalarSubtraction extends BaseScalarOp { public List doDiff(List i_v1) { SDVariable g = i_v1.get(0); - return Arrays.asList(g); + return Collections.singletonList(g); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarAnd.java 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarAnd.java index c8434bb1e..8d0c71531 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarAnd.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarAnd.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -69,6 +70,6 @@ public class ScalarAnd extends BaseScalarBoolOp { public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEps.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEps.java index c3e4cc306..1845fb542 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEps.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEps.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -83,7 +84,7 @@ public class ScalarEps extends BaseScalarBoolOp { public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java index 095420c5a..320afe6d5 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarEquals.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -82,7 +83,7 @@ public class ScalarEquals extends BaseScalarBoolOp { @Override public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java index 4a19f53cc..587ba1c05 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThan.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -77,6 +78,6 @@ public class ScalarGreaterThan extends BaseScalarBoolOp { public List doDiff(List f1) { //Not continuously 
differentiable, but 0 gradient in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java index 5e183ddbc..63ddda993 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarGreaterThanOrEqual.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -76,7 +77,7 @@ public class ScalarGreaterThanOrEqual extends BaseScalarBoolOp { public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java index 393728d06..fa804b30a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThan.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -73,6 +74,6 @@ public class ScalarLessThan extends BaseScalarBoolOp { public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java index 3dce29315..ee8ffffde 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarLessThanOrEqual.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -69,7 +70,7 @@ public class ScalarLessThanOrEqual extends BaseScalarBoolOp { public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNot.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNot.java index f1f0e78f1..e89ed9264 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNot.java +++ 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNot.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -74,6 +75,6 @@ public class ScalarNot extends BaseScalarBoolOp { public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java index 21a6e14c4..923cc12cf 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarNotEquals.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -68,6 +69,6 @@ public class ScalarNotEquals extends BaseScalarBoolOp { public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarOr.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarOr.java index 7a3fb3ab7..5e0c4cacf 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarOr.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarOr.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,6 +76,6 @@ public class ScalarOr extends BaseScalarBoolOp { public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarSetValue.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarSetValue.java index cc488277e..8d65c1c62 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarSetValue.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarSetValue.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class ScalarSetValue extends BaseScalarOp { @@ -82,6 +83,6 @@ public class ScalarSetValue extends BaseScalarOp { @Override public List doDiff(List f1) { - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarXor.java 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarXor.java index 54d368bed..b92bcd705 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarXor.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/comparison/ScalarXor.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarBoolOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -74,6 +75,6 @@ public class ScalarXor extends BaseScalarBoolOp { public List doDiff(List f1) { //Not continuously differentiable, but 0 gradient in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java index a5dffd37d..eca904de8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java @@ -71,7 +71,7 @@ public class ScatterAdd extends DynamicCustomOp { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { - if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { + if (nodeDef.getAttrOrThrow("use_locking").getB()) { bArguments.add(true); } else { bArguments.add(false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java index ae7c12d8b..34e3278fe 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java @@ -96,7 +96,7 @@ public class ScatterDiv extends DynamicCustomOp { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { - if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { + if (nodeDef.getAttrOrThrow("use_locking").getB()) { bArguments.add(true); } else { bArguments.add(false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java index b035772db..73b5eb688 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java @@ -68,7 +68,7 @@ public class ScatterMax extends DynamicCustomOp { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { - if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { + if (nodeDef.getAttrOrThrow("use_locking").getB()) { bArguments.add(true); } else { bArguments.add(false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java index 77ecd2404..3d138c2f0 100644 --- 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java @@ -68,7 +68,7 @@ public class ScatterMin extends DynamicCustomOp { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { - if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { + if (nodeDef.getAttrOrThrow("use_locking").getB()) { bArguments.add(true); } else { bArguments.add(false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java index 1b2a458b0..db22db585 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java @@ -71,7 +71,7 @@ public class ScatterMul extends DynamicCustomOp { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { - if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { + if (nodeDef.getAttrOrThrow("use_locking").getB()) { bArguments.add(true); } else { bArguments.add(false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java index e680313b2..2700c096a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNd.java @@ -69,7 +69,7 @@ public class ScatterNd extends DynamicCustomOp { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { - if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { + if (nodeDef.getAttrOrThrow("use_locking").getB()) { bArguments.add(true); } else { bArguments.add(false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java index 781382d2b..96cdbdeb8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdAdd.java @@ -69,7 +69,7 @@ public class ScatterNdAdd extends DynamicCustomOp { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { - if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { + if (nodeDef.getAttrOrThrow("use_locking").getB()) { bArguments.add(true); } else { bArguments.add(false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java index 546cb5055..da50aba66 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdSub.java @@ -69,7 +69,7 @@ public class ScatterNdSub extends DynamicCustomOp { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), 
this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { - if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { + if (nodeDef.getAttrOrThrow("use_locking").getB()) { bArguments.add(true); } else { bArguments.add(false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java index 825965fb0..157fb7bf9 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterNdUpdate.java @@ -69,7 +69,7 @@ public class ScatterNdUpdate extends DynamicCustomOp { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { - if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { + if (nodeDef.getAttrOrThrow("use_locking").getB()) { bArguments.add(true); } else { bArguments.add(false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java index 37cdbb495..db344f0f6 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java @@ -87,7 +87,7 @@ public class ScatterSub extends DynamicCustomOp { TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph); if (nodeDef.containsAttr("use_locking")) { - if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { + if (nodeDef.getAttrOrThrow("use_locking").getB()) { bArguments.add(true); } else { bArguments.add(false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java index 59bd0a744..5cf1146be 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java @@ -38,7 +38,7 @@ import java.util.Map; public class ScatterUpdate extends DynamicCustomOp { - public static enum UpdateOp { + public enum UpdateOp { ADD, SUBTRACT, MULTIPLY, @@ -76,7 +76,7 @@ public class ScatterUpdate extends DynamicCustomOp { @Override public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { if (nodeDef.containsAttr("use_locking")) { - if (nodeDef.getAttrOrThrow("use_locking").getB() == true) { + if (nodeDef.getAttrOrThrow("use_locking").getB()) { bArguments.add(true); } else { bArguments.add(false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ApplyGradientDescent.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ApplyGradientDescent.java index 62d1878af..1171bf080 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ApplyGradientDescent.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ApplyGradientDescent.java @@ -29,6 +29,7 @@ import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; import java.util.Arrays; +import java.util.Collections; import java.util.List; import 
java.util.Map; @@ -108,7 +109,7 @@ public class ApplyGradientDescent extends DynamicCustomOp { @Override public List doDiff(List i_v) { SDVariable ret = this.outputVariables()[0]; - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Create.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Create.java index d4d177688..f4c2ed15c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Create.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Create.java @@ -81,7 +81,7 @@ public class Create extends DynamicCustomOp { protected void addArgs() { addBArgument(initialize); - addIArgument((int) order,outputType.toInt()); + addIArgument(order,outputType.toInt()); } @Override @@ -121,7 +121,7 @@ public class Create extends DynamicCustomOp { @Override public List doDiff(List i_v) { SDVariable ret = sameDiff.zerosLike(outputVariables()[0]); - return Arrays.asList(ret); + return Collections.singletonList(ret); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java index 179565dfe..34dd74ca7 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ExpandDims.java @@ -151,7 +151,7 @@ public class ExpandDims extends DynamicCustomOp { public List doDiff(List i_v) { //Simply need a reshape to remove the dimension... SDVariable ret = sameDiff.squeeze(i_v.get(0), jaxis); - return Arrays.asList(ret); + return Collections.singletonList(ret); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java index 8b2b7e6a0..5642ef253 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Eye.java @@ -127,7 +127,7 @@ public class Eye extends DynamicCustomOp { } } - addTArgument((double) dataType.toInt()); + addTArgument(dataType.toInt()); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java index 80ea9ac62..00aa5b59a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/OnesLike.java @@ -110,7 +110,7 @@ public class OnesLike extends DynamicCustomOp { @Override public List doDiff(List i_v) { SDVariable ret = sameDiff.zerosLike(outputVariables()[0]); - return Arrays.asList(ret); + return Collections.singletonList(ret); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java index 9e18a18fa..96200d9ee 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Repeat.java @@ -67,7 +67,7 @@ public class Repeat extends DynamicCustomOp { @Override public Map propertiesForFunction() { - return 
Collections.singletonMap("axis", axis); + return Collections.singletonMap("axis", axis); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index 977241e23..7d6c9f12f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -86,7 +86,6 @@ public class Reshape extends DynamicCustomOp { public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { if (!nodeDef.containsAttr("TShape") && nodeDef.getInputCount() == 1) { this.shape = new long[]{}; - return; } else if(nodeDef.getInputCount() == 1){ val shape = nodeDef.getAttrOrThrow("Tshape"); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java index b25762f3a..d319ac8bf 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Squeeze.java @@ -86,14 +86,13 @@ public class Squeeze extends DynamicCustomOp { for (int d : squeezeDims) { ret = sameDiff.expandDims(ret, d); } - ; - return Arrays.asList(ret); + return Collections.singletonList(ret); } @Override public List calculateOutputDataTypes(List dataTypes){ Preconditions.checkState(!dataTypes.isEmpty(), "Expected list with at least 1 datatype for %s, got %s", getClass(), dataTypes); //Output type is same as input type - return Arrays.asList(dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java index eb0a7947e..7fc1953c0 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/MaxOut.java @@ -115,10 +115,7 @@ public class MaxOut extends BaseTransformOp { if (y != null && !y().isR()) return false; - if (z != null && z().dataType() != x().dataType()) - return false; - - return true; + return z == null || z().dataType() == x().dataType(); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java index be8c99fef..ff48dfa31 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java @@ -32,7 +32,7 @@ import java.util.Collections; import java.util.List; /** - * [1, 2, 3, 1] -> [0, 0, 1, 0] + * [1, 2, 3, 1] -> [0, 0, 1, 0] * @author Adam Gibson */ public class IsMax extends DynamicCustomOp { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java index 431764efd..ad93c34fd 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java +++ 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java @@ -98,6 +98,6 @@ public class ClipByValue extends DynamicCustomOp { public List calculateOutputDataTypes(List inputDataTypes){ Preconditions.checkState(inputDataTypes != null && !inputDataTypes.isEmpty() , "Expected at least 1 input datatype for %s, got %s", getClass(), inputDataTypes); //get the final data type (sometimes model import passes in 2 dummy data types that aren't relevant) - return Arrays.asList(inputDataTypes.get(inputDataTypes.size() - 1)); + return Collections.singletonList(inputDataTypes.get(inputDataTypes.size() - 1)); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java index de05d6476..3bc42556c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java @@ -89,7 +89,7 @@ public class BatchToSpace extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of batch to space is space to batch with same blocks and padding as crops SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops[0], crops[1])); + return Collections.singletonList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops[0], crops[1])); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java index d9661edd3..df8e50566 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java @@ -71,7 +71,7 @@ public class BatchToSpaceND extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of batch to space is space to batch with same blocks and padding as crops SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops[0], crops[1])); + return Collections.singletonList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops[0], crops[1])); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java index 69665508a..44e8ea079 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Choose.java @@ -69,7 +69,7 @@ public class Choose extends DynamicCustomOp { * @param condition */ public Choose(INDArray[] inputs,Condition condition) { - this(inputs, Collections.emptyList(),Collections.emptyList(),condition); + this(inputs, Collections.emptyList(),Collections.emptyList(),condition); } /** diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java index f00743762..362d16392 100644 --- 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java @@ -80,7 +80,7 @@ public class SpaceToBatch extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of space to batch is batch to space with same blocks and crops as padding SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding[0], padding[1])); + return Collections.singletonList(sameDiff.cnn().batchToSpace(gradient, blocks, padding[0], padding[1])); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java index 327252215..432f69931 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java @@ -71,7 +71,7 @@ public class SpaceToBatchND extends DynamicCustomOp { public List doDiff(List i_v) { // Inverse of space to batch is batch to space with same blocks and crops as padding SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding[0], padding[1])); + return Collections.singletonList(sameDiff.cnn().batchToSpace(gradient, blocks, padding[0], padding[1])); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/StandardizeBp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/StandardizeBp.java index df0679604..6d4d9d79b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/StandardizeBp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/StandardizeBp.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class StandardizeBp extends DynamicCustomOp { @@ -70,6 +71,6 @@ public class StandardizeBp extends DynamicCustomOp { Preconditions.checkState(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatype for %s, got %s", getClass(), dataTypes); Preconditions.checkState(dataTypes.get(0).isFPType(), "Input 0 must be a floating point type, got %s", dataTypes.get(0)); Preconditions.checkState(dataTypes.get(1).isFPType(), "Input 1 must be a floating point type, got %s", dataTypes.get(1)); - return Arrays.asList(dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java index 6df9ff2d9..6894fffb4 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java @@ -47,7 +47,7 @@ public class SquaredDifferenceOp extends BaseDynamicTransformOp { } public 
SquaredDifferenceOp(INDArray x, INDArray y) { - addInputArgument(new INDArray[]{x,y}); + addInputArgument(x,y); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java index 905263b0d..ce2bd3c52 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java @@ -77,7 +77,7 @@ public class Abs extends BaseTransformSameOp { @Override public List doDiff(List i_v) { SDVariable ret = sameDiff.math.sign(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java index 787545aa0..1591596f8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class Ceil extends BaseTransformSameOp { @@ -73,6 +74,6 @@ public class Ceil extends BaseTransformSameOp { public List doDiff(List f1) { //not continuously differentiable, but dOut/dIn = 0 in most places - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java index bd95aa501..e7d755da7 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Floor.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; @NoArgsConstructor @@ -71,7 +72,7 @@ public class Floor extends BaseTransformSameOp { public List doDiff(List i_v) { //Floor op: non-continuous at integers, but 0 gradient otherwise - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java index f5249b4bf..e6b285281 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java @@ -77,7 +77,7 @@ public class Identity extends BaseDynamicTransformOp { public List calculateOutputDataTypes(List dataTypes) { Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got input %s", getClass(), dataTypes); if(!dArguments.isEmpty()) - return Arrays.asList(dArguments.get(0)); + return Collections.singletonList(dArguments.get(0)); return 
dataTypes; } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java index b9ef68cbe..2945feab2 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; @NoArgsConstructor @@ -70,7 +71,7 @@ public class Negative extends BaseTransformSameOp { @Override public List doDiff(List i_v) { - return Arrays.asList(sameDiff.math.neg(i_v.get(0))); + return Collections.singletonList(sameDiff.math.neg(i_v.get(0))); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java index b99f04f91..fe55e245c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; @NoArgsConstructor @@ -73,6 +74,6 @@ public class Round extends BaseTransformSameOp { @Override public List doDiff(List f1) { - return Arrays.asList(sameDiff.zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java index 12962cdba..945b9f7d1 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Sign.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class Sign extends BaseTransformSameOp { @@ -73,7 +74,7 @@ public class Sign extends BaseTransformSameOp { @Override public List doDiff(List i_v) { SDVariable ret = sameDiff.zerosLike(arg()); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java index 92eeb2df4..8ef16b7eb 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACosh.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; @NoArgsConstructor @@ -75,7 +76,7 @@ public class ACosh extends BaseTransformStrictOp { //dacosh(x)/dx = 1/(sqrt(x^2-1)) -- note that domain is x >= 1 SDVariable xSqPlus1 = sameDiff.math().square(arg()).sub(1.0); SDVariable sqrt 
= sameDiff.math().sqrt(xSqPlus1); - return Arrays.asList(i_v.get(0).div(sqrt)); + return Collections.singletonList(i_v.get(0).div(sqrt)); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java index 2975aabf8..864b04867 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class ASinh extends BaseTransformStrictOp { @@ -74,6 +75,6 @@ public class ASinh extends BaseTransformStrictOp { //dasinh(x)/dx = 1 / sqrt(x^2+1) SDVariable xSqPlus1 = sameDiff.math.square(arg()).add(1.0); SDVariable ret = i_v.get(0).div(sameDiff.math.sqrt(xSqPlus1)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATanh.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATanh.java index 6953a8df1..d63cee3e1 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATanh.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATanh.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class ATanh extends BaseTransformStrictOp { @@ -77,7 +78,7 @@ public class ATanh extends BaseTransformStrictOp { SDVariable oneMinusX2 = sameDiff.math().square(arg()).rsub(1.0); SDVariable ret = oneMinusX2.rdiv(1.0).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java index 7a519db06..2a300e5e7 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; @NoArgsConstructor @@ -62,7 +63,7 @@ public class Cos extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { SDVariable ret = sameDiff.math.neg(sameDiff.math.sin(arg())).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java index 8b5b9dbcd..9f9ee4e52 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import 
org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; @NoArgsConstructor @@ -72,7 +73,7 @@ public class Cosh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { SDVariable ret = sameDiff.math.sinh(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java index 7974c6197..02d946052 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; @NoArgsConstructor @@ -71,7 +72,7 @@ public class Exp extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { SDVariable ret = sameDiff.math.mul(sameDiff.math.exp(arg()), i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java index 6733a26b7..cec890024 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; @NoArgsConstructor @@ -72,7 +73,7 @@ public class Expm1 extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { SDVariable ret = sameDiff.math.mul(sameDiff.math.exp(arg()), i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java index d8e5ce1f6..bb899e850 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; @NoArgsConstructor @@ -72,7 +73,7 @@ public class Swish extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { SDVariable ret = new SwishDerivative(sameDiff, arg()).outputVariable().mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/performance/PerformanceTracker.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/performance/PerformanceTracker.java index dbed287b3..e49397525 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/performance/PerformanceTracker.java +++ 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/performance/PerformanceTracker.java @@ -35,8 +35,8 @@ import java.util.Map; public class PerformanceTracker { private static final PerformanceTracker INSTANCE = new PerformanceTracker(); - private Map bandwidth = new HashMap<>(); - private Map operations = new HashMap<>(); + private final Map bandwidth = new HashMap<>(); + private final Map operations = new HashMap<>(); private PerformanceTracker() { // we put in initial holders, one per device @@ -77,7 +77,7 @@ public class PerformanceTracker { */ public long addMemoryTransaction(int deviceId, long timeSpentNanos, long numberOfBytes, @NonNull MemcpyDirection direction) { // we calculate bytes per microsecond now - val bw = (long) (numberOfBytes / (timeSpentNanos / (double) 1000.0)); + val bw = (long) (numberOfBytes / (timeSpentNanos / 1000.0)); // we skip too small values if (bw > 0) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/random/compat/RandomStandardNormal.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/random/compat/RandomStandardNormal.java index e1e7dcdce..baae1a0a6 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/random/compat/RandomStandardNormal.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/random/compat/RandomStandardNormal.java @@ -60,7 +60,7 @@ public class RandomStandardNormal extends DynamicCustomOp { addTArgument(0.0, 1.0); } - public RandomStandardNormal(long shape[]) { + public RandomStandardNormal(long[] shape) { this(Nd4j.create(ArrayUtil.toDouble(shape)), Nd4j.create(shape)); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomPoisson.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomPoisson.java index 8f24a5ca4..d73b28ba0 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomPoisson.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/random/custom/RandomPoisson.java @@ -81,7 +81,7 @@ public class RandomPoisson extends DynamicCustomOp { getClass(), inputDataTypes.size()); if(!dArguments.isEmpty()) - return Arrays.asList(dArguments.get(0)); + return Collections.singletonList(dArguments.get(0)); return Collections.singletonList(outputDataType); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/DefaultRandom.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/DefaultRandom.java index bf49e168d..d7d813176 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/DefaultRandom.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/DefaultRandom.java @@ -52,7 +52,7 @@ public class DefaultRandom implements Random, RandomGenerator { @Override public void setSeed(int seed) { - this.seed = (long) seed; + this.seed = seed; getRandomGenerator().setSeed(seed); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java index 6e929fa8d..b56722c30 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java @@ -43,7 +43,7 @@ public class ConstantDistribution extends BaseDistribution { /** * Mean of this 
distribution. */ - private double value; + private final double value; public ConstantDistribution(double value) { this.value = value; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java index bd7cff94d..3b1faaf71 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java @@ -48,7 +48,7 @@ public class OrthogonalDistribution extends BaseDistribution { /** * Mean of this distribution. */ - private double gain; + private final double gain; private INDArray gains; public OrthogonalDistribution(double gain) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java index b06896146..07627f05c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java @@ -30,7 +30,8 @@ import org.nd4j.linalg.api.rng.distribution.BaseDistribution; import org.nd4j.linalg.factory.Nd4j; public class UniformDistribution extends BaseDistribution { - private double upper, lower; + private final double upper; + private final double lower; /** * Create a uniform real distribution using the given lower and upper diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/shape/LongShapeDescriptor.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/shape/LongShapeDescriptor.java index 0263533b6..bcd6fad49 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/shape/LongShapeDescriptor.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/shape/LongShapeDescriptor.java @@ -32,20 +32,20 @@ import java.util.Arrays; public class LongShapeDescriptor { @Getter - private char order; + private final char order; - private long offset; + private final long offset; - private long ews; + private final long ews; - private long hashShape = 0; - private long hashStride = 0; + private final long hashShape = 0; + private final long hashStride = 0; @Getter - private long[] shape; + private final long[] shape; @Getter - private long[] stride; + private final long[] stride; @Getter @Setter private long extras; @@ -107,7 +107,7 @@ public class LongShapeDescriptor { @Override public int hashCode() { - int result = (int) order; + int result = order; result = 31 * result + longHashCode(offset); result = 31 * result + longHashCode(ews); @@ -120,13 +120,11 @@ public class LongShapeDescriptor { @Override public String toString() { - StringBuilder builder = new StringBuilder(); + String builder = shape.length + "," + Arrays.toString(shape) + "," + + Arrays.toString(stride) + "," + extras + "," + ews + "," + + order; - builder.append(shape.length).append(",").append(Arrays.toString(shape)).append(",") - .append(Arrays.toString(stride)).append(",").append(extras).append(",").append(ews).append(",") - .append(order); - - String result = builder.toString().replaceAll("\\]", "").replaceAll("\\[", ""); + String result = builder.replaceAll("\\]", "").replaceAll("\\[", ""); result = "[" + result + "]"; return result; diff 
--git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index cacc677cb..0f2174fd1 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -907,7 +907,7 @@ public class Shape { throw new IllegalArgumentException( String.format("J: Index [%d] must not be >= shape[%d]=%d.", i, i, shape[i])); if (shape[i] != 1) { - offset += indices[i] * stride[i]; + offset += (long) indices[i] * stride[i]; } } @@ -995,7 +995,7 @@ public class Shape { throw new IllegalArgumentException( String.format("J: Index [%d] must not be >= shape[%d]=%d.", i, i, size_dimi)); if (size_dimi != 1) { - offset += indices[i] * stride(shapeInformation, i); + offset += (long) indices[i] * stride(shapeInformation, i); } } return offset; @@ -1058,9 +1058,9 @@ public class Shape { + Arrays.toString(shape(shapeInformation)) + " NDArray"); if (size_0 != 1) - offset += row * strideUnsafe(shapeInformation, 0, 2); + offset += (long) row * strideUnsafe(shapeInformation, 0, 2); if (size_1 != 1) - offset += col * strideUnsafe(shapeInformation, 1, 2); + offset += (long) col * strideUnsafe(shapeInformation, 1, 2); return offset; } @@ -1075,9 +1075,9 @@ public class Shape { + Arrays.toString(shape(shapeInformation)) + " NDArray"); if (size_0 != 1) - offset += row * strideUnsafe(shapeInformation, 0, 2); + offset += (long) row * strideUnsafe(shapeInformation, 0, 2); if (size_1 != 1) - offset += col * strideUnsafe(shapeInformation, 1, 2); + offset += (long) col * strideUnsafe(shapeInformation, 1, 2); return offset; } @@ -1118,9 +1118,9 @@ public class Shape { + Arrays.toString(shape(shapeInformation)) + " NDArray"); if (size_0 != 1) - offset += row * stride(shapeInformation, 0); + offset += (long) row * stride(shapeInformation, 0); if (size_1 != 1) - offset += col * stride(shapeInformation, 1); + offset += (long) col * stride(shapeInformation, 1); return offset; } @@ -1147,11 +1147,11 @@ public class Shape { + "] from a " + Arrays.toString(shape(shapeInformation)) + " NDArray"); if (size_0 != 1) - offset += dim0 * stride(shapeInformation, 0); + offset += (long) dim0 * stride(shapeInformation, 0); if (size_1 != 1) - offset += dim1 * stride(shapeInformation, 1); + offset += (long) dim1 * stride(shapeInformation, 1); if (size_2 != 1) - offset += dim2 * stride(shapeInformation, 2); + offset += (long) dim2 * stride(shapeInformation, 2); return offset; } @@ -1185,11 +1185,11 @@ public class Shape { + "] from a " + Arrays.toString(shape(shapeInformation)) + " NDArray"); if (size_0 != 1) - offset += dim0 * strideUnsafe(shapeInformation, 0, 3); + offset += (long) dim0 * strideUnsafe(shapeInformation, 0, 3); if (size_1 != 1) - offset += dim1 * strideUnsafe(shapeInformation, 1, 3); + offset += (long) dim1 * strideUnsafe(shapeInformation, 1, 3); if (size_2 != 1) - offset += dim2 * strideUnsafe(shapeInformation, 2, 3); + offset += (long) dim2 * strideUnsafe(shapeInformation, 2, 3); return offset; } @@ -1237,13 +1237,13 @@ public class Shape { + dim3 + "] from a " + Arrays.toString(shape(shapeInformation)) + " NDArray"); if (size_0 != 1) - offset += dim0 * stride(shapeInformation, 0); + offset += (long) dim0 * stride(shapeInformation, 0); if (size_1 != 1) - offset += dim1 * stride(shapeInformation, 1); + offset += (long) dim1 * stride(shapeInformation, 1); if (size_2 != 1) - offset += dim2 * stride(shapeInformation, 2); + offset += (long) 
dim2 * stride(shapeInformation, 2); if (size_3 != 1) - offset += dim3 * stride(shapeInformation, 3); + offset += (long) dim3 * stride(shapeInformation, 3); return offset; } @@ -1276,13 +1276,13 @@ public class Shape { + dim3 + "] from a " + Arrays.toString(shape(shapeInformation)) + " NDArray"); if (size_0 != 1) - offset += dim0 * strideUnsafe(shapeInformation, 0, 4); + offset += (long) dim0 * strideUnsafe(shapeInformation, 0, 4); if (size_1 != 1) - offset += dim1 * strideUnsafe(shapeInformation, 1, 4); + offset += (long) dim1 * strideUnsafe(shapeInformation, 1, 4); if (size_2 != 1) - offset += dim2 * strideUnsafe(shapeInformation, 2, 4); + offset += (long) dim2 * strideUnsafe(shapeInformation, 2, 4); if (size_3 != 1) - offset += dim3 * strideUnsafe(shapeInformation, 3, 4); + offset += (long) dim3 * strideUnsafe(shapeInformation, 3, 4); return offset; } @@ -1299,13 +1299,13 @@ public class Shape { + dim3 + "] from a " + Arrays.toString(shape(shapeInformation)) + " NDArray"); if (size_0 != 1) - offset += dim0 * strideUnsafe(shapeInformation, 0, 4); + offset += (long) dim0 * strideUnsafe(shapeInformation, 0, 4); if (size_1 != 1) - offset += dim1 * strideUnsafe(shapeInformation, 1, 4); + offset += (long) dim1 * strideUnsafe(shapeInformation, 1, 4); if (size_2 != 1) - offset += dim2 * strideUnsafe(shapeInformation, 2, 4); + offset += (long) dim2 * strideUnsafe(shapeInformation, 2, 4); if (size_3 != 1) - offset += dim3 * strideUnsafe(shapeInformation, 3, 4); + offset += (long) dim3 * strideUnsafe(shapeInformation, 3, 4); return offset; } @@ -1630,21 +1630,13 @@ public class Shape { public static boolean scalarEquals(int[] shape1, int[] shape2) { if (shape1.length == 0 && shape2.length == 1 && shape2[0] == 1) { return true; - } else if (shape2.length == 0 && shape1.length == 1 && shape1[0] == 1) { - return true; - } - - return false; + } else return shape2.length == 0 && shape1.length == 1 && shape1[0] == 1; } public static boolean scalarEquals(long[] shape1, long[] shape2) { if (shape1.length == 0 && shape2.length == 1 && shape2[0] == 1) { return true; - } else if (shape2.length == 0 && shape1.length == 1 && shape1[0] == 1) { - return true; - } - - return false; + } else return shape2.length == 0 && shape1.length == 1 && shape1[0] == 1; } /** @@ -2310,7 +2302,7 @@ public class Shape { long index = 0; int shift = 1; for (int i = 0; i < shape.length; i++) { - index += shift * indices[i]; + index += (long) shift * indices[i]; shift *= shape[i]; } return index; @@ -2891,13 +2883,13 @@ public class Shape { * @return */ public static IntBuffer shapeOf(IntBuffer buffer) { - Buffer buffer2 = (Buffer) buffer; + Buffer buffer2 = buffer; IntBuffer ret = (IntBuffer) buffer2.position(1); return ret.slice(); } public static LongBuffer shapeOf(LongBuffer buffer) { - Buffer buffer2 = (Buffer) buffer; + Buffer buffer2 = buffer; val ret = (LongBuffer) buffer2.position(1); return ret.slice(); } @@ -3230,7 +3222,7 @@ public class Shape { @Deprecated public static void setOrder(IntBuffer buffer, char order) { int length = Shape.shapeInfoLength(Shape.rank(buffer)); - buffer.put(length - 1, (int) order); + buffer.put(length - 1, order); throw new RuntimeException("setOrder called"); } @@ -3449,7 +3441,7 @@ public class Shape { */ public static boolean contentEquals(int[] arr, IntBuffer other) { for (int i = 0; i < arr.length; i++) { - Buffer buffer2 = (Buffer) other; + Buffer buffer2 = other; buffer2.position(i); if (arr[i] != other.get()) { return false; @@ -3644,7 +3636,7 @@ public class Shape { //Length is simply 
1 + the buffer index of the last element long length = 1; for(int i=0; i> getAllTestMatricesWithShape(char ordering, int rows, int cols, int seed, DataType dataType) { List> all = new ArrayList<>(); Nd4j.getRandom().setSeed(seed); - all.add(new Pair<>(Nd4j.linspace(1, rows * cols, rows * cols, dataType).reshape(ordering, rows, cols), + all.add(new Pair<>(Nd4j.linspace(1, (long) rows * cols, (long) rows * cols, dataType).reshape(ordering, rows, cols), "Nd4j..linspace(1,rows * cols,rows * cols).reshape(rows,cols)")); all.add(getTransposedMatrixWithShape(ordering, rows, cols, seed, dataType)); @@ -96,7 +96,7 @@ public class NDArrayCreationUtil { List> all = new ArrayList<>(); if (rank == 0) { //scalar - all.add(new Pair<>(Nd4j.scalar(dataType, Nd4j.rand(dataType, new int[]{1, 1}).getDouble(0)), "{}")); + all.add(new Pair<>(Nd4j.scalar(dataType, Nd4j.rand(dataType, 1, 1).getDouble(0)), "{}")); return all; } //generate all possible combinations with a 1 and a 2 @@ -128,7 +128,7 @@ public class NDArrayCreationUtil { public static Pair getTransposedMatrixWithShape(char ordering, int rows, int cols, int seed, DataType dataType) { Nd4j.getRandom().setSeed(seed); - INDArray out = Nd4j.linspace(1, rows * cols, rows * cols, dataType).reshape(ordering, cols, rows); + INDArray out = Nd4j.linspace(1, (long) rows * cols, (long) rows * cols, dataType).reshape(ordering, cols, rows); return new Pair<>(out.transpose(), "getTransposedMatrixWithShape(" + rows + "," + cols + "," + seed + ")"); } @@ -181,7 +181,7 @@ public class NDArrayCreationUtil { out[1] = temp01.tensorAlongDimension(2, 0, 1).reshape(rows, cols); Nd4j.getRandom().setSeed(seed); - INDArray temp02 = Nd4j.linspace(1, len, len, dataType).reshape(new long[] {cols, 4, rows}); + INDArray temp02 = Nd4j.linspace(1, len, len, dataType).reshape(cols, 4, rows); out[2] = temp02.tensorAlongDimension(0, 0, 2).reshape(rows, cols); temp02 = Nd4j.linspace(1, len, len, dataType).reshape(cols, 4, rows); out[3] = temp02.tensorAlongDimension(2, 0, 2).reshape(rows, cols); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/BasicNDArrayCompressor.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/BasicNDArrayCompressor.java index 283694f56..76bb501b8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/BasicNDArrayCompressor.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/BasicNDArrayCompressor.java @@ -86,7 +86,7 @@ public class BasicNDArrayCompressor { builder.append("[").append(comp).append("] "); } - System.out.println(builder.toString()); + System.out.println(builder); } /** diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java index cdbd52d3b..29c5ebe90 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java @@ -44,7 +44,7 @@ public class CompressedDataBuffer extends BaseDataBuffer { @Getter @Setter protected CompressionDescriptor compressionDescriptor; - private static Logger logger = LoggerFactory.getLogger(CompressedDataBuffer.class); + private static final Logger logger = LoggerFactory.getLogger(CompressedDataBuffer.class); public CompressedDataBuffer(Pointer pointer, @NonNull CompressionDescriptor descriptor) { this.compressionDescriptor = descriptor; diff --git 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/CompressionUtils.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/CompressionUtils.java index 03ce9b4b5..ff101063f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/CompressionUtils.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/compression/CompressionUtils.java @@ -27,16 +27,10 @@ public class CompressionUtils { public static boolean goingToDecompress(@NonNull DataTypeEx from, @NonNull DataTypeEx to) { // TODO: eventually we want FLOAT16 here - if (to.equals(DataTypeEx.FLOAT) || to.equals(DataTypeEx.DOUBLE) ) - return true; - - return false; + return to.equals(DataTypeEx.FLOAT) || to.equals(DataTypeEx.DOUBLE); } public static boolean goingToCompress(@NonNull DataTypeEx from, @NonNull DataTypeEx to) { - if (!goingToDecompress(from, to) && goingToDecompress(to, from)) - return true; - - return false; + return !goingToDecompress(from, to) && goingToDecompress(to, from); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java index 4d91be1ba..b2cb74a30 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/convolution/Convolution.java @@ -75,7 +75,7 @@ public class Convolution { if (col.rank() != 6) throw new IllegalArgumentException("col2im input array must be rank 6"); - INDArray output = Nd4j.create(col.dataType(), new long[]{col.size(0), col.size(1), kH, kW}); + INDArray output = Nd4j.create(col.dataType(), col.size(0), col.size(1), kH, kW); val cfg = Conv2DConfig.builder() .sH(sH) @@ -289,8 +289,8 @@ public class Convolution { output = Nd4j.createUninitialized(img.dataType(), new long[]{img.size(0), img.size(1), kh, kw, oH, oW}, 'c'); } else { - long oH = (img.size(2) - (kh + (kh - 1) * (1 - 1)) + 2 * ph) / sy + 1; - long oW = (img.size(3) - (kw + (kw - 1) * (1 - 1)) + 2 * pw) / sx + 1; + long oH = (img.size(2) - (kh + 0) + 2L * ph) / sy + 1; + long oW = (img.size(3) - (kw + 0) + 2L * pw) / sx + 1; output = Nd4j.createUninitialized(img.dataType(), new long[]{img.size(0), img.size(1), kh, kw, oH, oW}, 'c'); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/convolution/OldConvolution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/convolution/OldConvolution.java index 383e807b1..f898598b1 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/convolution/OldConvolution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/convolution/OldConvolution.java @@ -71,7 +71,7 @@ public class OldConvolution { //out width long outW = col.size(5); - INDArray img = Nd4j.create(n, c, h + 2 * ph + sy - 1, w + 2 * pw + sx - 1); + INDArray img = Nd4j.create(n, c, h + 2L * ph + sy - 1, w + 2L * pw + sx - 1); for (int i = 0; i < kh; i++) { //iterate over the kernel rows long iLim = i + sy * outH; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java index 842b770d3..e57f29072 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java @@ -117,7 +117,7 @@ public class AsyncDataSetIterator implements DataSetIterator { 
this.buffer = queue; this.prefetchSize = queueSize; this.backedIterator = iterator; - this.workspaceId = "ADSI_ITER-" + java.util.UUID.randomUUID().toString(); + this.workspaceId = "ADSI_ITER-" + java.util.UUID.randomUUID(); if (iterator.resetSupported() && !iterator.hasNext()) this.backedIterator.reset(); @@ -364,11 +364,11 @@ public class AsyncDataSetIterator implements DataSetIterator { } protected class AsyncPrefetchThread extends Thread implements Runnable { - private BlockingQueue queue; - private DataSetIterator iterator; - private DataSet terminator; + private final BlockingQueue queue; + private final DataSetIterator iterator; + private final DataSet terminator; private boolean isShutdown = false; // locked around `this` - private WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().minSize(10 * 1024L * 1024L) + private final WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().minSize(10 * 1024L * 1024L) .overallocationLimit(prefetchSize + 2).policyReset(ResetPolicy.ENDOFBUFFER_REACHED) .policyLearning(LearningPolicy.FIRST_LOOP).policyAllocation(AllocationPolicy.OVERALLOCATE) .policySpill(SpillPolicy.REALLOCATE).build(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java index 7b32dca06..822fa3ce2 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java @@ -103,7 +103,7 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { this.backedIterator = iterator; this.useWorkspaces = useWorkspace; this.prefetchSize = queueSize; - this.workspaceId = "AMDSI_ITER-" + java.util.UUID.randomUUID().toString(); + this.workspaceId = "AMDSI_ITER-" + java.util.UUID.randomUUID(); this.deviceId = deviceId; if (iterator.resetSupported() && !iterator.hasNext()) @@ -312,11 +312,11 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { } protected class AsyncPrefetchThread extends Thread implements Runnable { - private BlockingQueue queue; - private MultiDataSetIterator iterator; - private MultiDataSet terminator; + private final BlockingQueue queue; + private final MultiDataSetIterator iterator; + private final MultiDataSet terminator; private boolean isShutdown = false; // locked around `this` - private WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().minSize(10 * 1024L * 1024L) + private final WorkspaceConfiguration configuration = WorkspaceConfiguration.builder().minSize(10 * 1024L * 1024L) .overallocationLimit(prefetchSize + 1).policyReset(ResetPolicy.ENDOFBUFFER_REACHED) .policyLearning(LearningPolicy.FIRST_LOOP).policyAllocation(AllocationPolicy.OVERALLOCATE) .policySpill(SpillPolicy.REALLOCATE).build(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/ExistingMiniBatchDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/ExistingMiniBatchDataSetIterator.java index 87f430b19..1447ec78c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/ExistingMiniBatchDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/ExistingMiniBatchDataSetIterator.java @@ -33,7 +33,7 @@ public class ExistingMiniBatchDataSetIterator implements DataSetIterator { public static final String DEFAULT_PATTERN = 
"dataset-%d.bin"; private int currIdx; - private File rootDir; + private final File rootDir; private int totalBatches = -1; private DataSetPreProcessor dataSetPreProcessor; private final String pattern; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIterator.java index 5a9222865..7ea2e04c1 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/MiniBatchFileDataSetIterator.java @@ -35,12 +35,12 @@ import java.util.UUID; @Slf4j public class MiniBatchFileDataSetIterator implements DataSetIterator { - private int batchSize; - private List paths; + private final int batchSize; + private final List paths; private int currIdx; private File rootDir; - private int totalExamples; - private int totalLabels; + private final int totalExamples; + private final int totalLabels; private int totalBatches = -1; private DataSetPreProcessor dataSetPreProcessor; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/ViewIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/ViewIterator.java index b6def37c8..dd4ba90a7 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/ViewIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/ViewIterator.java @@ -34,7 +34,7 @@ import java.util.List; public class ViewIterator implements DataSetIterator { private int batchSize = -1; private int cursor = 0; - private DataSet data; + private final DataSet data; private DataSetPreProcessor preProcessor; public ViewIterator(DataSet data, int batchSize) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/adapter/MultiDataSetIteratorAdapter.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/adapter/MultiDataSetIteratorAdapter.java index 7ca5f922a..e100166b8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/adapter/MultiDataSetIteratorAdapter.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/adapter/MultiDataSetIteratorAdapter.java @@ -26,7 +26,7 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; public class MultiDataSetIteratorAdapter implements MultiDataSetIterator { - private org.nd4j.linalg.dataset.api.iterator.DataSetIterator iter; + private final org.nd4j.linalg.dataset.api.iterator.DataSetIterator iter; private MultiDataSetPreProcessor preProcessor; public MultiDataSetIteratorAdapter(org.nd4j.linalg.dataset.api.iterator.DataSetIterator iter) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSetUtil.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSetUtil.java index 967fe995d..5545569d7 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSetUtil.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSetUtil.java @@ -107,7 +107,7 @@ public class DataSetUtil { mask = mask.dup('f'); } - INDArray mask1d = mask.reshape('f', new long[] {mask.length(), 1}); + INDArray mask1d = mask.reshape('f', mask.length(), 1); //Assume masks are 0s and 1s: then sum == number of elements int numElements = mask.sumNumber().intValue(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/CachingDataSetIterator.java 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/CachingDataSetIterator.java index 31a571ba3..b69207188 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/CachingDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/CachingDataSetIterator.java @@ -31,12 +31,12 @@ import java.util.List; public class CachingDataSetIterator implements DataSetIterator { private static final Logger log = LoggerFactory.getLogger(DataSetCache.class); - private DataSetIterator sourceIterator; - private DataSetCache cache; - private String namespace; + private final DataSetIterator sourceIterator; + private final DataSetCache cache; + private final String namespace; private int currentIndex = 0; private boolean usingCache = false; - private boolean allowPrefetching; + private final boolean allowPrefetching; public CachingDataSetIterator(DataSetIterator sourceIterator, DataSetCache cache, String namespace) { this(sourceIterator, cache, namespace, false); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/MultipleEpochsIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/MultipleEpochsIterator.java index 6ebea0682..fe3117040 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/MultipleEpochsIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/MultipleEpochsIterator.java @@ -33,9 +33,9 @@ import java.util.List; @Deprecated public class MultipleEpochsIterator implements DataSetIterator { private static final Logger log = LoggerFactory.getLogger(MultipleEpochsIterator.class); - private int numPasses; + private final int numPasses; private int batch = 0; - private DataSetIterator iter; + private final DataSetIterator iter; private int passes = 0; private DataSetPreProcessor preProcessor; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/SamplingDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/SamplingDataSetIterator.java index 87c6eafae..606700099 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/SamplingDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/SamplingDataSetIterator.java @@ -32,9 +32,9 @@ import java.util.List; * @author Adam Gibson */ public class SamplingDataSetIterator implements DataSetIterator { - private DataSet sampleFrom; - private int batchSize; - private int totalNumberSamples; + private final DataSet sampleFrom; + private final int batchSize; + private final int totalNumberSamples; private int numTimesSampled; private boolean replace = false; private DataSetPreProcessor preProcessor; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/StandardScaler.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/StandardScaler.java index 757f459ac..240e35f20 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/StandardScaler.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/StandardScaler.java @@ -33,7 +33,7 @@ import java.io.IOException; @Deprecated public class StandardScaler { - private static Logger logger = LoggerFactory.getLogger(StandardScaler.class); + private static final Logger logger = LoggerFactory.getLogger(StandardScaler.class); 
private INDArray mean, std; private long runningTotal = 0; private long batchCount = 0; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java index 982987645..ff82d068c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java @@ -31,7 +31,7 @@ public class TestDataSetIterator implements DataSetIterator { private static final long serialVersionUID = -7569201667767185411L; private int curr = 0; private int batch = 10; - private List list; + private final List list; private DataSetPreProcessor preProcessor; public TestDataSetIterator(DataSet dataset, int batch) { @@ -73,12 +73,12 @@ public class TestDataSetIterator implements DataSetIterator { @Override public int inputColumns() { - return (int)list.get(0).getFeatures().columns(); + return list.get(0).getFeatures().columns(); } @Override public int totalOutcomes() { - return (int) list.get(0).getLabels().columns(); + return list.get(0).getLabels().columns(); } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestMultiDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestMultiDataSetIterator.java index 17ce251e9..b0af492bc 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestMultiDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestMultiDataSetIterator.java @@ -33,7 +33,7 @@ import java.util.List; public class TestMultiDataSetIterator implements MultiDataSetIterator { private int curr = 0; private int batch = 10; - private List list; + private final List list; private MultiDataSetPreProcessor preProcessor; /** diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InFileAndMemoryDataSetCache.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InFileAndMemoryDataSetCache.java index 47e985ef3..d274372eb 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InFileAndMemoryDataSetCache.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InFileAndMemoryDataSetCache.java @@ -27,8 +27,8 @@ import java.nio.file.Path; public class InFileAndMemoryDataSetCache implements DataSetCache { - private InFileDataSetCache fileCache; - private InMemoryDataSetCache memoryCache; + private final InFileDataSetCache fileCache; + private final InMemoryDataSetCache memoryCache; public InFileAndMemoryDataSetCache(File cacheDirectory) { this.fileCache = new InFileDataSetCache(cacheDirectory); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InFileDataSetCache.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InFileDataSetCache.java index 2eb7f5dc8..bf415379a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InFileDataSetCache.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InFileDataSetCache.java @@ -27,7 +27,7 @@ import java.io.IOException; import java.nio.file.Path; public class InFileDataSetCache implements DataSetCache { - private File 
cacheDirectory; + private final File cacheDirectory; public InFileDataSetCache(File cacheDirectory) { if (cacheDirectory.exists() && !cacheDirectory.isDirectory()) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InMemoryDataSetCache.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InMemoryDataSetCache.java index 75a2a9177..df7e10c42 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InMemoryDataSetCache.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/cache/InMemoryDataSetCache.java @@ -35,8 +35,8 @@ public class InMemoryDataSetCache implements DataSetCache { private static final Logger log = LoggerFactory.getLogger(DataSetCache.class); - private Map cache = new HashMap<>(); - private Set completeNamespaces = new HashSet<>(); + private final Map cache = new HashMap<>(); + private final Set completeNamespaces = new HashSet<>(); @Override public boolean isComplete(String namespace) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java index f39e5eed2..dde3a3963 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeDataSetPreProcessor.java @@ -27,7 +27,7 @@ import org.nd4j.linalg.dataset.api.DataSetPreProcessor; public class CompositeDataSetPreProcessor implements DataSetPreProcessor { private final boolean stopOnEmptyDataSet; - private DataSetPreProcessor[] preProcessors; + private final DataSetPreProcessor[] preProcessors; /** * @param preProcessors Preprocessors to apply. They will be applied in this order diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeMultiDataSetPreProcessor.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeMultiDataSetPreProcessor.java index 248b66798..5a972d632 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeMultiDataSetPreProcessor.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/CompositeMultiDataSetPreProcessor.java @@ -25,7 +25,7 @@ import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; public class CompositeMultiDataSetPreProcessor implements MultiDataSetPreProcessor { - private MultiDataSetPreProcessor[] preProcessors; + private final MultiDataSetPreProcessor[] preProcessors; /** * @param preProcessors Preprocessors to apply. 
They will be applied in this order diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImageMultiPreProcessingScaler.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImageMultiPreProcessingScaler.java index 8e367f620..8393c9b01 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImageMultiPreProcessingScaler.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/ImageMultiPreProcessingScaler.java @@ -28,9 +28,10 @@ import org.nd4j.linalg.dataset.api.preprocessor.serializer.NormalizerType; public class ImageMultiPreProcessingScaler implements MultiDataNormalization { - private double minRange, maxRange; - private double maxPixelVal; - private int[] featureIndices; + private final double minRange; + private final double maxRange; + private final double maxPixelVal; + private final int[] featureIndices; public ImageMultiPreProcessingScaler(int... featureIndices) { this(0, 1, 8, featureIndices); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategy.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategy.java index 1644bcba1..40c2c413e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategy.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/MinMaxStrategy.java @@ -37,8 +37,8 @@ import java.io.Serializable; @Getter @EqualsAndHashCode public class MinMaxStrategy implements NormalizerStrategy, Serializable { - private double minRange; - private double maxRange; + private final double minRange; + private final double maxRange; public MinMaxStrategy() { this(0, 1); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/BaseUnderSamplingPreProcessor.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/BaseUnderSamplingPreProcessor.java index 9812b240e..de942aa04 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/BaseUnderSamplingPreProcessor.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/BaseUnderSamplingPreProcessor.java @@ -138,7 +138,7 @@ public abstract class BaseUnderSamplingPreProcessor { INDArray floatMask = labelMask.castTo(label.dataType()); if (!sum1.equals(floatMask)) { throw new IllegalArgumentException("Labels of size minibatchx2xtimesteps are expected to be one hot." 
- + label.toString() + "\n is not one-hot"); + + label + "\n is not one-hot"); } } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/UnderSamplingByMaskingMultiDataSetPreProcessor.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/UnderSamplingByMaskingMultiDataSetPreProcessor.java index d4fe5e792..ae9609c2c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/UnderSamplingByMaskingMultiDataSetPreProcessor.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/UnderSamplingByMaskingMultiDataSetPreProcessor.java @@ -31,8 +31,8 @@ import java.util.Map; public class UnderSamplingByMaskingMultiDataSetPreProcessor extends BaseUnderSamplingPreProcessor implements MultiDataSetPreProcessor { - private Map targetMinorityDistMap; - private Map minorityLabelMap = new HashMap<>(); + private final Map targetMinorityDistMap; + private final Map minorityLabelMap = new HashMap<>(); /** * The target distribution to approximate. Valid values are between (0,0.5]. diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/UnderSamplingByMaskingPreProcessor.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/UnderSamplingByMaskingPreProcessor.java index 565190197..c7a6d2222 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/UnderSamplingByMaskingPreProcessor.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/classimbalance/UnderSamplingByMaskingPreProcessor.java @@ -26,7 +26,7 @@ import org.nd4j.linalg.dataset.api.DataSetPreProcessor; public class UnderSamplingByMaskingPreProcessor extends BaseUnderSamplingPreProcessor implements DataSetPreProcessor { - private double targetMinorityDist; + private final double targetMinorityDist; private int minorityLabel = 1; /** diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/serializer/NormalizerSerializer.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/serializer/NormalizerSerializer.java index 5a90aaeb1..2a1ce15a2 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/serializer/NormalizerSerializer.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/preprocessor/serializer/NormalizerSerializer.java @@ -34,7 +34,7 @@ public class NormalizerSerializer { private static final String HEADER = "NORMALIZER"; private static NormalizerSerializer defaultSerializer; - private List strategies = new ArrayList<>(); + private final List strategies = new ArrayList<>(); /** * Serialize a normalizer to the given file diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/PCA.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/PCA.java index 9d877025c..4c244f014 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/PCA.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/PCA.java @@ -118,7 +118,7 @@ public class PCA { * @return A matrix of size count rows by N columns */ public INDArray generateGaussianSamples(long count) { - INDArray samples = Nd4j.randn(new long[] {count, eigenvalues.columns()}); + INDArray 
samples = Nd4j.randn(count, eigenvalues.columns()); INDArray factors = Transforms.pow(eigenvalues, -0.5, true); samples.muliRowVector(factors); return Nd4j.tensorMmul(eigenvectors, samples, new int[][] {{1}, {1}}).transposei().addiRowVector(mean); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java index 779365b50..326b9d45f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java @@ -35,9 +35,9 @@ import java.util.List; public class RandomProjection { private int components; - private Random rng; + private final Random rng; private double eps; - private boolean autoMode; + private final boolean autoMode; public RandomProjection(double eps, Random rng){ this.rng = rng; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/impl/WorkspacesDebugAction.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/impl/WorkspacesDebugAction.java index 2bb87f2bc..e6406a5ae 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/impl/WorkspacesDebugAction.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/impl/WorkspacesDebugAction.java @@ -35,8 +35,8 @@ public class WorkspacesDebugAction implements EnvironmentalAction { switch (value.toUpperCase()) { case "SPILL_EVERYTHING": { Nd4j.getWorkspaceManager().setDebugMode(DebugMode.SPILL_EVERYTHING); - }; - break; + } + break; case "BYPASS_EVERYTHING": { Nd4j.getWorkspaceManager().setDebugMode(DebugMode.BYPASS_EVERYTHING); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseBlasWrapper.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseBlasWrapper.java index 62a4cb445..24f92f98e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseBlasWrapper.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseBlasWrapper.java @@ -187,7 +187,7 @@ public abstract class BaseBlasWrapper implements BlasWrapper { LinAlgExceptions.assertMatrix(a); if (a.data().dataType() == DataType.DOUBLE) { - return gemv((double) alpha, a, x, (double) beta, y); + return gemv(alpha, a, x, (double) beta, y); } level2().gemv('N', 'N', alpha, a, x, beta, y); return y; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java index 667432f6c..3458ed06b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java @@ -749,7 +749,7 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { INDArray ret = Nd4j.createUninitialized(new long[] {indexes.length, vectorLength}, order); for (int cnt = 0; cnt < indexes.length; cnt++) { - ret.putRow(cnt, source.tensorAlongDimension((int) indexes[cnt], sourceDimension)); + ret.putRow(cnt, source.tensorAlongDimension(indexes[cnt], sourceDimension)); } return ret; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java index 6f319e31d..fe162c8e2 100644 --- 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java @@ -1374,7 +1374,7 @@ public interface NDArrayFactory { * @param file the file to create the map from * @return Map */ - public Map createFromNpzFile(File file) throws Exception; + Map createFromNpzFile(File file) throws Exception; /** * Convert an {@link INDArray} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 46f538dfc..2dfff5fbd 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -259,7 +259,7 @@ public class Nd4j { private static AffinityManager affinityManager; private static MemoryManager memoryManager; - private static AtomicBoolean fallbackMode; + private static final AtomicBoolean fallbackMode; protected static Properties props = new Properties(); @@ -1961,7 +1961,7 @@ public class Nd4j { return Nd4j.getExecutioner().exec(new Linspace((double) lower, num, (double)step, dtype)); } else { - throw new IllegalStateException("Illegal data type for linspace: " + dtype.toString()); + throw new IllegalStateException("Illegal data type for linspace: " + dtype); } } @@ -1999,7 +1999,7 @@ public class Nd4j { return linspace((double) lower, (double)upper, (int) num, dtype); } else { - throw new IllegalStateException("Illegal data type for linspace: " + dtype.toString()); + throw new IllegalStateException("Illegal data type for linspace: " + dtype); } } @@ -4482,7 +4482,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray valueArrayOf(long[] shape, float value) { - return valueArrayOf(shape, (double)value, DataType.FLOAT); + return valueArrayOf(shape, value, DataType.FLOAT); } /** @@ -5214,7 +5214,7 @@ public class Nd4j { try { action.process(value); } catch (Exception e2) { - logger.info("Failed to process env variable [" + e + "], got exception: " + e2.toString()); + logger.info("Failed to process env variable [" + e + "], got exception: " + e2); } } } @@ -5776,7 +5776,7 @@ public class Nd4j { val doubles = new float[prod]; val sb = bb.order(_order).asShortBuffer(); for (int e = 0; e < prod; e++) - doubles[e] = HalfIndexer.toFloat((int) sb.get(e)); + doubles[e] = HalfIndexer.toFloat(sb.get(e)); return Nd4j.create(doubles, shapeOf, stridesOf, ordering, DataType.HALF); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java index cbce0c2f5..b56410cd3 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java @@ -220,15 +220,15 @@ public abstract class Nd4jBackend { java.net.URLClassLoader loader = (URLClassLoader) ND4JClassLoading.getNd4jClassloader(); java.net.URL url = jar.toURI().toURL(); /*Disallow if already loaded*/ - for (java.net.URL it : java.util.Arrays.asList(loader.getURLs())) { + for (java.net.URL it : loader.getURLs()) { if (it.equals(url)) { return; } } java.lang.reflect.Method method = - java.net.URLClassLoader.class.getDeclaredMethod("addURL", new Class[] {java.net.URL.class}); + java.net.URLClassLoader.class.getDeclaredMethod("addURL", java.net.URL.class); method.setAccessible(true); /*promote the method to public 
access*/ - method.invoke(loader, new Object[] {url}); + method.invoke(loader, url); } catch (final java.lang.NoSuchMethodException | java.lang.IllegalAccessException | java.net.MalformedURLException | java.lang.reflect.InvocationTargetException e) { throw new NoAvailableBackendException(e); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/RandomFactory.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/RandomFactory.java index 93911850d..c97f4261a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/RandomFactory.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/RandomFactory.java @@ -25,8 +25,8 @@ import org.nd4j.linalg.api.rng.Random; import java.lang.reflect.Constructor; public class RandomFactory { - private ThreadLocal threadRandom = new ThreadLocal<>(); - private Class randomClass; + private final ThreadLocal threadRandom = new ThreadLocal<>(); + private final Class randomClass; public RandomFactory(Class randomClass) { this.randomClass = randomClass; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/Heartbeat.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/Heartbeat.java index 32f6c7263..4e73774e2 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/Heartbeat.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/Heartbeat.java @@ -30,7 +30,7 @@ import java.util.concurrent.atomic.AtomicBoolean; public class Heartbeat { private static final Heartbeat INSTANCE = new Heartbeat(); private volatile long serialVersionID; - private AtomicBoolean enabled = new AtomicBoolean(true); + private final AtomicBoolean enabled = new AtomicBoolean(true); protected Heartbeat() { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/reports/Environment.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/reports/Environment.java index f58ab7ea7..abccf8706 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/reports/Environment.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/reports/Environment.java @@ -64,7 +64,6 @@ public class Environment implements Serializable { public String toCompactString() { - StringBuilder builder = new StringBuilder(); /* new format is: @@ -79,10 +78,10 @@ public class Environment implements Serializable { builder.append(backendUsed).append(" "); */ - builder.append(backendUsed).append(" (").append(numCores).append(" cores ") - .append(Math.max(availableMemory / 1024 / 1024 / 1024, 1)).append("GB ").append(osName) - .append(" ").append(osArch).append(")"); + String builder = backendUsed + " (" + numCores + " cores " + + Math.max(availableMemory / 1024 / 1024 / 1024, 1) + "GB " + osName + + " " + osArch + ")"; - return builder.toString(); + return builder; } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/reports/Task.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/reports/Task.java index b9df4d545..a0b32ccae 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/reports/Task.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/reports/Task.java @@ -42,12 +42,11 @@ public class Task { private int numSamples; public String toCompactString() { - StringBuilder builder = new StringBuilder(); - builder.append("F: ").append(numFeatures).append("/"); - builder.append("L: ").append(numLabels).append("/"); - 
builder.append("S: ").append(numSamples).append(" "); + String builder = "F: " + numFeatures + "/" + + "L: " + numLabels + "/" + + "S: " + numSamples + " "; - return builder.toString(); + return builder; } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/utils/EnvironmentUtils.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/utils/EnvironmentUtils.java index e8ecf94cd..a5f2945bf 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/utils/EnvironmentUtils.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/heartbeat/utils/EnvironmentUtils.java @@ -67,12 +67,12 @@ public class EnvironmentUtils { return random.nextLong(); } catch (Exception e) { - ; // do nothing, just skip to next interface + // do nothing, just skip to next interface } } } catch (Exception e) { - ; // do nothing here + // do nothing here } return ret; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java index 8d203ce56..f5fe68632 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/BooleanIndexing.java @@ -56,10 +56,7 @@ public class BooleanIndexing { if (cond instanceof BaseCondition) { long val = (long) Nd4j.getExecutioner().exec(new MatchCondition(n, cond)).getDouble(0); - if (val == n.length()) - return true; - else - return false; + return val == n.length(); } else { throw new RuntimeException("Can only execute BaseCondition conditions using this method"); @@ -85,10 +82,7 @@ public class BooleanIndexing { long tadLength = Shape.getTADLength(n.shape(), dimension); for (int i = 0; i < arr.length(); i++) { - if (arr.getDouble(i) == tadLength) - result[i] = true; - else - result[i] = false; + result[i] = arr.getDouble(i) == tadLength; } return result; @@ -113,10 +107,7 @@ public class BooleanIndexing { boolean[] result = new boolean[(int) arr.length()]; for (int i = 0; i < arr.length(); i++) { - if (arr.getDouble(i) > 0) - result[i] = true; - else - result[i] = false; + result[i] = arr.getDouble(i) > 0; } return result; @@ -133,10 +124,7 @@ public class BooleanIndexing { if (cond instanceof BaseCondition) { long val = (long) Nd4j.getExecutioner().exec(new MatchCondition(n, cond)).getDouble(0); - if (val > 0) - return true; - else - return false; + return val > 0; } else { throw new RuntimeException("Can only execute BaseCondition conditions using this method"); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/IndexInfo.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/IndexInfo.java index 8499f9f58..9f7c134bd 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/IndexInfo.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/IndexInfo.java @@ -24,7 +24,7 @@ package org.nd4j.linalg.indexing; * @author Adam Gibson */ public class IndexInfo { - private INDArrayIndex[] indexes; + private final INDArrayIndex[] indexes; private boolean[] point; private boolean[] newAxis; private int numNewAxes = 0; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java index cfd583d63..8c856bcae 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java +++ 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/NDArrayIndex.java @@ -39,8 +39,8 @@ import java.util.List; @Slf4j public abstract class NDArrayIndex implements INDArrayIndex { - private long[] indices; - private static NewAxis NEW_AXIS = new NewAxis(); + private final long[] indices; + private static final NewAxis NEW_AXIS = new NewAxis(); /** @@ -655,9 +655,7 @@ public abstract class NDArrayIndex implements INDArrayIndex { NDArrayIndex that = (NDArrayIndex) o; - if (!Arrays.equals(indices, that.indices)) - return false; - return true; + return Arrays.equals(indices, that.indices); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java index acd69d529..796a2a3da 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/PointIndex.java @@ -28,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; */ @EqualsAndHashCode public class PointIndex implements INDArrayIndex { - private long point; + private final long point; /** * diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/SpecifiedIndex.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/SpecifiedIndex.java index f2109e8d2..702f315bc 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/SpecifiedIndex.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/SpecifiedIndex.java @@ -130,7 +130,7 @@ public class SpecifiedIndex implements INDArrayIndex { */ public static class SpecifiedIndexesGenerator implements Generator>> { private int index = 0; - private INDArrayIndex[] indexes; + private final INDArrayIndex[] indexes; /** * The indexes to generate from @@ -166,7 +166,7 @@ public class SpecifiedIndex implements INDArrayIndex { */ public static class SparseSpecifiedIndexesGenerator implements Generator>> { private int index = 0; - private INDArrayIndex[] indexes; + private final INDArrayIndex[] indexes; /** * The indexes to generate from diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/And.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/And.java index 82178d754..cff235d50 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/And.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/And.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.indexing.conditions; public class And implements Condition { - private Condition[] conditions; + private final Condition[] conditions; public And(Condition... conditions) { this.conditions = conditions; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/ConditionEquals.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/ConditionEquals.java index cd8e78b87..aeedb79f2 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/ConditionEquals.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/ConditionEquals.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.indexing.conditions; public class ConditionEquals implements Condition { - private Condition[] conditions; + private final Condition[] conditions; public ConditionEquals(Condition... 
conditions) { this.conditions = conditions; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/Not.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/Not.java index db6cfd6a7..182924c91 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/Not.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/Not.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.indexing.conditions; public class Not implements Condition { - private Condition opposite; + private final Condition opposite; /** * Returns condition ID for native side diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/Or.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/Or.java index 6c1f6adee..5a7e971d8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/Or.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/conditions/Or.java @@ -22,7 +22,7 @@ package org.nd4j.linalg.indexing.conditions; public class Or implements Condition { - private Condition[] conditions; + private final Condition[] conditions; public Or(Condition... conditions) { this.conditions = conditions; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/functions/StableNumber.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/functions/StableNumber.java index 62d1e23c3..b6528b1ec 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/functions/StableNumber.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/functions/StableNumber.java @@ -29,7 +29,7 @@ import org.nd4j.common.function.Function; * or nan */ public class StableNumber implements Function { - private Type type; + private final Type type; public enum Type { DOUBLE, FLOAT diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/functions/Value.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/functions/Value.java index 6464a7db7..94207dfd7 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/functions/Value.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/functions/Value.java @@ -23,7 +23,7 @@ package org.nd4j.linalg.indexing.functions; import org.nd4j.common.function.Function; public class Value implements Function { - private Number number; + private final Number number; public Value(Number number) { this.number = number; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java index 70ad03d19..9cf173894 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossBinaryXENT.java @@ -117,7 +117,7 @@ public class LossBinaryXENT implements ILossFunction { INDArray scoreArr; if (activationFn instanceof ActivationSoftmax) { //TODO Post GPU support for custom ops: Use LogSoftMax op to avoid numerical issues when calculating score - INDArray logsoftmax = Nd4j.exec((CustomOp) new SoftMax(preOutput, preOutput.ulike(), -1))[0]; + INDArray logsoftmax = Nd4j.exec(new SoftMax(preOutput, preOutput.ulike(), -1))[0]; Transforms.log(logsoftmax, false); scoreArr = logsoftmax.muli(labels); diff --git 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java index 14894362b..84babae8c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMixtureDensity.java @@ -85,7 +85,7 @@ public class LossMixtureDensity implements ILossFunction { // through Nd4j operations in order to increase performance. public MixtureDensityComponents extractComponents(INDArray output) { long outputSize = output.size(1); - if (outputSize != (mLabelWidth + 2) * mMixtures) { + if (outputSize != (long) (mLabelWidth + 2) * mMixtures) { throw new IllegalArgumentException( "Network output size " + outputSize + " must be (labels+2)*mixtures where labels = " + mLabelWidth + " and mixtures = " + mMixtures); @@ -114,7 +114,7 @@ public class LossMixtureDensity implements ILossFunction { // Alpha is a softmax because // the alpha should all sum to 1 for a given gaussian mixture. - mdc.alpha = Nd4j.exec((CustomOp) new SoftMax(mdc.alpha, mdc.alpha, -1))[0]; + mdc.alpha = Nd4j.exec(new SoftMax(mdc.alpha, mdc.alpha, -1))[0]; // Mu comes directly from the network as an unmolested value. // Note that this effectively means that the output layer of @@ -254,10 +254,10 @@ public class LossMixtureDensity implements ILossFunction { INDArray dLdZMu = Nd4j.create(nSamples, mMixtures, mLabelWidth); for (int k = 0; k < mLabelWidth; k++) { dLdZMu.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(k)}, - labelsMinusMu.get(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.point(k)}).muli(pi).divi(variance).negi()); + labelsMinusMu.get(NDArrayIndex.all(), NDArrayIndex.all(), + NDArrayIndex.point(k)).muli(pi).divi(variance).negi()); } - dLdZMu = dLdZMu.reshape(nSamples, mMixtures * mLabelWidth); + dLdZMu = dLdZMu.reshape(nSamples, (long) mMixtures * mLabelWidth); // Place components of gradient into gradient holder. gradient.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, mMixtures)}, dLdZAlpha); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java index bffdffdf5..fe44374d9 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java @@ -519,7 +519,7 @@ public class Transforms { * @return */ public static INDArray softmax(INDArray in, boolean copy) { - return Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, (copy ? in.ulike() : in), -1))[0]; + return Nd4j.getExecutioner().exec(new SoftMax(in, (copy ? 
in.ulike() : in), -1))[0]; } /** diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java index 067b25370..65339d4f0 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/OpProfiler.java @@ -511,7 +511,7 @@ public class OpProfiler { if (operands[e] == null) buffer.append("null"); else - buffer.append(new String("" + operands[e].ordering()).toUpperCase()); + buffer.append(("" + operands[e].ordering()).toUpperCase()); if (e < operands.length - 1) buffer.append(" x "); @@ -631,8 +631,8 @@ public class OpProfiler { if (x == z || y == z) { return processOperands(x, y); } else { - PenaltyCause causeXY[] = processOperands(x, y); - PenaltyCause causeXZ[] = processOperands(x, z); + PenaltyCause[] causeXY = processOperands(x, y); + PenaltyCause[] causeXZ = processOperands(x, z); if ((causeXY.length == 1 && causeXY[0] == NONE) && (causeXZ.length == 1 && causeXZ[0] == NONE)) { return causeXY; @@ -675,7 +675,7 @@ public class OpProfiler { if (operands[e] == null && operands[e + 1] == null) continue; - PenaltyCause lc[] = processOperands(operands[e], operands[e + 1]); + PenaltyCause[] lc = processOperands(operands[e], operands[e + 1]); for (PenaltyCause cause : lc) { if (cause != NONE && !causes.contains(cause)) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StackAggregator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StackAggregator.java index 5383dead5..5151a745e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StackAggregator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StackAggregator.java @@ -24,7 +24,7 @@ import org.nd4j.linalg.profiler.data.primitives.StackDescriptor; import org.nd4j.linalg.profiler.data.primitives.StackTree; public class StackAggregator { - private StackTree tree = new StackTree(); + private final StackTree tree = new StackTree(); public StackAggregator() { // nothing to do here so far diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StringAggregator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StringAggregator.java index 465bdfa4e..7eb8086da 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StringAggregator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StringAggregator.java @@ -32,8 +32,8 @@ import java.util.concurrent.atomic.AtomicLong; public class StringAggregator { - private Map times = new ConcurrentHashMap<>(); - private Map longCalls = new ConcurrentHashMap<>(); + private final Map times = new ConcurrentHashMap<>(); + private final Map longCalls = new ConcurrentHashMap<>(); private static final long THRESHOLD = 100000; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StringCounter.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StringCounter.java index 5054b680a..ab9784de6 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StringCounter.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/StringCounter.java @@ -28,8 +28,8 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; public class StringCounter { - private Map counter = new 
ConcurrentHashMap<>(); - private AtomicLong totals = new AtomicLong(0); + private final Map counter = new ConcurrentHashMap<>(); + private final AtomicLong totals = new AtomicLong(0); public StringCounter() { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackDescriptor.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackDescriptor.java index 60656802c..3121b9785 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackDescriptor.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackDescriptor.java @@ -33,9 +33,9 @@ import java.util.Arrays; @Slf4j public class StackDescriptor { @Getter - protected StackTraceElement stackTrace[]; + protected StackTraceElement[] stackTrace; - public StackDescriptor(@NonNull StackTraceElement stack[]) { + public StackDescriptor(@NonNull StackTraceElement[] stack) { // we cut off X first elements from stack, because they belong to profiler // basically, we just want to make sure, no profiler-related code is mentioned in stack trace int start = 0; @@ -46,7 +46,6 @@ public class StackDescriptor { // in tests it's quite possible to have no DefaultOpExecutioner calls being used if (start == stack.length) { - ; for (start = 0; start < stack.length; start++) { if (!stack[start + 1].getClassName().contains("OpProfiler") && !stack[start + 1].getClassName().contains("StackAggregator")) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackNode.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackNode.java index 316170700..eefe67394 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackNode.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackNode.java @@ -49,12 +49,12 @@ public class StackNode implements Comparable { builder.append(" "); } - builder.append("").append(nodeURI); + builder.append(nodeURI); if (displayCounts) builder.append(" ").append(counter.get()).append(" us"); - System.out.println(builder.toString()); + System.out.println(builder); for (StackNode node : entries.values()) { node.traverse(ownLevel + 1, displayCounts); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackTree.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackTree.java index 1785737f3..60d3a1e4b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackTree.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/StackTree.java @@ -41,14 +41,14 @@ public class StackTree { } public String renderTree(boolean displayCounts) { - StringBuilder builder = new StringBuilder(); + String builder = ""; // we'll always have single entry here, but let's keep loop here for (StackNode cNode : basement.values()) { cNode.traverse(0, displayCounts); } - return builder.toString(); + return builder; } public void consumeStackTrace(@NonNull StackDescriptor descriptor) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/TimeSet.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/TimeSet.java index 97f36f73a..0062b2633 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/TimeSet.java +++ 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/profiler/data/primitives/TimeSet.java @@ -25,7 +25,7 @@ import java.util.List; public class TimeSet implements Comparable { - private List times = new ArrayList<>(); + private final List times = new ArrayList<>(); private long sum = 0; public void addTime(long time) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/schedule/MapSchedule.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/schedule/MapSchedule.java index 578996bb9..ce28a5901 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/schedule/MapSchedule.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/schedule/MapSchedule.java @@ -91,8 +91,8 @@ public class MapSchedule implements ISchedule { */ public static class Builder { - private ScheduleType scheduleType; - private Map values = new HashMap<>(); + private final ScheduleType scheduleType; + private final Map values = new HashMap<>(); /** * @param scheduleType Schedule opType to use diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java index efcd327df..dae8056c0 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/string/NDArrayStrings.java @@ -215,11 +215,10 @@ public class NDArrayStrings { } else if (arr.isRowVector()) { //a slice from a higher dim array if (offset == 0) { - StringBuilder sb = new StringBuilder(); - sb.append("["); - sb.append(vectorToString(arr, summarize)); - sb.append("]"); - return sb.toString(); + String sb = "[" + + vectorToString(arr, summarize) + + "]"; + return sb; } return vectorToString(arr, summarize); } else { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/DataSetUtils.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/DataSetUtils.java index 5ec3fc2de..1e374bab3 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/DataSetUtils.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/DataSetUtils.java @@ -32,7 +32,7 @@ import org.nd4j.common.tools.SIS; public class DataSetUtils { // - private SIS sis; + private final SIS sis; // public DataSetUtils( SIS sis, @@ -179,7 +179,7 @@ public class DataSetUtils { if (in_INDA.rows() > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } - int i_CharsCount = BTools.getIndexCharsCount( (int) in_INDA.rows() - 1 ); + int i_CharsCount = BTools.getIndexCharsCount( in_INDA.rows() - 1 ); // oinfo = ""; oinfo += BTools.getMtLvESS( mtLv ); @@ -201,7 +201,7 @@ public class DataSetUtils { // int c_I = 0; // - for ( int j = (int) in_INDA.columns() - 1; j >= 0; j-- ) { + for (int j = in_INDA.columns() - 1; j >= 0; j-- ) { // if ( c_I > c_End_I ) break; // @@ -221,7 +221,7 @@ public class DataSetUtils { if ( ot_INDA != null ) { if (ot_INDA.columns() - 1 > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - for ( int j = (int) ot_INDA.columns() - 1; j >= 0; j-- ) { + for (int j = ot_INDA.columns() - 1; j >= 0; j-- ) { // if ( c_I > c_End_I ) break; // @@ -349,7 +349,7 @@ public class DataSetUtils { double j_Dbl = -1; if (INDA.rows() - 1 > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - int i_CharsCount = BTools.getIndexCharsCount( (int) INDA.rows() - 1 ); + int i_CharsCount = BTools.getIndexCharsCount( INDA.rows() - 1 ); // if ( !turned ) { //= standard oinfo = ""; @@ -370,7 +370,7 @@ public class 
DataSetUtils { int c_I = 0; if (INDA.columns() - 1 > Integer.MAX_VALUE) throw new ND4JArraySizeException(); - for ( int j = (int) INDA.columns() - 1; j >= 0; j-- ) { + for (int j = INDA.columns() - 1; j >= 0; j-- ) { // if ( c_I > c_End_I ) break; // diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/ND4JTestUtils.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/ND4JTestUtils.java index 41018c288..ad386be2f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/ND4JTestUtils.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/ND4JTestUtils.java @@ -49,7 +49,7 @@ public class ND4JTestUtils { } /** - * A function for use with {@link #validateSerializedArrays(File, File, boolean, BiFunction)} using {@link INDArray#equals(Object)} + * A function for use with {@link #validateSerializedArrays(File, File, boolean, BiFunction)} using {@code INDArray#equals(Object)} */ public static class EqualsFn implements BiFunction { @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/NDArrayMath.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/NDArrayMath.java index e20b79f30..cf55c721a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/NDArrayMath.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/NDArrayMath.java @@ -98,7 +98,7 @@ public class NDArrayMath { */ public static long vectorsPerSlice(INDArray arr) { if (arr.rank() > 2) { - return ArrayUtil.prodLong(new long[] {arr.size(-1), arr.size(-2)}); + return ArrayUtil.prodLong(arr.size(-1), arr.size(-2)); } return arr.slices(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java index af0c04d08..9baf97578 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java @@ -129,9 +129,7 @@ public abstract class BaseWorkspaceMgr> implements WorkspaceMg @Override public void setWorkspace(@NonNull T forEnum, @NonNull String wsName, @NonNull WorkspaceConfiguration configuration) { - if(scopeOutOfWs.contains(forEnum)){ - scopeOutOfWs.remove(forEnum); - } + scopeOutOfWs.remove(forEnum); setWorkspaceName(forEnum, wsName); setConfiguration(forEnum, configuration); } @@ -169,7 +167,7 @@ public abstract class BaseWorkspaceMgr> implements WorkspaceMg throw new ND4JWorkspaceException("Assertion failed: expected current workspace to be \"" + getWorkspaceName(arrayType) + "\" (for array type " + arrayType + ") - actual current workspace is " + (curr == null ? null : curr.getId()) + (msg == null ? 
"" : ": " + msg)); - }; + } } @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/serde/json/JsonMappers.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/serde/json/JsonMappers.java index 8190f6141..573d0ff81 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/serde/json/JsonMappers.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/serde/json/JsonMappers.java @@ -38,8 +38,8 @@ import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; @Slf4j public class JsonMappers { - private static ObjectMapper jsonMapper = configureMapper(new ObjectMapper()); - private static ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory())); + private static final ObjectMapper jsonMapper = configureMapper(new ObjectMapper()); + private static final ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory())); /** * @return The default/primary ObjectMapper for deserializing JSON objects diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/systeminfo/SystemInfo.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/systeminfo/SystemInfo.java index f292cf8eb..89175c137 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/systeminfo/SystemInfo.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/systeminfo/SystemInfo.java @@ -305,7 +305,7 @@ public class SystemInfo { sb.append(String.format(wsFormat, ws.getId(), (ws.isScopeActive() ? "OPEN" : "CLOSED"), fBytes(ws.getCurrentSize()), - String.valueOf(numCycles))).append("\n"); + numCycles)).append("\n"); } } sb.append(fBytes("Workspaces total size", totalWsSize)); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/versioncheck/VersionCheck.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/versioncheck/VersionCheck.java index 5ca7116f0..43b868342 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/versioncheck/VersionCheck.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/versioncheck/VersionCheck.java @@ -234,7 +234,7 @@ public class VersionCheck { try { URI uri = u.toURI(); - try (FileSystem fileSystem = (uri.getScheme().equals("jar") ? FileSystems.newFileSystem(uri, Collections.emptyMap()) : null)) { + try (FileSystem fileSystem = (uri.getScheme().equals("jar") ? 
FileSystems.newFileSystem(uri, Collections.emptyMap()) : null)) { Path myPath = Paths.get(uri); Files.walkFileTree(myPath, new SimpleFileVisitor() { @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/BaseWeightInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/BaseWeightInitScheme.java index be0abbcf5..898e97c5a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/BaseWeightInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/BaseWeightInitScheme.java @@ -29,7 +29,7 @@ import java.util.Arrays; @EqualsAndHashCode public abstract class BaseWeightInitScheme implements WeightInitScheme { - private char order; + private final char order; /** * Initialize with c weight ordering by default diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ConstantInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ConstantInitScheme.java index 74370926c..7177e0302 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ConstantInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ConstantInitScheme.java @@ -32,7 +32,7 @@ import org.nd4j.weightinit.WeightInit; * @author Adam Gibson */ public class ConstantInitScheme extends BaseWeightInitScheme { - private double constant; + private final double constant; @Builder public ConstantInitScheme(char order,double constant) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/DistributionInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/DistributionInitScheme.java index b9cc7469c..ac4c36698 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/DistributionInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/DistributionInitScheme.java @@ -32,7 +32,7 @@ import org.nd4j.weightinit.WeightInit; * @author Adam Gibson */ public class DistributionInitScheme extends BaseWeightInitScheme { - private Distribution distribution; + private final Distribution distribution; @Builder public DistributionInitScheme(char order, Distribution distribution) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java index 91c398424..7272bea70 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java @@ -34,7 +34,7 @@ import org.nd4j.weightinit.WeightInit; */ public class LecunUniformInitScheme extends BaseWeightInitScheme { - private double fanIn; + private final double fanIn; @Builder public LecunUniformInitScheme(char order, double fanIn) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ReluInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ReluInitScheme.java index 4843cd6cf..f12857f55 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ReluInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ReluInitScheme.java @@ -35,7 +35,7 @@ import org.nd4j.weightinit.WeightInit; */ public class ReluInitScheme extends BaseWeightInitScheme { - private double fanIn; + private final double fanIn; @Builder public ReluInitScheme(char order,double fanIn) { diff --git 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java index 0ffa0d6ef..907165d9b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java @@ -35,7 +35,7 @@ import org.nd4j.weightinit.WeightInit; */ public class ReluUniformInitScheme extends BaseWeightInitScheme { - private double fanIn; + private final double fanIn; @Builder public ReluUniformInitScheme(char order, double fanIn) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java index 615070802..43284932d 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java @@ -35,8 +35,8 @@ import org.nd4j.weightinit.WeightInit; */ public class SigmoidUniformInitScheme extends BaseWeightInitScheme { - private double fanIn; - private double fanOut; + private final double fanIn; + private final double fanOut; @Builder public SigmoidUniformInitScheme(char order, double fanIn,double fanOut) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java index 680b5ce4a..7aa09e5eb 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java @@ -35,7 +35,7 @@ import org.nd4j.weightinit.WeightInit; */ public class UniformInitScheme extends BaseWeightInitScheme { - private double fanIn; + private final double fanIn; @Builder public UniformInitScheme(char order, double fanIn) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanAvgInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanAvgInitScheme.java index b384973a6..cc7530e5a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanAvgInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanAvgInitScheme.java @@ -35,8 +35,8 @@ import org.nd4j.weightinit.WeightInit; */ public class VarScalingNormalFanAvgInitScheme extends BaseWeightInitScheme { - private double fanIn; - private double fanOut; + private final double fanIn; + private final double fanOut; @Builder public VarScalingNormalFanAvgInitScheme(char order, double fanIn, double fanOut) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanInInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanInInitScheme.java index e3839efc0..c0d7057ec 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanInInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanInInitScheme.java @@ -35,7 +35,7 @@ import org.nd4j.weightinit.WeightInit; */ public class VarScalingNormalFanInInitScheme extends BaseWeightInitScheme { - private double fanIn; + private final double fanIn; @Builder public 
VarScalingNormalFanInInitScheme(char order, double fanIn) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanOutInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanOutInitScheme.java index 24fa4c944..5bdaa2066 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanOutInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalFanOutInitScheme.java @@ -36,7 +36,7 @@ import org.nd4j.weightinit.WeightInit; */ public class VarScalingNormalFanOutInitScheme extends BaseWeightInitScheme { - private double fanOut; + private final double fanOut; @Builder public VarScalingNormalFanOutInitScheme(char order, double fanOut) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java index 4a44b6a36..74ace0a67 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java @@ -35,7 +35,7 @@ import org.nd4j.weightinit.WeightInit; */ public class VarScalingNormalUniformFanInInitScheme extends BaseWeightInitScheme { - private double fanIn; + private final double fanIn; @Builder public VarScalingNormalUniformFanInInitScheme(char order, double fanIn) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java index a2e5c49e0..ffe71b716 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java @@ -34,7 +34,7 @@ import org.nd4j.weightinit.WeightInit; */ public class VarScalingNormalUniformFanOutInitScheme extends BaseWeightInitScheme { - private double fanOut; + private final double fanOut; @Builder public VarScalingNormalUniformFanOutInitScheme(char order, double fanOut) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java index a1aa1a54a..fe7485d81 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java @@ -35,8 +35,8 @@ import org.nd4j.weightinit.WeightInit; */ public class VarScalingUniformFanAvgInitScheme extends BaseWeightInitScheme { - private double fanIn; - private double fanOut; + private final double fanIn; + private final double fanOut; @Builder public VarScalingUniformFanAvgInitScheme(char order, double fanIn, double fanOut) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierFanInInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierFanInInitScheme.java index d361d645d..dd9ec6ab6 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierFanInInitScheme.java +++ 
b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierFanInInitScheme.java @@ -35,7 +35,7 @@ import org.nd4j.weightinit.WeightInit; */ public class XavierFanInInitScheme extends BaseWeightInitScheme { - private double fanIn; + private final double fanIn; @Builder public XavierFanInInitScheme(char order, double fanIn) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierInitScheme.java index c10a40ce5..a10026407 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierInitScheme.java @@ -35,8 +35,8 @@ import org.nd4j.weightinit.WeightInit; */ public class XavierInitScheme extends BaseWeightInitScheme { - private double fanIn; - private double fanOut; + private final double fanIn; + private final double fanOut; @Builder public XavierInitScheme(char order, double fanIn, double fanOut) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java index 2bfebd419..ecf3c928e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java @@ -29,8 +29,8 @@ import org.nd4j.weightinit.WeightInit; public class XavierUniformInitScheme extends BaseWeightInitScheme { - private double fanIn; - private double fanOut; + private final double fanIn; + private final double fanOut; @Builder public XavierUniformInitScheme(char order, double fanIn, double fanOut) { diff --git a/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java b/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java index cfaae7561..f364c49d4 100644 --- a/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java +++ b/cavis-dnn/cavis-dnn-common-tests/src/main/java/org/deeplearning4j/BaseDL4JTest.java @@ -156,7 +156,7 @@ public abstract class BaseDL4JTest { int threadsAfter = ManagementFactory.getThreadMXBean().getThreadCount(); long duration = System.currentTimeMillis() - startTime; - sb.append(getClass().getSimpleName()).append(".").append("") + sb.append(getClass().getSimpleName()).append(".") .append(": ").append(duration).append(" ms") .append(", threadCount: (").append(threadCountBefore).append("->").append(threadsAfter).append(")") .append(", jvmTotal=").append(jvmTotal) diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/base/Preconditions.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/base/Preconditions.java index c8bc3966d..c709d29b8 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/base/Preconditions.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/base/Preconditions.java @@ -687,12 +687,12 @@ public final class Preconditions { } else { if(nextCustom < 0 || (nextIdx > 0 && nextIdx < nextCustom)){ //%s tag - sb.append(message.substring(indexOfStart, nextIdx)) + sb.append(message, indexOfStart, nextIdx) .append(formatArg(args[i])); indexOfStart = nextIdx + 2; } else { //Custom tag - sb.append(message.substring(indexOfStart, nextCustom)); + sb.append(message, indexOfStart, nextCustom); String s = FORMATTERS.get(nextCustomTag).format(nextCustomTag, args[i]); 
sb.append(s); indexOfStart = nextCustom + nextCustomTag.length(); diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/CompactHeapStringList.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/CompactHeapStringList.java index b7a25f248..dec86b34b 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/CompactHeapStringList.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/CompactHeapStringList.java @@ -286,7 +286,7 @@ public class CompactHeapStringList implements List { while (e1.hasNext() && e2.hasNext()) { String o1 = e1.next(); Object o2 = e2.next(); - if (!(o1 == null ? o2 == null : o1.equals(o2))) + if (!(Objects.equals(o1, o2))) return false; } return !(e1.hasNext() || e2.hasNext()); diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/IntArrayKeyMap.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/IntArrayKeyMap.java index 17de2f5a1..e4895051b 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/IntArrayKeyMap.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/IntArrayKeyMap.java @@ -28,7 +28,7 @@ import java.util.*; public class IntArrayKeyMap implements Map { - private Map map = new LinkedHashMap<>(); + private final Map map = new LinkedHashMap<>(); @Override public int size() { @@ -120,7 +120,7 @@ public class IntArrayKeyMap implements Map { public static class IntArray implements Comparable { @Getter - private int[] backingArray; + private final int[] backingArray; public IntArray(int[] backingArray) { Preconditions.checkNotNull(backingArray,"Backing array must not be null!"); diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/IntArrayKeySet.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/IntArrayKeySet.java index 1a8893cda..b1db74f72 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/IntArrayKeySet.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/IntArrayKeySet.java @@ -23,7 +23,7 @@ package org.nd4j.common.collection; import java.util.*; public class IntArrayKeySet implements Set { - private Set set = new LinkedHashSet<>(); + private final Set set = new LinkedHashSet<>(); @Override public int size() { return set.size(); diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java index a88871152..03ec92701 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java @@ -361,7 +361,7 @@ public class MultiDimensionalMap implements Serializable { MultiDimensionalMap that = (MultiDimensionalMap) o; - return !(backedMap != null ? 
!backedMap.equals(that.backedMap) : that.backedMap != null); + return !(!Objects.equals(backedMap, that.backedMap)); } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java index c5712d3eb..d16c190cb 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java @@ -28,7 +28,7 @@ import java.util.concurrent.ConcurrentSkipListSet; public class MultiDimensionalSet implements Set> { - private Set> backedSet; + private final Set> backedSet; public MultiDimensionalSet(Set> backedSet) { this.backedSet = backedSet; diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/holder/ObjectMapperHolder.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/holder/ObjectMapperHolder.java index 0cd5166a1..9df59f5f7 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/holder/ObjectMapperHolder.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/holder/ObjectMapperHolder.java @@ -26,7 +26,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; public class ObjectMapperHolder { - private static ObjectMapper objectMapper = getMapper(); + private static final ObjectMapper objectMapper = getMapper(); private ObjectMapperHolder() {} diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/AbstractFileResolvingResource.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/AbstractFileResolvingResource.java index eb21e75a6..9774b5cd3 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/AbstractFileResolvingResource.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/AbstractFileResolvingResource.java @@ -124,7 +124,7 @@ public abstract class AbstractFileResolvingResource extends AbstractResource { ((HttpURLConnection) con).setRequestMethod("HEAD"); } - return (long) con.getContentLength(); + return con.getContentLength(); } } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/AbstractResource.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/AbstractResource.java index a6595a0e3..cf7ac3f38 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/AbstractResource.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/AbstractResource.java @@ -79,8 +79,7 @@ public abstract class AbstractResource implements Resource { long size = 0L; int read; - for (byte[] buf = new byte[255]; (read = is.read(buf)) != -1; size += (long) read) { - ; + for (byte[] buf = new byte[255]; (read = is.read(buf)) != -1; size += read) { } long var6 = size; @@ -89,7 +88,6 @@ public abstract class AbstractResource implements Resource { try { is.close(); } catch (IOException var14) { - ; } } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java index 7c5687ef9..b79d395a9 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ClassPathResource.java @@ -358,7 +358,7 @@ public class ClassPathResource extends AbstractFileResolvingResource { private ZipFile zipFile; private ZipEntry entry; private InputStream stream; - private String 
resourceName; + private final String resourceName; public GetStreamFromZip(URL url, String resourceName) { this.url = url; diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/CollectionUtils.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/CollectionUtils.java index 4212e69e9..bb256403f 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/CollectionUtils.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/CollectionUtils.java @@ -50,10 +50,7 @@ public abstract class CollectionUtils { Object[] arr$ = arr; int len$ = arr.length; - for (int i$ = 0; i$ < len$; ++i$) { - Object elem = arr$[i$]; - collection.add(elem); - } + collection.addAll(Arrays.asList(arr$).subList(0, len$)); } } @@ -157,7 +154,7 @@ public abstract class CollectionUtils { } public static T findValueOfType(Collection collection, Class type) { - if (isEmpty((Collection) collection)) { + if (isEmpty(collection)) { return null; } else { Object value = null; @@ -179,7 +176,7 @@ public abstract class CollectionUtils { } public static Object findValueOfType(Collection collection, Class[] types) { - if (!isEmpty((Collection) collection) && !ObjectUtils.isEmpty(types)) { + if (!isEmpty(collection) && !ObjectUtils.isEmpty(types)) { Class[] arr$ = types; int len$ = types.length; @@ -260,7 +257,7 @@ public abstract class CollectionUtils { } public static MultiValueMap unmodifiableMultiValueMap(MultiValueMap map) { - Assert.notNull(map, "\'map\' must not be null"); + Assert.notNull(map, "'map' must not be null"); LinkedHashMap result = new LinkedHashMap(map.size()); Iterator unmodifiableMap = map.entrySet().iterator(); @@ -278,7 +275,7 @@ public abstract class CollectionUtils { private final Map> map; public MultiValueMapAdapter(Map> map) { - Assert.notNull(map, "\'map\' must not be null"); + Assert.notNull(map, "'map' must not be null"); this.map = map; } @@ -374,7 +371,7 @@ public abstract class CollectionUtils { } public boolean equals(Object other) { - return this == other ? 
true : this.map.equals(other); + return this == other || this.map.equals(other); } public int hashCode() { @@ -387,7 +384,7 @@ public abstract class CollectionUtils { } private static class EnumerationIterator implements Iterator { - private Enumeration enumeration; + private final Enumeration enumeration; public EnumerationIterator(Enumeration enumeration) { this.enumeration = enumeration; diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ObjectUtils.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ObjectUtils.java index e1dcf32e9..43f6db46b 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ObjectUtils.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ObjectUtils.java @@ -122,7 +122,7 @@ public abstract class ObjectUtils { } throw new IllegalArgumentException(String.format("constant [%s] does not exist in enum opType %s", - new Object[] {constant, enumValues.getClass().getComponentType().getName()})); + constant, enumValues.getClass().getComponentType().getName())); } public static A[] addObjectToArray(A[] array, O obj) { @@ -479,7 +479,7 @@ public abstract class ObjectUtils { sb.append(", "); } - sb.append(String.valueOf(array[i])); + sb.append(array[i]); } sb.append("}"); @@ -557,7 +557,7 @@ public abstract class ObjectUtils { sb.append(", "); } - sb.append("\'").append(array[i]).append("\'"); + sb.append("'").append(array[i]).append("'"); } sb.append("}"); diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java index 5c41dba38..f973d1a26 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java @@ -289,7 +289,7 @@ public abstract class ReflectionUtils { mc.doWith(superIfc); } catch (IllegalAccessException var9) { throw new IllegalStateException( - "Shouldn\'t be illegal to access method \'" + superIfc.getName() + "\': " + var9); + "Shouldn't be illegal to access method '" + superIfc.getName() + "': " + var9); } } } @@ -374,7 +374,7 @@ public abstract class ReflectionUtils { fc.doWith(field); } catch (IllegalAccessException var10) { throw new IllegalStateException( - "Shouldn\'t be illegal to access field \'" + field.getName() + "\': " + var10); + "Shouldn't be illegal to access field '" + field.getName() + "': " + var10); } } } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/StringUtils.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/StringUtils.java index 9f4fecbec..264f76cf5 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/StringUtils.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/StringUtils.java @@ -242,7 +242,7 @@ public abstract class StringUtils { int index = inString.indexOf(oldPattern); for (int patLen = oldPattern.length(); index >= 0; index = inString.indexOf(oldPattern, pos)) { - sb.append(inString.substring(pos, index)); + sb.append(inString, pos, index); sb.append(newPattern); pos = index + patLen; } @@ -276,7 +276,7 @@ public abstract class StringUtils { } public static String quote(String str) { - return str != null ? "\'" + str + "\'" : null; + return str != null ? 
"'" + str + "'" : null; } public static Object quoteIfString(Object obj) { @@ -536,10 +536,7 @@ public abstract class StringUtils { String[] arr$ = array; int len$ = array.length; - for (int i$ = 0; i$ < len$; ++i$) { - String element = arr$[i$]; - set.add(element); - } + set.addAll(Arrays.asList(arr$).subList(0, len$)); return toStringArray(set); } @@ -656,10 +653,7 @@ public abstract class StringUtils { String[] arr$ = tokens; int len$ = tokens.length; - for (int i$ = 0; i$ < len$; ++i$) { - String token = arr$[i$]; - set.add(token); - } + set.addAll(Arrays.asList(arr$).subList(0, len$)); return set; } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/VfsUtils.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/VfsUtils.java index 93b11b937..cb4cc04b1 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/VfsUtils.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/io/VfsUtils.java @@ -43,14 +43,14 @@ public abstract class VfsUtils { private static Method VFS_METHOD_GET_ROOT_URL = null; private static Method VFS_METHOD_GET_ROOT_URI = null; private static Method VIRTUAL_FILE_METHOD_EXISTS = null; - private static Method VIRTUAL_FILE_METHOD_GET_INPUT_STREAM; - private static Method VIRTUAL_FILE_METHOD_GET_SIZE; - private static Method VIRTUAL_FILE_METHOD_GET_LAST_MODIFIED; - private static Method VIRTUAL_FILE_METHOD_TO_URL; - private static Method VIRTUAL_FILE_METHOD_TO_URI; - private static Method VIRTUAL_FILE_METHOD_GET_NAME; - private static Method VIRTUAL_FILE_METHOD_GET_PATH_NAME; - private static Method VIRTUAL_FILE_METHOD_GET_CHILD; + private static final Method VIRTUAL_FILE_METHOD_GET_INPUT_STREAM; + private static final Method VIRTUAL_FILE_METHOD_GET_SIZE; + private static final Method VIRTUAL_FILE_METHOD_GET_LAST_MODIFIED; + private static final Method VIRTUAL_FILE_METHOD_TO_URL; + private static final Method VIRTUAL_FILE_METHOD_TO_URI; + private static final Method VIRTUAL_FILE_METHOD_GET_NAME; + private static final Method VIRTUAL_FILE_METHOD_GET_PATH_NAME; + private static final Method VIRTUAL_FILE_METHOD_GET_CHILD; protected static Class VIRTUAL_FILE_VISITOR_INTERFACE; protected static Method VIRTUAL_FILE_METHOD_VISIT; private static Method VFS_UTILS_METHOD_IS_NESTED_FILE = null; @@ -122,11 +122,11 @@ public abstract class VfsUtils { } static Object getRelative(URL url) throws IOException { - return invokeVfsMethod(VFS_METHOD_GET_ROOT_URL, null, new Object[] {url}); + return invokeVfsMethod(VFS_METHOD_GET_ROOT_URL, null, url); } static Object getChild(Object vfsResource, String path) throws IOException { - return invokeVfsMethod(VIRTUAL_FILE_METHOD_GET_CHILD, vfsResource, new Object[] {path}); + return invokeVfsMethod(VIRTUAL_FILE_METHOD_GET_CHILD, vfsResource, path); } static File getFile(Object vfsResource) throws IOException { @@ -148,11 +148,11 @@ public abstract class VfsUtils { } static Object getRoot(URI url) throws IOException { - return invokeVfsMethod(VFS_METHOD_GET_ROOT_URI, null, new Object[] {url}); + return invokeVfsMethod(VFS_METHOD_GET_ROOT_URI, null, url); } protected static Object getRoot(URL url) throws IOException { - return invokeVfsMethod(VFS_METHOD_GET_ROOT_URL, null, new Object[] {url}); + return invokeVfsMethod(VFS_METHOD_GET_ROOT_URL, null, url); } protected static Object doGetVisitorAttribute() { @@ -195,8 +195,8 @@ public abstract class VfsUtils { try { String ex = VFS_VER.V3.equals(version) ? 
"getChild" : "getRoot"; - VFS_METHOD_GET_ROOT_URL = ReflectionUtils.findMethod(vfsClass, ex, new Class[] {URL.class}); - VFS_METHOD_GET_ROOT_URI = ReflectionUtils.findMethod(vfsClass, ex, new Class[] {URI.class}); + VFS_METHOD_GET_ROOT_URL = ReflectionUtils.findMethod(vfsClass, ex, URL.class); + VFS_METHOD_GET_ROOT_URI = ReflectionUtils.findMethod(vfsClass, ex, URI.class); Class virtualFile = loader.loadClass(pkg + "VirtualFile"); VIRTUAL_FILE_METHOD_EXISTS = ReflectionUtils.findMethod(virtualFile, "exists"); VIRTUAL_FILE_METHOD_GET_INPUT_STREAM = ReflectionUtils.findMethod(virtualFile, "openStream"); @@ -208,15 +208,15 @@ public abstract class VfsUtils { VIRTUAL_FILE_METHOD_GET_PATH_NAME = ReflectionUtils.findMethod(virtualFile, "getPathName"); GET_PHYSICAL_FILE = ReflectionUtils.findMethod(virtualFile, "getPhysicalFile"); ex = VFS_VER.V3.equals(version) ? "getChild" : "findChild"; - VIRTUAL_FILE_METHOD_GET_CHILD = ReflectionUtils.findMethod(virtualFile, ex, new Class[] {String.class}); + VIRTUAL_FILE_METHOD_GET_CHILD = ReflectionUtils.findMethod(virtualFile, ex, String.class); Class utilsClass = loader.loadClass(pkg + "VFSUtils"); VFS_UTILS_METHOD_GET_COMPATIBLE_URI = - ReflectionUtils.findMethod(utilsClass, "getCompatibleURI", new Class[] {virtualFile}); + ReflectionUtils.findMethod(utilsClass, "getCompatibleURI", virtualFile); VFS_UTILS_METHOD_IS_NESTED_FILE = - ReflectionUtils.findMethod(utilsClass, "isNestedFile", new Class[] {virtualFile}); + ReflectionUtils.findMethod(utilsClass, "isNestedFile", virtualFile); VIRTUAL_FILE_VISITOR_INTERFACE = loader.loadClass(pkg + "VirtualFileVisitor"); VIRTUAL_FILE_METHOD_VISIT = ReflectionUtils.findMethod(virtualFile, "visit", - new Class[] {VIRTUAL_FILE_VISITOR_INTERFACE}); + VIRTUAL_FILE_VISITOR_INTERFACE); Class visitorAttributesClass = loader.loadClass(pkg + "VisitorAttributes"); VISITOR_ATTRIBUTES_FIELD_RECURSE = ReflectionUtils.findField(visitorAttributesClass, "RECURSE"); } catch (ClassNotFoundException var7) { @@ -224,9 +224,9 @@ public abstract class VfsUtils { } } - private static enum VFS_VER { + private enum VFS_VER { V2, V3; - private VFS_VER() {} + VFS_VER() {} } } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/CounterMap.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/CounterMap.java index 1cc6758e6..597513300 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/CounterMap.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/CounterMap.java @@ -192,7 +192,7 @@ public class CounterMap implements Serializable{ public Iterator> getIterator() { return new Iterator>() { - Iterator outerIt; + final Iterator outerIt; Iterator innerIt; F curKey; diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicBoolean.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicBoolean.java index 6c807feea..13a9e523a 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicBoolean.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicBoolean.java @@ -31,7 +31,7 @@ import java.io.IOException; public class JsonDeserializerAtomicBoolean extends JsonDeserializer { @Override - public AtomicBoolean deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { + public AtomicBoolean 
deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException { JsonNode node = jsonParser.getCodec().readTree(jsonParser); boolean value = node.asBoolean(); return new AtomicBoolean(value); diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicDouble.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicDouble.java index d777b0072..2b152e750 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicDouble.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicDouble.java @@ -31,7 +31,7 @@ import java.io.IOException; public class JsonDeserializerAtomicDouble extends JsonDeserializer { @Override - public AtomicDouble deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { + public AtomicDouble deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException { JsonNode node = jsonParser.getCodec().readTree(jsonParser); double value = node.asDouble(); return new AtomicDouble(value); diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicBoolean.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicBoolean.java index c10f1bc95..e2d51b105 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicBoolean.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicBoolean.java @@ -30,7 +30,7 @@ import java.io.IOException; public class JsonSerializerAtomicBoolean extends JsonSerializer { @Override - public void serialize(AtomicBoolean atomicDouble, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException, JsonProcessingException { + public void serialize(AtomicBoolean atomicDouble, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException { jsonGenerator.writeBoolean(atomicDouble.get()); } } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicDouble.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicDouble.java index 1f9041ccd..9e00819d4 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicDouble.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicDouble.java @@ -30,7 +30,7 @@ import java.io.IOException; public class JsonSerializerAtomicDouble extends JsonSerializer { @Override - public void serialize(AtomicDouble atomicDouble, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException, JsonProcessingException { + public void serialize(AtomicDouble atomicDouble, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException { jsonGenerator.writeNumber(atomicDouble.doubleValue()); } } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/Resources.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/Resources.java index f8fa974f4..aec97ba3e 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/Resources.java +++ 
b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/Resources.java @@ -31,7 +31,7 @@ import java.util.*; @Slf4j public class Resources { - private static Resources INSTANCE = new Resources(); + private static final Resources INSTANCE = new Resources(); protected final List resolvers; @@ -123,7 +123,7 @@ public class Resources { } throw new IllegalStateException("Cannot resolve resource (not found): none of " + resolvers.size() + - " resolvers can resolve resource \"" + resourcePath + "\" - available resolvers: " + resolvers.toString()); + " resolvers can resolve resource \"" + resourcePath + "\" - available resolvers: " + resolvers); } public InputStream getAsStream(String resourcePath) { @@ -135,7 +135,7 @@ public class Resources { } throw new IllegalStateException("Cannot resolve resource (not found): none of " + resolvers.size() + - " resolvers can resolve resource \"" + resourcePath + "\" - available resolvers: " + resolvers.toString()); + " resolvers can resolve resource \"" + resourcePath + "\" - available resolvers: " + resolvers); } public void copyDir(String directoryPath, File destinationDir) { diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/strumpf/ResourceFile.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/strumpf/ResourceFile.java index 0141be02f..8bdeae89c 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/strumpf/ResourceFile.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/strumpf/ResourceFile.java @@ -118,10 +118,7 @@ public class ResourceFile { Preconditions.checkState(expSha256 != null, "Expected JSON property %s was not found in resource reference file %s", sha256Property, filePath); String actualSha256 = sha256(file); - if (!expSha256.equals(actualSha256)) { - return false; - } - return true; + return expSha256.equals(actualSha256); } /** diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/strumpf/StrumpfResolver.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/strumpf/StrumpfResolver.java index 54ff89459..ba879f740 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/strumpf/StrumpfResolver.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/resources/strumpf/StrumpfResolver.java @@ -94,11 +94,7 @@ public class StrumpfResolver implements Resolver { } cpr = new ClassPathResource(resourcePath); - if (cpr.exists()) { - return true; - } - - return false; + return cpr.exists(); } @Override @@ -116,11 +112,7 @@ public class StrumpfResolver implements Resolver { //Second: Check classpath ClassPathResource cpr = new ClassPathResource(dirPath); - if (cpr.exists()) { - return true; - } - - return false; + return cpr.exists(); } @Override diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/tools/BTools.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/tools/BTools.java index 7e4d06b49..d22b22998 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/tools/BTools.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/tools/BTools.java @@ -272,10 +272,10 @@ public class BTools { // String FormatS = ""; if ( LeadingChar == '0' ) { - FormatS = "%" + LeadingChar + Integer.toString( CharsCount ) + "d"; + FormatS = "%" + LeadingChar + CharsCount + "d"; } else { - FormatS = "%" + Integer.toString( CharsCount ) + "d"; + FormatS = "%" + CharsCount + "d"; } // Result = String.format( FormatS, Value ); 
diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/tools/SIS.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/tools/SIS.java index b10296fcc..a2ee4f925 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/tools/SIS.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/tools/SIS.java @@ -33,7 +33,7 @@ import java.time.format.DateTimeFormatter; public class SIS { // System Informations Saving // - private String baseModuleCode = "SIS"; + private final String baseModuleCode = "SIS"; private String moduleCode = "?"; // private PrintStream out; diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArchiveUtils.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArchiveUtils.java index 317c5a23d..cd682f3b2 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArchiveUtils.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArchiveUtils.java @@ -80,7 +80,7 @@ public class ArchiveUtils { new File(dest).mkdirs(); FileInputStream fin = new FileInputStream(target); int BUFFER = 2048; - byte data[] = new byte[BUFFER]; + byte[] data = new byte[BUFFER]; if (file.endsWith(".zip") || file.endsWith(".jar")) { try(ZipInputStream zis = new ZipInputStream(fin)) { @@ -152,7 +152,7 @@ public class ArchiveUtils { else { int count; try(FileOutputStream fos = new FileOutputStream(dest + File.separator + entry.getName()); - BufferedOutputStream destStream = new BufferedOutputStream(fos, BUFFER);) { + BufferedOutputStream destStream = new BufferedOutputStream(fos, BUFFER)) { while ((count = tarIn.read(data, 0, BUFFER)) != -1) { destStream.write(data, 0, count); } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java index 8a30f0e48..13780f3a6 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java @@ -295,7 +295,7 @@ public class ArrayUtil { public static long[] toLongs(byte[] data) { val ret = new long[data.length]; for (int i = 0; i < ret.length; i++) { - ret[i] = (long) data[i]; + ret[i] = data[i]; } return ret; } @@ -311,7 +311,7 @@ public class ArrayUtil { public static long[] toLongs(short[] data) { val ret = new long[data.length]; for (int i = 0; i < ret.length; i++) { - ret[i] = (long) data[i]; + ret[i] = data[i]; } return ret; } @@ -319,7 +319,7 @@ public class ArrayUtil { public static long[] toLongs(int[] data) { val ret = new long[data.length]; for (int i = 0; i < ret.length; i++) { - ret[i] = (long) data[i]; + ret[i] = data[i]; } return ret; } @@ -1105,7 +1105,7 @@ public class ArrayUtil { public static double[] toDoubles(int[] ints) { double[] ret = new double[ints.length]; for (int i = 0; i < ints.length; i++) - ret[i] = (double) ints[i]; + ret[i] = ints[i]; return ret; } @@ -1119,7 +1119,7 @@ public class ArrayUtil { public static double[] toDoubles(float[] ints) { double[] ret = new double[ints.length]; for (int i = 0; i < ints.length; i++) - ret[i] = (double) ints[i]; + ret[i] = ints[i]; return ret; } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/Index.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/Index.java index cc64e145d..ff91a9a4e 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/Index.java +++ 
b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/Index.java @@ -23,14 +23,15 @@ package org.nd4j.common.util; import java.io.Serializable; import java.util.Map; +import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; @SuppressWarnings({"rawtypes", "unchecked"}) public class Index implements Serializable { private static final long serialVersionUID = 1160629777026141078L; - private Map objects = new ConcurrentHashMap<>(); - private Map indexes = new ConcurrentHashMap<>(); + private final Map objects = new ConcurrentHashMap<>(); + private final Map indexes = new ConcurrentHashMap<>(); public synchronized boolean add(Object o, int idx) { if (o instanceof String && o.toString().isEmpty()) { @@ -103,9 +104,9 @@ public class Index implements Serializable { Index index = (Index) o; - if (objects != null ? !objects.equals(index.objects) : index.objects != null) + if (!Objects.equals(objects, index.objects)) return false; - return !(indexes != null ? !indexes.equals(index.indexes) : index.indexes != null); + return !(!Objects.equals(indexes, index.indexes)); } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/MathUtils.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/MathUtils.java index 58d72eace..6e249ffbd 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/MathUtils.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/MathUtils.java @@ -163,7 +163,7 @@ public class MathUtils { * @param targetAttribute target attribute vector * @return the correlation coefficient or r */ - public static double correlation(double[] residuals, double targetAttribute[]) { + public static double correlation(double[] residuals, double[] targetAttribute) { double[] predictedValues = new double[residuals.length]; for (int i = 0; i < predictedValues.length; i++) { predictedValues[i] = targetAttribute[i] - residuals[i]; @@ -1042,7 +1042,7 @@ public class MathUtils { */ public static /*@pure@*/ double roundDouble(double value, int afterDecimalPoint) { - double mask = Math.pow(10.0, (double) afterDecimalPoint); + double mask = Math.pow(10.0, afterDecimalPoint); return (double) (Math.round(value * mask)) / mask; }//end roundDouble diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/Rational.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/Rational.java index 404874016..e9914479c 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/Rational.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/Rational.java @@ -234,10 +234,10 @@ class Rational implements Cloneable { public Rational pow(BigInteger exponent) throws NumberFormatException { /* test for overflow */ if (exponent.compareTo(MAX_INT) == 1) { - throw new NumberFormatException("Exponent " + exponent.toString() + " too large."); + throw new NumberFormatException("Exponent " + exponent + " too large."); } if (exponent.compareTo(MIN_INT) == -1) { - throw new NumberFormatException("Exponent " + exponent.toString() + " too small."); + throw new NumberFormatException("Exponent " + exponent + " too small."); } /* promote to the simpler interface above */ return pow(exponent.intValue()); diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/SynchronizedTable.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/SynchronizedTable.java index ace0bf5f1..37c16114e 100644 --- 
a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/SynchronizedTable.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/SynchronizedTable.java @@ -27,7 +27,7 @@ import java.util.Map; import java.util.Set; public class SynchronizedTable implements Table { - private Table wrapped; + private final Table wrapped; public SynchronizedTable(Table wrapped) { this.wrapped = wrapped; diff --git a/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java b/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java index 88a4ba98d..e75d8638a 100644 --- a/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java +++ b/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java @@ -44,9 +44,9 @@ public class FunctionalUtilsTest { //[(fish,([],[alex])), (dog,([adam],[steve])), (cat,([adam],[alice]))] Map,List>> assertion = new HashMap<>(); - assertion.put("cat",Pair.of(Arrays.asList("adam"),Arrays.asList("alice"))); - assertion.put("dog",Pair.of(Arrays.asList("adam"),Arrays.asList("steve"))); - assertion.put("fish",Pair.of(Collections.emptyList(),Arrays.asList("alex"))); + assertion.put("cat",Pair.of(Collections.singletonList("adam"), Collections.singletonList("alice"))); + assertion.put("dog",Pair.of(Collections.singletonList("adam"), Collections.singletonList("steve"))); + assertion.put("fish",Pair.of(Collections.emptyList(), Collections.singletonList("alex"))); Map, List>> cogroup = FunctionalUtils.cogroup(leftMap, rightMap); assertEquals(assertion,cogroup); diff --git a/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java b/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java index 8e7a22e6d..1d49a22ad 100644 --- a/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java +++ b/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/io/ClassPathResourceTest.java @@ -39,7 +39,7 @@ public class ClassPathResourceTest { ClassPathResource cpr = new ClassPathResource("somedir"); - File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID().toString()); + File f = new File(FileUtils.getTempDirectoryPath()+File.separatorChar+ UUID.randomUUID()); FileUtils.forceMkdir(f); cpr.copyDirectory(f); diff --git a/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java b/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java index bd3f9c569..201dcf5f5 100644 --- a/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java +++ b/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java @@ -81,7 +81,7 @@ public class TestFileBatch { //Check that it is indeed a valid zip file: - File f = new File(FileUtils.getTempDirectoryPath()+"/"+UUID.randomUUID().toString()); + File f = new File(FileUtils.getTempDirectoryPath()+"/"+ UUID.randomUUID()); f.delete(); fb.writeAsZip(f); diff --git a/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java b/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java index 1dc5860d7..19e5dfadc 100644 --- a/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java +++ b/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java @@ -26,9 +26,9 @@ import static org.junit.jupiter.api.Assertions.*; public class 
InfoValuesTest { // - private String[] t1_titleA = { "T0", "T1", "T2", "T3", "T4", "T5" }; + private final String[] t1_titleA = { "T0", "T1", "T2", "T3", "T4", "T5" }; // - private String[] t2_titleA = { "", "T1", "T2" }; + private final String[] t2_titleA = { "", "T1", "T2" }; // @Test diff --git a/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/tools/SISTest.java b/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/tools/SISTest.java index 34953554c..d40d32406 100644 --- a/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/tools/SISTest.java +++ b/cavis-dnn/cavis-dnn-common/src/test/java/org/nd4j/common/tools/SISTest.java @@ -48,7 +48,7 @@ public class SISTest { // assertEquals( 33, fFName.length() ); assertEquals( "Z", fFName.substring( 0, 1 ) ); - assertEquals( "_Test_ABC.txt", fFName.substring( fFName.length() - 13, fFName.length() ) ); + assertEquals( "_Test_ABC.txt", fFName.substring( fFName.length() - 13) ); // assertEquals( "", fFName ); // assertEquals( "", tmpFld.getRoot().getAbsolutePath() ); // diff --git a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/datasets/test/TestDataSetIterator.java b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/datasets/test/TestDataSetIterator.java index 2bda23111..076756b71 100644 --- a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/datasets/test/TestDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/datasets/test/TestDataSetIterator.java @@ -32,7 +32,7 @@ public class TestDataSetIterator implements DataSetIterator { * */ private static final long serialVersionUID = -3042802726018263331L; - private DataSetIterator wrapped; + private final DataSetIterator wrapped; private int numDataSets = 0; @Getter private DataSetPreProcessor preProcessor; diff --git a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/evaluation/EvaluationTools.java b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/evaluation/EvaluationTools.java index 6d4295e1f..8f0712eb9 100644 --- a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/evaluation/EvaluationTools.java +++ b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/evaluation/EvaluationTools.java @@ -173,7 +173,7 @@ public class EvaluationTools { if (classNames != null && classNames.size() > i) { headerText += " (" + classNames.get(i) + ")"; } - headerText += " vs. All";; + headerText += " vs. 
All"; Component headerDivPad = new ComponentDiv(HEADER_DIV_PAD_STYLE); components.add(headerDivPad); diff --git a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemPolling.java b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemPolling.java index 954be2065..b74a2382b 100644 --- a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemPolling.java +++ b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemPolling.java @@ -36,10 +36,10 @@ import java.util.concurrent.TimeUnit; public class SystemPolling { private ScheduledExecutorService scheduledExecutorService; - private long pollEveryMillis; - private File outputDirectory; - private NameProvider nameProvider; - private ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory()); + private final long pollEveryMillis; + private final File outputDirectory; + private final NameProvider nameProvider; + private final ObjectMapper objectMapper = new ObjectMapper(new YAMLFactory()); private SystemPolling(long pollEveryMillis,File outputDirectory,NameProvider nameProvider) { this.pollEveryMillis = pollEveryMillis; diff --git a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/parallelism/AsyncIterator.java b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/parallelism/AsyncIterator.java index 0ecab4f12..9fb604e89 100644 --- a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/parallelism/AsyncIterator.java +++ b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/parallelism/AsyncIterator.java @@ -102,9 +102,9 @@ public class AsyncIterator implements Iterator { } private class ReaderThread extends Thread implements Runnable { - private BlockingQueue buffer; - private Iterator iterator; - private T terminator; + private final BlockingQueue buffer; + private final Iterator iterator; + private final T terminator; public ReaderThread(Iterator iterator, BlockingQueue buffer, T terminator) { this.buffer = buffer; @@ -133,8 +133,6 @@ public class AsyncIterator implements Iterator { } catch (Exception e) { // TODO: pass that forward throw new RuntimeException(e); - } finally { - //log.info("AsyncReader [{}] stopped", Thread.currentThread().getId()); } } } diff --git a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/storage/impl/RemoteUIStatsStorageRouter.java b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/storage/impl/RemoteUIStatsStorageRouter.java index b67798056..df21ac930 100644 --- a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/storage/impl/RemoteUIStatsStorageRouter.java +++ b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/storage/impl/RemoteUIStatsStorageRouter.java @@ -74,8 +74,8 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa private transient Thread postThread; - private AtomicBoolean shutdown = new AtomicBoolean(false); - private AtomicLong shutdownWarnCount = new AtomicLong(0); + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final AtomicLong shutdownWarnCount = new AtomicLong(0); private static final ObjectMapper objectMapper = new ObjectMapper(); @@ -368,7 +368,7 @@ public class RemoteUIStatsStorageRouter implements StatsStorageRouter, Serializa in.close(); log.warn("Error posting to remote UI - received response code {}\tContent: {}", response, - response.toString()); + response); return false; } diff --git 
a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/ui/UiConnectionInfo.java b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/ui/UiConnectionInfo.java index a2c81576c..47689316e 100644 --- a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/ui/UiConnectionInfo.java +++ b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/ui/UiConnectionInfo.java @@ -49,11 +49,8 @@ public class UiConnectionInfo { * @return */ public String getFirstPart() { - StringBuilder builder = new StringBuilder(); - builder.append(useHttps ? "https" : "http").append("://").append(address).append(":").append(port).append(""); - - return builder.toString(); + return (useHttps ? "https" : "http") + "://" + address + ":" + port; } public String getSecondPart() { @@ -89,7 +86,7 @@ public class UiConnectionInfo { } public static class Builder { - private UiConnectionInfo info = new UiConnectionInfo(); + private final UiConnectionInfo info = new UiConnectionInfo(); /** * This method allows you to specify sessionId for this UiConnectionInfo instance diff --git a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/ModelGuesser.java b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/ModelGuesser.java index 8aabe3fe4..70b250978 100644 --- a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/ModelGuesser.java +++ b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/ModelGuesser.java @@ -122,7 +122,7 @@ public class ModelGuesser { */ public static Object loadConfigGuess(InputStream stream) throws Exception { String p = System.getProperty(DL4JSystemProperties.DL4J_TEMP_DIR_PROPERTY); - File tmp = DL4JFileUtils.createTempFile("model-" + UUID.randomUUID().toString(), "bin"); + File tmp = DL4JFileUtils.createTempFile("model-" + UUID.randomUUID(), "bin"); BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(new FileOutputStream(tmp)); IOUtils.copy(stream, bufferedOutputStream); bufferedOutputStream.flush(); diff --git a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/MovingWindowMatrix.java b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/MovingWindowMatrix.java index 853a2bd98..d25fdb1fb 100644 --- a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/MovingWindowMatrix.java +++ b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/MovingWindowMatrix.java @@ -43,7 +43,7 @@ public class MovingWindowMatrix { private int windowRowSize = 28; private int windowColumnSize = 28; - private INDArray toSlice; + private final INDArray toSlice; private boolean addRotate = false; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java index 4a037de0d..c7f354937 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java @@ -46,7 +46,7 @@ import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException import org.deeplearning4j.datasets.datavec.tools.SpecialImageRecordReader; import org.junit.jupiter.api.Timeout; import org.junit.jupiter.api.io.TempDir; -import org.nd4j.linalg.dataset.AsyncDataSetIterator;; +import org.nd4j.linalg.dataset.AsyncDataSetIterator; import 
org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; @@ -98,11 +98,11 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { recordReader.initialize(csv); DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, 10, -1, -1, 2); DataSet ds = iter.next(); - assertFalse(ds == null); + assertNotNull(ds); assertEquals(10, ds.numExamples()); iter.hasNext(); iter.next(); - assertEquals(false, iter.hasNext()); + assertFalse(iter.hasNext()); } @Test @@ -841,14 +841,14 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { public void testSeqRRDSIArrayWritableOneReader() { List> sequence1 = new ArrayList<>(); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, new long[]{1,3})), + sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, 1,3)), new IntWritable(0))); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, new long[]{1,3})), + sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, 1,3)), new IntWritable(1))); List> sequence2 = new ArrayList<>(); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, new long[]{1,3})), + sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, 1,3)), new IntWritable(2))); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, new long[]{1,3})), + sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, 1,3)), new IntWritable(3))); @@ -874,15 +874,15 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { public void testSeqRRDSIArrayWritableOneReaderRegression() { //Regression, where the output is an array writable List> sequence1 = new ArrayList<>(); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {100, 200, 300}, new long[]{1,3})))); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {400, 500, 600}, new long[]{1,3})))); + sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, 1,3)), + new NDArrayWritable(Nd4j.create(new double[] {100, 200, 300}, 1,3)))); + sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, 1,3)), + new NDArrayWritable(Nd4j.create(new double[] {400, 500, 600}, 1,3)))); List> sequence2 = new ArrayList<>(); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {700, 800, 900}, new long[]{1,3})))); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {1000, 1100, 1200}, new long[]{1,3})))); + sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, 1,3)), + new NDArrayWritable(Nd4j.create(new double[] {700, 800, 900}, 1,3)))); + sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, 1,3)), + new NDArrayWritable(Nd4j.create(new double[] {1000, 1100, 1200}, 1,3)))); SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); @@ -910,15 +910,15 @@ public class 
RecordReaderDataSetiteratorTest extends BaseDL4JTest { //Input with multiple array writables: List> sequence1 = new ArrayList<>(); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {100, 200, 300}, new long[]{1,3})), new IntWritable(0))); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {400, 500, 600}, new long[]{1,3})), new IntWritable(1))); + sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, 1,3)), + new NDArrayWritable(Nd4j.create(new double[] {100, 200, 300}, 1,3)), new IntWritable(0))); + sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, 1,3)), + new NDArrayWritable(Nd4j.create(new double[] {400, 500, 600}, 1,3)), new IntWritable(1))); List> sequence2 = new ArrayList<>(); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {700, 800, 900}, new long[]{1,3})), new IntWritable(2))); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, new long[]{1,3})), - new NDArrayWritable(Nd4j.create(new double[] {1000, 1100, 1200}, new long[]{1,3})), new IntWritable(3))); + sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, 1,3)), + new NDArrayWritable(Nd4j.create(new double[] {700, 800, 900}, 1,3)), new IntWritable(2))); + sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, 1,3)), + new NDArrayWritable(Nd4j.create(new double[] {1000, 1100, 1200}, 1,3)), new IntWritable(3))); SequenceRecordReader rr = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); @@ -944,26 +944,26 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { @Test public void testSeqRRDSIArrayWritableTwoReaders() { List> sequence1 = new ArrayList<>(); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, new long[]{1,3})), + sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {1, 2, 3}, 1,3)), new IntWritable(100))); - sequence1.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, new long[]{1,3})), + sequence1.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {4, 5, 6}, 1,3)), new IntWritable(200))); List> sequence2 = new ArrayList<>(); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, new long[]{1,3})), + sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {7, 8, 9}, 1,3)), new IntWritable(300))); - sequence2.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, new long[]{1,3})), + sequence2.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {10, 11, 12}, 1,3)), new IntWritable(400))); SequenceRecordReader rrFeatures = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); List> sequence1L = new ArrayList<>(); - sequence1L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {100, 200, 300}, new long[]{1,3})), + sequence1L.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {100, 200, 300}, 1,3)), new IntWritable(101))); - sequence1L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {400, 
500, 600}, new long[]{1,3})), + sequence1L.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {400, 500, 600}, 1,3)), new IntWritable(201))); List> sequence2L = new ArrayList<>(); - sequence2L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {700, 800, 900}, new long[]{1,3})), + sequence2L.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {700, 800, 900}, 1,3)), new IntWritable(301))); - sequence2L.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.create(new double[] {1000, 1100, 1200}, new long[]{1,3})), + sequence2L.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {1000, 1100, 1200}, 1,3)), new IntWritable(401))); SequenceRecordReader rrLabels = new CollectionSequenceRecordReader(Arrays.asList(sequence1L, sequence2L)); @@ -1050,12 +1050,12 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { Collection> data = new ArrayList<>(); - data.add(Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), - new NDArrayWritable(Nd4j.create(new double[] {1.1, 2.1, 3.1}, new long[]{1,3})))); - data.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(3), - new NDArrayWritable(Nd4j.create(new double[] {4.1, 5.1, 6.1}, new long[]{1,3})))); - data.add(Arrays.asList(new DoubleWritable(4), new DoubleWritable(5), - new NDArrayWritable(Nd4j.create(new double[] {7.1, 8.1, 9.1}, new long[]{1,3})))); + data.add(Arrays.asList(new DoubleWritable(0), new DoubleWritable(1), + new NDArrayWritable(Nd4j.create(new double[] {1.1, 2.1, 3.1}, 1,3)))); + data.add(Arrays.asList(new DoubleWritable(2), new DoubleWritable(3), + new NDArrayWritable(Nd4j.create(new double[] {4.1, 5.1, 6.1}, 1,3)))); + data.add(Arrays.asList(new DoubleWritable(4), new DoubleWritable(5), + new NDArrayWritable(Nd4j.create(new double[] {7.1, 8.1, 9.1}, 1,3)))); RecordReader rr = new CollectionRecordReader(data); int batchSize = 3; @@ -1075,12 +1075,12 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { //ALSO: test if we have NDArrayWritables for BOTH the features and the labels data = new ArrayList<>(); - data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {0, 1}, new long[]{1,2})), - new NDArrayWritable(Nd4j.create(new double[] {1.1, 2.1, 3.1}, new long[]{1,3})))); - data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {2, 3}, new long[]{1,2})), - new NDArrayWritable(Nd4j.create(new double[] {4.1, 5.1, 6.1}, new long[]{1,3})))); - data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {4, 5}, new long[]{1,2})), - new NDArrayWritable(Nd4j.create(new double[] {7.1, 8.1, 9.1}, new long[]{1,3})))); + data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {0, 1}, 1,2)), + new NDArrayWritable(Nd4j.create(new double[] {1.1, 2.1, 3.1}, 1,3)))); + data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {2, 3}, 1,2)), + new NDArrayWritable(Nd4j.create(new double[] {4.1, 5.1, 6.1}, 1,3)))); + data.add(Arrays.asList(new NDArrayWritable(Nd4j.create(new double[] {4, 5}, 1,2)), + new NDArrayWritable(Nd4j.create(new double[] {7.1, 8.1, 9.1}, 1,3)))); labelIndexFrom = 1; labelIndexTo = 1; @@ -1203,7 +1203,7 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { //[DoubleWritable, DoubleWritable, NDArrayWritable([1,10]), IntWritable] -> concatenate to a [1,13] feature vector automatically. 
- List l = Arrays.asList(new DoubleWritable(1), + List l = Arrays.asList(new DoubleWritable(1), new NDArrayWritable(Nd4j.create(new double[] {2, 3, 4})), new DoubleWritable(5), new NDArrayWritable(Nd4j.create(new double[] {6, 7, 8})), new IntWritable(9), new IntWritable(1)); @@ -1241,12 +1241,12 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { //Idea: input vector is like [f,f,f,f,l,l,f,f] or similar - i.e., label writables aren't start/end - List l = Arrays.asList(new DoubleWritable(1), - new NDArrayWritable(Nd4j.create(new float[] {2, 3, 4}, new long[]{1,3})), new DoubleWritable(5), - new NDArrayWritable(Nd4j.create(new float[] {6, 7, 8}, new long[]{1,3}))); + List l = Arrays.asList(new DoubleWritable(1), + new NDArrayWritable(Nd4j.create(new float[] {2, 3, 4}, 1,3)), new DoubleWritable(5), + new NDArrayWritable(Nd4j.create(new float[] {6, 7, 8}, 1,3))); - INDArray expF = Nd4j.create(new float[] {1, 6, 7, 8}, new long[]{1,4}); - INDArray expL = Nd4j.create(new float[] {2, 3, 4, 5}, new long[]{1,4}); + INDArray expF = Nd4j.create(new float[] {1, 6, 7, 8}, 1,4); + INDArray expL = Nd4j.create(new float[] {2, 3, 4, 5}, 1,4); RecordReader rr = new CollectionRecordReader(Collections.singletonList(l)); @@ -1368,12 +1368,12 @@ public class RecordReaderDataSetiteratorTest extends BaseDL4JTest { @Test public void testSeqRRDSINoLabels(){ List> sequence1 = new ArrayList<>(); - sequence1.add(Arrays.asList((Writable) new DoubleWritable(1), new DoubleWritable(2))); - sequence1.add(Arrays.asList((Writable) new DoubleWritable(3), new DoubleWritable(4))); - sequence1.add(Arrays.asList((Writable) new DoubleWritable(5), new DoubleWritable(6))); + sequence1.add(Arrays.asList(new DoubleWritable(1), new DoubleWritable(2))); + sequence1.add(Arrays.asList(new DoubleWritable(3), new DoubleWritable(4))); + sequence1.add(Arrays.asList(new DoubleWritable(5), new DoubleWritable(6))); List> sequence2 = new ArrayList<>(); - sequence2.add(Arrays.asList((Writable) new DoubleWritable(10), new DoubleWritable(20))); - sequence2.add(Arrays.asList((Writable) new DoubleWritable(30), new DoubleWritable(40))); + sequence2.add(Arrays.asList(new DoubleWritable(10), new DoubleWritable(20))); + sequence2.add(Arrays.asList(new DoubleWritable(30), new DoubleWritable(40))); SequenceRecordReader rrFeatures = new CollectionSequenceRecordReader(Arrays.asList(sequence1, sequence2)); SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(rrFeatures, 2, -1, -1); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java index 7b163892e..0341ac846 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java @@ -657,17 +657,17 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { //2 in, 2 out, 3 total sequences of length [1,3,5] List> seq1 = - Arrays.asList(Arrays.asList(new DoubleWritable(1.0), new DoubleWritable(2.0))); + Collections.singletonList(Arrays.asList(new DoubleWritable(1.0), new DoubleWritable(2.0))); List> seq2 = - Arrays.asList(Arrays.asList(new DoubleWritable(10.0), new DoubleWritable(11.0)), - Arrays.asList(new DoubleWritable(20.0), new DoubleWritable(21.0)), - Arrays.asList(new 
DoubleWritable(30.0), new DoubleWritable(31.0))); + Arrays.asList(Arrays.asList(new DoubleWritable(10.0), new DoubleWritable(11.0)), + Arrays.asList(new DoubleWritable(20.0), new DoubleWritable(21.0)), + Arrays.asList(new DoubleWritable(30.0), new DoubleWritable(31.0))); List> seq3 = - Arrays.asList(Arrays.asList(new DoubleWritable(100.0), new DoubleWritable(101.0)), - Arrays.asList(new DoubleWritable(200.0), new DoubleWritable(201.0)), - Arrays.asList(new DoubleWritable(300.0), new DoubleWritable(301.0)), - Arrays.asList(new DoubleWritable(400.0), new DoubleWritable(401.0)), - Arrays.asList(new DoubleWritable(500.0), new DoubleWritable(501.0))); + Arrays.asList(Arrays.asList(new DoubleWritable(100.0), new DoubleWritable(101.0)), + Arrays.asList(new DoubleWritable(200.0), new DoubleWritable(201.0)), + Arrays.asList(new DoubleWritable(300.0), new DoubleWritable(301.0)), + Arrays.asList(new DoubleWritable(400.0), new DoubleWritable(401.0)), + Arrays.asList(new DoubleWritable(500.0), new DoubleWritable(501.0))); Collection>> seqs = Arrays.asList(seq1, seq2, seq3); @@ -732,8 +732,8 @@ public class RecordReaderMultiDataSetIteratorTest extends BaseDL4JTest { features.add(Arrays.asList(l(new DoubleWritable(1)), l(new DoubleWritable(2)), l(new DoubleWritable(3)))); features.add(Arrays.asList(l(new DoubleWritable(4)), l(new DoubleWritable(5)))); - labels.add(Arrays.asList(l(new IntWritable(0)))); - labels.add(Arrays.asList(l(new IntWritable(1)))); + labels.add(Collections.singletonList(l(new IntWritable(0)))); + labels.add(Collections.singletonList(l(new IntWritable(1)))); CollectionSequenceRecordReader fR = new CollectionSequenceRecordReader(features); CollectionSequenceRecordReader lR = new CollectionSequenceRecordReader(labels); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/tools/SpecialImageRecordReader.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/tools/SpecialImageRecordReader.java index e1ec7d90b..545be93e3 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/tools/SpecialImageRecordReader.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/tools/SpecialImageRecordReader.java @@ -39,11 +39,15 @@ import java.util.concurrent.atomic.AtomicInteger; @Slf4j public class SpecialImageRecordReader extends ImageRecordReader { - private AtomicInteger counter = new AtomicInteger(0); - private AtomicInteger labelsCounter = new AtomicInteger(0); - private int limit, channels, width, height, numClasses; - private List labels = new ArrayList<>(); - private INDArray zFeatures; + private final AtomicInteger counter = new AtomicInteger(0); + private final AtomicInteger labelsCounter = new AtomicInteger(0); + private final int limit; + private final int channels; + private final int width; + private final int height; + private final int numClasses; + private final List labels = new ArrayList<>(); + private final INDArray zFeatures; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIteratorTest.java index a67078e8a..471515e30 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIteratorTest.java @@ -63,7 +63,7 @@ public class 
AbstractDataSetIteratorTest extends BaseDL4JTest { @Override public Iterator> iterator() { return new Iterator>() { - private AtomicInteger cnt = new AtomicInteger(0); + private final AtomicInteger cnt = new AtomicInteger(0); @Override public boolean hasNext() { @@ -72,8 +72,8 @@ public class AbstractDataSetIteratorTest extends BaseDL4JTest { @Override public Pair next() { - float features[] = new float[numColumns]; - float labels[] = new float[numColumns]; + float[] features = new float[numColumns]; + float[] labels = new float[numColumns]; for (int i = 0; i < numColumns; i++) { features[i] = (float) i; labels[i] = RandomUtils.nextFloat(0, 5); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java index e999aee23..c4157cddd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncDataSetIteratorTest.java @@ -188,7 +188,7 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest { DataSet ds = adsi.next(); //log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); - assertEquals((double) cnt, ds.getFeatures().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";"); + assertEquals(cnt, ds.getFeatures().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";"); assertEquals( (double) cnt + 0.25, ds.getLabels().meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); assertEquals((double) cnt + 0.5, @@ -219,7 +219,7 @@ public class AsyncDataSetIteratorTest extends BaseDL4JTest { ds.detach(); //log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); - assertEquals((double) cnt, + assertEquals(cnt, ds.getFeatures().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";"); assertEquals((double) cnt + 0.25, ds.getLabels().meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";"); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java index 2952382b7..5887bfe90 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/AsyncMultiDataSetIteratorTest.java @@ -57,7 +57,7 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest { //log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); - assertEquals( (double) cnt, + assertEquals(cnt, mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10,"Failed on epoch " + e + "; iteration: " + cnt + ";"); assertEquals((double) cnt + 0.25, mds.getLabels()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";"); @@ -96,7 +96,7 @@ public class AsyncMultiDataSetIteratorTest extends BaseDL4JTest { //log.info("Features ptr: {}", AtomicAllocator.getInstance().getPointer(mds.getFeatures()[0].data()).address()); - assertEquals((double) cnt, + assertEquals(cnt, 
mds.getFeatures()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";"); assertEquals((double) cnt + 0.25, mds.getLabels()[0].meanNumber().doubleValue(), 1e-10, "Failed on epoch " + e + "; iteration: " + cnt + ";"); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java index 138298e89..dc9b3ffcf 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java @@ -70,7 +70,7 @@ public class DataSetIteratorTest extends BaseDL4JTest { while (iris.hasNext()) { irisC++; DataSet ds = iris.next(); - assertTrue(ds.getLabels().sum(Integer.MAX_VALUE).getDouble(0) == 1.0); + assertEquals(1.0, ds.getLabels().sum(Integer.MAX_VALUE).getDouble(0)); } assertEquals(5, irisC); } @@ -84,7 +84,7 @@ public class DataSetIteratorTest extends BaseDL4JTest { while (mnist.hasNext()) { mnistC++; DataSet ds = mnist.next(); - assertTrue(ds.getLabels().sum(Integer.MAX_VALUE).getDouble(0) == 1.0); + assertEquals(1.0, ds.getLabels().sum(Integer.MAX_VALUE).getDouble(0)); } assertEquals(5, mnistC); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java index 26302914e..b98c31929 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetSplitterTests.java @@ -245,7 +245,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { trained = true; val ds = trainIter.next(); assertNotNull(ds); - assertEquals( (double) globalIter, ds.getFeatures().getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); + assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); globalIter++; } assertTrue(trained, "Failed at epoch [" + e + "]"); @@ -260,7 +260,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { val ds = testIter.next(); assertNotNull(ds); - assertEquals((double) globalIter, ds.getFeatures().getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); + assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); globalIter++; } assertTrue(tested, "Failed at epoch [" + e + "]"); @@ -275,7 +275,7 @@ public class DataSetSplitterTests extends BaseDL4JTest { val ds = validationIter.next(); assertNotNull(ds); - assertEquals((double) globalIter, ds.getFeatures().getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); + assertEquals(globalIter, ds.getFeatures().getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); globalIter++; } assertTrue(validated, "Failed at epoch [" + e + "]"); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java index d95c63ce7..559865b22 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIteratorTest.java @@ -51,7 +51,7 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest { List seenData = new ArrayList<>(); while (earlyEndIter.hasNext()) { DataSet path = earlyEndIter.next(); - assertFalse(path == null); + assertNotNull(path); seenData.add(path); batchesSeen++; } @@ -76,10 +76,10 @@ public class EarlyTerminationDataSetIteratorTest extends BaseDL4JTest { EarlyTerminationDataSetIterator earlyEndIter = new EarlyTerminationDataSetIterator(iter, terminateAfter); earlyEndIter.next(10); - assertEquals(false, earlyEndIter.hasNext()); + assertFalse(earlyEndIter.hasNext()); earlyEndIter.reset(); - assertEquals(true, earlyEndIter.hasNext()); + assertTrue(earlyEndIter.hasNext()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java index b05240ac7..f5e956653 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIteratorTest.java @@ -32,8 +32,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest { @@ -90,10 +89,10 @@ public class EarlyTerminationMultiDataSetIteratorTest extends BaseDL4JTest { new EarlyTerminationMultiDataSetIterator(iter, terminateAfter); earlyEndIter.next(10); - assertEquals(false, earlyEndIter.hasNext()); + assertFalse(earlyEndIter.hasNext()); earlyEndIter.reset(); - assertEquals(true, earlyEndIter.hasNext()); + assertTrue(earlyEndIter.hasNext()); } @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/JointParallelDataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/JointParallelDataSetIteratorTest.java index 6705f6430..6bc8d8a6c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/JointParallelDataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/JointParallelDataSetIteratorTest.java @@ -59,7 +59,7 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest { // ds.detach(); //ds.migrate(); - assertEquals( (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); + assertEquals(example, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); assertEquals( (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); cnt++; @@ -96,7 +96,7 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest { nulls++; if (cnt % 2 == 2) { - assertEquals( (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); + assertEquals(example, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); assertEquals((double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); } @@ -130,7 
+130,7 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest { DataSet ds = jpdsi.next(); assertNotNull( ds, "Failed on iteration " + cnt); - assertEquals((double) example, ds.getFeatures().meanNumber().doubleValue(),0.001, "Failed on iteration " + cnt); + assertEquals(example, ds.getFeatures().meanNumber().doubleValue(),0.001, "Failed on iteration " + cnt); assertEquals((double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); @@ -169,14 +169,14 @@ public class JointParallelDataSetIteratorTest extends BaseDL4JTest { assertNotNull( ds, "Failed on iteration " + cnt); if (cnt % 2 == 0) { - assertEquals( (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); + assertEquals(example, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); assertEquals((double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); } else { if (cnt <= 200) { - assertEquals( (double) example, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); + assertEquals(example, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); assertEquals( (double) example + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt); } else { - assertEquals( (double) example_sec, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt + ", second iteration " + cnt_sec); + assertEquals(example_sec, ds.getFeatures().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt + ", second iteration " + cnt_sec); assertEquals((double) example_sec + 0.5, ds.getLabels().meanNumber().doubleValue(), 0.001, "Failed on iteration " + cnt + ", second iteration " + cnt_sec); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java index 27ffe5bba..26dc7803a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/MultiDataSetSplitterTests.java @@ -185,7 +185,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); + assertEquals(globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); } globalIter++; } @@ -202,7 +202,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); + assertEquals(globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); } globalIter++; } @@ -219,7 +219,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); + assertEquals(globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); } globalIter++; } @@ -298,7 
+298,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); + assertEquals(globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); } globalIter++; } @@ -314,7 +314,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { val ds = testIter.next(); assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); + assertEquals(globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); } globalIter++; } @@ -331,7 +331,7 @@ public class MultiDataSetSplitterTests extends BaseDL4JTest { assertNotNull(ds); for (int i = 0; i < ds.getFeatures().length; ++i) { - assertEquals((double) globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); + assertEquals(globalIter, ds.getFeatures()[i].getDouble(0), 1e-5f, "Failed at iteration [" + globalIter + "]"); } globalIter++; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java index 3b221afd9..69a33aa75 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIteratorTest.java @@ -54,7 +54,7 @@ public class MultipleEpochsIteratorTest extends BaseDL4JTest { assertTrue(multiIter.hasNext()); while (multiIter.hasNext()) { DataSet path = multiIter.next(); - assertFalse(path == null); + assertNotNull(path); } assertEquals(epochs, multiIter.epochs); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/TestAsyncIterator.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/TestAsyncIterator.java index 199953dbc..6d2097a4b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/TestAsyncIterator.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/TestAsyncIterator.java @@ -175,9 +175,9 @@ public class TestAsyncIterator extends BaseDL4JTest { private static class TestIterator implements DataSetIterator { - private int size; + private final int size; private int cursor; - private long delayMSOnNext; + private final long delayMSOnNext; private TestIterator(int size, long delayMSOnNext) { this.size = size; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/SimpleVariableGenerator.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/SimpleVariableGenerator.java index 172a80167..cf0f578af 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/SimpleVariableGenerator.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/tools/SimpleVariableGenerator.java @@ -30,13 +30,13 @@ import java.util.List; import java.util.concurrent.atomic.AtomicInteger; public class SimpleVariableGenerator implements DataSetIterator { - private long seed; - private int numBatches; - private int batchSize; - private int 
numFeatures; - private int numLabels; + private final long seed; + private final int numBatches; + private final int batchSize; + private final int numFeatures; + private final int numLabels; - private AtomicInteger counter = new AtomicInteger(0); + private final AtomicInteger counter = new AtomicInteger(0); public SimpleVariableGenerator(long seed, int numBatches, int batchSize, int numFeatures, int numLabels) { this.seed = seed; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java index 2774e9961..13ae46efb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java @@ -537,7 +537,7 @@ public class TestEarlyStopping extends BaseDL4JTest { private static class LoggingEarlyStoppingListener implements EarlyStoppingListener { - private static Logger log = LoggerFactory.getLogger(LoggingEarlyStoppingListener.class); + private static final Logger log = LoggerFactory.getLogger(LoggingEarlyStoppingListener.class); private int onStartCallCount = 0; private int onEpochCallCount = 0; private int onCompletionCallCount = 0; @@ -852,7 +852,7 @@ public class TestEarlyStopping extends BaseDL4JTest { int outputs = 2; DataSet ds = new DataSet( - Nd4j.rand(new int[]{3, 10, 50}), + Nd4j.rand(3, 10, 50), TestUtils.randomOneHotTimeSeries(3, outputs, 50, 12345)); DataSetIterator train = new ExistingDataSetIterator( Arrays.asList(ds, ds, ds, ds, ds, ds, ds, ds, ds, ds)); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java index 1a02ffd7f..4209f8dd3 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java @@ -263,7 +263,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { private static class LoggingEarlyStoppingListener implements EarlyStoppingListener { - private static Logger log = LoggerFactory.getLogger(LoggingEarlyStoppingListener.class); + private static final Logger log = LoggerFactory.getLogger(LoggingEarlyStoppingListener.class); private int onStartCallCount = 0; private int onEpochCallCount = 0; private int onCompletionCallCount = 0; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java index c33a69c87..024804c0c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java @@ -328,10 +328,10 @@ public class EvalTest extends BaseDL4JTest { for(boolean useMask : new boolean[]{false, true}) { - INDArray in1 = Nd4j.rand(new int[]{3, nIn, tsLength}); + INDArray in1 = Nd4j.rand(3, nIn, tsLength); INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength); - INDArray in2 = Nd4j.rand(new int[]{5, nIn, tsLength}); + INDArray in2 = Nd4j.rand(5, nIn, tsLength); INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength); INDArray lMask1 = null; @@ -409,10 +409,10 @@ public class EvalTest extends BaseDL4JTest { for (boolean 
useMask : new boolean[]{false, true}) {
-                INDArray in1 = Nd4j.rand(new int[]{3, nIn, tsLength});
+                INDArray in1 = Nd4j.rand(3, nIn, tsLength);
                 INDArray out1 = TestUtils.randomOneHotTimeSeries(3, nOut, tsLength);
-                INDArray in2 = Nd4j.rand(new int[]{5, nIn, tsLength});
+                INDArray in2 = Nd4j.rand(5, nIn, tsLength);
                 INDArray out2 = TestUtils.randomOneHotTimeSeries(5, nOut, tsLength);
                 INDArray lMask1 = null;
@@ -442,11 +442,11 @@ public class EvalTest extends BaseDL4JTest {
     @Test
     public void testEvalSplitting2(){
         List<List<Writable>> seqFeatures = new ArrayList<>();
-        List<Writable> step = Arrays.<Writable>asList(new FloatWritable(0), new FloatWritable(0), new FloatWritable(0));
+        List<Writable> step = Arrays.asList(new FloatWritable(0), new FloatWritable(0), new FloatWritable(0));
         for( int i=0; i<30; i++ ){
             seqFeatures.add(step);
         }
-        List<List<Writable>> seqLabels = Collections.singletonList(Collections.<Writable>singletonList(new FloatWritable(0)));
+        List<List<Writable>> seqLabels = Collections.singletonList(Collections.singletonList(new FloatWritable(0)));
         SequenceRecordReader fsr = new CollectionSequenceRecordReader(Collections.singletonList(seqFeatures));
         SequenceRecordReader lsr = new CollectionSequenceRecordReader(Collections.singletonList(seqLabels));
diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/ROCTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/ROCTest.java
index 629ce0d9b..5684a76d6 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/ROCTest.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/ROCTest.java
@@ -47,8 +47,8 @@ import static org.junit.jupiter.api.Assertions.*;
 public class ROCTest extends BaseDL4JTest {
-    private static Map<Double, Double> expTPR;
-    private static Map<Double, Double> expFPR;
+    private static final Map<Double, Double> expTPR;
+    private static final Map<Double, Double> expFPR;
     static {
         expTPR = new HashMap<>();
diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java
index 23e69502c..b5e2b994e 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java
@@ -102,7 +102,7 @@ public class RegressionEvalTest extends BaseDL4JTest {
         re.eval(l, predictions, mask);
-        double[] mse = new double[] {(10 * 10) / 1.0, (2 * 2 + 20 * 20 + 10 * 10) / 3, (3 * 3) / 1.0};
+        double[] mse = new double[] {(10 * 10), (2 * 2 + 20 * 20 + 10 * 10) / 3, (3 * 3)};
         double[] mae = new double[] {10.0, (2 + 20 + 10) / 3.0, 3.0};
@@ -118,11 +118,11 @@ public class RegressionEvalTest extends BaseDL4JTest {
     @Test
     public void testRegressionEvalTimeSeriesSplit(){
-        INDArray out1 = Nd4j.rand(new int[]{3, 5, 20});
+        INDArray out1 = Nd4j.rand(3, 5, 20);
         INDArray outSub1 = out1.get(all(), all(), interval(0,10));
         INDArray outSub2 = out1.get(all(), all(), interval(10, 20));
-        INDArray label1 = Nd4j.rand(new int[]{3, 5, 20});
+        INDArray label1 = Nd4j.rand(3, 5, 20);
         INDArray labelSub1 = label1.get(all(), all(), interval(0,10));
         INDArray labelSub2 = label1.get(all(), all(), interval(10, 20));
diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java
index 868ec0809..3ec50df59 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java
@@ -37,6 +37,7 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -45,8 +46,8 @@ public class TestRecordReaders extends BaseDL4JTest {
     @Test
     public void testClassIndexOutsideOfRangeRRDSI() {
         Collection<Collection<Writable>> c = new ArrayList<>();
-        c.add(Arrays.<Writable>asList(new DoubleWritable(0.5), new IntWritable(0)));
-        c.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new IntWritable(2)));
+        c.add(Arrays.asList(new DoubleWritable(0.5), new IntWritable(0)));
+        c.add(Arrays.asList(new DoubleWritable(1.0), new IntWritable(2)));
         CollectionRecordReader crr = new CollectionRecordReader(c);
@@ -67,13 +68,13 @@ public class TestRecordReaders extends BaseDL4JTest {
         Collection<Collection<Collection<Writable>>> c = new ArrayList<>();
         Collection<Collection<Writable>> seq1 = new ArrayList<>();
-        seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
-        seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(1)));
+        seq1.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(0)));
+        seq1.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(1)));
         c.add(seq1);
         Collection<Collection<Writable>> seq2 = new ArrayList<>();
-        seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
-        seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));
+        seq2.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(0)));
+        seq2.add(Arrays.asList(new DoubleWritable(0.0), new IntWritable(2)));
         c.add(seq2);
         CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c);
@@ -94,24 +95,24 @@ public class TestRecordReaders extends BaseDL4JTest {
         Collection<Collection<Collection<Writable>>> c1 = new ArrayList<>();
         Collection<Collection<Writable>> seq1 = new ArrayList<>();
-        seq1.add(Arrays.asList(new DoubleWritable(0.0)));
-        seq1.add(Arrays.asList(new DoubleWritable(0.0)));
+        seq1.add(Collections.singletonList(new DoubleWritable(0.0)));
+        seq1.add(Collections.singletonList(new DoubleWritable(0.0)));
         c1.add(seq1);
         Collection<Collection<Writable>> seq2 = new ArrayList<>();
-        seq2.add(Arrays.asList(new DoubleWritable(0.0)));
-        seq2.add(Arrays.asList(new DoubleWritable(0.0)));
+        seq2.add(Collections.singletonList(new DoubleWritable(0.0)));
+        seq2.add(Collections.singletonList(new DoubleWritable(0.0)));
         c1.add(seq2);
         Collection<Collection<Collection<Writable>>> c2 = new ArrayList<>();
         Collection<Collection<Writable>> seq1a = new ArrayList<>();
-        seq1a.add(Arrays.asList(new IntWritable(0)));
-        seq1a.add(Arrays.asList(new IntWritable(1)));
+        seq1a.add(Collections.singletonList(new IntWritable(0)));
+        seq1a.add(Collections.singletonList(new IntWritable(1)));
         c2.add(seq1a);
         Collection<Collection<Writable>> seq2a = new ArrayList<>();
-        seq2a.add(Arrays.asList(new IntWritable(0)));
-        seq2a.add(Arrays.asList(new IntWritable(2)));
+        seq2a.add(Collections.singletonList(new IntWritable(0)));
+        seq2a.add(Collections.singletonList(new IntWritable(2)));
         c2.add(seq2a);
         CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c1);
diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java
index 739168b31..e375aa180 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java
@@ -62,7 +62,7 @@ public class AttentionLayerTest extends
BaseDL4JTest { for (int mb : new int[]{1, 3}) { for (boolean inputMask : new boolean[]{false, true}) { for (boolean projectInput : new boolean[]{false, true}) { - INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength}); + INDArray in = Nd4j.rand(DataType.DOUBLE, mb, nIn, tsLength); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? "inputMask" : "none"); @@ -123,7 +123,7 @@ public class AttentionLayerTest extends BaseDL4JTest { for (boolean inputMask : new boolean[]{false, true}) { for (int mb : new int[]{3, 1}) { for (boolean projectInput : new boolean[]{false, true}) { - INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength}); + INDArray in = Nd4j.rand(DataType.DOUBLE, mb, nIn, tsLength); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? "inputMask" : "none"); @@ -205,7 +205,7 @@ public class AttentionLayerTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); for (int mb : new int[]{3, 1}) { - INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength}); + INDArray in = Nd4j.rand(DataType.DOUBLE, mb, nIn, tsLength); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? "inputMask" : "none"); @@ -257,11 +257,11 @@ public class AttentionLayerTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - final INDArray initialInput = Nd4j.rand(new int[]{8, nIn, 7}); - final INDArray goodNextInput = Nd4j.rand(new int[]{8, nIn, 7}); - final INDArray badNextInput = Nd4j.rand(new int[]{8, nIn, 12}); + final INDArray initialInput = Nd4j.rand(8, nIn, 7); + final INDArray goodNextInput = Nd4j.rand(8, nIn, 7); + final INDArray badNextInput = Nd4j.rand(8, nIn, 12); - final INDArray labels = Nd4j.rand(new int[]{8, nOut}); + final INDArray labels = Nd4j.rand(8, nOut); net.fit(initialInput, labels); net.fit(goodNextInput, labels); @@ -281,7 +281,7 @@ public class AttentionLayerTest extends BaseDL4JTest { for (int mb : new int[]{3, 1}) { for (boolean inputMask : new boolean[]{true, false}) { - INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength}); + INDArray in = Nd4j.rand(DataType.DOUBLE, mb, nIn, tsLength); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? "inputMask" : "none"); @@ -339,7 +339,7 @@ public class AttentionLayerTest extends BaseDL4JTest { for (boolean inputMask : new boolean[]{false, true}) { for (int mb : new int[]{3, 1}) { for (boolean projectInput : new boolean[]{false, true}) { - INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{mb, nIn, tsLength}); + INDArray in = Nd4j.rand(DataType.DOUBLE, mb, nIn, tsLength); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? "inputMask" : "none"); @@ -403,7 +403,7 @@ public class AttentionLayerTest extends BaseDL4JTest { for (boolean inputMask : new boolean[]{false, true}) { for (int mb : new int[]{3, 1}) { for (boolean projectInput : new boolean[]{false, true}) { - INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength}); + INDArray in = Nd4j.rand(mb, nIn, tsLength); INDArray labels = TestUtils.randomOneHot(mb, nOut); String maskType = (inputMask ? 
"inputMask" : "none"); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java index 3d945b27e..65f8787d8 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java @@ -111,7 +111,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { int depth = 1; int hw = 4; int nOut = 4; - INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw}); + INDArray input = Nd4j.rand(minibatch, depth, hw, hw); INDArray labels = Nd4j.zeros(minibatch, nOut); Random r = new Random(12345); for (int i = 0; i < minibatch; i++) { @@ -171,7 +171,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { int depth = 2; int hw = 5; int nOut = 2; - INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw}).muli(5).subi(2.5); + INDArray input = Nd4j.rand(minibatch, depth, hw, hw).muli(5).subi(2.5); INDArray labels = TestUtils.randomOneHot(minibatch, nOut); DataSet ds = new DataSet(input, labels); @@ -277,7 +277,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { int minibatch = 10; int nIn = 5; int nOut = 3; - INDArray input = Nd4j.rand(new int[]{minibatch, nIn}); + INDArray input = Nd4j.rand(minibatch, nIn); INDArray labels = Nd4j.zeros(minibatch, nOut); Random r = new Random(12345); for (int i = 0; i < minibatch; i++) { @@ -406,7 +406,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { int depth = 1; int hw = 4; int nOut = 4; - INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw}); + INDArray input = Nd4j.rand(minibatch, depth, hw, hw); INDArray labels = Nd4j.zeros(minibatch, nOut); Random r = new Random(12345); for (int i = 0; i < minibatch; i++) { @@ -470,7 +470,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { net.init(); Random r = new Random(12345); - INDArray input = Nd4j.rand(new int[]{minibatchSize, channels, height, width}); //Order: examples, channels, height, width + INDArray input = Nd4j.rand(minibatchSize, channels, height, width); //Order: examples, channels, height, width INDArray labels = Nd4j.zeros(minibatchSize, numClasses); for (int i = 0; i < minibatchSize; i++) { labels.putScalar(new int[]{i, r.nextInt(numClasses)}, 1.0); @@ -510,7 +510,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { int depth = 2; int hw = 5; int nOut = 3; - INDArray input = Nd4j.rand(new int[]{minibatch, depth, hw, hw}); + INDArray input = Nd4j.rand(minibatch, depth, hw, hw); INDArray labels = Nd4j.zeros(minibatch, nOut); Random r = new Random(12345); for (int i = 0; i < minibatch; i++) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java index 094034320..b61c1fe24 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java @@ -82,7 +82,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { for (Activation afn : activations) { for (int minibatchSize : minibatchSizes) { for (int kernel : kernels) { - INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length}); + INDArray input = Nd4j.rand(minibatchSize, convNIn, length); 
INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length); for (int i = 0; i < minibatchSize; i++) { for (int j = 0; j < length; j++) { @@ -162,7 +162,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { for (int kernel : kernels) { - INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length}); + INDArray input = Nd4j.rand(minibatchSize, convNIn, length); INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, croppedLength); for (int i = 0; i < minibatchSize; i++) { for (int j = 0; j < croppedLength; j++) { @@ -243,7 +243,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { for (int kernel : kernels) { - INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length}); + INDArray input = Nd4j.rand(minibatchSize, convNIn, length); INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, paddedLength); for (int i = 0; i < minibatchSize; i++) { for (int j = 0; j < paddedLength; j++) { @@ -322,7 +322,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { for (SubsamplingLayer.PoolingType poolingType : poolingTypes) { for (int minibatchSize : minibatchSizes) { for (int kernel : kernels) { - INDArray input = Nd4j.rand(new int[]{minibatchSize, convNIn, length}); + INDArray input = Nd4j.rand(minibatchSize, convNIn, length); INDArray labels = Nd4j.zeros(minibatchSize, finalNOut, length); for (int i = 0; i < minibatchSize; i++) { for (int j = 0; j < length; j++) { @@ -418,7 +418,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray f = Nd4j.rand(new int[]{2, convNIn, length}); + INDArray f = Nd4j.rand(2, convNIn, length); INDArray fm = Nd4j.create(2, length); fm.get(NDArrayIndex.point(0), NDArrayIndex.all()).assign(1); fm.get(NDArrayIndex.point(1), NDArrayIndex.interval(0,6)).assign(1); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java index 2c8f4dead..4d3de0bfb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java @@ -103,9 +103,9 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { INDArray input; if(df == Convolution3D.DataFormat.NDHWC){ - input = Nd4j.rand(new int[]{miniBatchSize, depth, height, width, convNIn}); + input = Nd4j.rand(miniBatchSize, depth, height, width, convNIn); } else { - input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width}); + input = Nd4j.rand(miniBatchSize, convNIn, depth, height, width); } INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); for (int i = 0; i < miniBatchSize; i++) { @@ -142,7 +142,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", stride = " - + Arrays.toString(stride) + ", mode = " + mode.toString() + + Arrays.toString(stride) + ", mode = " + mode + ", input depth " + depth + ", input height " + height + ", input width " + width; @@ -209,7 +209,7 @@ public class CNN3DGradientCheckTest extends 
BaseDL4JTest { outHeight += zeroPadding[2] + zeroPadding[3]; outWidth += zeroPadding[4] + zeroPadding[5]; - INDArray input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width}); + INDArray input = Nd4j.rand(miniBatchSize, convNIn, depth, height, width); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); for (int i = 0; i < miniBatchSize; i++) { labels.putScalar(new int[]{i, i % finalNOut}, 1.0); @@ -245,7 +245,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { net.init(); String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn - + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode + ", input depth " + depth + ", input height " + height + ", input width " + width; @@ -337,7 +337,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { net.init(); String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn - + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode + ", input depth " + depth + ", input height " + height + ", input width " + width + ", dataFormat=" + df; @@ -424,7 +424,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { net.init(); String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn - + ", kernel = " + Arrays.toString(upsamplingSize) + ", mode = " + mode.toString() + + ", kernel = " + Arrays.toString(upsamplingSize) + ", mode = " + mode + ", input depth " + depth + ", input height " + height + ", input width " + width; @@ -487,7 +487,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { outHeight -= cropping[2] + cropping[3]; outWidth -= cropping[4] + cropping[5]; - INDArray input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width}); + INDArray input = Nd4j.rand(miniBatchSize, convNIn, depth, height, width); INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); for (int i = 0; i < miniBatchSize; i++) { labels.putScalar(new int[]{i, i % finalNOut}, 1.0); @@ -523,7 +523,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { net.init(); String msg = "Minibatch size = " + miniBatchSize + ", activationFn=" + afn - + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode.toString() + + ", kernel = " + Arrays.toString(kernel) + ", mode = " + mode + ", input depth " + depth + ", input height " + height + ", input width " + width; @@ -583,9 +583,9 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { INDArray input; if (df == Convolution3D.DataFormat.NDHWC) { - input = Nd4j.rand(new int[]{miniBatchSize, depth, height, width, convNIn}); + input = Nd4j.rand(miniBatchSize, depth, height, width, convNIn); } else { - input = Nd4j.rand(new int[]{miniBatchSize, convNIn, depth, height, width}); + input = Nd4j.rand(miniBatchSize, convNIn, depth, height, width); } INDArray labels = Nd4j.zeros(miniBatchSize, finalNOut); for (int j = 0; j < miniBatchSize; j++) { @@ -618,7 +618,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { String msg = "DataFormat = " + df + ", minibatch size = " + miniBatchSize + ", activationFn=" + afn + ", kernel = " + Arrays.toString(kernel) + ", stride = " - + Arrays.toString(stride) + ", mode = " + mode.toString() + + Arrays.toString(stride) + ", mode = " + mode + ", input depth " + depth + ", input height " + height + ", input width " + width; diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index 3772741d5..b9536ee41 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -63,7 +63,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - private CNN2DFormat format; + private final CNN2DFormat format; public CNNGradientCheckTest(CNN2DFormat format){ this.format = format; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java index 193ede7ac..9aafd297c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java @@ -125,7 +125,7 @@ public class DropoutGradientCheck extends BaseDL4JTest { INDArray f; if(cnn){ - f = Nd4j.rand(new int[]{minibatch, 2, 6, 6}).muli(10).subi(5); + f = Nd4j.rand(minibatch, 2, 6, 6).muli(10).subi(5); } else { f = Nd4j.rand(minibatch, 6).muli(10).subi(5); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java index f4b9d4dc5..7cb10f83b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java @@ -276,7 +276,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { mln.init(); Random r = new Random(12345L); - INDArray input = Nd4j.rand(new int[] {miniBatchSize, inputDepth, inputH, inputW}).subi(0.5); + INDArray input = Nd4j.rand(miniBatchSize, inputDepth, inputH, inputW).subi(0.5); INDArray inputMask; if (miniBatchSize == 1) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java index be641898e..ec99f3852 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java @@ -472,7 +472,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(conf); graph.init(); - INDArray input = Nd4j.rand(new int[] {batchSize, inLength, timeSeriesLength}); + INDArray input = Nd4j.rand(batchSize, inLength, timeSeriesLength); INDArray labels = TestUtils.randomOneHotTimeSeries(batchSize, 2, timeSeriesLength); if (PRINT_RESULTS) { @@ -509,7 +509,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { graph.init(); Random r = new Random(12345); - INDArray input = Nd4j.rand(new int[] {2, 3, 4}); + INDArray input = Nd4j.rand(2, 3, 4); INDArray labels = TestUtils.randomOneHot(2, 2); //Here: labels are 2d (due to LastTimeStepVertex) if (PRINT_RESULTS) { @@ -572,8 +572,8 
@@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { graph.init(); Random r = new Random(12345); - INDArray input1 = Nd4j.rand(new int[] {batchSize, 3, 4}); - INDArray input2 = Nd4j.rand(new int[] {batchSize, 2, 4}); + INDArray input1 = Nd4j.rand(batchSize, 3, 4); + INDArray input2 = Nd4j.rand(batchSize, 2, 4); INDArray labels = TestUtils.randomOneHotTimeSeries(batchSize, outSize, timeSeriesLength); if (PRINT_RESULTS) { @@ -622,7 +622,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { graph.init(); Random r = new Random(12345); - INDArray input = Nd4j.rand(new int[] {2, 2, 4}); + INDArray input = Nd4j.rand(2, 2, 4); INDArray labels = TestUtils.randomOneHotTimeSeries(2, 2, 4); if (PRINT_RESULTS) { @@ -813,7 +813,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int[] minibatchSizes = {1, 3}; for (int mb : minibatchSizes) { - INDArray input = Nd4j.rand(new int[] {mb, 2, inH, inW}).muli(4); //Order: examples, channels, height, width + INDArray input = Nd4j.rand(mb, 2, inH, inW).muli(4); //Order: examples, channels, height, width INDArray out = Nd4j.rand(mb, 2); String msg = "testMultipleOutputsMergeVertex() - minibatchSize = " + mb; @@ -991,7 +991,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray example = Nd4j.rand(new int[] {150, inputDepth, inputH, inputW}); + INDArray example = Nd4j.rand(150, inputDepth, inputH, inputW); INDArray labels = Nd4j.zeros(150, numLabels); Random r = new Random(12345); @@ -1001,7 +1001,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (train) { for (int i = 0; i < 10; i++) { - INDArray f = Nd4j.rand(new int[] {10, inputDepth, inputH, inputW}); + INDArray f = Nd4j.rand(10, inputDepth, inputH, inputW); INDArray l = Nd4j.zeros(10, numLabels); for (int j = 0; j < 10; j++) { l.putScalar(j, r.nextInt(numLabels), 1.0); @@ -1227,15 +1227,15 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int[] mbSizes = new int[] {1, 2, 3}; for (int minibatch : mbSizes) { - INDArray in1 = Nd4j.rand(new int[] {minibatch, layerSizes, 4}); - INDArray in2 = Nd4j.rand(new int[] {minibatch, layerSizes, 5}); + INDArray in1 = Nd4j.rand(minibatch, layerSizes, 4); + INDArray in2 = Nd4j.rand(minibatch, layerSizes, 5); INDArray inMask1 = Nd4j.zeros(minibatch, 4); inMask1.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 3)).assign(1); INDArray inMask2 = Nd4j.zeros(minibatch, 5); inMask2.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 4)).assign(1); - INDArray labels1 = Nd4j.rand(new int[] {minibatch, 2}); - INDArray labels2 = Nd4j.rand(new int[] {minibatch, 2}); + INDArray labels1 = Nd4j.rand(minibatch, 2); + INDArray labels2 = Nd4j.rand(minibatch, 2); String testName = "testBasicStackUnstackVariableLengthTS() - minibatch = " + minibatch; @@ -1389,7 +1389,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int[] mbSizes = new int[] {1, 3, 10}; for (int minibatch : mbSizes) { - INDArray in1 = Nd4j.rand(new int[] {minibatch, dIn, h, w}); + INDArray in1 = Nd4j.rand(minibatch, dIn, h, w); INDArray labels1 = Nd4j.rand(minibatch, 2); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java index a444e1146..4efd20ee7 100644 --- 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java @@ -411,7 +411,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray f = Nd4j.rand(new int[]{mb, 3, tsLength}); + INDArray f = Nd4j.rand(mb, 3, tsLength); INDArray l = TestUtils.randomOneHot(mb, 3); INDArray lm = TestUtils.randomBernoulli(mb, 1); @@ -468,7 +468,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { ComputationGraph net = new ComputationGraph(conf); net.init(); - INDArray f = Nd4j.rand(new int[]{mb, 3, tsLength}); + INDArray f = Nd4j.rand(mb, 3, tsLength); INDArray l = TestUtils.randomOneHot(mb, 3); INDArray lm = TestUtils.randomBernoulli(mb, 1); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java index ad1b564db..9d982818a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java @@ -61,7 +61,7 @@ public class LRNGradientCheckTests extends BaseDL4JTest { int depth = 6; int hw = 5; int nOut = 4; - INDArray input = Nd4j.rand(new int[] {minibatch, depth, hw, hw}); + INDArray input = Nd4j.rand(minibatch, depth, hw, hw); INDArray labels = Nd4j.zeros(minibatch, nOut); Random r = new Random(12345); for (int i = 0; i < minibatch; i++) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java index 452742f10..c1e20d858 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java @@ -418,7 +418,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { //Generate Nd4j.getRandom().setSeed(12345); - INDArray input = Nd4j.rand(new int[] {miniBatchSize, inputSize, timeSeriesLength}); + INDArray input = Nd4j.rand(miniBatchSize, inputSize, timeSeriesLength); INDArray labels = Nd4j.zeros(miniBatchSize, nClasses, timeSeriesLength); Random r = new Random(12345); for (int i = 0; i < miniBatchSize; i++) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java index fe4c1eb3b..74b142845 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java @@ -343,7 +343,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { lossFunctions[i] = lf2; } catch(IOException ex) { ex.printStackTrace(); - assertTrue(false, "Tests failed: serialization of " + lossFunctions[i]); + fail("Tests failed: serialization of " + lossFunctions[i]); } Nd4j.getRandom().setSeed(12345); @@ -362,8 +362,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - 
assertTrue(((LossLayer) net.getLayer(1).conf().getLayer()).getLossFn().getClass() == lossFunctions[i] - .getClass()); + assertSame(((LossLayer) net.getLayer(1).conf().getLayer()).getLossFn().getClass(), lossFunctions[i] + .getClass()); INDArray[] inOut = getFeaturesAndLabels(lossFunctions[i], minibatchSizes[j], 4, nOut[i], 12345); INDArray input = inOut[0]; @@ -421,22 +421,22 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { labels = Nd4j.diag(Nd4j.ones(3)); gradientAndScore = lossMultiLabel.computeGradientAndScore(labels, preOutput, activationFn, null, true); - assertTrue(!gradientAndScore.getFirst().isNaN()); - assertTrue(!gradientAndScore.getFirst().isInfinite()); + assertFalse(gradientAndScore.getFirst().isNaN()); + assertFalse(gradientAndScore.getFirst().isInfinite()); // Edge Case: Labels are all 1 labels = Nd4j.ones(3, 3); gradientAndScore = lossMultiLabel.computeGradientAndScore(labels, preOutput, activationFn, null, true); - assertTrue(!gradientAndScore.getFirst().isNaN()); - assertTrue(!gradientAndScore.getFirst().isInfinite()); + assertFalse(gradientAndScore.getFirst().isNaN()); + assertFalse(gradientAndScore.getFirst().isInfinite()); // Edge Case: Labels are all 0 labels = Nd4j.zeros(3, 3); gradientAndScore = lossMultiLabel.computeGradientAndScore(labels, preOutput, activationFn, null, true); - assertTrue(!gradientAndScore.getFirst().isNaN()); - assertTrue(!gradientAndScore.getFirst().isInfinite()); + assertFalse(gradientAndScore.getFirst().isNaN()); + assertFalse(gradientAndScore.getFirst().isInfinite()); } public static INDArray[] getFeaturesAndLabels(ILossFunction l, long minibatch, long nIn, long nOut, long seed) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java index 8acbf157e..5cfec0631 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java @@ -135,7 +135,7 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { int layerSize = 6; for (int minibatch : new int[]{1, 4}) { - INDArray input = Nd4j.rand(new int[]{minibatch, nIn, tsLength}); + INDArray input = Nd4j.rand(minibatch, nIn, tsLength); INDArray labels = TestUtils.randomOneHotTimeSeries(minibatch, nOut, tsLength); for (boolean rnnOutHasBias : new boolean[]{true, false}) { @@ -292,9 +292,9 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { net.init(); if(cnnHasBias){ - assertEquals(3 * 2 * kernel[0] * kernel[1] + 2, net.getLayer(2).numParams()); + assertEquals(3L * 2 * kernel[0] * kernel[1] + 2, net.getLayer(2).numParams()); } else { - assertEquals(3 * 2 * kernel[0] * kernel[1], net.getLayer(2).numParams()); + assertEquals(3L * 2 * kernel[0] * kernel[1], net.getLayer(2).numParams()); } String msg = "testCnnWithSubsamplingNoBias(), minibatch = " + minibatchSize + ", cnnHasBias = " + cnnHasBias; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java index f11daf9ec..1c1da4cee 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java @@ -67,7 +67,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { for (int maskType = 0; maskType < 3; maskType++) { Random r = new Random(12345L); - INDArray input = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength}); + INDArray input = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); INDArray labelMask; String mt; @@ -172,7 +172,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { for (int maskType = 0; maskType < 4; maskType++) { Random r = new Random(12345L); - INDArray input = Nd4j.rand(new int[]{miniBatchSize, dIn, h, w}); + INDArray input = Nd4j.rand(miniBatchSize, dIn, h, w); INDArray labelMask; String mt; @@ -190,13 +190,13 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { break; case 2: //Per x/y masking (3d mask, shape [minibatch, h, w]) - labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, h, w}); + labelMask = Nd4j.createUninitialized(miniBatchSize, h, w); Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5)); mt = "PerXY"; break; case 3: //Per output masking (4d mask, same shape as output [minibatch, c, h, w]) - labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, dOut, h, w}); + labelMask = Nd4j.createUninitialized(miniBatchSize, dOut, h, w); Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5)); mt = "PerOutput"; break; @@ -208,7 +208,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { INDArray labels; if (lf instanceof LossMSE) { - labels = Nd4j.rand(new int[]{miniBatchSize, dOut, h, w}); + labels = Nd4j.rand(miniBatchSize, dOut, h, w); } else { labels = Nd4j.zeros(miniBatchSize, dOut, h, w); for (int mb = 0; mb < miniBatchSize; mb++) { @@ -283,9 +283,9 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { Random r = new Random(12345L); INDArray input; if(dataFormat == Convolution3D.DataFormat.NCDHW) { - input = Nd4j.rand(new int[]{miniBatchSize, chIn, d, h, w}); + input = Nd4j.rand(miniBatchSize, chIn, d, h, w); } else { - input = Nd4j.rand(new int[]{miniBatchSize, d, h, w, chIn}); + input = Nd4j.rand(miniBatchSize, d, h, w, chIn); } INDArray labelMask; @@ -298,16 +298,16 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { break; case 1: //Per example masking (shape [minibatch, 1, 1, 1, 1] - labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, 1, 1, 1, 1}); + labelMask = Nd4j.createUninitialized(miniBatchSize, 1, 1, 1, 1); Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5)); mt = "PerExample"; break; case 2: //Per channel masking (5d mask, shape [minibatch, d, 1, 1, 1] or [minibatch, 1, 1, 1, d]) if(dataFormat == Convolution3D.DataFormat.NCDHW) { - labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, chOut, 1, 1, 1}); + labelMask = Nd4j.createUninitialized(miniBatchSize, chOut, 1, 1, 1); } else { - labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, 1, 1, 1, chOut}); + labelMask = Nd4j.createUninitialized(miniBatchSize, 1, 1, 1, chOut); } Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5)); mt = "PerChannel"; @@ -315,9 +315,9 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { case 3: //Per output masking (5d mask, same shape as output [minibatch, c, h, w]) if(dataFormat == Convolution3D.DataFormat.NCDHW) { - labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, chOut, d, h, w}); + labelMask = Nd4j.createUninitialized(miniBatchSize, chOut, d, h, 
w); } else { - labelMask = Nd4j.createUninitialized(new int[]{miniBatchSize, d, h, w, chOut}); + labelMask = Nd4j.createUninitialized(miniBatchSize, d, h, w, chOut); } Nd4j.getExecutioner().exec(new BernoulliDistribution(labelMask, 0.5)); mt = "PerOutput"; @@ -336,9 +336,9 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { INDArray labels; if (lf instanceof LossMSE) { if(dataFormat == Convolution3D.DataFormat.NCDHW) { - labels = Nd4j.rand(new int[]{miniBatchSize, chOut, d, h, w}); + labels = Nd4j.rand(miniBatchSize, chOut, d, h, w); } else { - labels = Nd4j.rand(new int[]{miniBatchSize, d, h, w, chOut}); + labels = Nd4j.rand(miniBatchSize, d, h, w, chOut); } } else { if(dataFormat == Convolution3D.DataFormat.NCDHW) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java index 4555904ca..87a42e4e0 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java @@ -74,7 +74,7 @@ public class RnnGradientChecks extends BaseDL4JTest { if(!simple && hasLayerNorm) continue; - INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength}); + INDArray in = Nd4j.rand(mb, nIn, tsLength); INDArray labels = Nd4j.create(mb, nOut, tsLength); for (int i = 0; i < mb; i++) { for (int j = 0; j < tsLength; j++) { @@ -159,7 +159,7 @@ public class RnnGradientChecks extends BaseDL4JTest { if(r.nextInt(5) != 0) continue; - INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength}); + INDArray in = Nd4j.rand(mb, nIn, tsLength); INDArray labels = Nd4j.create(mb, nOut, tsLength); for (int i = 0; i < mb; i++) { for (int j = 0; j < tsLength; j++) { @@ -236,7 +236,7 @@ public class RnnGradientChecks extends BaseDL4JTest { continue; - INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength}); + INDArray in = Nd4j.rand(mb, nIn, tsLength); INDArray labels = Nd4j.create(mb, nOut); for (int i = 0; i < mb; i++) { labels.putScalar(i, r.nextInt(nOut), 1.0); @@ -306,7 +306,7 @@ public class RnnGradientChecks extends BaseDL4JTest { for (boolean inputMask : new boolean[]{false, true}) { - INDArray in = Nd4j.rand(new int[]{mb, nIn, tsLength}); + INDArray in = Nd4j.rand(mb, nIn, tsLength); INDArray labels = TestUtils.randomOneHotTimeSeries(mb, nOut, tsLength); String maskType = (inputMask ? 
"inputMask" : "none"); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java index 105fcb284..9ae3e598a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java @@ -60,7 +60,7 @@ public class YoloGradientCheckTests extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } - private CNN2DFormat format; + private final CNN2DFormat format; public YoloGradientCheckTests(CNN2DFormat format){ this.format = format; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java index 33d8856cd..a10a9a3c7 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java @@ -159,7 +159,7 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { .setInputType(InputType.convolutional(32, 32, 1)).build(); String str = conf.toJson(); - MultiLayerConfiguration fromJson = conf.fromJson(str); + MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(str); assertEquals(conf, fromJson); } @@ -253,7 +253,7 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { private static MultiLayerConfiguration getConf() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345l).list() + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list() .layer(0, new DenseLayer.Builder().nIn(2).nOut(2) .dist(new NormalDistribution(0, 1)).build()) .layer(1, new OutputLayer.Builder().nIn(2).nOut(1) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiNeuralNetConfLayerBuilderTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiNeuralNetConfLayerBuilderTest.java index 76fb090ea..08e162b7a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiNeuralNetConfLayerBuilderTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiNeuralNetConfLayerBuilderTest.java @@ -37,8 +37,7 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.*; /** * @author Jeffrey Tang. 
@@ -84,6 +83,6 @@ public class MultiNeuralNetConfLayerBuilderTest extends BaseDL4JTest {
         NeuralNetConfiguration firstLayer = multiConf1.getConf(0);
         NeuralNetConfiguration secondLayer = multiConf1.getConf(1);
-        assertFalse(firstLayer.equals(secondLayer));
+        assertNotEquals(firstLayer, secondLayer);
     }
 }
diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java
index fda02a451..37260087d 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java
@@ -452,9 +452,9 @@ public class TestConstraints extends BaseDL4JTest {
         for( int i=0; i<100; i++ ){
-            INDArray in1 = Nd4j.rand(new int[]{1, nIn, 5});
-            INDArray in2 = Nd4j.rand(new int[]{1, 1});
-            INDArray label = Nd4j.rand(new int[]{1, 1});
+            INDArray in1 = Nd4j.rand(1, nIn, 5);
+            INDArray in2 = Nd4j.rand(1, 1);
+            INDArray label = Nd4j.rand(1, 1);
             g.fit(new INDArray[]{in1, in2}, new INDArray[]{label});
             for(Map.Entry<String, INDArray> e : g.paramTable().entrySet()){
diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java
index 941309304..046cf0f63 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java
@@ -56,7 +56,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
          * from @agibsonccc: check for the basics: like 0 numParams
          */
-        ElementWiseVertex.Op ops[] = new ElementWiseVertex.Op[] {ElementWiseVertex.Op.Add,
+        ElementWiseVertex.Op[] ops = new ElementWiseVertex.Op[] {ElementWiseVertex.Op.Add,
                         ElementWiseVertex.Op.Subtract, ElementWiseVertex.Op.Product};
         for (ElementWiseVertex.Op op : ops) {
@@ -706,5 +706,5 @@ public class ElementWiseVertexTest extends BaseDL4JTest {
         return clean;
     }
-    private double epsilon = 1e-10;
+    private final double epsilon = 1e-10;
 }
diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java
index 766854407..acab33814 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java
@@ -258,5 +258,5 @@ public class ShiftVertexTest extends BaseDL4JTest {
         return clean;
     }
-    private double epsilon = 1e-10;
+    private final double epsilon = 1e-10;
 }
diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java
index 96a2bc739..484da1ff9 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java
@@ -184,7 +184,7 @@ public class LayerBuilderTest extends BaseDL4JTest {
         assertEquals(numIn, bN.nIn);
         assertEquals(numOut, bN.nOut);
-        assertEquals(true, bN.isLockGammaBeta());
+        assertTrue(bN.isLockGammaBeta());
         assertEquals(0.5, bN.decay, 1e-4);
         assertEquals(2,
bN.gamma, 1e-4); assertEquals(1, bN.beta, 1e-4); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java index 635926f7c..60b549714 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java @@ -290,11 +290,11 @@ public class LayerConfigTest extends BaseDL4JTest { net.init(); assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, - ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization()); + conf.getConf(0).getLayer().getGradientNormalization()); assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, - ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalization()); - assertEquals(10, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalizationThreshold(), 0.0); - assertEquals(10, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalizationThreshold(), 0.0); + conf.getConf(1).getLayer().getGradientNormalization()); + assertEquals(10, conf.getConf(0).getLayer().getGradientNormalizationThreshold(), 0.0); + assertEquals(10, conf.getConf(1).getLayer().getGradientNormalizationThreshold(), 0.0); //With: conf = new NeuralNetConfiguration.Builder() @@ -310,10 +310,10 @@ public class LayerConfigTest extends BaseDL4JTest { net.init(); assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, - ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalization()); - assertEquals(GradientNormalization.None, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalization()); - assertEquals(10, ((BaseLayer) conf.getConf(0).getLayer()).getGradientNormalizationThreshold(), 0.0); - assertEquals(2.5, ((BaseLayer) conf.getConf(1).getLayer()).getGradientNormalizationThreshold(), 0.0); + conf.getConf(0).getLayer().getGradientNormalization()); + assertEquals(GradientNormalization.None, conf.getConf(1).getLayer().getGradientNormalization()); + assertEquals(10, conf.getConf(0).getLayer().getGradientNormalizationThreshold(), 0.0); + assertEquals(2.5, conf.getConf(1).getLayer().getGradientNormalizationThreshold(), 0.0); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java index 79878cd4c..48112c682 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java @@ -43,11 +43,11 @@ import static org.junit.jupiter.api.Assertions.*; **/ public class CNNProcessorTest extends BaseDL4JTest { - private static int rows = 28; - private static int cols = 28; - private static INDArray in2D = Nd4j.create(DataType.FLOAT, 1, 784); - private static INDArray in3D = Nd4j.create(DataType.FLOAT, 20, 784, 7); - private static INDArray in4D = Nd4j.create(DataType.FLOAT, 20, 1, 28, 28); + private static final int rows = 28; + private static final int cols = 28; + private static final INDArray in2D = Nd4j.create(DataType.FLOAT, 1, 784); + private static final INDArray in3D = Nd4j.create(DataType.FLOAT, 20, 784, 7); + private static final INDArray in4D = Nd4j.create(DataType.FLOAT, 20, 1, 28, 28); @Test @@ -56,12 +56,12 @@ public class CNNProcessorTest extends BaseDL4JTest { 
INDArray check2to4 = convProcessor.preProcess(in2D, -1, LayerWorkspaceMgr.noWorkspaces()); int val2to4 = check2to4.shape().length; - assertTrue(val2to4 == 4); + assertEquals(4, val2to4); assertEquals(Nd4j.create(DataType.FLOAT, 1, 1, 28, 28), check2to4); INDArray check4to4 = convProcessor.preProcess(in4D, -1, LayerWorkspaceMgr.noWorkspaces()); int val4to4 = check4to4.shape().length; - assertTrue(val4to4 == 4); + assertEquals(4, val4to4); assertEquals(Nd4j.create(DataType.FLOAT, 20, 1, 28, 28), check4to4); } @@ -134,7 +134,7 @@ public class CNNProcessorTest extends BaseDL4JTest { INDArray check2to2 = convProcessor.backprop(in2D, -1, LayerWorkspaceMgr.noWorkspaces()); int val2to2 = check2to2.shape().length; - assertTrue(val2to2 == 2); + assertEquals(2, val2to2); assertEquals(Nd4j.create(DataType.FLOAT, 1, 784), check2to2); } @@ -144,12 +144,12 @@ public class CNNProcessorTest extends BaseDL4JTest { INDArray check2to4 = convProcessor.backprop(in2D, -1, LayerWorkspaceMgr.noWorkspaces()); int val2to4 = check2to4.shape().length; - assertTrue(val2to4 == 4); + assertEquals(4, val2to4); assertEquals(Nd4j.create(DataType.FLOAT, 1, 1, 28, 28), check2to4); INDArray check4to4 = convProcessor.backprop(in4D, -1, LayerWorkspaceMgr.noWorkspaces()); int val4to4 = check4to4.shape().length; - assertTrue(val4to4 == 4); + assertEquals(4, val4to4); assertEquals(Nd4j.create(DataType.FLOAT, 20, 1, 28, 28), check4to4); } @@ -160,12 +160,12 @@ public class CNNProcessorTest extends BaseDL4JTest { INDArray check2to2 = convProcessor.preProcess(in2D, -1, LayerWorkspaceMgr.noWorkspaces()); int val2to2 = check2to2.shape().length; - assertTrue(val2to2 == 2); + assertEquals(2, val2to2); assertEquals(Nd4j.create(DataType.FLOAT, 1, 784), check2to2); INDArray check4to2 = convProcessor.preProcess(in4D, -1, LayerWorkspaceMgr.noWorkspaces()); int val4to2 = check4to2.shape().length; - assertTrue(val4to2 == 2); + assertEquals(2, val4to2); assertEquals(Nd4j.create(DataType.FLOAT, 20, 784), check4to2); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java index 3f6741b89..56c6cfb1d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java @@ -239,8 +239,8 @@ public class TestPreProcessors extends BaseDL4JTest { (ConvolutionLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); - INDArray activationsCnn = Nd4j.rand(new int[] {miniBatchSize * timeSeriesLength, nChannels, - inputHeight, inputWidth}); + INDArray activationsCnn = Nd4j.rand(miniBatchSize * timeSeriesLength, nChannels, + inputHeight, inputWidth); //Check shape of outputs: val prod = nChannels * inputHeight * inputWidth; @@ -262,8 +262,8 @@ public class TestPreProcessors extends BaseDL4JTest { INDArray activationsRnnComp = compProc.preProcess(activationsCnn, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); assertEquals(activationsRnnComp, activationsRnn, msg); - INDArray epsilonsRnn = Nd4j.rand(new int[] {miniBatchSize, - nChannels * inputHeight * inputWidth, timeSeriesLength}); + INDArray epsilonsRnn = Nd4j.rand(miniBatchSize, + nChannels * inputHeight * inputWidth, timeSeriesLength); INDArray epsilonsCnnComp = compProc.backprop(epsilonsRnn, miniBatchSize, 
LayerWorkspaceMgr.noWorkspaces()); INDArray epsilonsCnn = proc.backprop(epsilonsRnn, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); if (!epsilonsCnn.equals(epsilonsCnnComp)) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java index f495e8fb0..d4bae91a6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java @@ -259,7 +259,7 @@ public class TestWeightNoise extends BaseDL4JTest { DropConnect d = new DropConnect(0.5); INDArray outTest = d.getParameter(l, "W", 0, 0, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(l.getParam("W") == outTest); //Should be same object + assertSame(l.getParam("W"), outTest); //Should be same object INDArray outTrain = d.getParameter(l, "W", 0, 0, true, LayerWorkspaceMgr.noWorkspaces()); assertNotEquals(l.getParam("W"), outTrain); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java index 2bddca70a..eb8c1cbcc 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java @@ -89,7 +89,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(conf); graph.init(); - INDArray input = Nd4j.rand(new int[] {3, 5, timeSeriesLength}); + INDArray input = Nd4j.rand(3, 5, timeSeriesLength); Map allOutputActivations = graph.feedForward(input, true); INDArray fullOutL0 = allOutputActivations.get("0"); @@ -117,7 +117,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { INDArray inputSubset = input.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeRange, endTimeRange)); if (inLength > 1) - assertTrue(inputSubset.size(2) == inLength); + assertEquals(inputSubset.size(2), inLength); INDArray[] outArr = graph.rnnTimeStep(inputSubset); assertEquals(1, outArr.length); @@ -173,7 +173,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(conf); graph.init(); - INDArray input3d = Nd4j.rand(new int[] {3, 5, timeSeriesLength}); + INDArray input3d = Nd4j.rand(3, 5, timeSeriesLength); INDArray out3d = graph.rnnTimeStep(input3d)[0]; assertArrayEquals(out3d.shape(), new long[] {3, 4, timeSeriesLength}); @@ -191,12 +191,12 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { //Check same but for input of size [3,5,1]. 
Expect [3,4,1] out graph.rnnClearPreviousState(); for (int i = 0; i < timeSeriesLength; i++) { - INDArray temp = Nd4j.create(new int[] {3, 5, 1}); + INDArray temp = Nd4j.create(3, 5, 1); temp.tensorAlongDimension(0, 1, 0).assign(input3d.tensorAlongDimension(i, 1, 0)); INDArray out3dSlice = graph.rnnTimeStep(temp)[0]; assertArrayEquals(out3dSlice.shape(), new long[] {3, 4, 1}); - assertTrue(out3dSlice.tensorAlongDimension(0, 1, 0).equals(out3d.tensorAlongDimension(i, 1, 0))); + assertEquals(out3dSlice.tensorAlongDimension(0, 1, 0), out3d.tensorAlongDimension(i, 1, 0)); } } @@ -245,8 +245,8 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(conf); graph.init(); - INDArray input0 = Nd4j.rand(new int[] {3, 5, timeSeriesLength}); - INDArray input1 = Nd4j.rand(new int[] {3, 4, timeSeriesLength}); + INDArray input0 = Nd4j.rand(3, 5, timeSeriesLength); + INDArray input1 = Nd4j.rand(3, 4, timeSeriesLength); Map allOutputActivations = graph.feedForward(new INDArray[] {input0, input1}, true); INDArray fullActLSTM0 = allOutputActivations.get("lstm0"); @@ -276,12 +276,12 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { INDArray inputSubset0 = input0.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeRange, endTimeRange)); if (inLength > 1) - assertTrue(inputSubset0.size(2) == inLength); + assertEquals(inputSubset0.size(2), inLength); INDArray inputSubset1 = input1.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(startTimeRange, endTimeRange)); if (inLength > 1) - assertTrue(inputSubset1.size(2) == inLength); + assertEquals(inputSubset1.size(2), inLength); INDArray[] outArr = graph.rnnTimeStep(inputSubset0, inputSubset1); assertEquals(2, outArr.length); @@ -395,8 +395,8 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { assertEquals(timeSeriesLength, graphTBPTT.getConfiguration().getTbpttFwdLength()); assertEquals(timeSeriesLength, graphTBPTT.getConfiguration().getTbpttBackLength()); - INDArray inputData = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); - INDArray labels = Nd4j.rand(new int[] {miniBatchSize, nOut, timeSeriesLength}); + INDArray inputData = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); + INDArray labels = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength); graph.setInput(0, inputData); graph.setLabel(0, labels); @@ -479,8 +479,8 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(conf); graph.init(); - INDArray inputLong = Nd4j.rand(new int[] {miniBatchSize, nIn, nTimeSlices * timeSeriesLength}); - INDArray labelsLong = Nd4j.rand(new int[] {miniBatchSize, nOut, nTimeSlices * timeSeriesLength}); + INDArray inputLong = Nd4j.rand(miniBatchSize, nIn, nTimeSlices * timeSeriesLength); + INDArray labelsLong = Nd4j.rand(miniBatchSize, nOut, nTimeSlices * timeSeriesLength); graph.fit(new INDArray[] {inputLong}, new INDArray[] {labelsLong}); } @@ -517,8 +517,8 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(conf); graph.init(); - INDArray inputLong = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); - INDArray labelsLong = Nd4j.rand(new int[] {miniBatchSize, nOut, timeSeriesLength}); + INDArray inputLong = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); + INDArray labelsLong = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength); INDArray initialParams = graph.params().dup(); graph.fit(new INDArray[] {inputLong}, new INDArray[] 
{labelsLong}); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java index 96532aa69..95691fed6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java @@ -93,7 +93,7 @@ public class TestCompGraphCNN extends BaseDL4JTest { } protected static int getNumParams() { - return 2 * (3 * 1 * 4 * 4 * 3 + 3) + (7 * 14 * 14 * 6 + 7) + (7 * 10 + 10); + return 2 * (3 * 4 * 4 * 3 + 3) + (7 * 14 * 14 * 6 + 7) + (7 * 10 + 10); } @BeforeEach diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index 6b1191a51..7a918a674 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -1205,7 +1205,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { NeuralNetConfiguration nnc = new NeuralNetConfiguration(); nnc.setLayer(new DenseLayer.Builder().build()); GraphVertex[] singleInputVertices = new GraphVertex[]{new L2NormalizeVertex(), new LayerVertex(nnc, null), - new PoolHelperVertex(), new PreprocessorVertex(), new ReshapeVertex(new int[]{1, 1}), + new PoolHelperVertex(), new PreprocessorVertex(), new ReshapeVertex(1, 1), new ScaleVertex(1.0), new ShiftVertex(1.0), new SubsetVertex(1, 1), new UnstackVertex(0, 2), new DuplicateToTimeSeriesVertex("in1"), new LastTimeStepVertex("in1")}; @@ -1971,7 +1971,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { IDropout d1 = model.getLayer(0).conf().getLayer().getIDropout(); IDropout d2 = cg2.getLayer(0).conf().getLayer().getIDropout(); - assertFalse(d1 == d2); //Should not be same object! + assertNotSame(d1, d2); //Should not be same object! 
assertEquals(d1, d2); //But should be equal } @@ -1988,9 +1988,9 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { .addInputs("x_emb") .addLayer("agg_lstm", new Bidirectional(CONCAT, new LSTM.Builder().nOut(hiddenSize/2).build()), "x_emb") .addLayer("agg_att", new DenseLayer.Builder().nIn(100).nOut(1).activation(Activation.SOFTMAX).build(), "agg_lstm") - .addVertex("att", new PreprocessorVertex(new ComposableInputPreProcessor(new FeedForwardToRnnPreProcessor(), new PermutePreprocessor(new int[] {0,2,1}), new RnnToFeedForwardPreProcessor())), "agg_att") + .addVertex("att", new PreprocessorVertex(new ComposableInputPreProcessor(new FeedForwardToRnnPreProcessor(), new PermutePreprocessor(0,2,1), new RnnToFeedForwardPreProcessor())), "agg_att") .addLayer("att_repeat", new RepeatVector.Builder(hiddenSize).build(),"att") - .addVertex("att_trans", new PreprocessorVertex(new PermutePreprocessor(new int[] {0, 2, 1})), "att_repeat") + .addVertex("att_trans", new PreprocessorVertex(new PermutePreprocessor(0, 2, 1)), "att_repeat") .addVertex("mult", new ElementWiseVertex(ElementWiseVertex.Op.Product), "agg_lstm", "att_trans") .addLayer("sum", new GlobalPoolingLayer.Builder().build(), "mult") .addLayer("agg_out", new DenseLayer.Builder().nIn(100).nOut(6).activation(Activation.TANH).build(), "sum") @@ -2003,8 +2003,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { net.init(); - INDArray features = Nd4j.rand(new int[] {dataSize, inputSize, seqLen}); - INDArray labels = Nd4j.rand(new int[] {dataSize, 6}); + INDArray features = Nd4j.rand(dataSize, inputSize, seqLen); + INDArray labels = Nd4j.rand(dataSize, 6); INDArray featuresMask = Nd4j.ones(dataSize, seqLen); INDArray labelsMask = Nd4j.ones(dataSize, 6); @@ -2056,8 +2056,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { soFar += 3*2; INDArray m1b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2)).assign(3); //m1b soFar += 2; - INDArray m2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2*1)).assign(4); //m2w - soFar += 2*1; + INDArray m2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+ 2)).assign(4); //m2w + soFar += 2; INDArray m2b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+1)).assign(5); //m2b soFar += 1; @@ -2069,8 +2069,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { soFar += 3*2; INDArray v1b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2)).assign(9); //v1b soFar += 2; - INDArray v2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2*1)).assign(10); //v2w - soFar += 2*1; + INDArray v2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+ 2)).assign(10); //v2w + soFar += 2; INDArray v2b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+1)).assign(11); //v2b soFar += 1; @@ -2140,8 +2140,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { int dataSize = 11; - INDArray features = Nd4j.rand(new int[] {dataSize, inputSize}); - INDArray labels = Nd4j.rand(new int[] {dataSize, outputSize}); + INDArray features = Nd4j.rand(dataSize, inputSize); + INDArray labels = Nd4j.rand(dataSize, outputSize); boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(net).inputs(new INDArray[]{features}) .labels(new INDArray[]{labels})); diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java index ec5c47894..0c17238db 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java @@ -68,8 +68,8 @@ public class TestSetGetParameters extends BaseDL4JTest { assertEquals(params, net2.params()); assertEquals(params, net3.params()); - assertFalse(params == net2.params()); //Different objects due to clone - assertTrue(params == net3.params()); //Same object due to clone + assertNotSame(params, net2.params()); //Different objects due to clone + assertSame(params, net3.params()); //Same object due to clone Map paramsMap = net.paramTable(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java index a39ac53b5..96e1dcf12 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java @@ -81,14 +81,14 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { ComputationGraph net = new ComputationGraph(conf); net.init(); - INDArray in1 = Nd4j.rand(new int[] {nExamples, 2, 4}); - INDArray in2 = Nd4j.rand(new int[] {nExamples, 2, 5}); + INDArray in1 = Nd4j.rand(nExamples, 2, 4); + INDArray in2 = Nd4j.rand(nExamples, 2, 5); in2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true)}, in1); assertEquals(in1, in2.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4))); - INDArray labels1 = Nd4j.rand(new int[] {nExamples, 1, 4}); + INDArray labels1 = Nd4j.rand(nExamples, 1, 4); INDArray labels2 = Nd4j.create(nExamples, 1, 5); labels2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true)}, labels1); @@ -178,14 +178,14 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { ComputationGraph net = new ComputationGraph(conf); net.init(); - INDArray in1 = Nd4j.rand(new int[] {nExamples, 2, 4}); - INDArray in2 = Nd4j.rand(new int[] {nExamples, 2, 5}); + INDArray in1 = Nd4j.rand(nExamples, 2, 4); + INDArray in2 = Nd4j.rand(nExamples, 2, 5); in2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true)}, in1); assertEquals(in1, in2.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4))); - INDArray labels1 = Nd4j.rand(new int[] {nExamples, 1, 4}); + INDArray labels1 = Nd4j.rand(nExamples, 1, 4); INDArray labels2 = Nd4j.create(nExamples, 1, 5); labels2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true)}, labels1); @@ -296,7 +296,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { } } - INDArray input = Nd4j.rand(new int[] {miniBatch, nIn, tsLength}); + INDArray input = Nd4j.rand(miniBatch, nIn, tsLength); INDArray labels = Nd4j.ones(miniBatch, nOut, tsLength); ComputationGraphConfiguration conf = diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java index de4010554..ba3eb90bb 100644 --- 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java @@ -168,7 +168,7 @@ public class TestGraphNodes extends BaseDL4JTest { assertEquals(Nd4j.zeros(5, 2), backward.get(NDArrayIndex.all(), NDArrayIndex.interval(8, 9, true))); //Test same for CNNs: - in = Nd4j.rand(new int[] {5, 10, 3, 3}); + in = Nd4j.rand(5, 10, 3, 3); subset.setInputs(in); out = subset.doForward(false, LayerWorkspaceMgr.noWorkspaces()); assertEquals(in.get(NDArrayIndex.all(), NDArrayIndex.interval(4, 7, true), NDArrayIndex.all(), @@ -198,7 +198,7 @@ public class TestGraphNodes extends BaseDL4JTest { //First: test without input mask array Nd4j.getRandom().setSeed(12345); - INDArray in = Nd4j.rand(new int[] {3, 5, 6}); + INDArray in = Nd4j.rand(3, 5, 6); INDArray expOut = in.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(5)); GraphVertex gv = graph.getVertex("lastTS"); @@ -250,7 +250,7 @@ public class TestGraphNodes extends BaseDL4JTest { graph.init(); INDArray in2d = Nd4j.rand(3, 5); - INDArray in3d = Nd4j.rand(new int[] {3, 2, 7}); + INDArray in3d = Nd4j.rand(3, 2, 7); graph.setInputs(in2d, in3d); @@ -339,9 +339,9 @@ public class TestGraphNodes extends BaseDL4JTest { GraphVertex stack = new StackVertex(null, "", -1, Nd4j.dataType()); //Test stack with variable length + mask arrays - INDArray in0 = Nd4j.rand(new int[] {5, 2, 5}); - INDArray in1 = Nd4j.rand(new int[] {5, 2, 6}); - INDArray in2 = Nd4j.rand(new int[] {5, 2, 7}); + INDArray in0 = Nd4j.rand(5, 2, 5); + INDArray in1 = Nd4j.rand(5, 2, 6); + INDArray in2 = Nd4j.rand(5, 2, 7); INDArray mask0 = Nd4j.ones(5, 5); INDArray mask1 = Nd4j.ones(5, 6); @@ -434,7 +434,7 @@ public class TestGraphNodes extends BaseDL4JTest { //Test same for CNNs: - in = Nd4j.rand(new int[] {15, 10, 3, 3}); + in = Nd4j.rand(15, 10, 3, 3); unstack0.setInputs(in); unstack1.setInputs(in); unstack2.setInputs(in); @@ -533,7 +533,7 @@ public class TestGraphNodes extends BaseDL4JTest { reshapeVertex.setEpsilon(out); INDArray[] backward = reshapeVertex.doBackward(false, LayerWorkspaceMgr.noWorkspaces()).getSecond(); - assertTrue(Arrays.equals(backward[0].shape(), inputShape)); + assertArrayEquals(backward[0].shape(), inputShape); } @Test @@ -591,7 +591,7 @@ public class TestGraphNodes extends BaseDL4JTest { .build(); - INDArray input = Nd4j.rand(new int[]{10, numInputs, 16}); + INDArray input = Nd4j.rand(10, numInputs, 16); INDArray[] out = updatedModel.output(input); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java index 0c7375cfd..14e169767 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java @@ -86,7 +86,7 @@ public class ActivationLayerTest extends BaseDL4JTest { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() - .layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU) + .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10).activation(Activation.RELU) .weightInit(WeightInit.XAVIER).build()) .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( 
LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER) @@ -102,7 +102,7 @@ public class ActivationLayerTest extends BaseDL4JTest { MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() - .layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.IDENTITY) + .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10).activation(Activation.IDENTITY) .weightInit(WeightInit.XAVIER).build()) .layer(1, new org.deeplearning4j.nn.conf.layers.ActivationLayer.Builder() .activation(Activation.RELU).build()) @@ -144,7 +144,7 @@ public class ActivationLayerTest extends BaseDL4JTest { int layerSize = 5; int nOut = 3; - INDArray next = Nd4j.rand(new int[] {minibatch, nIn}); + INDArray next = Nd4j.rand(minibatch, nIn); INDArray labels = Nd4j.zeros(minibatch, nOut); for (int i = 0; i < minibatch; i++) { labels.putScalar(i, i % nOut, 1.0); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java index d4eea7a49..7b55a4641 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java @@ -92,7 +92,7 @@ public class CacheModeTest extends BaseDL4JTest { MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - INDArray in = Nd4j.rand(new int[]{3, 3, 10}); + INDArray in = Nd4j.rand(3, 3, 10); INDArray labels = TestUtils.randomOneHotTimeSeries(3, 10, 10); INDArray out1 = net1.output(in); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java index cee20827c..3aa7e37dd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java @@ -159,7 +159,7 @@ public class DropoutLayerTest extends BaseDL4JTest { MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() - .layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10) + .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10) .activation(Activation.RELU).weightInit( WeightInit.XAVIER) .build()) @@ -176,7 +176,7 @@ public class DropoutLayerTest extends BaseDL4JTest { MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() - .layer(0, new DenseLayer.Builder().nIn(28 * 28 * 1).nOut(10).activation(Activation.RELU) + .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10).activation(Activation.RELU) .weightInit(WeightInit.XAVIER).build()) .layer(1, new DropoutLayer.Builder(0.25).build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java index fcd509494..232a9a46e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java @@ -68,7 +68,7 @@ 
public class OutputLayerTest extends BaseDL4JTest { long numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); OutputLayer l = (OutputLayer) conf.getLayer().instantiate(conf, - Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType()); + Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType()); params = l.params(); l.setParams(params); assertEquals(params, l.params()); @@ -221,8 +221,8 @@ public class OutputLayerTest extends BaseDL4JTest { double score = mln.score() * timeSeriesLength; double scoreRNN = mlnRnn.score(); - assertTrue(!Double.isNaN(score)); - assertTrue(!Double.isNaN(scoreRNN)); + assertFalse(Double.isNaN(score)); + assertFalse(Double.isNaN(scoreRNN)); double relError = Math.abs(score - scoreRNN) / (Math.abs(score) + Math.abs(scoreRNN)); System.out.println(relError); @@ -306,7 +306,7 @@ public class OutputLayerTest extends BaseDL4JTest { mln2.setParams(mln.params()); - INDArray in = Nd4j.rand(new int[]{miniBatchSize, nIn, timeSeriesLength}); + INDArray in = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); INDArray out1 = mln.output(in); INDArray out2 = mln.output(in); @@ -390,7 +390,7 @@ public class OutputLayerTest extends BaseDL4JTest { mln2.setParams(mln.params()); - INDArray in = Nd4j.rand(new int[]{3, 3, 5, 5}); + INDArray in = Nd4j.rand(3, 3, 5, 5); INDArray out1 = mln.output(in); INDArray out2 = mln2.output(in); @@ -412,8 +412,8 @@ public class OutputLayerTest extends BaseDL4JTest { assertEquals(mln.gradient().gradient(), mln2.gradient().gradient()); //Also check computeScoreForExamples - INDArray in2a = Nd4j.rand(new int[]{1, 3, 5, 5}); - INDArray labels2a = Nd4j.rand(new int[]{1, 4, 5, 5}); + INDArray in2a = Nd4j.rand(1, 3, 5, 5); + INDArray labels2a = Nd4j.rand(1, 4, 5, 5); INDArray in2 = Nd4j.concat(0, in2a, in2a); INDArray labels2 = Nd4j.concat(0, labels2a, labels2a); @@ -483,7 +483,7 @@ public class OutputLayerTest extends BaseDL4JTest { graph2.setParams(graph.params()); - INDArray in = Nd4j.rand(new int[]{3, 3, 5, 5}); + INDArray in = Nd4j.rand(3, 3, 5, 5); INDArray out1 = graph.outputSingle(in); INDArray out2 = graph2.outputSingle(in); @@ -505,8 +505,8 @@ public class OutputLayerTest extends BaseDL4JTest { assertEquals(graph.gradient().gradient(), graph2.gradient().gradient()); //Also check computeScoreForExamples - INDArray in2a = Nd4j.rand(new int[]{1, 3, 5, 5}); - INDArray labels2a = Nd4j.rand(new int[]{1, 4, 5, 5}); + INDArray in2a = Nd4j.rand(1, 3, 5, 5); + INDArray labels2a = Nd4j.rand(1, 4, 5, 5); INDArray in2 = Nd4j.concat(0, in2a, in2a); INDArray labels2 = Nd4j.concat(0, labels2a, labels2a); @@ -540,7 +540,7 @@ public class OutputLayerTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray in = Nd4j.rand(new int[]{2,3,4,5}); + INDArray in = Nd4j.rand(2,3,4,5); INDArray out = net.output(in); double min = out.minNumber().doubleValue(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java index 2aa98575e..3e526e774 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java @@ -34,12 +34,11 @@ import org.nd4j.common.primitives.Pair; import java.util.Arrays; -import static 
org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public class RepeatVectorTest extends BaseDL4JTest { - private int REPEAT = 4; + private final int REPEAT = 4; private Layer getRepeatVectorLayer() { @@ -55,18 +54,18 @@ public class RepeatVectorTest extends BaseDL4JTest { double[] arr = new double[] {1., 2., 3., 1., 2., 3., 1., 2., 3., 1., 2., 3.}; INDArray expectedOut = Nd4j.create(arr, new long[] {1, 3, REPEAT}, 'f'); - INDArray input = Nd4j.create(new double[] {1., 2., 3.}, new long[] {1, 3}); + INDArray input = Nd4j.create(new double[] {1., 2., 3.}, 1, 3); Layer layer = getRepeatVectorLayer(); INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(expectedOut.shape(), output.shape())); + assertArrayEquals(expectedOut.shape(), output.shape()); assertEquals(expectedOut, output); INDArray epsilon = Nd4j.ones(1,3,4); Pair out = layer.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); INDArray outEpsilon = out.getSecond(); - INDArray expectedEpsilon = Nd4j.create(new double[] {4., 4., 4.}, new long[] {1, 3}); + INDArray expectedEpsilon = Nd4j.create(new double[] {4., 4., 4.}, 1, 3); assertEquals(expectedEpsilon, outEpsilon); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java index a9b3ee532..4d46d5066 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java @@ -40,8 +40,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class SeedTest extends BaseDL4JTest { - private DataSetIterator irisIter = new IrisDataSetIterator(50, 50); - private DataSet data = irisIter.next(); + private final DataSetIterator irisIter = new IrisDataSetIterator(50, 50); + private final DataSet data = irisIter.next(); @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java index 596906d10..d282690bb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java @@ -36,8 +36,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; /** * @author Max Pumperla @@ -45,18 +44,18 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class Convolution3DTest extends BaseDL4JTest { private int nExamples = 1; - private int nChannelsOut = 1; - private int nChannelsIn = 1; - private int inputDepth = 2 * 2; - private int inputWidth = 28 / 2; - private int inputHeight = 28 / 2; + private final int nChannelsOut = 1; + private final int nChannelsIn = 1; + private final int inputDepth = 2 * 2; + private final int inputWidth = 28 / 2; + private final int inputHeight = 28 / 2; - private int[] kernelSize = new int[]{2, 2, 2}; - private int outputDepth = inputDepth - kernelSize[0] + 1; - private int outputHeight = inputHeight - kernelSize[1] + 1; - private int outputWidth = 
inputWidth - kernelSize[2] + 1; + private final int[] kernelSize = new int[]{2, 2, 2}; + private final int outputDepth = inputDepth - kernelSize[0] + 1; + private final int outputHeight = inputHeight - kernelSize[1] + 1; + private final int outputWidth = inputWidth - kernelSize[2] + 1; - private INDArray epsilon = Nd4j.ones(nExamples, nChannelsOut, outputDepth, outputHeight, outputWidth); + private final INDArray epsilon = Nd4j.ones(nExamples, nChannelsOut, outputDepth, outputHeight, outputWidth); @Test @@ -65,11 +64,11 @@ public class Convolution3DTest extends BaseDL4JTest { INDArray containedInput = getContainedData(); Convolution3DLayer layer = (Convolution3DLayer) getConvolution3DLayer(ConvolutionMode.Same); - assertTrue(layer.convolutionMode == ConvolutionMode.Same); + assertSame(layer.convolutionMode, ConvolutionMode.Same); INDArray containedOutput = layer.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(containedInput.shape(), containedOutput.shape())); + assertArrayEquals(containedInput.shape(), containedOutput.shape()); } @@ -78,13 +77,12 @@ public class Convolution3DTest extends BaseDL4JTest { Convolution3DLayer layer = (Convolution3DLayer) getConvolution3DLayer(ConvolutionMode.Strict); - assertTrue(layer.convolutionMode == ConvolutionMode.Strict); + assertSame(layer.convolutionMode, ConvolutionMode.Strict); INDArray input = getData(); INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(new long[]{nExamples, nChannelsOut, outputDepth, outputWidth, outputHeight}, - output.shape())); + assertArrayEquals(new long[]{nExamples, nChannelsOut, outputDepth, outputWidth, outputHeight}, output.shape()); } private Layer getConvolution3DLayer(ConvolutionMode mode) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java index 26fc3a4e3..246dfee5b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java @@ -108,7 +108,7 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { .setInputType(InputType.convolutional(numRows, numColumns, nChannels)); - DataSet d = new DataSet(Nd4j.rand(new int[]{10, nChannels, numRows, numColumns}), + DataSet d = new DataSet(Nd4j.rand(10, nChannels, numRows, numColumns), FeatureUtil.toOutcomeMatrix(new int[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 6)); MultiLayerNetwork network = new MultiLayerNetwork(builder.build()); network.init(); @@ -137,7 +137,7 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { @Test public void testMultiChannel() throws Exception { - INDArray in = Nd4j.rand(new int[] {10, 3, 28, 28}); + INDArray in = Nd4j.rand(10, 3, 28, 28); INDArray labels = Nd4j.rand(10, 2); DataSet next = new DataSet(in, labels); @@ -288,7 +288,7 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) .build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nIn(5 * 5 * 1 * 6) //216 + .nIn(5 * 5 * 6) //216 .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) .build()) .inputPreProcessor(0, new FeedForwardToCnnPreProcessor(numRows, 
numColumns, nChannels)) @@ -440,8 +440,8 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { network.fit(next); INDArray actualGammaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.GAMMA); INDArray actualBetaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.BETA); - assertTrue(actualGammaParam != null); - assertTrue(actualBetaParam != null); + assertNotNull(actualGammaParam); + assertNotNull(actualBetaParam); } @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java index fa8c88493..e4921b555 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java @@ -111,7 +111,7 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest { network.init(); INDArray input = Nd4j.ones(10, 3, 8); - INDArray output = network.output(input, false);; + INDArray output = network.output(input, false); for (int i = 0; i < 100; i++) { // TODO: this falls flat for 1000 iterations on my machine output = network.output(input, false); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java index f69b0041e..0ee4e322f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java @@ -33,22 +33,21 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.util.Arrays; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public class SpaceToDepthTest extends BaseDL4JTest { - private int mb = 1; - private int inDepth = 2; - private int inputWidth = 2; - private int inputHeight = 2; + private final int mb = 1; + private final int inDepth = 2; + private final int inputWidth = 2; + private final int inputHeight = 2; - private int blockSize = 2; - private SpaceToDepthLayer.DataFormat dataFormat = SpaceToDepthLayer.DataFormat.NCHW; + private final int blockSize = 2; + private final SpaceToDepthLayer.DataFormat dataFormat = SpaceToDepthLayer.DataFormat.NCHW; - private int outDepth = inDepth * blockSize * blockSize; - private int outputHeight = inputHeight / blockSize; - private int outputWidth = inputWidth / blockSize; + private final int outDepth = inDepth * blockSize * blockSize; + private final int outputHeight = inputHeight / blockSize; + private final int outputWidth = inputWidth / blockSize; private INDArray getContainedData() { @@ -75,7 +74,7 @@ public class SpaceToDepthTest extends BaseDL4JTest { Layer std = getSpaceToDepthLayer(); INDArray containedOutput = std.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); + assertArrayEquals(containedExpectedOut.shape(), containedOutput.shape()); assertEquals(containedExpectedOut, containedOutput); } @@ -89,7 +88,7 @@ public class SpaceToDepthTest extends BaseDL4JTest { std.setInput(getContainedData(), 
LayerWorkspaceMgr.noWorkspaces()); INDArray containedOutput = std.backpropGradient(containedInputEpsilon, LayerWorkspaceMgr.noWorkspaces()).getRight(); - assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); + assertArrayEquals(containedExpectedOut.shape(), containedOutput.shape()); assertEquals(containedExpectedOut, containedOutput); } } \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java index 2fca7643a..75434a4c3 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java @@ -54,12 +54,12 @@ import static org.junit.jupiter.api.Assertions.*; public class SubsamplingLayerTest extends BaseDL4JTest { private int nExamples = 1; - private int depth = 20; //channels & nOut - private int nChannelsIn = 1; - private int inputWidth = 28; - private int inputHeight = 28; - private int[] kernelSize = new int[] {2, 2}; - private int[] stride = new int[] {2, 2}; + private final int depth = 20; //channels & nOut + private final int nChannelsIn = 1; + private final int inputWidth = 28; + private final int inputHeight = 28; + private final int[] kernelSize = new int[] {2, 2}; + private final int[] stride = new int[] {2, 2}; int featureMapWidth = (inputWidth - kernelSize[0]) / stride[0] + 1; int featureMapHeight = (inputHeight - kernelSize[1]) / stride[0] + 1; @@ -73,18 +73,17 @@ public class SubsamplingLayerTest extends BaseDL4JTest { @Test public void testSubSampleMaxActivate() throws Exception { INDArray containedExpectedOut = - Nd4j.create(new double[] {5., 7., 6., 8., 4., 7., 5., 9.}, new long[] {1, 2, 2, 2}).castTo(Nd4j.defaultFloatingPointType()); + Nd4j.create(new double[] {5., 7., 6., 8., 4., 7., 5., 9.}, 1, 2, 2, 2).castTo(Nd4j.defaultFloatingPointType()); INDArray containedInput = getContainedData(); INDArray input = getData(); Layer layer = getSubsamplingLayer(SubsamplingLayer.PoolingType.MAX); INDArray containedOutput = layer.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); + assertArrayEquals(containedExpectedOut.shape(), containedOutput.shape()); assertEquals(containedExpectedOut, containedOutput); INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(new long[] {nExamples, nChannelsIn, featureMapWidth, featureMapHeight}, - output.shape())); + assertArrayEquals(new long[]{nExamples, nChannelsIn, featureMapWidth, featureMapHeight}, output.shape()); assertEquals(nChannelsIn, output.size(1), 1e-4); // channels retained } @@ -97,12 +96,11 @@ public class SubsamplingLayerTest extends BaseDL4JTest { Layer layer = getSubsamplingLayer(SubsamplingLayer.PoolingType.AVG); INDArray containedOutput = layer.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); + assertArrayEquals(containedExpectedOut.shape(), containedOutput.shape()); assertEquals(containedExpectedOut, containedOutput); INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(new long[] {nExamples, nChannelsIn, featureMapWidth, 
featureMapHeight}, - output.shape())); + assertArrayEquals(new long[]{nExamples, nChannelsIn, featureMapWidth, featureMapHeight}, output.shape()); assertEquals(nChannelsIn, output.size(1), 1e-4); // channels retained } @@ -124,7 +122,7 @@ public class SubsamplingLayerTest extends BaseDL4JTest { Pair containedOutput = layer.backpropGradient(expectedContainedEpsilonInput, LayerWorkspaceMgr.noWorkspaces()); assertEquals(expectedContainedEpsilonResult, containedOutput.getSecond()); - assertEquals(null, containedOutput.getFirst().getGradientFor("W")); + assertNull(containedOutput.getFirst().getGradientFor("W")); assertEquals(expectedContainedEpsilonResult.shape().length, containedOutput.getSecond().shape().length); INDArray input2 = getData(); @@ -153,7 +151,7 @@ public class SubsamplingLayerTest extends BaseDL4JTest { Pair containedOutput = layer.backpropGradient(expectedContainedEpsilonInput, LayerWorkspaceMgr.noWorkspaces()); assertEquals(expectedContainedEpsilonResult, containedOutput.getSecond()); - assertEquals(null, containedOutput.getFirst().getGradientFor("W")); + assertNull(containedOutput.getFirst().getGradientFor("W")); assertArrayEquals(expectedContainedEpsilonResult.shape(), containedOutput.getSecond().shape()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java index 6cc561ceb..35ba6d924 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java @@ -68,14 +68,14 @@ public class TestConvolutionModes extends BaseDL4JTest { for (int minibatch : minibatches) { for (int inDepth : inDepths) { - INDArray origData = Nd4j.rand(new int[] {minibatch, inDepth, 9, 9}); + INDArray origData = Nd4j.rand(minibatch, inDepth, 9, 9); for (int inSize : inSizes) { for (ConvolutionMode cm : new ConvolutionMode[] {ConvolutionMode.Strict, ConvolutionMode.Truncate}) { - INDArray inputData = Nd4j.rand(new int[] {minibatch, inDepth, inSize, inSize}); + INDArray inputData = Nd4j.rand(minibatch, inDepth, inSize, inSize); inputData.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 9), NDArrayIndex.interval(0, 9)).assign(origData); @@ -147,14 +147,14 @@ public class TestConvolutionModes extends BaseDL4JTest { for (int minibatch : minibatches) { for (int inDepth : inDepths) { - INDArray origData = Nd4j.rand(new int[] {minibatch, inDepth, 9, 9}); + INDArray origData = Nd4j.rand(minibatch, inDepth, 9, 9); for (int inSize : inSizes) { for (ConvolutionMode cm : new ConvolutionMode[] {ConvolutionMode.Strict, ConvolutionMode.Truncate}) { - INDArray inputData = Nd4j.rand(new int[] {minibatch, inDepth, inSize, inSize}); + INDArray inputData = Nd4j.rand(minibatch, inDepth, inSize, inSize); inputData.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 9), NDArrayIndex.interval(0, 9)).assign(origData); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java index 0504c4fac..277b43c31 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java @@ -38,8 +38,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.util.Arrays; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; /** * @author Max Pumperla @@ -47,11 +46,11 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class Upsampling1DTest extends BaseDL4JTest { private int nExamples = 1; - private int depth = 20; - private int nChannelsIn = 1; - private int inputLength = 28; - private int size = 2; - private int outputLength = inputLength * size; + private final int depth = 20; + private final int nChannelsIn = 1; + private final int inputLength = 28; + private final int size = 2; + private final int outputLength = inputLength * size; private INDArray epsilon = Nd4j.ones(nExamples, depth, outputLength); @@ -65,12 +64,11 @@ public class Upsampling1DTest extends BaseDL4JTest { Layer layer = getUpsampling1DLayer(); INDArray containedOutput = layer.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); + assertArrayEquals(containedExpectedOut.shape(), containedOutput.shape()); assertEquals(containedExpectedOut, containedOutput); INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(new long[] {nExamples, nChannelsIn, outputLength}, - output.shape())); + assertArrayEquals(new long[]{nExamples, nChannelsIn, outputLength}, output.shape()); assertEquals(nChannelsIn, output.size(1), 1e-4); } @@ -92,7 +90,7 @@ public class Upsampling1DTest extends BaseDL4JTest { Pair containedOutput = layer.backpropGradient(expectedContainedEpsilonInput, LayerWorkspaceMgr.noWorkspaces()); assertEquals(expectedContainedEpsilonResult, containedOutput.getSecond()); - assertEquals(null, containedOutput.getFirst().getGradientFor("W")); + assertNull(containedOutput.getFirst().getGradientFor("W")); assertEquals(expectedContainedEpsilonResult.shape().length, containedOutput.getSecond().shape().length); INDArray input2 = getData(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java index a0ee3de55..e1d46f911 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java @@ -46,14 +46,14 @@ import static org.junit.jupiter.api.Assertions.*; public class Upsampling2DTest extends BaseDL4JTest { private int nExamples = 1; - private int depth = 20; - private int nChannelsIn = 1; - private int inputWidth = 28; - private int inputHeight = 28; + private final int depth = 20; + private final int nChannelsIn = 1; + private final int inputWidth = 28; + private final int inputHeight = 28; - private int size = 2; - private int outputWidth = inputWidth * size; - private int outputHeight = inputHeight * size; + private final int size = 2; + private final int outputWidth = inputWidth * size; + private final int outputHeight = inputHeight * size; private INDArray epsilon = Nd4j.ones(nExamples, depth, outputHeight, outputWidth); @@ -68,12 +68,11 @@ public class Upsampling2DTest extends BaseDL4JTest { 
Layer layer = getUpsamplingLayer(); INDArray containedOutput = layer.activate(containedInput, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(containedExpectedOut.shape(), containedOutput.shape())); + assertArrayEquals(containedExpectedOut.shape(), containedOutput.shape()); assertEquals(containedExpectedOut, containedOutput); INDArray output = layer.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(new long[] {nExamples, nChannelsIn, outputWidth, outputHeight}, - output.shape())); + assertArrayEquals(new long[]{nExamples, nChannelsIn, outputWidth, outputHeight}, output.shape()); assertEquals(nChannelsIn, output.size(1), 1e-4); } @@ -95,7 +94,7 @@ public class Upsampling2DTest extends BaseDL4JTest { Pair containedOutput = layer.backpropGradient(expectedContainedEpsilonInput, LayerWorkspaceMgr.noWorkspaces()); assertEquals(expectedContainedEpsilonResult, containedOutput.getSecond()); - assertEquals(null, containedOutput.getFirst().getGradientFor("W")); + assertNull(containedOutput.getFirst().getGradientFor("W")); assertEquals(expectedContainedEpsilonResult.shape().length, containedOutput.getSecond().shape().length); INDArray input2 = getData(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java index 2c4968e52..25c8074a8 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java @@ -43,9 +43,9 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class DenseTest extends BaseDL4JTest { - private int numSamples = 150; - private int batchSize = 150; - private DataSetIterator iter = new IrisDataSetIterator(batchSize, numSamples); + private final int numSamples = 150; + private final int batchSize = 150; + private final DataSetIterator iter = new IrisDataSetIterator(batchSize, numSamples); private DataSet data; @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java index 30e221c1a..259a38382 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java @@ -399,7 +399,6 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net2.setParams(net.params().dup()); - ; INDArray inEmbedding = Nd4j.create(batchSize, 1, timeSeriesLength); INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, timeSeriesLength); INDArray outLabels = Nd4j.create(batchSize, 4, timeSeriesLength); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java index 10ca617fe..c4950d3c4 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java @@ -149,7 +149,7 @@ public class BatchNormalizationTest extends 
BaseDL4JTest { int nIn = 4; int minibatch = 2; Nd4j.getRandom().setSeed(12345); - INDArray input = Nd4j.rand('c', new int[]{minibatch, nIn}); + INDArray input = Nd4j.rand('c', minibatch, nIn); //TODO: other values for gamma/beta INDArray gamma = Nd4j.ones(1, nIn); @@ -207,7 +207,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { int hw = 15; Nd4j.getRandom().setSeed(12345); - INDArray randInput = Nd4j.rand(new int[]{100, nOut, hw, hw}); + INDArray randInput = Nd4j.rand(100, nOut, hw, hw); INDArray output = l.activate(randInput, true, LayerWorkspaceMgr.noWorkspaces()); assertEquals(4, output.rank()); @@ -288,7 +288,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { int hw = 3; int minibatch = 2; Nd4j.getRandom().setSeed(12345); - INDArray input = Nd4j.rand('c', new int[]{minibatch, nIn, hw, hw}); + INDArray input = Nd4j.rand('c', minibatch, nIn, hw, hw); //TODO: other values for gamma/beta INDArray gamma = Nd4j.ones(1, nIn); @@ -313,7 +313,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { //------------------------------------------------------------- //Check backprop - INDArray epsilon = Nd4j.rand('c', new int[]{minibatch, nIn, hw, hw}); //dL/dy + INDArray epsilon = Nd4j.rand('c', minibatch, nIn, hw, hw); //dL/dy int effectiveMinibatch = minibatch * hw * hw; @@ -388,8 +388,8 @@ public class BatchNormalizationTest extends BaseDL4JTest { network.fit(next); INDArray actualGammaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.GAMMA); INDArray actualBetaParam = network.getLayer(1).getParam(BatchNormalizationParamInitializer.BETA); - assertTrue(actualGammaParam != null); - assertTrue(actualBetaParam != null); + assertNotNull(actualGammaParam); + assertNotNull(actualBetaParam); } @Test @@ -599,7 +599,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { int minibatch = 32; List list = new ArrayList<>(); for (int i = 0; i < 100; i++) { - list.add(new DataSet(Nd4j.rand(new int[]{minibatch, 3, 5, 5}), Nd4j.rand(minibatch, 10))); + list.add(new DataSet(Nd4j.rand(minibatch, 3, 5, 5), Nd4j.rand(minibatch, 10))); } DataSetIterator iter = new ListDataSetIterator(list); @@ -672,7 +672,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { int minibatch = 32; for (int i = 0; i < 10; i++) { - DataSet ds = new DataSet(Nd4j.rand(new int[]{minibatch, 3, 5, 5}), Nd4j.rand(minibatch, 10)); + DataSet ds = new DataSet(Nd4j.rand(minibatch, 3, 5, 5), Nd4j.rand(minibatch, 10)); net.fit(ds); net2.fit(ds); @@ -743,8 +743,8 @@ public class BatchNormalizationTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray in = Nd4j.rand(new int[]{1, 3, 5}); - INDArray label = Nd4j.rand(new int[]{1, 3, 5}); + INDArray in = Nd4j.rand(1, 3, 5); + INDArray label = Nd4j.rand(1, 3, 5); INDArray out = net.output(in); assertArrayEquals(new long[]{1, 3, 5}, out.shape()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java index 99fc1e5a3..e876b736b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java @@ -46,15 +46,14 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; 
-import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.*; /** * */ public class LocalResponseTest extends BaseDL4JTest { - private INDArray x = Nd4j.create(new double[] {0.88128096, -0.96666986, -0.61832994, 0.26418415, 0.05694608, + private final INDArray x = Nd4j.create(new double[] {0.88128096, -0.96666986, -0.61832994, 0.26418415, 0.05694608, 0.2950289, 0.99222249, 0.24541704, 0.4219842, 0.96430975, 0.19299535, -0.06658337, -0.27603117, 0.24216647, 0.21834095, 0.03863283, -0.82313406, -0.37236378, -0.77667993, 0.66295379, -0.34406275, -0.25924176, 0.26652309, -0.58964926, -0.46907067, 0.34666502, 0.81208313, -0.17042427, -0.22470538, @@ -67,7 +66,7 @@ public class LocalResponseTest extends BaseDL4JTest { -0.31666604, 0.19781154, 0.09908111, 0.64796048, -0.99037546, 0.67919868, 0.43810204}, new int[] {2, 7, 3, 2}); - private INDArray activationsExpected = Nd4j.create(new double[] {0.52397668, -0.57476264, -0.3676528, 0.15707894, + private final INDArray activationsExpected = Nd4j.create(new double[] {0.52397668, -0.57476264, -0.3676528, 0.15707894, 0.03385943, 0.17542371, 0.58992499, 0.14591768, 0.25090647, 0.57335907, 0.11475233, -0.03958985, -0.16411273, 0.14398433, 0.12981956, 0.02297027, -0.48942304, -0.22139823, -0.46177959, 0.39418164, -0.20457059, -0.15413573, 0.15846729, -0.3505919, -0.27889356, 0.20611978, 0.48284137, -0.10133155, @@ -80,7 +79,7 @@ public class LocalResponseTest extends BaseDL4JTest { 0.57277, -0.18827969, 0.1176173, 0.05891332, 0.38526815, -0.58884346, 0.40383074, 0.26048511}, new int[] {2, 7, 3, 2}); - private INDArray epsilon = Nd4j.create(new double[] {-0.13515499, 0.96470547, -0.62253004, 0.80172491, -0.97510445, + private final INDArray epsilon = Nd4j.create(new double[] {-0.13515499, 0.96470547, -0.62253004, 0.80172491, -0.97510445, -0.41198033, -0.4790071, 0.07551047, -0.01383764, -0.05797465, 0.21242172, 0.7145375, -0.17809176, -0.11465316, -0.2066526, 0.21950938, 0.4627091, 0.30275798, 0.61443841, 0.75912178, -0.132248, -0.82923287, 0.74962652, -0.88993639, 0.04406403, 0.32096064, -0.46400586, 0.1603231, 0.63007826, @@ -93,7 +92,7 @@ public class LocalResponseTest extends BaseDL4JTest { 0.04847952, -0.82953823, 0.8089835, 0.50185651, -0.88619858, -0.78598201, 0.27489874, 0.63673472}, new int[] {2, 7, 3, 2}); - private INDArray newEpsilonExpected = Nd4j.create(new double[] {-0.08033668, 0.57355404, -0.37014094, 0.47668865, + private final INDArray newEpsilonExpected = Nd4j.create(new double[] {-0.08033668, 0.57355404, -0.37014094, 0.47668865, -0.57978398, -0.24495915, -0.28474802, 0.04490108, -0.00823483, -0.03448687, 0.12630466, 0.42485803, -0.10589627, -0.06816553, -0.12287001, 0.13051508, 0.27510744, 0.18001786, 0.36528736, 0.45133191, -0.07863599, -0.49303374, 0.44571424, -0.52912313, 0.02620371, 0.19082049, -0.27585581, 0.09532529, @@ -133,7 +132,7 @@ public class LocalResponseTest extends BaseDL4JTest { assertEquals(newEpsilonExpected.getDouble(8), containedOutput.getSecond().getDouble(8), 1e-4); assertEquals(newEpsilonExpected.getDouble(20), containedOutput.getSecond().getDouble(20), 1e-4); - assertEquals(null, containedOutput.getFirst().getGradientFor("W")); + assertNull(containedOutput.getFirst().getGradientFor("W")); assertArrayEquals(newEpsilonExpected.shape(), containedOutput.getSecond().shape()); } @@ -182,7 +181,7 @@ public class LocalResponseTest extends BaseDL4JTest { double alpha = 1e-4; double beta = 0.75; - 
INDArray in = Nd4j.rand(new int[] {minibatch, depth, wh, wh}); + INDArray in = Nd4j.rand(minibatch, depth, wh, wh); INDArray outExp = Nd4j.zeros(minibatch, depth, wh, wh); for (int m = 0; m < minibatch; m++) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java index 7fb8dc8af..c732ab366 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java @@ -102,7 +102,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer y2impl = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) net.getLayer(1); - INDArray input = Nd4j.rand(new int[]{mb, depth, h, w}); + INDArray input = Nd4j.rand(mb, depth, h, w); INDArray out = y2impl.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); assertNotNull(out); @@ -115,7 +115,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { INDArray labels = Nd4j.zeros(mb, labelDepth, h, w); //put 1 object per minibatch, at positions (0,0), (1,1) etc. //Positions for label boxes: (1,1) to (2,2), (2,2) to (4,4) etc - labels.putScalar(0, 4 + 0, 0, 0, 1); + labels.putScalar(0, 4, 0, 0, 1); labels.putScalar(1, 4 + 1, 1, 1, 1); labels.putScalar(2, 4 + 2, 2, 2, 1); @@ -190,7 +190,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer y2impl = (org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer) net.getLayer(1); - INDArray input = Nd4j.rand(new int[]{mb, depth, h, w}); + INDArray input = Nd4j.rand(mb, depth, h, w); INDArray out = y2impl.activate(input, false, LayerWorkspaceMgr.noWorkspaces()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java index 1c9da8933..e9f76dfc2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java @@ -147,7 +147,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest { System.out.println("Normal probabilities " + normalProbs); System.out.println("Normal raw output " + outputForNormalSamples); - File tmpFile = new File(testDir.getAbsoluteFile(),"tmp-file-" + UUID.randomUUID().toString()); + File tmpFile = new File(testDir.getAbsoluteFile(),"tmp-file-" + UUID.randomUUID()); ModelSerializer.writeModel(network,tmpFile,true); tmpFile.deleteOnExit(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java index f6ef09732..a7f3d1867 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java @@ -74,7 +74,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { net.init(); Random r = new Random(12345L); - INDArray input = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}).subi(0.5); + INDArray input = 
Nd4j.rand(miniBatchSize, nIn, timeSeriesLength).subi(0.5); INDArray mask; if (miniBatchSize == 1) { @@ -136,7 +136,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray inToBeMasked = Nd4j.rand(new int[] {minibatch, depthIn, height, width}); + INDArray inToBeMasked = Nd4j.rand(minibatch, depthIn, height, width); //Shape for mask: [minibatch, 1, 1, width] INDArray maskArray = Nd4j.create(new double[] {1, 1, 1, 1, 1, 0}, new int[]{1,1,1,width}); @@ -164,7 +164,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { //Finally: check gradient calc for exceptions net.setLayerMaskArrays(maskArray, null); net.setInput(inToBeMasked); - INDArray labels = Nd4j.create(new double[] {0, 1}, new long[]{1,2}); + INDArray labels = Nd4j.create(new double[] {0, 1}, 1,2); net.setLabels(labels); net.computeGradientAndScore(); @@ -199,7 +199,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray inToBeMasked = Nd4j.rand(new int[] {minibatch, depthIn, height, width}); + INDArray inToBeMasked = Nd4j.rand(minibatch, depthIn, height, width); //Shape for mask: [minibatch, width] INDArray maskArray = Nd4j.create(new double[] {1, 1, 1, 1, 1, 0}, new int[]{1,1,height,1}); @@ -227,7 +227,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { //Finally: check gradient calc for exceptions net.setLayerMaskArrays(maskArray, null); net.setInput(inToBeMasked); - INDArray labels = Nd4j.create(new double[] {0, 1}, new long[]{1,2}); + INDArray labels = Nd4j.create(new double[] {0, 1}, 1,2); net.setLabels(labels); net.computeGradientAndScore(); @@ -263,7 +263,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray inToBeMasked = Nd4j.rand(new int[] {minibatch, depthIn, height, width}); + INDArray inToBeMasked = Nd4j.rand(minibatch, depthIn, height, width); //Shape for mask: [minibatch, width] INDArray maskArray = Nd4j.create(new double[][] {{1, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 1, 0}, {1, 1, 1, 1, 0, 0}}) @@ -322,7 +322,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray inToBeMasked = Nd4j.rand(new int[] {minibatch, depthIn, height, width}); + INDArray inToBeMasked = Nd4j.rand(minibatch, depthIn, height, width); //Shape for mask: [minibatch, 1, height, 1] -> broadcast INDArray maskArray = Nd4j.create(new double[][] {{1, 1, 1, 1, 1}, {1, 1, 1, 1, 0}, {1, 1, 1, 0, 0}}) @@ -381,7 +381,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray inToBeMasked = Nd4j.rand(new int[] {minibatch, depthIn, height, width}); + INDArray inToBeMasked = Nd4j.rand(minibatch, depthIn, height, width); //Second example in minibatch: size [3,2] inToBeMasked.get(point(1), NDArrayIndex.all(), NDArrayIndex.interval(3,height), NDArrayIndex.all()).assign(0); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java index 2c9f0886e..e785b36e5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java @@ -68,7 +68,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class BidirectionalTest extends BaseDL4JTest { - private RNNFormat rnnDataFormat; + private final RNNFormat rnnDataFormat; public BidirectionalTest(RNNFormat rnnDataFormat){ this.rnnDataFormat = rnnDataFormat; @@ -128,9 +128,9 @@ public class BidirectionalTest extends BaseDL4JTest { INDArray in; if (rnnDataFormat == NCW){ - in = Nd4j.rand(new int[]{3, 10, 5}); + in = Nd4j.rand(3, 10, 5); }else{ - in = Nd4j.rand(new int[]{3, 5, 10}); + in = Nd4j.rand(3, 5, 10); } INDArray out1 = net1.output(in); @@ -140,9 +140,9 @@ public class BidirectionalTest extends BaseDL4JTest { INDArray labels; if (rnnDataFormat == NCW){ - labels = Nd4j.rand(new int[]{3, 10, 5}); + labels = Nd4j.rand(3, 10, 5); }else{ - labels = Nd4j.rand(new int[]{3, 5, 10}); + labels = Nd4j.rand(3, 5, 10); } net1.setInput(in); net1.setLabels(labels); @@ -234,14 +234,14 @@ public class BidirectionalTest extends BaseDL4JTest { net2.setParams(net1.params()); //Assuming exact same layout here... - INDArray in = Nd4j.rand(new int[]{3, 10, 5}); + INDArray in = Nd4j.rand(3, 10, 5); INDArray out1 = net1.outputSingle(in); INDArray out2 = net2.outputSingle(in); assertEquals(out1, out2); - INDArray labels = Nd4j.rand(new int[]{3, 10, 5}); + INDArray labels = Nd4j.rand(3, 10, 5); net1.setInput(0,in); net1.setLabels(labels); @@ -261,8 +261,8 @@ public class BidirectionalTest extends BaseDL4JTest { assertEquals(g1.gradient(), g2.gradient()); //Ensure updates are equal: - ComputationGraphUpdater u1 = (ComputationGraphUpdater) net1.getUpdater(); - ComputationGraphUpdater u2 = (ComputationGraphUpdater) net2.getUpdater(); + ComputationGraphUpdater u1 = net1.getUpdater(); + ComputationGraphUpdater u2 = net2.getUpdater(); assertEquals(u1.getUpdaterStateViewArray(), u2.getUpdaterStateViewArray()); u1.update(g1, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); u2.update(g2, 0, 0, 3, LayerWorkspaceMgr.noWorkspaces()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java index b6f3e7a58..bd1291216 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java @@ -47,8 +47,8 @@ import org.nd4j.common.primitives.Pair; import static org.junit.jupiter.api.Assertions.*; public class GravesBidirectionalLSTMTest extends BaseDL4JTest { - private double score = 0.0; - private RNNFormat rnnDataFormat; + private final double score = 0.0; + private final RNNFormat rnnDataFormat; public GravesBidirectionalLSTMTest(RNNFormat rnnDataFormat){ this.rnnDataFormat = rnnDataFormat; @@ -170,13 +170,13 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { assertNotNull(inWeightGradientB); assertNotNull(recurrentWeightGradientB); - assertArrayEquals(biasGradientF.shape(), new long[] {1, 4 * lstmNHiddenUnits}); - assertArrayEquals(inWeightGradientF.shape(), new long[] {nIn, 4 * lstmNHiddenUnits}); - assertArrayEquals(recurrentWeightGradientF.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3}); + assertArrayEquals(biasGradientF.shape(), new long[] {1, 4L * lstmNHiddenUnits}); + 
assertArrayEquals(inWeightGradientF.shape(), new long[] {nIn, 4L * lstmNHiddenUnits}); + assertArrayEquals(recurrentWeightGradientF.shape(), new long[] {lstmNHiddenUnits, 4L * lstmNHiddenUnits + 3}); - assertArrayEquals(biasGradientB.shape(), new long[] {1, 4 * lstmNHiddenUnits}); - assertArrayEquals(inWeightGradientB.shape(), new long[] {nIn, 4 * lstmNHiddenUnits}); - assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3}); + assertArrayEquals(biasGradientB.shape(), new long[] {1, 4L * lstmNHiddenUnits}); + assertArrayEquals(inWeightGradientB.shape(), new long[] {nIn, 4L * lstmNHiddenUnits}); + assertArrayEquals(recurrentWeightGradientB.shape(), new long[] {lstmNHiddenUnits, 4L * lstmNHiddenUnits + 3}); assertNotNull(nextEpsilon); if (rnnDataFormat == RNNFormat.NCW) { @@ -212,7 +212,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { INDArray params = Nd4j.create(1, numParams); final GravesBidirectionalLSTM lstm = (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); - final INDArray input = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); + final INDArray input = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); @@ -236,7 +236,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { for (int i = 0; i < timeSeriesLength; i++) { final INDArray sliceFalse = fwdPassFalse.tensorAlongDimension(i, 1, 0); final INDArray sliceTrue = fwdPassTrue[i]; - assertTrue(sliceFalse.equals(sliceTrue)); + assertEquals(sliceFalse, sliceTrue); } } @@ -273,8 +273,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { .instantiate(confBidirectional, null, 0, params, true, params.dataType()); - final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}): - Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn}); + final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(miniBatchSize, nIn, timeSeriesLength): + Nd4j.rand(miniBatchSize, timeSeriesLength, nIn); final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces()); @@ -327,8 +327,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards))); - final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}): - Nd4j.rand(new int[] {miniBatchSize, timeSeriesLength, nIn}); + final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(miniBatchSize, nIn, timeSeriesLength): + Nd4j.rand(miniBatchSize, timeSeriesLength, nIn); final INDArray sigb = sig.dup(); if (rnnDataFormat == RNNFormat.NCW) { @@ -389,8 +389,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { assertArrayEquals(activation1.data().asFloat(), activation2.data().asFloat(), 1e-5f); - final INDArray randSig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(new int[] {1, layerSize, timeSeriesLength}): - Nd4j.rand(new int[] {1, timeSeriesLength, layerSize}); + final INDArray randSig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(1, layerSize, timeSeriesLength): + Nd4j.rand(1, timeSeriesLength, layerSize); INDArray randSigBackwards = randSig.dup(); if (rnnDataFormat == RNNFormat.NCW){ reverseColumnsInPlace(randSigBackwards.slice(0)); @@ -549,8 +549,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { assertEquals(gateAfn, 
((org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) net.getLayer(0).conf() .getLayer()).getGateActivationFn().toString()); - INDArray in = Nd4j.rand(new int[] {3, 2, 5}); - INDArray labels = Nd4j.rand(new int[] {3, 2, 5}); + INDArray in = Nd4j.rand(3, 2, 5); + INDArray labels = Nd4j.rand(3, 2, 5); if (rnnDataFormat == RNNFormat.NWC){ in = in.permute(0, 2, 1); labels = labels.permute(0, 2, 1); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java index 80d3af6fe..679066755 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java @@ -131,9 +131,9 @@ public class GravesLSTMTest extends BaseDL4JTest { assertNotNull(inWeightGradient); assertNotNull(recurrentWeightGradient); - assertArrayEquals(biasGradient.shape(), new long[] {1, 4 * lstmNHiddenUnits}); - assertArrayEquals(inWeightGradient.shape(), new long[] {nIn, 4 * lstmNHiddenUnits}); - assertArrayEquals(recurrentWeightGradient.shape(), new long[] {lstmNHiddenUnits, 4 * lstmNHiddenUnits + 3}); + assertArrayEquals(biasGradient.shape(), new long[] {1, 4L * lstmNHiddenUnits}); + assertArrayEquals(inWeightGradient.shape(), new long[] {nIn, 4L * lstmNHiddenUnits}); + assertArrayEquals(recurrentWeightGradient.shape(), new long[] {lstmNHiddenUnits, 4L * lstmNHiddenUnits + 3}); assertNotNull(nextEpsilon); assertArrayEquals(nextEpsilon.shape(), new long[] {miniBatchSize, nIn, timeSeriesLength}); @@ -164,7 +164,7 @@ public class GravesLSTMTest extends BaseDL4JTest { val numParams = conf.getLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); GravesLSTM lstm = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); - INDArray input = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); + INDArray input = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); Method actHelper = GravesLSTM.class.getDeclaredMethod("activateHelper", boolean.class, INDArray.class, @@ -189,7 +189,7 @@ public class GravesLSTMTest extends BaseDL4JTest { for (int i = 0; i < timeSeriesLength; i++) { INDArray sliceFalse = fwdPassFalse.tensorAlongDimension(i, 1, 0); INDArray sliceTrue = fwdPassTrue[i]; - assertTrue(sliceFalse.equals(sliceTrue)); + assertEquals(sliceFalse, sliceTrue); } } @@ -210,13 +210,13 @@ public class GravesLSTMTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray in1 = Nd4j.rand(new int[] {1, 2, 4}); - INDArray in2 = Nd4j.rand(new int[] {1, 2, 5}); + INDArray in1 = Nd4j.rand(1, 2, 4); + INDArray in2 = Nd4j.rand(1, 2, 5); in2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, in1); assertEquals(in1, in2.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4))); - INDArray labels1 = Nd4j.rand(new int[] {1, 1, 4}); + INDArray labels1 = Nd4j.rand(1, 1, 4); INDArray labels2 = Nd4j.create(1, 1, 5); labels2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4)}, labels1); assertEquals(labels1, labels2.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4))); @@ -271,8 +271,8 @@ public class GravesLSTMTest extends BaseDL4JTest { assertEquals(gateAfn, 
((org.deeplearning4j.nn.conf.layers.GravesLSTM) net.getLayer(0).conf().getLayer()) .getGateActivationFn().toString()); - INDArray in = Nd4j.rand(new int[] {3, 2, 5}); - INDArray labels = Nd4j.rand(new int[] {3, 2, 5}); + INDArray in = Nd4j.rand(3, 2, 5); + INDArray labels = Nd4j.rand(3, 2, 5); net.fit(in, labels); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java index 1508d4b62..f1fa71ab2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java @@ -42,7 +42,7 @@ import java.util.Collections; import static org.junit.jupiter.api.Assertions.assertEquals; public class MaskZeroLayerTest extends BaseDL4JTest { - private RNNFormat rnnDataFormat; + private final RNNFormat rnnDataFormat; public MaskZeroLayerTest(RNNFormat rnnDataFormat){ this.rnnDataFormat = rnnDataFormat; @@ -73,17 +73,17 @@ public class MaskZeroLayerTest extends BaseDL4JTest { .build(); NeuralNetConfiguration conf = new NeuralNetConfiguration(); conf.setLayer(underlying); - INDArray params = Nd4j.zeros(new int[]{1, 16}); + INDArray params = Nd4j.zeros(1, 16); //Set the biases to 1. for (int i = 12; i < 16; i++) { params.putScalar(i, 1.0); } - Layer lstm = underlying.instantiate(conf, Collections.emptyList(), 0, params, false, params.dataType()); + Layer lstm = underlying.instantiate(conf, Collections.emptyList(), 0, params, false, params.dataType()); double maskingValue = 0.0; MaskZeroLayer l = new MaskZeroLayer(lstm, maskingValue); - INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), new int[]{2, 2, 3}); + INDArray input = Nd4j.create(Arrays.asList(ex1, ex2), 2, 2, 3); if (rnnDataFormat == RNNFormat.NWC){ input = input.permute(0, 2, 1); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java index 170ab285f..4abcfa768 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java @@ -49,7 +49,7 @@ import static org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction.MSE; public class TestLastTimeStepLayer extends BaseDL4JTest { - private RNNFormat rnnDataFormat; + private final RNNFormat rnnDataFormat; public TestLastTimeStepLayer(RNNFormat rnnDataFormat){ this.rnnDataFormat = rnnDataFormat; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index b9f850453..b5fd0ac57 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -48,13 +48,11 @@ import java.util.Arrays; import java.util.List; import java.util.Random; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public 
class TestRnnLayers extends BaseDL4JTest { - private RNNFormat rnnDataFormat; + private final RNNFormat rnnDataFormat; public TestRnnLayers(RNNFormat rnnDataFormat){ this.rnnDataFormat = rnnDataFormat; @@ -87,13 +85,13 @@ public class TestRnnLayers extends BaseDL4JTest { INDArray rnnInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10,12, 1):Nd4j.create(10, 1, 12); INDArray simpleOut = simpleRnn.rnnTimeStep(rnnInput3d, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(simpleOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 3, 1}:new long[]{10, 1, 3})); + assertArrayEquals(simpleOut.shape(), (rnnDataFormat == RNNFormat.NCW) ? new long[]{10, 3, 1} : new long[]{10, 1, 3}); INDArray rnnInput2d = Nd4j.create(10, 12); try { simpleRnn.rnnTimeStep(rnnInput2d, LayerWorkspaceMgr.noWorkspaces()); } catch (IllegalStateException e) { - assertTrue(e.getMessage().equals("3D input expected to RNN layer expected, got 2")); + assertEquals("3D input expected to RNN layer expected, got 2", e.getMessage()); } org.deeplearning4j.nn.layers.recurrent.LSTM lstm = @@ -101,13 +99,13 @@ public class TestRnnLayers extends BaseDL4JTest { INDArray lstmInput3d = (rnnDataFormat==RNNFormat.NCW)?Nd4j.create(10, 3, 1):Nd4j.create(10, 1, 3); INDArray lstmOut = lstm.rnnTimeStep(lstmInput3d, LayerWorkspaceMgr.noWorkspaces()); - assertTrue(Arrays.equals(lstmOut.shape(), (rnnDataFormat==RNNFormat.NCW)?new long[] {10, 5, 1}:new long[]{10, 1, 5})); + assertArrayEquals(lstmOut.shape(), (rnnDataFormat == RNNFormat.NCW) ? new long[]{10, 5, 1} : new long[]{10, 1, 5}); INDArray lstmInput2d = Nd4j.create(10, 3); try { lstm.rnnTimeStep(lstmInput2d, LayerWorkspaceMgr.noWorkspaces()); } catch (IllegalStateException e) { - assertTrue(e.getMessage().equals("3D input expected to RNN layer expected, got 2")); + assertEquals("3D input expected to RNN layer expected, got 2", e.getMessage()); } @@ -178,7 +176,7 @@ public class TestRnnLayers extends BaseDL4JTest { assertEquals(net.params(), netD.params(), s); assertEquals(net.params(), netD2.params(), s); - INDArray f = Nd4j.rand(DataType.FLOAT, new int[]{3, 10, 10}); + INDArray f = Nd4j.rand(DataType.FLOAT, 3, 10, 10); //Output: test mode -> no dropout INDArray out1 = net.output(f); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java index 5fc4e8bb1..9d77537c8 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java @@ -42,7 +42,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.point; public class TestSimpleRnn extends BaseDL4JTest { - private RNNFormat rnnDataFormat; + private final RNNFormat rnnDataFormat; public TestSimpleRnn(RNNFormat rnnDataFormat){ this.rnnDataFormat = rnnDataFormat; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java index 6c9a55ed2..90a05de95 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java @@ -49,7 +49,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class 
TestTimeDistributed extends BaseDL4JTest { - private RNNFormat rnnDataFormat; + private final RNNFormat rnnDataFormat; public TestTimeDistributed(RNNFormat rnnDataFormat){ this.rnnDataFormat = rnnDataFormat; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java index d9a331d0b..f0d5d16ce 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java @@ -206,7 +206,7 @@ public class TestSameDiffConv extends BaseDL4JTest { } } - INDArray in = Nd4j.rand(new int[]{minibatch, nIn, imgH, imgW}); + INDArray in = Nd4j.rand(minibatch, nIn, imgH, imgW); INDArray out = net.output(in); INDArray outExp = net2.output(in); @@ -306,7 +306,7 @@ public class TestSameDiffConv extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray f = Nd4j.rand(new int[]{minibatch, nIn, imgH, imgW}); + INDArray f = Nd4j.rand(minibatch, nIn, imgH, imgW); INDArray l = TestUtils.randomOneHot(minibatch, nOut); log.info("Starting: " + msg); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java index e49e6aca6..e84390916 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java @@ -85,8 +85,8 @@ public class SameDiffDense extends SameDiffLayer { @Override public void defineParameters(SDLayerParams params) { params.clear(); - params.addWeightParam(DefaultParamInitializer.WEIGHT_KEY, new long[]{nIn, nOut}); - params.addBiasParam(DefaultParamInitializer.BIAS_KEY, new long[]{1, nOut}); + params.addWeightParam(DefaultParamInitializer.WEIGHT_KEY, nIn, nOut); + params.addBiasParam(DefaultParamInitializer.BIAS_KEY, 1, nOut); } @Override diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java index 2e60b8461..41d149b3b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java @@ -35,10 +35,10 @@ import java.util.Map; public class SameDiffMSEOutputLayer extends SameDiffOutputLayer { - private int nIn; - private int nOut; - private Activation activation; - private WeightInit weightInit; + private final int nIn; + private final int nOut; + private final Activation activation; + private final WeightInit weightInit; public SameDiffMSEOutputLayer(int nIn, int nOut, Activation activation, WeightInit weightInit){ this.nIn = nIn; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java index 934ba63a8..7138a2a42 100644 --- 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestReconstructionDistributions.java @@ -60,7 +60,7 @@ public class TestReconstructionDistributions extends BaseDL4JTest { INDArray mean = Nd4j.randn(minibatch, inputSize); INDArray logStdevSquared = Nd4j.rand(minibatch, inputSize).subi(0.5); - INDArray distributionParams = Nd4j.createUninitialized(new int[] {minibatch, 2 * inputSize}); + INDArray distributionParams = Nd4j.createUninitialized(minibatch, 2 * inputSize); distributionParams.get(NDArrayIndex.all(), NDArrayIndex.interval(0, inputSize)).assign(mean); distributionParams.get(NDArrayIndex.all(), NDArrayIndex.interval(inputSize, 2 * inputSize)) .assign(logStdevSquared); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java index e61614a1b..f535c81fa 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java @@ -206,7 +206,7 @@ public class TestVAE extends BaseDL4JTest { INDArray gArr = grads.get(p); assertArrayEquals(pArr.shape(), gvArr.shape()); - assertTrue(gvArr == gArr); //Should be the exact same object due to view mechanics + assertSame(gvArr, gArr); //Should be the exact same object due to view mechanics } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java index cd1ca1a28..fc8312630 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java @@ -59,8 +59,8 @@ public class TestNetConversion extends BaseDL4JTest { default: throw new RuntimeException(); } - INDArray in = (i <= 1 ? Nd4j.rand(new int[]{8, 3, 10, 10}) : Nd4j.rand(new int[]{8, 5, 10})); - INDArray labels = (i <= 1 ? Nd4j.rand(new int[]{8, 10}) : Nd4j.rand(new int[]{8, 10, 10})); + INDArray in = (i <= 1 ? Nd4j.rand(8, 3, 10, 10) : Nd4j.rand(8, 5, 10)); + INDArray labels = (i <= 1 ? 
Nd4j.rand(8, 10) : Nd4j.rand(8, 10, 10)); ComputationGraph cg = n.toComputationGraph(); @@ -109,7 +109,7 @@ public class TestNetConversion extends BaseDL4JTest { if(train) { for (int i = 0; i < 3; i++) { - INDArray f = Nd4j.rand(new int[]{8, 3, 10, 10}); + INDArray f = Nd4j.rand(8, 3, 10, 10); INDArray l = Nd4j.rand(8, 10); net.fit(f, l); @@ -137,8 +137,8 @@ public class TestNetConversion extends BaseDL4JTest { net.init(); for (int i = 0; i < 3; i++) { - INDArray f = Nd4j.rand(new int[]{8, 5, 10}); - INDArray l = Nd4j.rand(new int[]{8, 10, 10}); + INDArray f = Nd4j.rand(8, 5, 10); + INDArray l = Nd4j.rand(8, 10, 10); net.fit(f, l); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java index ad57a4688..cf7d31bd5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java @@ -76,7 +76,7 @@ public class WorkspaceTests extends BaseDL4JTest { c.getConfiguration().setTrainingWorkspaceMode(wm); c.getConfiguration().setInferenceWorkspaceMode(wm); - INDArray f = Nd4j.rand(new int[]{8, 1, 28, 28}); + INDArray f = Nd4j.rand(8, 1, 28, 28); INDArray l = Nd4j.rand(8, 10); c.setInputs(f); c.setLabels(l); @@ -112,7 +112,7 @@ public class WorkspaceTests extends BaseDL4JTest { net2.getLayerWiseConfigurations().setInferenceWorkspaceMode(WorkspaceMode.NONE); net2.getLayerWiseConfigurations().setTrainingWorkspaceMode(WorkspaceMode.NONE); - INDArray in = Nd4j.rand(new int[]{1, 2, 5, 5}); + INDArray in = Nd4j.rand(1, 2, 5, 5); net.output(in); net2.output(in); //Op [add_scalar] X argument uses leaked workspace pointer from workspace [LOOP_EXTERNAL] @@ -175,7 +175,7 @@ public class WorkspaceTests extends BaseDL4JTest { } cg.setInputs(input); - cg.setLabels(Nd4j.rand(new int[]{1, 3, 5})); + cg.setLabels(Nd4j.rand(1, 3, 5)); cg.computeGradientAndScore(); } } @@ -207,7 +207,7 @@ public class WorkspaceTests extends BaseDL4JTest { } net.setInput(input); - net.setLabels(Nd4j.rand(new int[]{1, 3, 5})); + net.setLabels(Nd4j.rand(1, 3, 5)); net.computeGradientAndScore(); } } @@ -303,11 +303,11 @@ public class WorkspaceTests extends BaseDL4JTest { net2.init(); for (int j = 0; j < 3; j++) { - net.rnnTimeStep(Nd4j.rand(new int[]{3, 10, 5})); + net.rnnTimeStep(Nd4j.rand(3, 10, 5)); } for (int j = 0; j < 3; j++) { - net2.rnnTimeStep(Nd4j.rand(new int[]{3, 10, 5})); + net2.rnnTimeStep(Nd4j.rand(3, 10, 5)); } } } @@ -384,11 +384,11 @@ public class WorkspaceTests extends BaseDL4JTest { net2.init(); for (int j = 0; j < 3; j++) { - net.fit(Nd4j.rand(new int[]{3, 10, 20}), Nd4j.rand(new int[]{3, 10, 20})); + net.fit(Nd4j.rand(3, 10, 20), Nd4j.rand(3, 10, 20)); } for (int j = 0; j < 3; j++) { - net2.fit(new DataSet(Nd4j.rand(new int[]{3, 10, 20}), Nd4j.rand(new int[]{3, 10, 20}))); + net2.fit(new DataSet(Nd4j.rand(3, 10, 20), Nd4j.rand(3, 10, 20))); } } } @@ -625,7 +625,7 @@ public class WorkspaceTests extends BaseDL4JTest { mlc.setTrainingWorkspaceMode(wm); mlc.setInferenceWorkspaceMode(wm); - INDArray f = Nd4j.rand(new int[]{1, 1, 5, 5}); + INDArray f = Nd4j.rand(1, 1, 5, 5); INDArray l = Nd4j.rand(1, 10); DataSet ds = new DataSet(f,l); @@ -669,7 +669,7 @@ public class WorkspaceTests extends BaseDL4JTest { c.getConfiguration().setTrainingWorkspaceMode(wm); c.getConfiguration().setInferenceWorkspaceMode(wm); - INDArray f = Nd4j.rand(new int[]{8, 1, 28, 28}); + INDArray f = 
Nd4j.rand(8, 1, 28, 28); INDArray l = Nd4j.rand(8, 10); DataSet ds = new DataSet(f,l); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java index 0f6337502..695fdb70d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java @@ -110,7 +110,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { String name = pt + ", mb=" + minibatch + ", cm=" + cm + ", kernel=" + Arrays.toString(kernel) + ", stride=" + Arrays.toString(stride); LayerHelperValidationUtil.TestCase tc = LayerHelperValidationUtil.TestCase.builder() .testName(name) - .allowHelpersForClasses(Arrays.>asList(org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer.class, + .allowHelpersForClasses(Arrays.asList(org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer.class, org.deeplearning4j.nn.layers.convolution.ConvolutionLayer.class)) .testForward(true) .testScore(true) @@ -179,7 +179,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { netWithout.init(); LayerHelperValidationUtil.TestCase tc = LayerHelperValidationUtil.TestCase.builder() - .allowHelpersForClasses(Collections.>singletonList(org.deeplearning4j.nn.layers.normalization.BatchNormalization.class)) + .allowHelpersForClasses(Collections.singletonList(org.deeplearning4j.nn.layers.normalization.BatchNormalization.class)) .testForward(true) .testScore(true) .testBackward(true) @@ -252,7 +252,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { netWithout.init(); LayerHelperValidationUtil.TestCase tc = LayerHelperValidationUtil.TestCase.builder() - .allowHelpersForClasses(Collections.>singletonList(org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization.class)) + .allowHelpersForClasses(Collections.singletonList(org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization.class)) .testForward(true) .testScore(true) .testBackward(true) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index c8e758feb..056f4a43e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -288,9 +288,9 @@ public class MultiLayerTest extends BaseDL4JTest { log.info("Testing full cycle..."); - List comparableResult = model.feedForward(Nd4j.create(trainingData[0], new long[]{1, trainingData[0].length})); + List comparableResult = model.feedForward(Nd4j.create(trainingData[0], 1, trainingData[0].length)); - INDArray encodeResult = model.activateSelectedLayers(0, 4, Nd4j.create(trainingData[0], new long[]{1, trainingData[0].length})); + INDArray encodeResult = model.activateSelectedLayers(0, 4, Nd4j.create(trainingData[0], 1, trainingData[0].length)); log.info("Compare feedForward results with selectedActivation"); @@ -541,8 +541,8 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray in = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, new long[]{1, 4}); - INDArray out = Nd4j.create(new double[] {1, 0, 0}, new long[]{1,3}); + INDArray in = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, 1, 4); + INDArray out = 
Nd4j.create(new double[] {1, 0, 0}, 1,3); double score = net.score(new DataSet(in, out)); } @@ -599,8 +599,8 @@ public class MultiLayerTest extends BaseDL4JTest { testData.setLabelNames(Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")); String actualLables = testData.getLabelName(0); List prediction = net.predict(testData); - assertTrue(actualLables != null); - assertTrue(prediction.get(0) != null); + assertNotNull(actualLables); + assertNotNull(prediction.get(0)); } @Test @@ -611,7 +611,7 @@ public class MultiLayerTest extends BaseDL4JTest { Environment environment = EnvironmentUtils.buildEnvironment(); environment.setSerialVersionID(EnvironmentUtils.buildCId()); - Task task = TaskUtils.buildTask(Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, new long[]{1,6})); + Task task = TaskUtils.buildTask(Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, 1,6)); Heartbeat.getInstance().reportEvent(Event.STANDALONE, environment, task); @@ -710,7 +710,7 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray inputWrongDepth = Nd4j.rand(new int[]{miniBatch, 5, height, width}); //Order: examples, channels, height, width + INDArray inputWrongDepth = Nd4j.rand(miniBatch, 5, height, width); //Order: examples, channels, height, width net.feedForward(inputWrongDepth); }); } @@ -1419,7 +1419,7 @@ public class MultiLayerTest extends BaseDL4JTest { INDArray bb1 = ((Yolo2OutputLayer)conf.getConf(1).getLayer()).getBoundingBoxes(); INDArray bb2 = ((Yolo2OutputLayer)conf2.getConf(1).getLayer()).getBoundingBoxes(); - assertFalse(bb1 == bb2); + assertNotSame(bb1, bb2); assertEquals(bb1, bb2); } @@ -1475,8 +1475,8 @@ public class MultiLayerTest extends BaseDL4JTest { soFar += 3*2; INDArray m1b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2)).assign(3); //m1b soFar += 2; - INDArray m2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2*1)).assign(4); //m2w - soFar += 2*1; + INDArray m2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+ 2)).assign(4); //m2w + soFar += 2; INDArray m2b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+1)).assign(5); //m2b soFar += 1; @@ -1488,8 +1488,8 @@ public class MultiLayerTest extends BaseDL4JTest { soFar += 3*2; INDArray v1b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2)).assign(9); //v1b soFar += 2; - INDArray v2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2*1)).assign(10); //v2w - soFar += 2*1; + INDArray v2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+ 2)).assign(10); //v2w + soFar += 2; INDArray v2b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+1)).assign(11); //v2b soFar += 1; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java index 6f1b3f732..5064e44ab 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java @@ -86,7 +86,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { assertTrue(layer instanceof GravesLSTM); Map paramTable = layer.paramTable(); - assertTrue(paramTable.size() == 3); //2 
sets of weights, 1 set of biases + assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases INDArray recurrentWeights = paramTable.get(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY); assertArrayEquals(recurrentWeights.shape(), new long[] {nHiddenUnits, 4 * nHiddenUnits + 3}); //Should be shape: [layerSize,4*layerSize+3] @@ -104,7 +104,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { assertEquals(nHiddenUnits, count); val nParams = recurrentWeights.length() + inputWeights.length() + biases.length(); - assertTrue(nParams == layer.numParams()); + assertEquals(nParams, layer.numParams()); } @Test @@ -131,7 +131,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { assertTrue(layer instanceof GravesLSTM); Map paramTable = layer.paramTable(); - assertTrue(paramTable.size() == 3); //2 sets of weights, 1 set of biases + assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases int layerNIn = (i == 0 ? nIn : nHiddenUnits[i - 1]); @@ -151,7 +151,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { assertEquals(nHiddenUnits[i], (int)count); val nParams = recurrentWeights.length() + inputWeights.length() + biases.length(); - assertTrue(nParams == layer.numParams()); + assertEquals(nParams, layer.numParams()); } } @@ -181,20 +181,20 @@ public class MultiLayerTestRNN extends BaseDL4JTest { .build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); - INDArray input = Nd4j.rand(new int[] {3, 5, timeSeriesLength}); + INDArray input = Nd4j.rand(3, 5, timeSeriesLength); List allOutputActivations = mln.feedForward(input, true); INDArray outAct = allOutputActivations.get(3); INDArray outRnnTimeStep = mln.rnnTimeStep(input); - assertTrue(outAct.equals(outRnnTimeStep)); //Should be identical here + assertEquals(outAct, outRnnTimeStep); //Should be identical here Map currStateL0 = mln.rnnGetPreviousState(0); Map currStateL1 = mln.rnnGetPreviousState(1); - assertTrue(currStateL0.size() == 2); - assertTrue(currStateL1.size() == 2); + assertEquals(2, currStateL0.size()); + assertEquals(2, currStateL1.size()); INDArray lastActL0 = currStateL0.get(GravesLSTM.STATE_KEY_PREV_ACTIVATION); INDArray lastMemL0 = currStateL0.get(GravesLSTM.STATE_KEY_PREV_MEMCELL); @@ -205,10 +205,10 @@ public class MultiLayerTestRNN extends BaseDL4JTest { assertTrue(lastActL1 != null && lastMemL1 != null); INDArray expectedLastActL0 = allOutputActivations.get(1).tensorAlongDimension(timeSeriesLength - 1, 1, 0); - assertTrue(expectedLastActL0.equals(lastActL0)); + assertEquals(expectedLastActL0, lastActL0); INDArray expectedLastActL1 = allOutputActivations.get(2).tensorAlongDimension(timeSeriesLength - 1, 1, 0); - assertTrue(expectedLastActL1.equals(lastActL1)); + assertEquals(expectedLastActL1, lastActL1); //Check clearing and setting of state: mln.rnnClearPreviousState(); @@ -216,9 +216,9 @@ public class MultiLayerTestRNN extends BaseDL4JTest { assertTrue(mln.rnnGetPreviousState(1).isEmpty()); mln.rnnSetPreviousState(0, currStateL0); - assertTrue(mln.rnnGetPreviousState(0).size() == 2); + assertEquals(2, mln.rnnGetPreviousState(0).size()); mln.rnnSetPreviousState(1, currStateL1); - assertTrue(mln.rnnGetPreviousState(1).size() == 2); + assertEquals(2, mln.rnnGetPreviousState(1).size()); } @Test @@ -278,7 +278,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { .inputPreProcessor(3, new FeedForwardToRnnPreProcessor()).build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); - INDArray input = Nd4j.rand(new int[]{3, 5, timeSeriesLength}); + INDArray input = Nd4j.rand(3, 5, 
timeSeriesLength); List allOutputActivations = mln.feedForward(input, true); INDArray fullOutL0 = allOutputActivations.get(1); @@ -311,7 +311,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { NDArrayIndex.interval(startTimeRange, endTimeRange)); } if (inLength > 1) - assertTrue(inputSubset.size(2) == inLength); + assertEquals(inputSubset.size(2), inLength); INDArray out = mln.rnnTimeStep(inputSubset); @@ -389,12 +389,12 @@ public class MultiLayerTestRNN extends BaseDL4JTest { //Check same but for input of size [3,5,1]. Expect [3,4,1] out mln.rnnClearPreviousState(); for (int i = 0; i < timeSeriesLength; i++) { - INDArray temp = Nd4j.create(new int[] {3, 5, 1}); + INDArray temp = Nd4j.create(3, 5, 1); temp.tensorAlongDimension(0, 1, 0).assign(input3d.tensorAlongDimension(i, 1, 0)); INDArray out3dSlice = mln.rnnTimeStep(temp); assertArrayEquals(out3dSlice.shape(), new long[] {3, 4, 1}); - assertTrue(out3dSlice.tensorAlongDimension(0, 1, 0).equals(out3d.tensorAlongDimension(i, 1, 0))); + assertEquals(out3dSlice.tensorAlongDimension(0, 1, 0), out3d.tensorAlongDimension(i, 1, 0)); } } @@ -460,8 +460,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest { assertEquals(timeSeriesLength, mlnTBPTT.getLayerWiseConfigurations().getTbpttFwdLength()); assertEquals(timeSeriesLength, mlnTBPTT.getLayerWiseConfigurations().getTbpttBackLength()); - INDArray inputData = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); - INDArray labels = Nd4j.rand(new int[] {miniBatchSize, nOut, timeSeriesLength}); + INDArray inputData = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); + INDArray labels = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength); mln.setInput(inputData); mln.setLabels(labels); @@ -542,7 +542,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - INDArray inputLong = Nd4j.rand(new int[] {miniBatchSize, nIn, nTimeSlices * timeSeriesLength}); + INDArray inputLong = Nd4j.rand(miniBatchSize, nIn, nTimeSlices * timeSeriesLength); INDArray input = inputLong.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, timeSeriesLength)); @@ -624,8 +624,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - INDArray inputLong = Nd4j.rand(new int[] {miniBatchSize, nIn, nTimeSlices * timeSeriesLength}); - INDArray labelsLong = Nd4j.rand(new int[] {miniBatchSize, nOut, nTimeSlices * timeSeriesLength}); + INDArray inputLong = Nd4j.rand(miniBatchSize, nIn, nTimeSlices * timeSeriesLength); + INDArray labelsLong = Nd4j.rand(miniBatchSize, nOut, nTimeSlices * timeSeriesLength); mln.fit(inputLong, labelsLong); } @@ -661,8 +661,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - INDArray features = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); - INDArray labels = Nd4j.rand(new int[] {miniBatchSize, nOut, timeSeriesLength}); + INDArray features = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); + INDArray labels = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength); INDArray maskArrayInput = Nd4j.ones(miniBatchSize, timeSeriesLength); INDArray maskArrayOutput = Nd4j.ones(miniBatchSize, timeSeriesLength); @@ -743,8 +743,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - INDArray features = Nd4j.rand(new int[] {miniBatchSize, nIn, timeSeriesLength}); - INDArray labels = Nd4j.rand(new int[] 
{miniBatchSize, nOut, timeSeriesLength}); + INDArray features = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); + INDArray labels = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength); INDArray maskArrayInput = Nd4j.ones(miniBatchSize, timeSeriesLength); INDArray maskArrayOutput = Nd4j.ones(miniBatchSize, timeSeriesLength); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java index 420417296..c4c3067a9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java @@ -105,7 +105,7 @@ public class TestMasking extends BaseDL4JTest { int nIn = 6; int layerSize = 4; - INDArray mask1 = Nd4j.create(new double[] {1, 0, 0, 1, 0}, new long[]{1,5}); + INDArray mask1 = Nd4j.create(new double[] {1, 0, 0, 1, 0}, 1,5); INDArray mask3 = Nd4j.create(new double[][] {{1, 1, 1, 1, 1}, {0, 1, 0, 1, 0}, {1, 0, 0, 1, 1}}); INDArray[] labelMasks = new INDArray[] {mask1, mask3}; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java index ff5efa35a..cb9536e3d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java @@ -63,7 +63,7 @@ public class TestSetGetParameters extends BaseDL4JTest { Map initParams2After = net.paramTable(); for (String s : initParams2.keySet()) { - assertTrue(initParams2.get(s).equals(initParams2After.get(s)), "Params differ: " + s); + assertEquals(initParams2.get(s), initParams2After.get(s), "Params differ: " + s); } assertEquals(initParams, initParamsAfter); @@ -100,7 +100,7 @@ public class TestSetGetParameters extends BaseDL4JTest { Map initParams2After = net.paramTable(); for (String s : initParams2.keySet()) { - assertTrue( initParams2.get(s).equals(initParams2After.get(s)), "Params differ: " + s); + assertEquals(initParams2.get(s), initParams2After.get(s), "Params differ: " + s); } assertEquals(initParams, initParamsAfter); @@ -141,8 +141,8 @@ public class TestSetGetParameters extends BaseDL4JTest { assertEquals(params, net2.params()); assertEquals(params, net3.params()); - assertFalse(params == net2.params()); //Different objects due to clone - assertTrue(params == net3.params()); //Same object due to clone + assertNotSame(params, net2.params()); //Different objects due to clone + assertSame(params, net3.params()); //Same object due to clone Map paramsMap = net.paramTable(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java index 5212865f6..5d5daed14 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java @@ -83,14 +83,14 @@ public class TestVariableLengthTS extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray in1 = Nd4j.rand(new int[] {nExamples, 2, 4}); - INDArray in2 = Nd4j.rand(new int[] {nExamples, 2, 5}); + INDArray in1 = Nd4j.rand(nExamples, 2, 
4); + INDArray in2 = Nd4j.rand(nExamples, 2, 5); in2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true)}, in1); assertEquals(in1, in2.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4))); - INDArray labels1 = Nd4j.rand(new int[] {nExamples, 1, 4}); + INDArray labels1 = Nd4j.rand(nExamples, 1, 4); INDArray labels2 = Nd4j.create(nExamples, 1, 5); labels2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true)}, labels1); @@ -176,14 +176,14 @@ public class TestVariableLengthTS extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray in1 = Nd4j.rand(new int[] {nExamples, 2, 4}); - INDArray in2 = Nd4j.rand(new int[] {nExamples, 2, 5}); + INDArray in1 = Nd4j.rand(nExamples, 2, 4); + INDArray in2 = Nd4j.rand(nExamples, 2, 5); in2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true)}, in1); assertEquals(in1, in2.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 4))); - INDArray labels1 = Nd4j.rand(new int[] {nExamples, 1, 4}); + INDArray labels1 = Nd4j.rand(nExamples, 1, 4); INDArray labels2 = Nd4j.create(nExamples, 1, 5); labels2.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 3, true)}, labels1); @@ -302,7 +302,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { } } - INDArray input = Nd4j.rand(new int[] {miniBatch, nIn, tsLength}); + INDArray input = Nd4j.rand(miniBatch, nIn, tsLength); INDArray labels = Nd4j.ones(miniBatch, nOut, tsLength); MultiLayerConfiguration conf = @@ -366,7 +366,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { } } - INDArray input = Nd4j.rand(new int[] {miniBatch, nIn, tsLength}); + INDArray input = Nd4j.rand(miniBatch, nIn, tsLength); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list() @@ -455,8 +455,8 @@ public class TestVariableLengthTS extends BaseDL4JTest { int tsLength = 5; int minibatch = 3; - INDArray input = Nd4j.rand(new int[] {minibatch, nIn, tsLength}); - INDArray labels = Nd4j.rand(new int[] {minibatch, nOut, tsLength}); + INDArray input = Nd4j.rand(minibatch, nIn, tsLength); + INDArray labels = Nd4j.rand(minibatch, nOut, tsLength); INDArray featuresMask = Nd4j.create(new double[][] {{1, 1, 1, 1, 1}, {1, 1, 1, 1, 0}, {1, 1, 1, 0, 0}}); INDArray labelsMask = featuresMask.dup(); @@ -537,8 +537,8 @@ public class TestVariableLengthTS extends BaseDL4JTest { int tsLength = 5; int minibatch = 3; - INDArray input = Nd4j.rand(new int[] {minibatch, nIn, tsLength}); - INDArray labels = Nd4j.rand(new int[] {minibatch, nOut}); + INDArray input = Nd4j.rand(minibatch, nIn, tsLength); + INDArray labels = Nd4j.rand(minibatch, nOut); INDArray featuresMask = Nd4j.create(new double[][] {{1, 1, 1, 1, 1}, {1, 1, 1, 1, 0}, {1, 1, 1, 0, 0}}); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java index e98680c51..ecda6b48a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java @@ -81,8 +81,8 @@ public class TestFrozenLayers extends BaseDL4JTest { } for( int i=0; i<20; i++ ){ - INDArray f = Nd4j.rand(new int[]{16,1,28,28}); - INDArray l = 
Nd4j.rand(new int[]{16,10}); + INDArray f = Nd4j.rand(16,1,28,28); + INDArray l = Nd4j.rand(16,10); transfer.fit(f,l); } @@ -133,8 +133,8 @@ public class TestFrozenLayers extends BaseDL4JTest { } for( int i=0; i<20; i++ ){ - INDArray f = Nd4j.rand(new int[]{16,1,28,28}); - INDArray l = Nd4j.rand(new int[]{16,10}); + INDArray f = Nd4j.rand(16,1,28,28); + INDArray l = Nd4j.rand(16,10); transfer.fit(new INDArray[]{f},new INDArray[]{l}); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java index d9735fb89..462143897 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java @@ -340,7 +340,7 @@ public class TestUpdaters extends BaseDL4JTest { actualM[i] = Math.round(actualM[i] * 1e2) / 1e2; } - assertTrue( Arrays.equals(expectedM, actualM), "Wrong weight gradient after first iteration's update"); + assertArrayEquals(expectedM, actualM, "Wrong weight gradient after first iteration's update"); } @@ -592,7 +592,7 @@ public class TestUpdaters extends BaseDL4JTest { Updater updater = net.getUpdater(); assertNotNull(updater); - assertTrue(updater.getClass() == MultiLayerUpdater.class); + assertSame(updater.getClass(), MultiLayerUpdater.class); MultiLayerUpdater mlu = (MultiLayerUpdater) updater; @@ -695,7 +695,7 @@ public class TestUpdaters extends BaseDL4JTest { Updater newUpdater = UpdaterCreator.getUpdater(net); net.setUpdater(newUpdater); - assertTrue(newUpdater == net.getUpdater()); //Should be identical object + assertSame(newUpdater, net.getUpdater()); //Should be identical object } @Test @@ -722,7 +722,7 @@ public class TestUpdaters extends BaseDL4JTest { Updater newUpdater = UpdaterCreator.getUpdater(net); net.setUpdater(newUpdater); - assertTrue(newUpdater == net.getUpdater()); //Should be identical object + assertSame(newUpdater, net.getUpdater()); //Should be identical object } @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java index 703d56eb2..170c6bdc1 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java @@ -103,7 +103,7 @@ public class TestCustomUpdater extends BaseDL4JTest { net2.setLabels(labels); net1.computeGradientAndScore(); - net2.computeGradientAndScore();; + net2.computeGradientAndScore(); assertEquals(net1.getFlattenedGradients(), net2.getFlattenedGradients()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/util/TestDataSetConsumer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/util/TestDataSetConsumer.java index cbc94f1f2..54ea33fe7 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/util/TestDataSetConsumer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/util/TestDataSetConsumer.java @@ -30,8 +30,8 @@ import java.util.concurrent.atomic.AtomicLong; public class TestDataSetConsumer { private DataSetIterator iterator; - private long delay; - private AtomicLong count = new AtomicLong(0); + private final long delay; + private final AtomicLong count = new AtomicLong(0); protected static final 
Logger logger = LoggerFactory.getLogger(TestDataSetConsumer.class); public TestDataSetConsumer(long delay) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java index b17032fdd..5b7bec134 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java @@ -65,6 +65,7 @@ import java.util.Collection; import java.util.Collections; import java.util.Map; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; public class TestOptimizers extends BaseDL4JTest { @@ -123,7 +124,7 @@ public class TestOptimizers extends BaseDL4JTest { } double scoreAfter = network.score(ds); scores[i + 1] = scoreAfter; - assertTrue( !Double.isNaN(scoreAfter), "Score is NaN after optimization"); + assertFalse(Double.isNaN(scoreAfter), "Score is NaN after optimization"); assertTrue( scoreAfter <= score, "OA= " + oa + ", before= " + score + ", after= " + scoreAfter); score = scoreAfter; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java index 78dbb6d14..d9b7e86e5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java @@ -50,7 +50,7 @@ public class SmartFancyBlockingQueueTest extends BaseDL4JTest { for (int e = 0; e < 6; e++) { queue.put(Nd4j.create(5, 5).assign(e)); - }; + } assertEquals(6, queue.size()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java index 8d3b3751e..a7ac0905b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java @@ -34,18 +34,18 @@ public class ScoreStatTest extends BaseDL4JTest { public void testScoreStatSmall() { CollectScoresIterationListener.ScoreStat statTest = new CollectScoresIterationListener.ScoreStat(); for (int i = 0; i < CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH; ++i) { - double score = (double)i; + double score = i; statTest.addScore(i, score); } List indexes = statTest.getIndexes(); List scores = statTest.getScores(); - assertTrue(indexes.size() == 1); - assertTrue(scores.size() == 1); + assertEquals(1, indexes.size()); + assertEquals(1, scores.size()); - assertTrue(indexes.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH); - assertTrue(scores.get(0).length == CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH); + assertEquals(CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH, indexes.get(0).length); + assertEquals(CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH, scores.get(0).length); assertEquals(indexes.get(0)[indexes.get(0).length-1], CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1); assertEquals(scores.get(0)[scores.get(0).length-1], 
CollectScoresIterationListener.ScoreStat.BUCKET_LENGTH-1, 1e-4); } @@ -109,12 +109,12 @@ public class ScoreStatTest extends BaseDL4JTest { List indexes = statTest.getIndexes(); List scores = statTest.getScores(); - assertTrue(indexes.size() == 2); - assertTrue(scores.size() == 2); + assertEquals(2, indexes.size()); + assertEquals(2, scores.size()); for (int i = 0; i < 5; ++i) { - assertTrue(indexes.get(1)[i] == Integer.MAX_VALUE + i); - assertTrue(scores.get(1)[i] == Integer.MAX_VALUE + i); + assertEquals(indexes.get(1)[i], Integer.MAX_VALUE + i); + assertEquals(scores.get(1)[i], Integer.MAX_VALUE + i); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java index 47430c8c3..55b1d39c8 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java @@ -268,7 +268,7 @@ public class TestListeners extends BaseDL4JTest { assertEquals(exp, tl.getCalls()); } - private static enum Call { + private enum Call { ITER_DONE, EPOCH_START, EPOCH_END, diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ArrayUtilTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ArrayUtilTest.java index 696faf3f9..edd398223 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ArrayUtilTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ArrayUtilTest.java @@ -27,6 +27,7 @@ import org.nd4j.common.util.ArrayUtil; import java.util.Arrays; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; /** * @@ -37,11 +38,11 @@ public class ArrayUtilTest extends BaseDL4JTest { public void testRange() { int[] range = ArrayUtil.range(0, 2); int[] test = {0, 1}; - assertEquals(true, Arrays.equals(test, range)); + assertTrue(Arrays.equals(test, range)); int[] test2 = {-1, 0}; int[] range2 = ArrayUtil.range(-1, 1); - assertEquals(true, Arrays.equals(test2, range2)); + assertTrue(Arrays.equals(test2, range2)); } @@ -52,16 +53,16 @@ public class ArrayUtilTest extends BaseDL4JTest { int[] fortranStyleStride = {1, 5, 20}; int[] fortranStyleTest = ArrayUtil.calcStridesFortran(shape); int[] cStyleTest = ArrayUtil.calcStrides(shape); - assertEquals(true, Arrays.equals(cStyleStride, cStyleTest)); - assertEquals(true, Arrays.equals(fortranStyleStride, fortranStyleTest)); + assertTrue(Arrays.equals(cStyleStride, cStyleTest)); + assertTrue(Arrays.equals(fortranStyleStride, fortranStyleTest)); int[] shape2 = {2, 2}; int[] cStyleStride2 = {2, 1}; int[] fortranStyleStride2 = {1, 2}; int[] cStyleTest2 = ArrayUtil.calcStrides(shape2); int[] fortranStyleTest2 = ArrayUtil.calcStridesFortran(shape2); - assertEquals(true, Arrays.equals(cStyleStride2, cStyleTest2)); - assertEquals(true, Arrays.equals(fortranStyleStride2, fortranStyleTest2)); + assertTrue(Arrays.equals(cStyleStride2, cStyleTest2)); + assertTrue(Arrays.equals(fortranStyleStride2, fortranStyleTest2)); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java index 74fbd476a..2ff1c481d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java @@ -99,7 +99,7 @@ public class ModelGuesserTest extends BaseDL4JTest { ModelSerializer.writeModel(net, tempFile, true); NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1); - normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2}))); + normalizer.fit(new DataSet(Nd4j.rand(2, 2), Nd4j.rand(2, 2))); ModelSerializer.addNormalizerToModel(tempFile, normalizer); Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); Normalizer normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath()); @@ -116,7 +116,7 @@ public class ModelGuesserTest extends BaseDL4JTest { File tempFile = new File(testDir, "testNormalizerInPlace.bin"); NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1); - normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2}))); + normalizer.fit(new DataSet(Nd4j.rand(2, 2), Nd4j.rand(2, 2))); ModelSerializer.writeModel(net, tempFile, true,normalizer); Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); @@ -135,7 +135,7 @@ public class ModelGuesserTest extends BaseDL4JTest { ModelSerializer.writeModel(net, tempFile, true); NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1); - normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2}))); + normalizer.fit(new DataSet(Nd4j.rand(2, 2), Nd4j.rand(2, 2))); ModelSerializer.addNormalizerToModel(tempFile, normalizer); Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); try (InputStream inputStream = new FileInputStream(tempFile)) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java index d1b1c3e02..610cb0961 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java @@ -248,7 +248,7 @@ public class ModelSerializerTest extends BaseDL4JTest { NormalizerStandardize restored = ModelSerializer.restoreNormalizerFromInputStream(fis); - assertEquals(null, restored); + assertNull(restored); } @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/TestUIDProvider.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/TestUIDProvider.java index a23f3d513..363b90dd6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/TestUIDProvider.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/TestUIDProvider.java @@ -24,10 +24,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.util.UIDProvider; import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public class TestUIDProvider extends BaseDL4JTest { @@ -40,8 +37,8 @@ public class TestUIDProvider extends BaseDL4JTest { assertNotNull(jvmUID); assertNotNull(hardwareUID); - assertTrue(!jvmUID.isEmpty()); - assertTrue(!hardwareUID.isEmpty()); + assertFalse(jvmUID.isEmpty()); + assertFalse(hardwareUID.isEmpty()); assertEquals(jvmUID, UIDProvider.getJVMUID()); assertEquals(hardwareUID, UIDProvider.getHardwareUID()); diff --git 
a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/base/IrisUtils.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/base/IrisUtils.java index 74c2208f0..d48d30586 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/base/IrisUtils.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/base/IrisUtils.java @@ -74,7 +74,7 @@ public class IrisUtils { } for (int i = 0; i < ret.rows(); i++) { - DataSet add = new DataSet(ret.getRow(i, true), Nd4j.create(outcomes[from + i], new long[]{1,3})); + DataSet add = new DataSet(ret.getRow(i, true), Nd4j.create(outcomes[from + i], 1,3)); list.add(add); } return list; diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/EmnistDataFetcher.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/EmnistDataFetcher.java index 70d974e99..d02749095 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/EmnistDataFetcher.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/EmnistDataFetcher.java @@ -86,11 +86,7 @@ public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetche //For some inexplicable reason, EMNIST LETTERS set is indexed 1 to 26 (i.e., 1 to nClasses), while everything else // is indexed (0 to nClasses-1) :/ - if (dataSet == EmnistDataSetIterator.Set.LETTERS) { - oneIndexed = true; - } else { - oneIndexed = false; - } + oneIndexed = dataSet == EmnistDataSetIterator.Set.LETTERS; this.fOrder = true; //MNIST is C order, EMNIST is F order } @@ -107,8 +103,6 @@ public class EmnistDataFetcher extends MnistDataFetcher implements DataSetFetche if (!f.exists()) return false; f = new File(EMNIST_ROOT, e.getTestFileLabelsFilename_unzipped()); - if (!f.exists()) - return false; - return true; + return f.exists(); } } diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java index be1dd952e..59deb860f 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/fetchers/MnistDataFetcher.java @@ -142,9 +142,7 @@ public class MnistDataFetcher extends BaseDataFetcher { if (!f.exists()) return false; f = new File(MNIST_ROOT, MnistFetcher.TEST_FILE_LABELS_FILENAME_UNZIPPED); - if (!f.exists()) - return false; - return true; + return f.exists(); } private void validateFiles(String[] files, long[] checksums){ diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/mnist/MnistDbFile.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/mnist/MnistDbFile.java index 5fd53de81..30791e129 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/mnist/MnistDbFile.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/mnist/MnistDbFile.java @@ -26,7 +26,7 @@ import java.io.IOException; import 
java.io.RandomAccessFile; public abstract class MnistDbFile extends RandomAccessFile { - private int count; + private final int count; /** diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/mnist/MnistImageFile.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/mnist/MnistImageFile.java index 196352e84..2a1a7a2b8 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/mnist/MnistImageFile.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/mnist/MnistImageFile.java @@ -26,8 +26,8 @@ import java.io.IOException; public class MnistImageFile extends MnistDbFile { - private int rows; - private int cols; + private final int rows; + private final int cols; /** * Creates new MNIST database image file ready for reading. diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java index c0c89f00b..34731a8ac 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java @@ -63,13 +63,13 @@ public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator, S EQUAL_LENGTH, ALIGN_START, ALIGN_END } - private int batchSize; - private AlignmentMode alignmentMode; + private final int batchSize; + private final AlignmentMode alignmentMode; private Map recordReaders = new HashMap<>(); private Map sequenceRecordReaders = new HashMap<>(); - private List inputs = new ArrayList<>(); - private List outputs = new ArrayList<>(); + private final List inputs = new ArrayList<>(); + private final List outputs = new ArrayList<>(); @Getter @Setter @@ -775,13 +775,13 @@ public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator, S public static class Builder { - private int batchSize; + private final int batchSize; private AlignmentMode alignmentMode = AlignmentMode.ALIGN_START; - private Map recordReaders = new HashMap<>(); - private Map sequenceRecordReaders = new HashMap<>(); + private final Map recordReaders = new HashMap<>(); + private final Map sequenceRecordReaders = new HashMap<>(); - private List inputs = new ArrayList<>(); - private List outputs = new ArrayList<>(); + private final List inputs = new ArrayList<>(); + private final List outputs = new ArrayList<>(); private boolean timeSeriesRandomOffset = false; private long timeSeriesRandomOffsetSeed = System.currentTimeMillis(); diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/SequenceRecordReaderDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/SequenceRecordReaderDataSetIterator.java index 4ad4ffb48..dbebd0642 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/SequenceRecordReaderDataSetIterator.java +++ 
b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/SequenceRecordReaderDataSetIterator.java @@ -64,8 +64,8 @@ public class SequenceRecordReaderDataSetIterator implements DataSetIterator { private static final String READER_KEY = "reader"; private static final String READER_KEY_LABEL = "reader_labels"; - private SequenceRecordReader recordReader; - private SequenceRecordReader labelsReader; + private final SequenceRecordReader recordReader; + private final SequenceRecordReader labelsReader; private int miniBatchSize = 10; private final boolean regression; private int labelIndex = -1; @@ -288,7 +288,7 @@ public class SequenceRecordReaderDataSetIterator implements DataSetIterator { fm = RecordReaderDataSetIterator.getOrNull(mds.getFeaturesMaskArrays(), 0); //Per-example masking only on the input -> same for both //Can assume 3d features here - f = Nd4j.createUninitialized(new long[] {f1.size(0), f1.size(1) + f2.size(1), f1.size(2)}); + f = Nd4j.createUninitialized(f1.size(0), f1.size(1) + f2.size(1), f1.size(2)); f.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(0, f1.size(1)), NDArrayIndex.all()}, f1); f.put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.interval(f1.size(1), f1.size(1) + f2.size(1)), diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIterator.java index 64959c6f3..fbcb248a0 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AbstractDataSetIterator.java @@ -36,7 +36,7 @@ import java.util.concurrent.LinkedBlockingQueue; public abstract class AbstractDataSetIterator implements DataSetIterator { private DataSetPreProcessor preProcessor; - private transient Iterable> iterable; + private final transient Iterable> iterable; private transient Iterator> iterator; private final int batchSize; diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AsyncShieldDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AsyncShieldDataSetIterator.java index a8cef412e..d2bbb9ebb 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AsyncShieldDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AsyncShieldDataSetIterator.java @@ -29,7 +29,7 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import java.util.List; public class AsyncShieldDataSetIterator implements DataSetIterator { - private DataSetIterator backingIterator; + private final DataSetIterator backingIterator; /** * @param iterator Iterator to wrop, to disable asynchronous prefetching for diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AsyncShieldMultiDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AsyncShieldMultiDataSetIterator.java index 947cc58ac..4129113f9 
100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AsyncShieldMultiDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/AsyncShieldMultiDataSetIterator.java @@ -26,7 +26,7 @@ import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; public class AsyncShieldMultiDataSetIterator implements MultiDataSetIterator { - private MultiDataSetIterator backingIterator; + private final MultiDataSetIterator backingIterator; public AsyncShieldMultiDataSetIterator(@NonNull MultiDataSetIterator iterator) { this.backingIterator = iterator; diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/CombinedMultiDataSetPreProcessor.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/CombinedMultiDataSetPreProcessor.java index 84e748430..d1d1a07b3 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/CombinedMultiDataSetPreProcessor.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/CombinedMultiDataSetPreProcessor.java @@ -43,7 +43,7 @@ public class CombinedMultiDataSetPreProcessor implements MultiDataSetPreProcesso } public static class Builder { - private List preProcessors = new ArrayList<>(); + private final List preProcessors = new ArrayList<>(); public Builder() { diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/CombinedPreProcessor.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/CombinedPreProcessor.java index ec225b248..d3f1d47a0 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/CombinedPreProcessor.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/CombinedPreProcessor.java @@ -47,7 +47,7 @@ public class CombinedPreProcessor implements DataSetPreProcessor { } public static class Builder { - private List preProcessors = new ArrayList<>(); + private final List preProcessors = new ArrayList<>(); public Builder() { diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/DataSetIteratorSplitter.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/DataSetIteratorSplitter.java index a1b183565..1c20f1bf5 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/DataSetIteratorSplitter.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/DataSetIteratorSplitter.java @@ -231,10 +231,7 @@ public class DataSetIteratorSplitter { } val state = backedIterator.hasNext(); - if (state && counter.get() < numTrain) - return true; - else - return false; + return state && counter.get() < numTrain; } @Override @@ -325,10 +322,7 @@ public class DataSetIteratorSplitter { @Override public boolean hasNext() { val state = backedIterator.hasNext(); - if (state && counter.get() < numTrain + numTest) - return true; - else 
- return false; + return state && counter.get() < numTrain + numTest; } @Override diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIterator.java index c3575fe40..c22f3b0e3 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/EarlyTerminationDataSetIterator.java @@ -28,8 +28,8 @@ import java.util.List; public class EarlyTerminationDataSetIterator implements DataSetIterator { - private DataSetIterator underlyingIterator; - private int terminationPoint; + private final DataSetIterator underlyingIterator; + private final int terminationPoint; private int minibatchCount = 0; /** diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIterator.java index 9284a26d2..18814a644 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/EarlyTerminationMultiDataSetIterator.java @@ -26,8 +26,8 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; public class EarlyTerminationMultiDataSetIterator implements MultiDataSetIterator { - private MultiDataSetIterator underlyingIterator; - private int terminationPoint; + private final MultiDataSetIterator underlyingIterator; + private final int terminationPoint; private int minibatchCount = 0; /** diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/FileSplitDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/FileSplitDataSetIterator.java index 54f576439..ca059e0b3 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/FileSplitDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/FileSplitDataSetIterator.java @@ -35,10 +35,10 @@ import java.util.concurrent.atomic.AtomicInteger; public class FileSplitDataSetIterator implements DataSetIterator { private DataSetPreProcessor preProcessor; - private List files; - private int numFiles; - private AtomicInteger counter = new AtomicInteger(0); - private FileCallback callback; + private final List files; + private final int numFiles; + private final AtomicInteger counter = new AtomicInteger(0); + private final FileCallback callback; /** * @param files List of files to iterate over diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultiDataSetIteratorSplitter.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultiDataSetIteratorSplitter.java index 
128113c9d..3d0769bac 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultiDataSetIteratorSplitter.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultiDataSetIteratorSplitter.java @@ -197,10 +197,7 @@ public class MultiDataSetIteratorSplitter { } val state = backedIterator.hasNext(); - if (state && counter.get() < numTrain) - return true; - else - return false; + return state && counter.get() < numTrain; } @Override @@ -272,10 +269,7 @@ public class MultiDataSetIteratorSplitter { @Override public boolean hasNext() { val state = backedIterator.hasNext(); - if (state && counter.get() < numTrain + numTest) - return true; - else - return false; + return state && counter.get() < numTrain + numTest; } @Override diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java index 28f69c7bf..85b0c2dd2 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/MultipleEpochsIterator.java @@ -231,7 +231,7 @@ public class MultipleEpochsIterator implements DataSetIterator { newEpoch = false; } if (iter == null) - return (epochs < numEpochs) && ((!batchedDS.isEmpty() && batchedDS.size() > batch) || batchedDS.isEmpty()); + return (epochs < numEpochs) && (batchedDS.isEmpty() || batchedDS.size() > batch); else // either there are still epochs to complete or its the first epoch return (epochs < numEpochs) || (iter.hasNext() && (epochs == 0 || epochs == numEpochs)); diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/RandomDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/RandomDataSetIterator.java index 78479e2c2..35cbc5b43 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/RandomDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/RandomDataSetIterator.java @@ -29,7 +29,7 @@ public class RandomDataSetIterator extends MultiDataSetWrapperIterator { public RandomMultiDataSetIterator.Values toMdsValues(){ return RandomMultiDataSetIterator.Values.valueOf(this.toString()); } - }; + } /** * @param numMiniBatches Number of minibatches per epoch diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/RandomMultiDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/RandomMultiDataSetIterator.java index 8ce205878..a4a5bbf44 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/RandomMultiDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/RandomMultiDataSetIterator.java @@ -119,9 +119,9 @@ public class RandomMultiDataSetIterator implements 
MultiDataSetIterator { public static class Builder { - private int numMiniBatches; - private List> features = new ArrayList<>(); - private List> labels = new ArrayList<>(); + private final int numMiniBatches; + private final List> features = new ArrayList<>(); + private final List> labels = new ArrayList<>(); /** * @param numMiniBatches Number of minibatches per epoch diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ReconstructionDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ReconstructionDataSetIterator.java index 8fb538d7f..b41008b97 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ReconstructionDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ReconstructionDataSetIterator.java @@ -35,7 +35,7 @@ import java.util.List; */ public class ReconstructionDataSetIterator implements DataSetIterator { - private DataSetIterator iter; + private final DataSetIterator iter; @Getter private DataSetPreProcessor preProcessor; diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ScrollableDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ScrollableDataSetIterator.java index 91d9579c9..074297e77 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ScrollableDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ScrollableDataSetIterator.java @@ -43,7 +43,7 @@ public class ScrollableDataSetIterator implements DataSetIterator { protected MultiDataSet firstMultiTrain = null; private double ratio; private long totalExamples; - private long itemsPerPart; + private final long itemsPerPart; private long current; @@ -152,10 +152,7 @@ public class ScrollableDataSetIterator implements DataSetIterator { state = backedIterator.hasNext(); if (!state) return false; - if (state && counter.get() < itemsPerPart) - return true; - else - return false; + return state && counter.get() < itemsPerPart; } diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ScrollableMultiDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ScrollableMultiDataSetIterator.java index a2ba36f36..b5b4377c1 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ScrollableMultiDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/ScrollableMultiDataSetIterator.java @@ -45,7 +45,7 @@ public class ScrollableMultiDataSetIterator implements MultiDataSetIterator { protected MultiDataSet firstMultiTrain = null; private double ratio; private long totalExamples; - private long itemsPerPart; + private final long itemsPerPart; private long current; public ScrollableMultiDataSetIterator(int num, MultiDataSetIterator backedIterator, AtomicLong counter, @@ -110,10 +110,7 @@ public class ScrollableMultiDataSetIterator implements MultiDataSetIterator { state = 
backedIterator.hasNext(); if (!state) return false; - if (state && counter.get() < itemsPerPart) - return true; - else - return false; + return state && counter.get() < itemsPerPart; } diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/callbacks/InterleavedDataSetCallback.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/callbacks/InterleavedDataSetCallback.java index 39be1e600..d9e5d2f42 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/callbacks/InterleavedDataSetCallback.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/callbacks/InterleavedDataSetCallback.java @@ -37,13 +37,13 @@ import java.util.concurrent.atomic.AtomicLong; @Slf4j public class InterleavedDataSetCallback implements DataSetCallback { - private List workspaces = new ArrayList<>(); - private int bufferSize; + private final List workspaces = new ArrayList<>(); + private final int bufferSize; private int numWorkspaces; private boolean isInitialized = false; - private AtomicLong counterInput = new AtomicLong(0); + private final AtomicLong counterInput = new AtomicLong(0); public InterleavedDataSetCallback(int bufferSize) { this.bufferSize = bufferSize; diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/BenchmarkDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/BenchmarkDataSetIterator.java index 7f4808fd2..03b9c1080 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/BenchmarkDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/BenchmarkDataSetIterator.java @@ -33,10 +33,10 @@ import java.util.concurrent.atomic.AtomicLong; @Slf4j public class BenchmarkDataSetIterator implements DataSetIterator { - private INDArray baseFeatures; - private INDArray baseLabels; - private long limit; - private AtomicLong counter = new AtomicLong(0); + private final INDArray baseFeatures; + private final INDArray baseLabels; + private final long limit; + private final AtomicLong counter = new AtomicLong(0); /** * @param featuresShape Shape of the features data to randomly generate diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/BenchmarkMultiDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/BenchmarkMultiDataSetIterator.java index 60457c297..025309a31 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/BenchmarkMultiDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/BenchmarkMultiDataSetIterator.java @@ -31,10 +31,10 @@ import java.util.concurrent.atomic.AtomicLong; @Slf4j public class BenchmarkMultiDataSetIterator implements MultiDataSetIterator { - private INDArray[] baseFeatures; - private INDArray[] baseLabels; - private long limit; - private AtomicLong counter = new AtomicLong(0); + private final INDArray[] baseFeatures; + 
private final INDArray[] baseLabels; + private final long limit; + private final AtomicLong counter = new AtomicLong(0); public BenchmarkMultiDataSetIterator(int[][] featuresShape, int[] numLabels, int totalIterations) { if (featuresShape.length != numLabels.length) diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/ListDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/ListDataSetIterator.java index ca95b334d..67ebc3a06 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/ListDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/impl/ListDataSetIterator.java @@ -34,7 +34,7 @@ public class ListDataSetIterator implements DataSetIterator { private static final long serialVersionUID = -7569201667767185411L; private int curr = 0; private int batch = 10; - private List list; + private final List list; @Getter private DataSetPreProcessor preProcessor; diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/BaseParallelDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/BaseParallelDataSetIterator.java index 38bda75ad..1f1568909 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/BaseParallelDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/BaseParallelDataSetIterator.java @@ -105,10 +105,7 @@ public abstract class BaseParallelDataSetIterator implements ParallelDataSetIter return true; } case STOP_EVERYONE: { - if (!states.allTrue()) - return false; - - return true; + return states.allTrue(); } default: throw new ND4JIllegalStateException( diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java index 40dd67594..de6fac884 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/FileSplitParallelDataSetIterator.java @@ -26,7 +26,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.commons.io.filefilter.IOFileFilter; import org.apache.commons.io.filefilter.RegexFileFilter; -import org.nd4j.linalg.dataset.AsyncDataSetIterator;; +import org.nd4j.linalg.dataset.AsyncDataSetIterator; import org.deeplearning4j.datasets.iterator.FileSplitDataSetIterator; import org.deeplearning4j.datasets.iterator.callbacks.FileCallback; import org.nd4j.linalg.dataset.DataSet; @@ -43,8 +43,8 @@ import java.util.List; public class FileSplitParallelDataSetIterator extends BaseParallelDataSetIterator { public static final String DEFAULT_PATTERN = "dataset-%d.bin"; - private String pattern; - private int buffer; + private final String 
pattern; + private final int buffer; protected List asyncIterators = new ArrayList<>(); diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/JointParallelDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/JointParallelDataSetIterator.java index 64cd04b89..ef255beb1 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/JointParallelDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/JointParallelDataSetIterator.java @@ -23,7 +23,7 @@ package org.deeplearning4j.datasets.iterator.parallel; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; -import org.nd4j.linalg.dataset.AsyncDataSetIterator;; +import org.nd4j.linalg.dataset.AsyncDataSetIterator; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.enums.InequalityHandling; @@ -94,10 +94,10 @@ public class JointParallelDataSetIterator extends BaseParallelDataSetIterator { public static class Builder { - private List iterators = new ArrayList<>(); + private final List iterators = new ArrayList<>(); private boolean enforceSingleDevice = true; private int bufferSize = 4; - private InequalityHandling inequalityHandling; + private final InequalityHandling inequalityHandling; public Builder(@NonNull InequalityHandling inequalityHandling) { this.inequalityHandling = inequalityHandling; diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/MultiBoolean.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/MultiBoolean.java index 61b58bd56..48dc20752 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/MultiBoolean.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-utility-iterators/src/main/java/org/deeplearning4j/datasets/iterator/parallel/MultiBoolean.java @@ -28,7 +28,7 @@ public class MultiBoolean { private final int numEntries; private int holder = 0; private int max = 0; - private boolean oneTime; + private final boolean oneTime; private MultiBoolean timeTracker; public MultiBoolean(int numEntries) { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java index bb8cd203f..29dbb8d88 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/Hdf5Archive.java @@ -59,8 +59,8 @@ public class Hdf5Archive implements Closeable { } } - private H5File file; - private static DataType dataType = new DataType(PredType.NATIVE_FLOAT()); + private final H5File file; + private static final DataType dataType = new DataType(PredType.NATIVE_FLOAT()); public Hdf5Archive(String archiveFilename) { synchronized (LOCK_OBJECT) { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java 
b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java index 1c001c1fd..d4bf6ba92 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java @@ -276,7 +276,7 @@ public class KerasModel { sameDiffLambdaLayer.setLayerName(newName); KerasLambda kerasLambda = new KerasLambda(configCopy,sameDiffLambdaLayer); kerasLambda.layerName = newName; - kerasLambda.setInboundLayerNames(new ArrayList<>(Arrays.asList(input))); + kerasLambda.setInboundLayerNames(new ArrayList<>(Collections.singletonList(input))); layers.put(newName,kerasLambda); int indexOfNewLayer = names.indexOf(input) + 1; updatedOrders.put(indexOfNewLayer,kerasLambda); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java index 5d2d04923..8e30f72f2 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java @@ -43,8 +43,8 @@ import java.util.Map; public class TFOpLayer extends Layer { - private Map nodeDef; - private Map constants; + private final Map nodeDef; + private final Map constants; public TFOpLayer(Map nodeDef, Map constants){ super(); this.nodeDef = nodeDef; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java index ab88f2aa5..ba2b98db4 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java @@ -83,7 +83,7 @@ public class TFOpLayerImpl extends AbstractLayer { Map inputDataTypes = new HashMap<>(); Map constArrays = new HashMap(); this.inputNames = new ArrayList<>(); - List outputNames = Arrays.asList(nodeDef.getName()); + List outputNames = Collections.singletonList(nodeDef.getName()); Map attrMap = nodeDef.getAttrMap(); for (int i = 0; i < nodeDef.getInputCount(); i++){ String inputName = nodeDef.getInput(i); @@ -104,7 +104,7 @@ public class TFOpLayerImpl extends AbstractLayer { this.inputNames.add(nodeDef.getInput(i)); } } - String graph = "node{\n" + nodeDef.toString() + "\n}\nversions {\n producer: 22\n}"; + String graph = "node{\n" + nodeDef + "\n}\nversions {\n producer: 22\n}"; for (int i = 0; i < allInputNames.size(); i++){ String inpName = allInputNames.get(i); String dtype = inputDataTypes.get(inpName); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasPReLU.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasPReLU.java index 6853ba203..5a4bb1c55 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasPReLU.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activations/KerasPReLU.java @@ -102,7 +102,7 @@ public 
class KerasPReLU extends KerasLayer { int[] intAxes = ArrayUtil.toArray(axesList); axes = new long[intAxes.length]; for (int i = 0; i < intAxes.length; i++) { - axes[i] = (long) intAxes[i]; + axes[i] = intAxes[i]; } } catch (Exception e) { // no shared axes diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDepthwiseConvolution2D.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDepthwiseConvolution2D.java index e14d97bb1..264200686 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDepthwiseConvolution2D.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasDepthwiseConvolution2D.java @@ -73,7 +73,7 @@ public class KerasDepthwiseConvolution2D extends KerasConvolution { */ public KerasDepthwiseConvolution2D(Map layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { - this(layerConfig, Collections.emptyMap(), true); + this(layerConfig, Collections.emptyMap(), true); } /** diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermute.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermute.java index bafdcc98e..6c0a9c7c7 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermute.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermute.java @@ -109,7 +109,7 @@ public class KerasPermute extends KerasLayer { case TENSORFLOW: // account for channels last permutationIndices = new int[] {permutationIndices[2], permutationIndices[0], permutationIndices[1]}; - preprocessor = new PermutePreprocessor(new int[]{1, 3, 2}); + preprocessor = new PermutePreprocessor(1, 3, 2); } } else if (inputType[0] instanceof InputType.InputTypeRecurrent) { if (Arrays.equals(permutationIndices, new int[] {2, 1})) diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java index a2b30b92b..4e35a6867 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java @@ -123,7 +123,7 @@ public class KerasLSTM extends KerasLayer { */ public KerasLSTM(Map layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { - this(layerConfig, enforceTrainingConfig, Collections.emptyMap()); + this(layerConfig, enforceTrainingConfig, Collections.emptyMap()); } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java index 60c25fe47..ac2d4c234 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java +++ 
b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java @@ -87,7 +87,7 @@ public class KerasSimpleRnn extends KerasLayer { */ public KerasSimpleRnn(Map layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { - this(layerConfig, true, Collections.emptyMap()); + this(layerConfig, true, Collections.emptyMap()); } /** @@ -113,7 +113,7 @@ public class KerasSimpleRnn extends KerasLayer { */ public KerasSimpleRnn(Map layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { - this(layerConfig, enforceTrainingConfig, Collections.emptyMap()); + this(layerConfig, enforceTrainingConfig, Collections.emptyMap()); } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java index 1042ca244..fa5f5b508 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java @@ -63,7 +63,7 @@ public class KerasBidirectional extends KerasLayer { */ public KerasBidirectional(Map layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { - this(layerConfig, true, Collections.emptyMap()); + this(layerConfig, true, Collections.emptyMap()); } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java index 002ce6b57..085db4e37 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java @@ -89,9 +89,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor { int shapeLength = shape.length; val miniBatchShape = new long[shapeLength + 1]; miniBatchShape[0] = miniBatchSize; - for (int i = 1; i < miniBatchShape.length; i++) { - miniBatchShape[i] = shape[i - 1]; - } + System.arraycopy(shape, 0, miniBatchShape, 1, miniBatchShape.length - 1); return miniBatchShape; } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelBuilder.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelBuilder.java index 92691ddf6..c0392c66a 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelBuilder.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelBuilder.java @@ -118,7 +118,7 @@ public class KerasModelBuilder implements Cloneable, Closeable { public KerasModelBuilder modelJsonInputStream(InputStream modelJsonInputStream) throws IOException { ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); IOUtils.copy(modelJsonInputStream, byteArrayOutputStream); - this.modelJson = new String(byteArrayOutputStream.toByteArray()); + this.modelJson = 
byteArrayOutputStream.toString(); return this; } @@ -132,7 +132,7 @@ public class KerasModelBuilder implements Cloneable, Closeable { public KerasModelBuilder modelYamlInputStream(InputStream modelYamlInputStream) throws IOException { ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); IOUtils.copy(modelYamlInputStream, byteArrayOutputStream); - this.modelJson = new String(byteArrayOutputStream.toByteArray()); + this.modelJson = byteArrayOutputStream.toString(); return this; } @@ -197,7 +197,7 @@ public class KerasModelBuilder implements Cloneable, Closeable { public KerasModelBuilder trainingJsonInputStream(InputStream trainingJsonInputStream) throws IOException { ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); IOUtils.copy(trainingJsonInputStream, byteArrayOutputStream); - this.trainingJson = new String(byteArrayOutputStream.toByteArray()); + this.trainingJson = byteArrayOutputStream.toString(); return this; } @@ -210,7 +210,7 @@ public class KerasModelBuilder implements Cloneable, Closeable { public KerasModelBuilder trainingYamlInputStream(InputStream trainingYamlInputStream) throws IOException { ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); IOUtils.copy(trainingYamlInputStream, byteArrayOutputStream); - this.trainingYaml = new String(byteArrayOutputStream.toByteArray()); + this.trainingYaml = byteArrayOutputStream.toString(); return this; } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java index ad11282e5..43f3b244f 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java @@ -170,8 +170,10 @@ public class KerasModelUtils { // check to ensure naming scheme doesn't include forward slash boolean includesSlash = false; for (String layerName : layers.keySet()) { - if (layerName.contains("/")) + if (layerName.contains("/")) { includesSlash = true; + break; + } } synchronized (KerasModelUtils.class) { List layerGroups; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java index e29503c5f..815b2cf68 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/MiscTests.java @@ -111,7 +111,7 @@ public class MiscTests extends BaseDL4JTest { assertFalse(vr0.isValid()); assertEquals("Keras Sequential Model HDF5", vr0.getFormatType()); assertTrue(vr0.getIssues().get(0).contains("exist"), vr0.getIssues().get(0)); - System.out.println(vr0.toString()); + System.out.println(vr0); //Test empty file: File fEmpty = new File(f, "empty.h5"); @@ -121,7 +121,7 @@ public class MiscTests extends BaseDL4JTest { assertEquals("Keras Sequential Model HDF5", vr1.getFormatType()); assertFalse(vr1.isValid()); assertTrue(vr1.getIssues().get(0).contains("empty"), vr1.getIssues().get(0)); - System.out.println(vr1.toString()); + System.out.println(vr1); //Test directory (not zip file) File directory = new File(f, "dir"); @@ -131,7 +131,7 @@ public class MiscTests 
extends BaseDL4JTest { assertEquals("Keras Sequential Model HDF5", vr2.getFormatType()); assertFalse(vr2.isValid()); assertTrue(vr2.getIssues().get(0).contains("directory"), vr2.getIssues().get(0)); - System.out.println(vr2.toString()); + System.out.println(vr2); //Test Keras HDF5 format: File fText = new File(f, "text.txt"); @@ -141,7 +141,7 @@ public class MiscTests extends BaseDL4JTest { assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); assertTrue( s.contains("Keras") && s.contains("Sequential") && s.contains("corrupt"), s); - System.out.println(vr3.toString()); + System.out.println(vr3); //Test corrupted npy format: File fValid = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"); @@ -157,7 +157,7 @@ public class MiscTests extends BaseDL4JTest { assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); assertTrue(s.contains("Keras") && s.contains("Sequential") && s.contains("corrupt"), s); - System.out.println(vr4.toString()); + System.out.println(vr4); //Test valid npy format: @@ -166,7 +166,7 @@ public class MiscTests extends BaseDL4JTest { assertTrue(vr5.isValid()); assertNull(vr5.getIssues()); assertNull(vr5.getException()); - System.out.println(vr4.toString()); + System.out.println(vr4); } @Test @@ -180,7 +180,7 @@ public class MiscTests extends BaseDL4JTest { assertFalse(vr0.isValid()); assertEquals("Keras Functional Model HDF5", vr0.getFormatType()); assertTrue( vr0.getIssues().get(0).contains("exist"), vr0.getIssues().get(0)); - System.out.println(vr0.toString()); + System.out.println(vr0); //Test empty file: File fEmpty = new File(f, "empty.h5"); @@ -190,7 +190,7 @@ public class MiscTests extends BaseDL4JTest { assertEquals("Keras Functional Model HDF5", vr1.getFormatType()); assertFalse(vr1.isValid()); assertTrue( vr1.getIssues().get(0).contains("empty"), vr1.getIssues().get(0)); - System.out.println(vr1.toString()); + System.out.println(vr1); //Test directory (not zip file) File directory = new File(f, "dir"); @@ -200,7 +200,7 @@ public class MiscTests extends BaseDL4JTest { assertEquals("Keras Functional Model HDF5", vr2.getFormatType()); assertFalse(vr2.isValid()); assertTrue(vr2.getIssues().get(0).contains("directory"), vr2.getIssues().get(0)); - System.out.println(vr2.toString()); + System.out.println(vr2); //Test Keras HDF5 format: File fText = new File(f, "text.txt"); @@ -210,7 +210,7 @@ public class MiscTests extends BaseDL4JTest { assertFalse(vr3.isValid()); String s = vr3.getIssues().get(0); assertTrue(s.contains("Keras") && s.contains("Functional") && s.contains("corrupt"),s); - System.out.println(vr3.toString()); + System.out.println(vr3); //Test corrupted npy format: File fValid = Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"); @@ -226,7 +226,7 @@ public class MiscTests extends BaseDL4JTest { assertFalse(vr4.isValid()); s = vr4.getIssues().get(0); assertTrue(s.contains("Keras") && s.contains("Functional") && s.contains("corrupt"),s); - System.out.println(vr4.toString()); + System.out.println(vr4); //Test valid npy format: @@ -235,6 +235,6 @@ public class MiscTests extends BaseDL4JTest { assertTrue(vr5.isValid()); assertNull(vr5.getIssues()); assertNull(vr5.getException()); - System.out.println(vr4.toString()); + System.out.println(vr4); } } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/DeepCTRLambdaTest.java 
b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/DeepCTRLambdaTest.java index 5b3fdc77a..1eca7ab48 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/DeepCTRLambdaTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/DeepCTRLambdaTest.java @@ -40,7 +40,7 @@ public class DeepCTRLambdaTest { @Override public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) { - return layerInput.sum("tensors_sum-" + UUID.randomUUID().toString(),false,1); + return layerInput.sum("tensors_sum-" + UUID.randomUUID(),false,1); } @Override @@ -53,7 +53,7 @@ public class DeepCTRLambdaTest { @Override public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) { - return layerInput.mul("tensor_square-" + UUID.randomUUID().toString(),layerInput); + return layerInput.mul("tensor_square-" + UUID.randomUUID(),layerInput); } @Override @@ -66,7 +66,7 @@ public class DeepCTRLambdaTest { @Override public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) { - return layerInput.mul("lambda1-" + UUID.randomUUID().toString(),0.5); + return layerInput.mul("lambda1-" + UUID.randomUUID(),0.5); } @Override @@ -80,9 +80,9 @@ public class DeepCTRLambdaTest { @Override public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput) { if(this.layerName.equals("concat_embed_2d") || this.layerName.equals("cat_embed_2d_genure_mean")) - return layerInput.mean("mean_pooling-" + UUID.randomUUID().toString(),true,1); + return layerInput.mean("mean_pooling-" + UUID.randomUUID(),true,1); else - return layerInput.mean("mean_pooling-" + UUID.randomUUID().toString(),false,1); + return layerInput.mean("mean_pooling-" + UUID.randomUUID(),false,1); } @Override diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java index 207a8b82f..1120dfbb8 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java @@ -52,7 +52,7 @@ import java.util.Arrays; import java.util.LinkedList; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; public class FullModelComparisons extends BaseDL4JTest { @@ -90,7 +90,7 @@ public class FullModelComparisons extends BaseDL4JTest { org.deeplearning4j.nn.conf.layers.LSTM firstConf = (org.deeplearning4j.nn.conf.layers.LSTM) firstLstm.conf().getLayer(); // "unit_forget_bias": true - assertTrue(firstConf.getForgetGateBiasInit() == 1.0); + assertEquals(1.0, firstConf.getForgetGateBiasInit()); assertTrue(firstConf.getGateActivationFn() instanceof ActivationHardSigmoid); assertTrue(firstConf.getActivationFn() instanceof ActivationTanH); @@ -101,7 +101,7 @@ public class FullModelComparisons extends BaseDL4JTest { // Need to convert from IFCO to CFOI order // INDArray W = firstLstm.getParam("W"); - Assertions.assertTrue(Arrays.equals(W.shape(), new long[]{nIn, 4 * nOut})); + assertArrayEquals(W.shape(), new long[]{nIn, 4 * nOut}); Assertions.assertEquals(W.getDouble(0, 288), -0.30737767, 1e-7); 
Assertions.assertEquals(W.getDouble(0, 289), -0.5845409, 1e-7); Assertions.assertEquals(W.getDouble(1, 288), -0.44083247, 1e-7); @@ -112,12 +112,12 @@ public class FullModelComparisons extends BaseDL4JTest { INDArray RW = firstLstm.getParam("RW"); - assertTrue(Arrays.equals(RW.shape(), new long[]{nOut, 4 * nOut})); + assertArrayEquals(RW.shape(), new long[]{nOut, 4 * nOut}); Assertions.assertEquals(RW.getDouble(0, 288), 0.15112677, 1e-7); INDArray b = firstLstm.getParam("b"); - assertTrue(Arrays.equals(b.shape(), new long[]{1, 4 * nOut})); + assertArrayEquals(b.shape(), new long[]{1, 4 * nOut}); Assertions.assertEquals(b.getDouble(0, 288), -0.36940336, 1e-7); // Keras I Assertions.assertEquals(b.getDouble(0, 96), 0.6031118, 1e-7); // Keras F Assertions.assertEquals(b.getDouble(0, 192), -0.13569744, 1e-7); // Keras O @@ -128,7 +128,7 @@ public class FullModelComparisons extends BaseDL4JTest { org.deeplearning4j.nn.conf.layers.LSTM secondConf = (org.deeplearning4j.nn.conf.layers.LSTM) secondLstm.conf().getLayer(); // "unit_forget_bias": true - assertTrue(secondConf.getForgetGateBiasInit() == 1.0); + assertEquals(1.0, secondConf.getForgetGateBiasInit()); assertTrue(firstConf.getGateActivationFn() instanceof ActivationHardSigmoid); assertTrue(firstConf.getActivationFn() instanceof ActivationTanH); @@ -137,16 +137,16 @@ public class FullModelComparisons extends BaseDL4JTest { nOut = 96; W = secondLstm.getParam("W"); - assertTrue(Arrays.equals(W.shape(), new long[]{nIn, 4 * nOut})); + assertArrayEquals(W.shape(), new long[]{nIn, 4 * nOut}); Assertions.assertEquals(W.getDouble(0, 288), -0.7559755, 1e-7); RW = secondLstm.getParam("RW"); - assertTrue(Arrays.equals(RW.shape(), new long[]{nOut, 4 * nOut})); + assertArrayEquals(RW.shape(), new long[]{nOut, 4 * nOut}); Assertions.assertEquals(RW.getDouble(0, 288), -0.33184892, 1e-7); b = secondLstm.getParam("b"); - assertTrue(Arrays.equals(b.shape(), new long[]{1, 4 * nOut})); + assertArrayEquals(b.shape(), new long[]{1, 4 * nOut}); Assertions.assertEquals(b.getDouble(0, 288), -0.2223678, 1e-7); Assertions.assertEquals(b.getDouble(0, 96), 0.73556226, 1e-7); Assertions.assertEquals(b.getDouble(0, 192), -0.63227624, 1e-7); @@ -167,7 +167,7 @@ public class FullModelComparisons extends BaseDL4JTest { INDArray sequence = dataSet.getFeatures().get(NDArrayIndex.point(0)).transpose(); INDArray bsSequence = sequence.reshape(1, 4, 12); // one batch INDArray pred = model.output(bsSequence); - assertTrue(Arrays.equals(pred.shape(), new long[]{1, 1})); + assertArrayEquals(pred.shape(), new long[]{1, 1}); preds.add(pred.getDouble(0, 0)); } INDArray dl4jPredictions = Nd4j.create(preds); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java index 71065896c..4558eccc7 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/JsonTest.java @@ -36,7 +36,7 @@ public class JsonTest extends BaseDL4JTest { public void testJsonPreprocessors() throws Exception { InputPreProcessor[] pp = new InputPreProcessor[] { new KerasFlattenRnnPreprocessor(10, 5), - new PermutePreprocessor(new int[]{0,1,2}), + new PermutePreprocessor(0,1,2), new ReshapePreprocessor(new long[]{10,10}, new long[]{100,1}, true, null) }; diff --git 
a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java index 7c8d8f73d..fc48183e2 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java @@ -36,7 +36,7 @@ import java.io.InputStream; @Slf4j public class Keras1ModelConfigurationTest extends BaseDL4JTest { - private ClassLoader classLoader = getClass().getClassLoader(); + private final ClassLoader classLoader = getClass().getClassLoader(); @Test public void imdbLstmTfSequentialConfigTest() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index a8eab6be6..05f6162f3 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -267,7 +267,7 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { model.init(); INDArray input = Nd4j.create(DataType.FLOAT, 50, 1500, 500); //NWC format - [Minibatch, seqLength, channels] INDArray out = model.output(input); - assertTrue(Arrays.equals(out.shape(), new long[]{50, 64})); + assertArrayEquals(out.shape(), new long[]{50, 64}); } } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java index 31fb10f09..e97a1685e 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java @@ -40,15 +40,15 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class KerasInitilizationTest extends BaseDL4JTest { - private double minValue = -0.2; - private double maxValue = 0.2; - private double mean = 0.0; - private double stdDev = 0.2; - private double value = 42.0; - private double gain = 0.2; + private final double minValue = -0.2; + private final double maxValue = 0.2; + private final double mean = 0.0; + private final double stdDev = 0.2; + private final double value = 42.0; + private final double gain = 0.2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testInitializers() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java 
b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java index 06683cd07..67caf1e3b 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java @@ -56,7 +56,7 @@ public class KerasCustomLayerTest extends BaseDL4JTest { // download file if (!cachedKerasFile.exists()) { - log.info("Downloading model to " + cachedKerasFile.toString()); + log.info("Downloading model to " + cachedKerasFile); FileUtils.copyURLToFile(new URL(kerasWeightsAndConfigUrl), cachedKerasFile); cachedKerasFile.deleteOnExit(); } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java index 5d394351a..7a13ab908 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLossTest.java @@ -64,7 +64,7 @@ public class KerasCustomLossTest extends BaseDL4JTest { .enforceTrainingConfig(true).buildSequential().getMultiLayerNetwork(); System.out.println(model.summary()); - INDArray input = Nd4j.create(new int[]{10, 3}); + INDArray input = Nd4j.create(10, 3); model.output(input); } finally { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java index 726de2e1f..592ad2d9c 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasLambdaTest.java @@ -83,7 +83,7 @@ public class KerasLambdaTest extends BaseDL4JTest { .enforceTrainingConfig(false).buildSequential().getMultiLayerNetwork(); System.out.println(model.summary()); - INDArray input = Nd4j.create(new int[]{10, 100}); + INDArray input = Nd4j.create(10, 100); model.output(input); } finally { @@ -105,7 +105,7 @@ public class KerasLambdaTest extends BaseDL4JTest { .enforceTrainingConfig(false).buildModel().getComputationGraph(); System.out.println(model.summary()); - INDArray input = Nd4j.create(new int[]{10, 784}); + INDArray input = Nd4j.create(10, 784); model.output(input); } finally { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java index 782923365..a5ab1f512 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasYolo9000PredictTest.java @@ -40,7 +40,7 @@ import java.io.File; public class KerasYolo9000PredictTest extends BaseDL4JTest { private static final String DL4J_MODEL_FILE_NAME = "."; - private static ImagePreProcessingScaler IMAGE_PREPROCESSING_SCALER = new ImagePreProcessingScaler(0, 1); + private static final ImagePreProcessingScaler 
IMAGE_PREPROCESSING_SCALER = new ImagePreProcessingScaler(0, 1); @Test ////@Ignore("Need to manually download file for ylo.") diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java index 91e890cba..b2417dad8 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasLeakyReLUTest.java @@ -38,8 +38,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; */ public class KerasLeakyReLUTest extends BaseDL4JTest { - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testLeakyReLULayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java index 7405a6007..202e06426 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java @@ -42,8 +42,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; */ public class KerasPReLUTest extends BaseDL4JTest { - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); private final String INIT_KERAS = "glorot_normal"; private final IWeightInit INIT_DL4J = new WeightInitXavier(); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java index 822a140a6..834886ef7 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasThresholdedReLUTest.java @@ -38,8 +38,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; */ public class KerasThresholdedReLUTest extends BaseDL4JTest { - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testThresholdedReLULayer() throws Exception { diff 
--git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java index 828d1c4c2..f5e25ea9f 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java @@ -58,8 +58,8 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int[] VALID_PADDING = new int[]{0, 0}; - private Integer keras1 = 1; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Integer keras1 = 1; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); @Test public void testAtrousConvolution1DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java index a29be581c..f2eebb8f2 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java @@ -61,7 +61,7 @@ public class KerasAtrousConvolution2DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int[] VALID_PADDING = new int[]{0, 0}; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); @Test public void testAtrousConvolution2DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java index 22a51b1a7..994d3affe 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java @@ -60,10 +60,10 @@ public class KerasConvolution1DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int[] VALID_PADDING = new int[]{0, 0}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testConvolution1DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java 
b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java index f449c2cae..b92ab0432 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java @@ -62,10 +62,10 @@ public class KerasConvolution2DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int[] VALID_PADDING = new int[]{0, 0}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java index a6c9af9c4..c36b0351d 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java @@ -61,10 +61,10 @@ public class KerasConvolution3DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int[] VALID_PADDING = new int[]{0, 0, 0}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java index 9519aa4ac..b3159e54b 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping1DTest.java @@ -41,8 +41,8 @@ public class KerasCropping1DTest extends BaseDL4JTest { private final String LAYER_NAME = "cropping_1D_layer"; private final int CROPPING = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testCropping1DLayer() throws Exception { diff --git 
a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java index 966690847..e65a59438 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping2DTest.java @@ -42,8 +42,8 @@ public class KerasCropping2DTest extends BaseDL4JTest { private final String LAYER_NAME = "cropping_2D_layer"; private final int[] CROPPING = new int[]{2, 3}; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testCropping2DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java index 7c8f45579..5fe65127a 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasCropping3DTest.java @@ -42,8 +42,8 @@ public class KerasCropping3DTest extends BaseDL4JTest { private final String LAYER_NAME = "cropping_3D_layer"; private final int[] CROPPING = new int[]{2, 3, 5}; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testCropping3DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java index 37fafc785..c0db1c47b 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java @@ -62,10 +62,10 @@ public class KerasDeconvolution2DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int[] VALID_PADDING = new int[]{0, 0}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git 
a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java index 1b4b8e7c7..4dc4856c0 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java @@ -62,8 +62,8 @@ public class KerasDepthwiseConvolution2DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int[] VALID_PADDING = new int[]{0, 0}; - private Integer keras2 = 2; - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras2 = 2; + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java index fb1df4525..54f50a478 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java @@ -63,10 +63,10 @@ public class KerasSeparableConvolution2DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int[] VALID_PADDING = new int[]{0, 0}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java index 4985681cd..394f768c9 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling1DTest.java @@ -39,12 +39,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class KerasUpsampling1DTest extends BaseDL4JTest { private final String LAYER_NAME = "upsampling_1D_layer"; - private int size = 4; + private final int size = 4; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 
= new Keras2LayerConfiguration(); @Test public void testUpsampling1DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java index eb38f4ec0..f75958315 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling2DTest.java @@ -41,12 +41,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class KerasUpsampling2DTest extends BaseDL4JTest { private final String LAYER_NAME = "upsampling_2D_layer"; - private int[] size = new int[]{2, 2}; + private final int[] size = new int[]{2, 2}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testUpsampling2DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java index 7741785d1..7c82f4907 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasUpsampling3DTest.java @@ -41,12 +41,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class KerasUpsampling3DTest extends BaseDL4JTest { private final String LAYER_NAME = "upsampling_3D_layer"; - private int[] size = new int[]{2, 2, 2}; + private final int[] size = new int[]{2, 2, 2}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testUpsampling3DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java index 9cfe0bdab..64bc6563f 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding1DTest.java @@ -38,8 +38,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; */ public class 
KerasZeroPadding1DTest extends BaseDL4JTest { - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testZeroPadding1DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java index 809cb5f0a..203c4b887 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding2DTest.java @@ -42,8 +42,8 @@ public class KerasZeroPadding2DTest extends BaseDL4JTest { private final String LAYER_NAME = "zero_padding_2D_layer"; private final int[] ZERO_PADDING = new int[]{2, 3}; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testZeroPadding2DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java index 6ae93473b..cc2c44968 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasZeroPadding3DTest.java @@ -42,8 +42,8 @@ public class KerasZeroPadding3DTest extends BaseDL4JTest { private final String LAYER_NAME = "zero_padding_3D_layer"; private final int[] ZERO_PADDING = new int[]{2, 3, 4}; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testZeroPadding3DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java index ad73a4c00..1f2496c16 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasActivationLayer.java @@ -38,10 +38,10 @@ public class KerasActivationLayer extends BaseDL4JTest { private final String ACTIVATION_DL4J = "identity"; private final String LAYER_NAME = "test_layer"; - private Integer keras1 = 1; - private Integer keras2 = 2; - private 
Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testActivationLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java index 2d5c4f864..c9c70e5ff 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java @@ -41,10 +41,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; */ public class KerasDenseTest extends BaseDL4JTest { - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); private final String ACTIVATION_KERAS = "linear"; private final String ACTIVATION_DL4J = "identity"; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java index 322955813..afc7506e2 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDropoutTest.java @@ -42,10 +42,10 @@ public class KerasDropoutTest extends BaseDL4JTest { private final double DROPOUT_KERAS = 0.3; private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java index f898209ce..8734351f4 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasMaskingTest.java @@ -39,8 +39,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class KerasMaskingTest extends BaseDL4JTest { - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); 
- private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java index 50efe158a..93df0c5d1 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasPermuteTest.java @@ -42,10 +42,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; */ public class KerasPermuteTest extends BaseDL4JTest { - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java index 7390c8bc5..958e2baad 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasRepeatVectorTest.java @@ -38,12 +38,12 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class KerasRepeatVectorTest extends BaseDL4JTest { String LAYER_NAME = "repeat"; - private int REPEAT = 4; + private final int REPEAT = 4; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java index 6e57fa561..68ca9b55c 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshapeTest.java @@ -43,10 +43,10 @@ import static org.junit.jupiter.api.Assertions.assertEquals; */ public class KerasReshapeTest extends BaseDL4JTest { - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final 
Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java index 01d225c19..8234c29b2 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasSpatialDropout2DTest.java @@ -42,10 +42,10 @@ public class KerasSpatialDropout2DTest extends BaseDL4JTest { private final double RATE_KERAS = 0.3; private final double RATE_DL4J = 1 - RATE_KERAS; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java index d358bd61e..010f890b7 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/embeddings/KerasEmbeddingTest.java @@ -43,10 +43,10 @@ public class KerasEmbeddingTest extends BaseDL4JTest { private final String INIT_KERAS = "glorot_normal"; private final int[] INPUT_SHAPE = new int[]{100, 20}; private static final boolean[] MASK_ZERO = new boolean[]{false, true}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testEmbeddingLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java index 1b2d9dfd7..42afecf32 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected1DTest.java @@ -58,10 +58,10 @@ public class KerasLocallyConnected1DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int VALID_PADDING = 0; - private Integer keras1 = 1; - private Integer keras2 = 2; - private 
Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java index b703a482b..42981f1b6 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/local/KerasLocallyConnected2DTest.java @@ -61,10 +61,10 @@ public class KerasLocallyConnected2DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int[] VALID_PADDING = new int[]{0, 0}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java index fa3a2feae..1f35515bb 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasAlphaDropoutTest.java @@ -42,10 +42,10 @@ public class KerasAlphaDropoutTest extends BaseDL4JTest { private final double RATE_KERAS = 0.3; private final double RATE_DL4J = 1 - RATE_KERAS; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java index e23356da2..eee0f1c8a 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianDropoutTest.java @@ -42,10 +42,10 @@ public class KerasGaussianDropoutTest extends BaseDL4JTest { private final double RATE_KERAS = 0.3; private final double RATE_DL4J = 1 - RATE_KERAS; - private Integer 
keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java index 4eb0042b5..6d8eb994c 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/noise/KerasGaussianNoiseTest.java @@ -41,10 +41,10 @@ public class KerasGaussianNoiseTest extends BaseDL4JTest { String LAYER_NAME = "gaussian_noise"; private final double STDDEV = 0.3; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java index d8341de8f..ca84e5244 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/normalization/KerasBatchNormalizationTest.java @@ -41,10 +41,10 @@ public class KerasBatchNormalizationTest extends BaseDL4JTest { public static final String PARAM_NAME_BETA = "beta"; private final String LAYER_NAME = "batch_norm_layer"; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java index 25557e595..d504e626f 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling1DTest.java @@ -47,10 +47,10 @@ public class KerasPooling1DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; 
private final int[] VALID_PADDING = new int[]{0, 0}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testPooling1DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java index 189cea1da..76aed15c1 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling2DTest.java @@ -49,10 +49,10 @@ public class KerasPooling2DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int[] VALID_PADDING = new int[]{0, 0}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testPooling2DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java index eefba12b4..44ed404eb 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/pooling/KerasPooling3DTest.java @@ -49,10 +49,10 @@ public class KerasPooling3DTest extends BaseDL4JTest { private final String BORDER_MODE_VALID = "valid"; private final int[] VALID_PADDING = new int[]{0, 0, 0}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testPooling3DLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java index b88f3c94c..7ce6bf0b3 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java +++ 
b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java @@ -61,12 +61,12 @@ public class KerasLSTMTest extends BaseDL4JTest { private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; private final int N_OUT = 13; - private Boolean[] returnSequences = new Boolean[]{true, false}; - private Boolean[] maskZero = new Boolean[]{true, false}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Boolean[] returnSequences = new Boolean[]{true, false}; + private final Boolean[] maskZero = new Boolean[]{true, false}; + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testLstmLayer() throws Exception { @@ -177,8 +177,8 @@ public class KerasLSTMTest extends BaseDL4JTest { layerConfig.put(conf.getLAYER_FIELD_CONFIG(), config); layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); layerConfig.put(conf.getLAYER_FIELD_INBOUND_NODES(), - Arrays.asList(Arrays.asList( - Arrays.asList("embedding")))); + Collections.singletonList(Collections.singletonList( + Collections.singletonList("embedding")))); KerasEmbedding embedding = getEmbedding(maskZero); Map previousLayers = Collections.singletonMap("embedding", embedding); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java index e68627e3e..c8e8287fb 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java @@ -52,11 +52,11 @@ public class KerasSimpleRnnTest extends BaseDL4JTest { private final double DROPOUT_DL4J = 1 - DROPOUT_KERAS; private final int N_OUT = 13; - private Boolean[] returnSequences = new Boolean[]{true, false}; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Boolean[] returnSequences = new Boolean[]{true, false}; + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testSimpleRnnLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java index 7073a6cba..1aa8b0a81 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectionalTest.java @@ -52,10 +52,10 @@ public class 
KerasBidirectionalTest extends BaseDL4JTest { private final int N_OUT = 13; private final String mode = "sum"; - private Integer keras1 = 1; - private Integer keras2 = 2; - private Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); - private Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); + private final Integer keras1 = 1; + private final Integer keras2 = 2; + private final Keras1LayerConfiguration conf1 = new Keras1LayerConfiguration(); + private final Keras2LayerConfiguration conf2 = new Keras2LayerConfiguration(); @Test public void testLstmLayer() throws Exception { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java index b40eb37c1..a5eb7f3a2 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/weights/KerasWeightSettingTests.java @@ -211,7 +211,6 @@ public class KerasWeightSettingTests extends BaseDL4JTest { int nOut = 12; int mb = 10; - ; int[] inShape = new int[]{5, 5, 5}; INDArray input = Nd4j.zeros(mb, inShape[0], inShape[1], inShape[2]); INDArray output = model.output(input); @@ -259,7 +258,7 @@ public class KerasWeightSettingTests extends BaseDL4JTest { ComputationGraph model = loadComputationalGraph(modelPath, false); // INDArray input[] = new INDArray[]{Nd4j.zeros(10, 4, 6, 6), Nd4j.zeros(10, 16, 3, 3)}; - INDArray input[] = new INDArray[]{Nd4j.zeros(10, 6, 6, 4), Nd4j.zeros(10, 3, 3, 16)}; + INDArray[] input = new INDArray[]{Nd4j.zeros(10, 6, 6, 4), Nd4j.zeros(10, 3, 3, 16)}; INDArray[] output = model.output(input); log.info(Arrays.toString(output[0].shape())); assertArrayEquals(new long[]{10, 3, 3, 32}, output[0].shape()); diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizer.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizer.java index 0928ee429..a83386d01 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizer.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizer.java @@ -42,6 +42,7 @@ import java.io.BufferedReader; import java.io.File; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -62,7 +63,7 @@ public class BagOfWordsVectorizer extends BaseTextVectorizer { @Override public DataSet vectorize(InputStream is, String label) { try { - BufferedReader reader = new BufferedReader(new InputStreamReader(is, "UTF-8")); + BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8)); String line = ""; StringBuilder builder = new StringBuilder(); while ((line = reader.readLine()) != null) { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/DefaultInputStreamCreator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/DefaultInputStreamCreator.java index b05583f49..a03ebb1a8 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/DefaultInputStreamCreator.java +++ 
b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/DefaultInputStreamCreator.java @@ -26,7 +26,7 @@ import org.deeplearning4j.text.documentiterator.DocumentIterator; import java.io.InputStream; public class DefaultInputStreamCreator implements InputStreamCreator { - private DocumentIterator iter; + private final DocumentIterator iter; public DefaultInputStreamCreator(DocumentIterator iter) { this.iter = iter; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizer.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizer.java index b74e6b953..ba344994b 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizer.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/bagofwords/vectorizer/TfidfVectorizer.java @@ -45,6 +45,7 @@ import java.io.BufferedReader; import java.io.File; import java.io.InputStream; import java.io.InputStreamReader; +import java.nio.charset.StandardCharsets; import java.util.*; import java.util.concurrent.atomic.AtomicLong; @@ -60,7 +61,7 @@ public class TfidfVectorizer extends BaseTextVectorizer { @Override public DataSet vectorize(InputStream is, String label) { try { - BufferedReader reader = new BufferedReader(new InputStreamReader(is, "UTF-8")); + BufferedReader reader = new BufferedReader(new InputStreamReader(is, StandardCharsets.UTF_8)); String line = ""; StringBuilder builder = new StringBuilder(); while ((line = reader.readLine()) != null) { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java index a004236e6..0478f1e14 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/BertIterator.java @@ -671,9 +671,9 @@ public class BertIterator implements MultiDataSetIterator { private int listLength = 0; @Getter - private long[] segIdOnesFrom; + private final long[] segIdOnesFrom; private int cursor = 0; - private SentenceListProcessed sentenceListProcessed; + private final SentenceListProcessed sentenceListProcessed; private SentencePairListProcessed(int listLength) { this.listLength = listLength; @@ -701,14 +701,14 @@ public class BertIterator implements MultiDataSetIterator { } private static class SentenceListProcessed { - private int listLength; + private final int listLength; @Getter @Setter private int maxL; @Getter - private List, String>> tokensAndLabelList; + private final List, String>> tokensAndLabelList; private SentenceListProcessed(int listLength) { this.listLength = listLength; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/CnnSentenceDataSetIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/CnnSentenceDataSetIterator.java index fd5fafe94..ccf669dc5 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/CnnSentenceDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/CnnSentenceDataSetIterator.java @@ -450,7 +450,7 @@ public class CnnSentenceDataSetIterator implements DataSetIterator { public static class Builder { - private Format format; + private final Format format; private LabeledSentenceProvider sentenceProvider = null; private WordVectors wordVectors; private TokenizerFactory tokenizerFactory = 
new DefaultTokenizerFactory(); diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/provider/LabelAwareConverter.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/provider/LabelAwareConverter.java index 9c0aa2018..555b15bda 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/provider/LabelAwareConverter.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/iterator/provider/LabelAwareConverter.java @@ -29,8 +29,8 @@ import org.nd4j.common.primitives.Pair; import java.util.List; public class LabelAwareConverter implements LabeledSentenceProvider { - private LabelAwareIterator backingIterator; - private List labels; + private final LabelAwareIterator backingIterator; + private final List labels; public LabelAwareConverter(@NonNull LabelAwareIterator iterator, @NonNull List labels) { this.backingIterator = iterator; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/BatchItem.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/BatchItem.java index 590e02e88..f775dbf86 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/BatchItem.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/BatchItem.java @@ -29,7 +29,7 @@ public class BatchItem { private int[] windowWords; // CBOW only private boolean[] wordStatuses; - private long randomValue; + private final long randomValue; private double alpha; private int windowWordsLength; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/BatchSequences.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/BatchSequences.java index 93c10bd5e..bce3fbeaa 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/BatchSequences.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/BatchSequences.java @@ -30,7 +30,7 @@ import java.util.concurrent.atomic.AtomicLong; @Slf4j public class BatchSequences { - private int batches; + private final int batches; List> buffer = new ArrayList<>(); @@ -56,7 +56,7 @@ public class BatchSequences { public List> get(int chunkNo) { List> retVal = new ArrayList<>(); - for (int i = 0 + chunkNo * batches; (i < batches + chunkNo * batches) && (i < buffer.size()); ++i) { + for (int i = chunkNo * batches; (i < batches + chunkNo * batches) && (i < buffer.size()); ++i) { BatchItem value = buffer.get(i); retVal.add(value); } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW.java index 80c5357ad..bb1ffd4a4 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/CBOW.java @@ -103,7 +103,7 @@ public class CBOW implements ElementsLearningAlgorith logger.info("Initializing syn1Neg..."); ((InMemoryLookupTable) lookupTable).setUseHS(configuration.isUseHierarchicSoftmax()); ((InMemoryLookupTable) lookupTable).setNegative(configuration.getNegative()); - ((InMemoryLookupTable) 
lookupTable).resetWeights(false); + lookupTable.resetWeights(false); } } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/SkipGram.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/SkipGram.java index 4fe3320a7..912cb29a9 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/SkipGram.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/elements/SkipGram.java @@ -118,7 +118,7 @@ public class SkipGram implements ElementsLearningAlgo log.info("Initializing syn1Neg..."); ((InMemoryLookupTable) lookupTable).setUseHS(configuration.isUseHierarchicSoftmax()); ((InMemoryLookupTable) lookupTable).setNegative(configuration.getNegative()); - ((InMemoryLookupTable) lookupTable).resetWeights(false); + lookupTable.resetWeights(false); } } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/sequence/DM.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/sequence/DM.java index 42ad78579..64dffe9de 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/sequence/DM.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/learning/impl/sequence/DM.java @@ -58,7 +58,7 @@ public class DM implements SequenceLearningAlgorithm< protected INDArray syn0, syn1, syn1Neg, table; - private CBOW cbow = new CBOW<>(); + private final CBOW cbow = new CBOW<>(); @Override public ElementsLearningAlgorithm getElementsLearningAlgorithm() { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/VectorsConfiguration.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/VectorsConfiguration.java index c613671db..42273db58 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/VectorsConfiguration.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/VectorsConfiguration.java @@ -29,6 +29,7 @@ import com.fasterxml.jackson.databind.SerializationFeature; import java.io.IOException; import java.io.Serializable; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; @@ -118,7 +119,7 @@ public class VectorsConfiguration implements Serializable { public String toEncodedJson() { Base64 base64 = new Base64(Integer.MAX_VALUE); try { - return base64.encodeAsString(this.toJson().getBytes("UTF-8")); + return base64.encodeAsString(this.toJson().getBytes(StandardCharsets.UTF_8)); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java index a2e3fb8c6..e521e2be7 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/loader/WordVectorSerializer.java @@ -338,7 +338,7 @@ public class WordVectorSerializer { if (i < vec.length() - 1) builder.append(" "); } - writer.println(builder.toString()); + writer.println(builder); } } } @@ -530,11 +530,11 @@ public class WordVectorSerializer { try (PrintWriter 
writer = new PrintWriter(new FileWriter(tempFileFreqs))) { for (int i = 0; i < vectors.getVocab().numWords(); i++) { VocabWord word = vectors.getVocab().elementAtIndex(i); - StringBuilder builder = new StringBuilder(ReadHelper.encodeB64(word.getLabel())).append(" ") - .append(word.getElementFrequency()).append(" ") - .append(vectors.getVocab().docAppearedIn(word.getLabel())); + String builder = ReadHelper.encodeB64(word.getLabel()) + " " + + word.getElementFrequency() + " " + + vectors.getVocab().docAppearedIn(word.getLabel()); - writer.println(builder.toString().trim()); + writer.println(builder.trim()); } } @@ -830,7 +830,7 @@ public class WordVectorSerializer { List rows = new ArrayList<>(); while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - double array[] = new double[split.length]; + double[] array = new double[split.length]; for (int i = 0; i < split.length; i++) { array[i] = Double.parseDouble(split[i]); } @@ -904,7 +904,7 @@ public class WordVectorSerializer { List rows = new ArrayList<>(); while ((line = reader.readLine()) != null) { String[] split = line.split(" "); - double array[] = new double[split.length]; + double[] array = new double[split.length]; for (int i = 0; i < split.length; i++) { array[i] = Double.parseDouble(split[i]); } @@ -1055,7 +1055,7 @@ public class WordVectorSerializer { InMemoryLookupTable lookupTable = - (InMemoryLookupTable) new InMemoryLookupTable.Builder() + new InMemoryLookupTable.Builder() .vectorLength(arrays.get(0).columns()).useAdaGrad(false).cache(vocabCache) .build(); Nd4j.clearNans(syn); @@ -1177,7 +1177,7 @@ public class WordVectorSerializer { PrintWriter printWriter = null; try { - printWriter = new PrintWriter(new OutputStreamWriter(new FileOutputStream(path), "UTF-8")); + printWriter = new PrintWriter(new OutputStreamWriter(new FileOutputStream(path), StandardCharsets.UTF_8)); } catch (Exception e) { throw new RuntimeException(e); } @@ -1265,7 +1265,7 @@ public class WordVectorSerializer { INDArray gradient = word.getHistoricalGradient(); if (gradient == null) gradient = Nd4j.zeros(word.getCodes().size()); - double ada[] = new double[gradient.columns()]; + double[] ada = new double[gradient.columns()]; for (int x = 0; x < gradient.columns(); x++) { ada[x] = gradient.getDouble(x); } @@ -1356,7 +1356,7 @@ public class WordVectorSerializer { // now, it's time to transfer syn0/syn1/syn1 neg values InMemoryLookupTable lookupTable = - (InMemoryLookupTable) new InMemoryLookupTable.Builder().negative(configuration.getNegative()) + new InMemoryLookupTable.Builder().negative(configuration.getNegative()) .useAdaGrad(configuration.isUseAdaGrad()).lr(configuration.getLearningRate()) .cache(vocabCache).vectorLength(configuration.getLayersSize()).build(); @@ -1409,7 +1409,7 @@ public class WordVectorSerializer { @Deprecated public static void writeWordVectors(@NonNull Word2Vec vec, @NonNull String path) throws IOException { BufferedWriter write = new BufferedWriter( - new OutputStreamWriter(new FileOutputStream(new File(path), false), "UTF-8")); + new OutputStreamWriter(new FileOutputStream(new File(path), false), StandardCharsets.UTF_8)); writeWordVectors(vec, write); @@ -1647,7 +1647,7 @@ public class WordVectorSerializer { lookupTable.setSyn0(syn); - return new Pair<>((InMemoryLookupTable) lookupTable, (VocabCache) cache); + return new Pair<>(lookupTable, cache); } catch (IOException readeTextStreamException) { throw new RuntimeException(readeTextStreamException); } finally { @@ -1741,7 +1741,7 @@ public class 
WordVectorSerializer { } InMemoryLookupTable lookupTable = - (InMemoryLookupTable) new InMemoryLookupTable.Builder() + new InMemoryLookupTable.Builder() .vectorLength(arrays.get(0).columns()).cache(cache).build(); INDArray syn = Nd4j.vstack(arrays); @@ -1749,7 +1749,7 @@ public class WordVectorSerializer { Nd4j.clearNans(syn); lookupTable.setSyn0(syn); - return fromPair(Pair.makePair((InMemoryLookupTable) lookupTable, (VocabCache) cache)); + return fromPair(Pair.makePair(lookupTable, cache)); } /** @@ -1925,11 +1925,11 @@ public class WordVectorSerializer { VectorsConfiguration configuration = vectors.getConfiguration(); String json = configuration.toJson().trim(); - zipfile.write(json.getBytes("UTF-8")); + zipfile.write(json.getBytes(StandardCharsets.UTF_8)); ZipEntry vocab = new ZipEntry(VOCAB_ENTRY); zipfile.putNextEntry(vocab); - zipfile.write(vocabCache.toJson().getBytes("UTF-8")); + zipfile.write(vocabCache.toJson().getBytes(StandardCharsets.UTF_8)); INDArray syn0Data = lookupTable.getSyn0(); ZipEntry syn0 = new ZipEntry(SYN0_ENTRY); @@ -2013,11 +2013,11 @@ public class WordVectorSerializer { byte[] bytes = IOUtils.toByteArray(zipfile); if (name.equals(CONFIG_ENTRY)) { - String content = new String(bytes, "UTF-8"); + String content = new String(bytes, StandardCharsets.UTF_8); configuration = VectorsConfiguration.fromJson(content); continue; } else if (name.equals(VOCAB_ENTRY)) { - String content = new String(bytes, "UTF-8"); + String content = new String(bytes, StandardCharsets.UTF_8); vocabCache = AbstractCache.fromJson(content); continue; } @@ -2068,12 +2068,12 @@ public class WordVectorSerializer { */ public static SequenceVectors readSequenceVectors( @NonNull SequenceElementFactory factory, @NonNull InputStream stream) throws IOException { - BufferedReader reader = new BufferedReader(new InputStreamReader(stream, "UTF-8")); + BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8)); // at first we load vectors configuration String line = reader.readLine(); VectorsConfiguration configuration = - VectorsConfiguration.fromJson(new String(Base64.decodeBase64(line), "UTF-8")); + VectorsConfiguration.fromJson(new String(Base64.decodeBase64(line), StandardCharsets.UTF_8)); AbstractCache vocabCache = new AbstractCache.Builder().build(); @@ -2092,7 +2092,7 @@ public class WordVectorSerializer { reader.close(); - InMemoryLookupTable lookupTable = (InMemoryLookupTable) new InMemoryLookupTable.Builder() + InMemoryLookupTable lookupTable = new InMemoryLookupTable.Builder() .vectorLength(rows.get(0).columns()).cache(vocabCache).build(); // fix: add vocab cache /* @@ -2225,7 +2225,7 @@ public class WordVectorSerializer { Base64 base64 = new Base64(Integer.MAX_VALUE); try { String json = mapper.writeValueAsString(this); - String output = base64.encodeAsString(json.getBytes("UTF-8")); + String output = base64.encodeAsString(json.getBytes(StandardCharsets.UTF_8)); return output; } catch (Exception e) { throw new RuntimeException(e); @@ -2241,7 +2241,7 @@ public class WordVectorSerializer { protected static ElementPair fromEncodedJson(String encoded) { ObjectMapper mapper = SequenceElement.mapper(); try { - String decoded = new String(Base64.decodeBase64(encoded), "UTF-8"); + String decoded = new String(Base64.decodeBase64(encoded), StandardCharsets.UTF_8); return mapper.readValue(decoded, ElementPair.class); } catch (IOException e) { throw new RuntimeException(e); @@ -2850,8 +2850,8 @@ public class WordVectorSerializer { } protected static class CSVReader 
implements Reader { - private BufferedReader reader; - private AtomicInteger idxCounter = new AtomicInteger(0); + private final BufferedReader reader; + private final AtomicInteger idxCounter = new AtomicInteger(0); private String nextLine; protected CSVReader(@NonNull File file) { @@ -3202,12 +3202,12 @@ public class WordVectorSerializer { bytes[i] = b; b = dis.readByte(); if (i == 49) { - sb.append(new String(bytes, "UTF-8")); + sb.append(new String(bytes, StandardCharsets.UTF_8)); i = -1; bytes = new byte[MAX_SIZE]; } } - sb.append(new String(bytes, 0, i + 1, "UTF-8")); + sb.append(new String(bytes, 0, i + 1, StandardCharsets.UTF_8)); return sb.toString(); } @@ -3221,7 +3221,7 @@ public class WordVectorSerializer { */ public static String encodeB64(String word) { try { - return B64 + Base64.encodeBase64String(word.getBytes("UTF-8")).replaceAll("(\r|\n)", ""); + return B64 + Base64.encodeBase64String(word.getBytes(StandardCharsets.UTF_8)).replaceAll("(\r|\n)", ""); } catch (Exception e) { throw new RuntimeException(e); } @@ -3238,7 +3238,7 @@ public class WordVectorSerializer { if (word.startsWith(B64)) { String arp = word.replaceFirst(B64, ""); try { - return new String(Base64.decodeBase64(arp), "UTF-8"); + return new String(Base64.decodeBase64(arp), StandardCharsets.UTF_8); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java index 39f8a3df0..9f27cea83 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/BasicModelUtils.java @@ -110,9 +110,8 @@ public class BasicModelUtils implements ModelUtils @Override public Collection wordsNearest(String label, int n) { - List collection = new ArrayList<>(wordsNearest(Arrays.asList(label), new ArrayList(), n + 1)); - if (collection.contains(label)) - collection.remove(label); + List collection = new ArrayList<>(wordsNearest(Collections.singletonList(label), new ArrayList(), n + 1)); + collection.remove(label); while (collection.size() > n) collection.remove(collection.size() - 1); @@ -147,7 +146,7 @@ public class BasicModelUtils implements ModelUtils } else { String[] split = s.split(" "); List positive = Arrays.asList(split[1], split[2]); - List negative = Arrays.asList(split[0]); + List negative = Collections.singletonList(split[0]); String predicted = split[3]; String w = wordsNearest(positive, negative, 1).iterator().next(); if (predicted.equals(w)) diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtils.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtils.java index 4725042e9..53261ced2 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtils.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/FlatModelUtils.java @@ -45,8 +45,7 @@ public class FlatModelUtils extends BasicModelUtils wordsNearest(String label, int n) { Collection collection = wordsNearest(lookupTable.vector(label), n); - if (collection.contains(label)) - collection.remove(label); + collection.remove(label); return collection; } diff --git 
a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.java index e2707ff0b..e0d354775 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/reader/impl/TreeModelUtils.java @@ -66,9 +66,8 @@ public class TreeModelUtils extends BasicModelUtils(); - Collection collection = wordsNearest(Arrays.asList(label), new ArrayList(), n + 1); - if (collection.contains(label)) - collection.remove(label); + Collection collection = wordsNearest(Collections.singletonList(label), new ArrayList(), n + 1); + collection.remove(label); return collection; } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java index fb5156441..a4a9917fb 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/embeddings/wordvectors/WordVectorsImpl.java @@ -227,7 +227,7 @@ public class WordVectorsImpl implements WordVectors { */ @Override public INDArray getWordVectors(@NonNull Collection labels) { - int indexes[] = new int[labels.size()]; + int[] indexes = new int[labels.size()]; int cnt = 0; boolean useIndexUnknown = useUnknown && vocab.containsWord(getUNK()); diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/fasttext/FastText.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/fasttext/FastText.java index 6e753860f..6a847bbef 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/fasttext/FastText.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/fasttext/FastText.java @@ -124,7 +124,7 @@ public class FastText implements WordVectors, Serializable { private static class ArgsFactory { - private List args = new ArrayList<>(); + private final List args = new ArrayList<>(); private void add(String label, String value) { args.add(label); diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/node2vec/Node2Vec.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/node2vec/Node2Vec.java index f5777ebb0..1dce091e3 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/node2vec/Node2Vec.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/node2vec/Node2Vec.java @@ -51,7 +51,7 @@ public class Node2Vec extends Seque } public static class Builder extends SequenceVectors.Builder { - private GraphWalker walker; + private final GraphWalker walker; public Builder(@NonNull GraphWalker walker, @NonNull VectorsConfiguration configuration) { this.walker = walker; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java index c4558c331..bcd1983ec 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectors.java @@ -981,14 +981,14 @@ public class 
ParagraphVectors extends Word2Vec { if (docIter instanceof LabelAwareDocumentIterator) this.labelAwareIterator = - new DocumentIteratorConverter((LabelAwareDocumentIterator) docIter, labelsSource); + new DocumentIteratorConverter(docIter, labelsSource); else this.labelAwareIterator = new DocumentIteratorConverter(docIter, labelsSource); } else if (sentenceIterator != null) { // we have SentenceIterator. Mechanics will be the same, as above if (sentenceIterator instanceof LabelAwareSentenceIterator) this.labelAwareIterator = new SentenceIteratorConverter( - (LabelAwareSentenceIterator) sentenceIterator, labelsSource); + sentenceIterator, labelsSource); else this.labelAwareIterator = new SentenceIteratorConverter(sentenceIterator, labelsSource); } else if (labelAwareIterator != null) { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java index c6752ae75..cbb20dc05 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/SequenceVectors.java @@ -1081,9 +1081,9 @@ public class SequenceVectors extends WordVectorsImpl< // private final AtomicLong linesCounter; private final int limitUpper; private final int limitLower; - private AtomicBoolean isRunning = new AtomicBoolean(true); - private AtomicLong nextRandom; - private Collection stopList; + private final AtomicBoolean isRunning = new AtomicBoolean(true); + private final AtomicLong nextRandom; + private final Collection stopList; private static final int DEFAULT_BUFFER_SIZE = 512; @@ -1220,7 +1220,7 @@ public class SequenceVectors extends WordVectorsImpl< .cyclesBeforeInitialization(3) .initialSize(25L * 1024L * 1024L) .build(); - val workspace_id = "sequence_vectors_training_" + java.util.UUID.randomUUID().toString(); + val workspace_id = "sequence_vectors_training_" + UUID.randomUUID(); Nd4j.getAffinityManager().getDeviceForCurrentThread(); while (digitizer.hasMoreLines()) { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/huffman/GraphHuffman.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/huffman/GraphHuffman.java index 7eaea0f30..e95a2cb88 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/huffman/GraphHuffman.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/huffman/GraphHuffman.java @@ -117,7 +117,7 @@ public class GraphHuffman implements BinaryTree { if (value) return (in | 1L << bitNum); //Bit mask |: 00010000 else - return (in & ~(1 << bitNum)); //Bit mask &: 11101111 + return (in & ~(1L << bitNum)); //Bit mask &: 11101111 } private static boolean getBit(long in, int bitNum) { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/primitives/IGraph.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/primitives/IGraph.java index 29f8e3e08..8f683218a 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/primitives/IGraph.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/primitives/IGraph.java @@ -30,33 +30,33 @@ import java.util.Random; public interface IGraph { /** Number 
of vertices in the graph */
- public int numVertices();
+ int numVertices();
/**Get a vertex in the graph for a given index
 * @param idx integer index of the vertex to get. must be in range 0 to numVertices()
 * @return vertex
 */
- public Vertex getVertex(int idx);
+ Vertex getVertex(int idx);
/** Get multiple vertices in the graph
 * @param indexes the indexes of the vertices to retrieve
 * @return list of vertices
 */
- public List> getVertices(int[] indexes);
+ List> getVertices(int[] indexes);
/** Get multiple vertices in the graph, with secified indices
 * @param from first vertex to get, inclusive
 * @param to last vertex to get, inclusive
 * @return list of vertices
 */
- public List> getVertices(int from, int to);
+ List> getVertices(int from, int to);
/** Add an edge to the graph. */
- public void addEdge(Edge edge);
+ void addEdge(Edge edge);
/** Convenience method for adding an edge (directed or undirected) to graph */
- public void addEdge(int from, int to, E value, boolean directed);
+ void addEdge(int from, int to, E value, boolean directed);
/** Returns a list of edges for a vertex with a given index
 * For undirected graphs, returns all edges incident on the vertex
 * @param vertex index of the vertex to
 * @return list of edges for this vertex
 */
- public List> getEdgesOut(int vertex);
+ List> getEdgesOut(int vertex);
/** Returns the degree of the vertex.
 * For undirected graphs, this is just the degree.
@@ -72,7 +72,7 @@ public interface IGraph {
 * @param vertex vertex to get degree for
 * @return vertex degree
 */
- public int getVertexDegree(int vertex);
+ int getVertexDegree(int vertex);
/** Randomly sample a vertex connected to a given vertex. Sampling is done uniformly at random.
 * Specifically, returns a random X such that either a directed edge (vertex -> X) exists,
@@ -84,7 +84,7 @@
 * @throws NoEdgesException thrown if the specified vertex has no edges, or no outgoing edges (in the case
 * of a directed graph).
 */
- public Vertex getRandomConnectedVertex(int vertex, Random rng) throws NoEdgesException;
+ Vertex getRandomConnectedVertex(int vertex, Random rng) throws NoEdgesException;
/**Get a list of all of the vertices that the specified vertex is connected to
 * Specifically, for undirected graphs return list of all X such that (vertex -- X) exists
@@ -92,7 +92,7 @@ public interface IGraph {
 * @param vertex Index of the vertex
 * @return list of vertices that the specified vertex is connected to
 */
- public List> getConnectedVertices(int vertex);
+ List> getConnectedVertices(int vertex);
/**Return an array of indexes of vertices that the specified vertex is connected to.
 * Specifically, for undirected graphs return int[] of all X.vertexID() such that (vertex -- X) exists
@@ -101,5 +101,5 @@ public interface IGraph { * @return list of vertices that the specified vertex is connected to * @see #getConnectedVertices(int) */ - public int[] getConnectedVertexIndices(int vertex); + int[] getConnectedVertexIndices(int vertex); } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/NearestVertexWalker.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/NearestVertexWalker.java index e6f245d47..28f6058e6 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/NearestVertexWalker.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/NearestVertexWalker.java @@ -46,7 +46,7 @@ public class NearestVertexWalker implements GraphWalk protected Random rng; protected int depth; - private AtomicInteger position = new AtomicInteger(0); + private final AtomicInteger position = new AtomicInteger(0); protected NearestVertexWalker() { @@ -259,7 +259,7 @@ public class NearestVertexWalker implements GraphWalk } protected class VertexComparator implements Comparator> { - private IGraph graph; + private final IGraph graph; public VertexComparator(@NonNull IGraph graph) { this.graph = graph; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/PopularityWalker.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/PopularityWalker.java index 433805d90..58f3e13cb 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/PopularityWalker.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/PopularityWalker.java @@ -181,7 +181,7 @@ public class PopularityWalker extends RandomWalker } break; case PROPORTIONAL: { - double norm[] = MathArrays.normalizeArray(weights, 1); + double[] norm = MathArrays.normalizeArray(weights, 1); double prob = rng.nextDouble(); double floor = 0.0; for (int b = 0; b < weights.length; b++) { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/iterators/AbstractSequenceIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/iterators/AbstractSequenceIterator.java index ad93ca7aa..b5b48bfc0 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/iterators/AbstractSequenceIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/iterators/AbstractSequenceIterator.java @@ -30,7 +30,7 @@ import java.util.concurrent.atomic.AtomicInteger; public class AbstractSequenceIterator implements SequenceIterator { - private Iterable> underlyingIterable; + private final Iterable> underlyingIterable; private Iterator> currentIterator; // used to tag each sequence with own Id @@ -71,7 +71,7 @@ public class AbstractSequenceIterator implements Sequ } public static class Builder { - private Iterable> underlyingIterable; + private final Iterable> underlyingIterable; /** * Builds AbstractSequenceIterator on top of Iterable object diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/ScoreListener.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/ScoreListener.java index 29b2340cc..3b209dc05 100644 
--- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/ScoreListener.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/ScoreListener.java @@ -44,10 +44,7 @@ public class ScoreListener implements VectorsListener @Override public boolean validateEvent(ListenerEvent event, long argument) { - if (event == targetEvent) - return true; - - return false; + return event == targetEvent; } @Override diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/SerializingListener.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/SerializingListener.java index d6b711ce2..9b031ae76 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/SerializingListener.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/listeners/SerializingListener.java @@ -41,7 +41,7 @@ public class SerializingListener implements VectorsLi private ListenerEvent targetEvent = ListenerEvent.EPOCH; private int targetFrequency = 100000; - private Semaphore locker = new Semaphore(1); + private final Semaphore locker = new Semaphore(1); protected SerializingListener() {} @@ -60,10 +60,7 @@ public class SerializingListener implements VectorsLi */ locker.acquire(); - if (event == targetEvent && argument % targetFrequency == 0) { - return true; - } else - return false; + return event == targetEvent && argument % targetFrequency == 0; } catch (Exception e) { throw new RuntimeException(e); } finally { @@ -85,9 +82,7 @@ public class SerializingListener implements VectorsLi SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS"); - StringBuilder builder = new StringBuilder(targetFolder.getAbsolutePath()); - builder.append("/").append(modelPrefix).append("_").append(sdf.format(new Date())).append(".seqvec"); - File targetFile = new File(builder.toString()); + File targetFile = new File(targetFolder.getAbsolutePath() + "/" + modelPrefix + "_" + sdf.format(new Date()) + ".seqvec"); if (useBinarySerialization) { SerializationUtils.saveObject(sequenceVectors, targetFile); @@ -104,7 +99,7 @@ public class SerializingListener implements VectorsLi public static class Builder { private File targetFolder = new File("./"); - private String modelPrefix = "Model_"; + private final String modelPrefix = "Model_"; private boolean useBinarySerialization = true; private ListenerEvent targetEvent = ListenerEvent.EPOCH; private int targetFrequency = 100000; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/Sequence.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/Sequence.java index 71de077b1..e80db82db 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/Sequence.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/sequence/Sequence.java @@ -204,7 +204,7 @@ public class Sequence implements Serializable { Sequence sequence = (Sequence) o; - return elements != null ? 
elements.equals(sequence.elements) : sequence.elements == null; + return Objects.equals(elements, sequence.elements); } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/GraphTransformer.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/GraphTransformer.java index 8e83448d8..04b64fa17 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/GraphTransformer.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/GraphTransformer.java @@ -86,7 +86,7 @@ public class GraphTransformer implements Iterable>() { - private GraphWalker walker = GraphTransformer.this.walker; + private final GraphWalker walker = GraphTransformer.this.walker; @Override public void remove() { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIterator.java index d68fc6095..2f0064038 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIterator.java @@ -107,8 +107,8 @@ public class ParallelTransformerIterator extends BasicTransformerIterator { private static class CallableTransformer implements Callable> { - private LabelledDocument document; - private SentenceTransformer transformer; + private final LabelledDocument document; + private final SentenceTransformer transformer; public CallableTransformer(LabelledDocument document, SentenceTransformer transformer) { this.transformer = transformer; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/Huffman.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/Huffman.java index 53a0b34f1..c773722d8 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/Huffman.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/Huffman.java @@ -38,7 +38,7 @@ public class Huffman { public final int MAX_CODE_LENGTH; private volatile boolean buildTrigger = false; - private Logger logger = LoggerFactory.getLogger(Huffman.class); + private final Logger logger = LoggerFactory.getLogger(Huffman.class); public Huffman(Collection words) { this(words, 40); @@ -63,7 +63,7 @@ public class Huffman { }); } - private List words; + private final List words; public void build() { buildTrigger = true; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/StaticWord2Vec.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/StaticWord2Vec.java index 10c71e58b..46841f364 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/StaticWord2Vec.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/StaticWord2Vec.java @@ -38,7 +38,7 @@ import java.util.concurrent.ConcurrentHashMap; @Slf4j public class StaticWord2Vec implements WordVectors { - private List> cacheWrtDevice = new ArrayList<>(); + private final List> cacheWrtDevice = new ArrayList<>(); private AbstractStorage storage; private long 
cachePerDevice = 0L; private VocabCache vocabCache; @@ -380,9 +380,9 @@ public class StaticWord2Vec implements WordVectors { public static class Builder { - private AbstractStorage storage; + private final AbstractStorage storage; private long cachePerDevice = 0L; - private VocabCache vocabCache; + private final VocabCache vocabCache; /** * diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/StreamWork.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/StreamWork.java index 8ab9a7071..e6272597b 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/StreamWork.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/StreamWork.java @@ -25,7 +25,7 @@ import java.io.Serializable; import java.util.concurrent.atomic.AtomicInteger; public class StreamWork implements Serializable { - private InputStreamCreator is; + private final InputStreamCreator is; private AtomicInteger count = new AtomicInteger(0); diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/VocabWork.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/VocabWork.java index ee8ed1ca2..aa4cc5556 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/VocabWork.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/VocabWork.java @@ -22,7 +22,9 @@ package org.deeplearning4j.models.word2vec; import java.io.Serializable; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; public class VocabWork implements Serializable { @@ -40,7 +42,7 @@ public class VocabWork implements Serializable { public VocabWork(AtomicInteger count, String work, boolean stem, String label) { - this(count, work, stem, Arrays.asList(label)); + this(count, work, stem, Collections.singletonList(label)); } public VocabWork(AtomicInteger count, String work, boolean stem, List label) { @@ -97,11 +99,11 @@ public class VocabWork implements Serializable { if (stem != vocabWork.stem) return false; - if (count != null ? !count.equals(vocabWork.count) : vocabWork.count != null) + if (!Objects.equals(count, vocabWork.count)) return false; - if (label != null ? !label.equals(vocabWork.label) : vocabWork.label != null) + if (!Objects.equals(label, vocabWork.label)) return false; - return !(work != null ? 
!work.equals(vocabWork.work) : vocabWork.work != null); + return !(!Objects.equals(work, vocabWork.work)); } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataFetcher.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataFetcher.java index 71bd91a75..e7d740dff 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataFetcher.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataFetcher.java @@ -51,15 +51,15 @@ public class Word2VecDataFetcher implements DataSetFetcher { */ private static final long serialVersionUID = 3245955804749769475L; private transient Iterator files; - private Word2Vec vec; - private static Pattern begin = Pattern.compile("<[A-Z]+>"); - private static Pattern end = Pattern.compile(""); + private final Word2Vec vec; + private static final Pattern begin = Pattern.compile("<[A-Z]+>"); + private static final Pattern end = Pattern.compile(""); private List labels = new ArrayList<>(); private int batch; - private List cache = new ArrayList<>(); + private final List cache = new ArrayList<>(); private static final Logger log = LoggerFactory.getLogger(Word2VecDataFetcher.class); private int totalExamples; - private String path; + private final String path; public Word2VecDataFetcher(String path, Word2Vec vec, List labels) { if (vec == null || labels == null || labels.isEmpty()) diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIterator.java index d06f92e7b..99d90793f 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIterator.java @@ -42,10 +42,10 @@ import java.util.concurrent.CopyOnWriteArrayList; @Slf4j public class Word2VecDataSetIterator implements DataSetIterator { - private Word2Vec vec; - private LabelAwareSentenceIterator iter; - private List cachedWindow; - private List labels; + private final Word2Vec vec; + private final LabelAwareSentenceIterator iter; + private final List cachedWindow; + private final List labels; private int batch = 10; @Getter private DataSetPreProcessor preProcessor; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructor.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructor.java index 9c3417878..7cbe51806 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructor.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructor.java @@ -48,7 +48,7 @@ public class VocabConstructor { private boolean useAdaGrad = false; private boolean fetchLabels = false; private int limit; - private AtomicLong seqCount = new AtomicLong(0); + private final AtomicLong seqCount = new AtomicLong(0); private InvertedIndex index; private boolean enableScavenger = false; private T unk; @@ -453,7 +453,7 @@ public class VocabConstructor { } public static class Builder { - private List> sources = new ArrayList<>(); + private final List> sources = new ArrayList<>(); private VocabCache cache; private Collection stopWords = 
new ArrayList<>(); private boolean useAdaGrad = false; @@ -608,7 +608,7 @@ public class VocabConstructor { private final Sequence document; private final AbstractCache targetVocab; private final AtomicLong loopCounter; - private AtomicBoolean done = new AtomicBoolean(false); + private final AtomicBoolean done = new AtomicBoolean(false); public VocabRunnable(@NonNull AbstractCache targetVocab, @NonNull Sequence sequence, @NonNull AtomicLong finalCounter, @NonNull AtomicLong loopCounter) { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyHolder.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyHolder.java index 668305d1c..eb4fe89f5 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyHolder.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyHolder.java @@ -39,7 +39,7 @@ public class VocabularyHolder implements Serializable { private final Map vocabulary = new ConcurrentHashMap<>(); // idxMap marked as transient, since there's no real reason to save this data on serialization - private transient Map idxMap = new ConcurrentHashMap<>(); + private final transient Map idxMap = new ConcurrentHashMap<>(); private int minWordFrequency = 0; private boolean hugeModelExpected = false; private int retentionDelay = 3; @@ -52,11 +52,11 @@ public class VocabularyHolder implements Serializable { private long totalWordOccurrences = 0; // for scavenger mechanics we need to know the actual number of words being added - private transient AtomicLong hiddenWordsCounter = new AtomicLong(0); + private final transient AtomicLong hiddenWordsCounter = new AtomicLong(0); - private AtomicInteger totalWordCount = new AtomicInteger(0); + private final AtomicInteger totalWordCount = new AtomicInteger(0); - private Logger logger = LoggerFactory.getLogger(VocabularyHolder.class); + private final Logger logger = LoggerFactory.getLogger(VocabularyHolder.class); private static final int MAX_CODE_LENGTH = 40; @@ -285,7 +285,6 @@ public class VocabularyHolder implements Serializable { && hiddenWordsCounter.incrementAndGet() % scavengerThreshold == 0) activateScavenger(); - return; } } @@ -410,9 +409,9 @@ public class VocabularyHolder implements Serializable { int i; // get vocabulary as sorted list List vocab = this.words(); - int count[] = new int[vocab.size() * 2 + 1]; - int parent_node[] = new int[vocab.size() * 2 + 1]; - byte binary[] = new byte[vocab.size() * 2 + 1]; + int[] count = new int[vocab.size() * 2 + 1]; + int[] parent_node = new int[vocab.size() * 2 + 1]; + byte[] binary = new byte[vocab.size() * 2 + 1]; // at this point vocab is sorted, with descending order for (int a = 0; a < vocab.size(); a++) diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyWord.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyWord.java index 46c2103af..46d4b2f28 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyWord.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/VocabularyWord.java @@ -29,6 +29,7 @@ import com.fasterxml.jackson.databind.SerializationFeature; import java.io.IOException; import java.io.Serializable; +import java.util.Objects; @Data public class VocabularyWord implements Serializable { @@ -84,7 +85,7 @@ public 
class VocabularyWord implements Serializable { VocabularyWord word1 = (VocabularyWord) o; - return word != null ? word.equals(word1.word) : word1.word == null; + return Objects.equals(word, word1.word); } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCache.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCache.java index 7c96a3ae6..641f32f9e 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCache.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCache.java @@ -70,14 +70,14 @@ public class AbstractCache implements VocabCache { // we're using for compatibility & failproof reasons: it's easier to store unique labels then abstract objects of unknown size // TODO: wtf this one is doing here? - private List stopWords = new ArrayList<>(); // stop words + private final List stopWords = new ArrayList<>(); // stop words // this variable defines how often scavenger will be activated private int scavengerThreshold = 3000000; // ser private int retentionDelay = 3; // ser // for scavenger mechanics we need to know the actual number of words being added - private transient AtomicLong hiddenWordsCounter = new AtomicLong(0); + private final transient AtomicLong hiddenWordsCounter = new AtomicLong(0); private final AtomicLong totalWordCount = new AtomicLong(0); // ser @@ -180,7 +180,7 @@ public class AbstractCache implements VocabCache { */ public boolean containsElement(T element) { // FIXME: lolwtf - return vocabulary.values().contains(element); + return vocabulary.containsValue(element); } /** diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/InMemoryLookupCache.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/InMemoryLookupCache.java index 349c827fb..96248536f 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/InMemoryLookupCache.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/InMemoryLookupCache.java @@ -32,6 +32,7 @@ import java.io.InputStream; import java.io.Serializable; import java.util.Collection; import java.util.Map; +import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; @@ -330,9 +331,7 @@ public class InMemoryLookupCache implements VocabCache, Serializable @Override public synchronized boolean addToken(VocabWord word) { - if (null == tokens.put(word.getLabel(), word)) - return true; - return false; + return null == tokens.put(word.getLabel(), word); } @Override @@ -448,11 +447,11 @@ public class InMemoryLookupCache implements VocabCache, Serializable if (numDocs != that.numDocs) return false; - if (wordIndex != null ? !wordIndex.equals(that.wordIndex) : that.wordIndex != null) + if (!Objects.equals(wordIndex, that.wordIndex)) return false; - if (wordFrequencies != null ? !wordFrequencies.equals(that.wordFrequencies) : that.wordFrequencies != null) + if (!Objects.equals(wordFrequencies, that.wordFrequencies)) return false; - if (docFrequencies != null ? 
!docFrequencies.equals(that.docFrequencies) : that.docFrequencies != null) + if (!Objects.equals(docFrequencies, that.docFrequencies)) return false; if (vocabWords().equals(that.vocabWords())) return true; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/FileDocumentIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/FileDocumentIterator.java index 990f5e629..2f9bc3fef 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/FileDocumentIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/FileDocumentIterator.java @@ -42,7 +42,7 @@ public class FileDocumentIterator implements DocumentIterator { private Iterator iter; private LineIterator lineIterator; - private File rootDir; + private final File rootDir; private static final Logger log = LoggerFactory.getLogger(FileDocumentIterator.class); public FileDocumentIterator(String path) { @@ -116,7 +116,7 @@ public class FileDocumentIterator implements DocumentIterator { if (rootDir.isDirectory()) iter = FileUtils.iterateFiles(rootDir, null, true); else - iter = Arrays.asList(rootDir).iterator(); + iter = Collections.singletonList(rootDir).iterator(); } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/FilenamesLabelAwareIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/FilenamesLabelAwareIterator.java index 98eba8752..99553ae8c 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/FilenamesLabelAwareIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/FilenamesLabelAwareIterator.java @@ -114,8 +114,8 @@ public class FilenamesLabelAwareIterator implements LabelAwareIterator { public static class Builder { protected List foldersToScan = new ArrayList<>(); - private List fileList = new ArrayList<>(); - private List labels = new ArrayList<>(); + private final List fileList = new ArrayList<>(); + private final List labels = new ArrayList<>(); private boolean absPath = false; public Builder() { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/LabelsSource.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/LabelsSource.java index 8088eadc0..1dc09a26a 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/LabelsSource.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/documentiterator/LabelsSource.java @@ -32,13 +32,13 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicLong; public class LabelsSource implements Serializable { - private AtomicLong counter = new AtomicLong(0); + private final AtomicLong counter = new AtomicLong(0); @Setter private String template; private boolean useFormatter = false; private List labels; private long maxCount = 0; - private Set uniq = Collections.newSetFromMap(new ConcurrentHashMap()); + private final Set uniq = Collections.newSetFromMap(new ConcurrentHashMap()); public LabelsSource() { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/inputsanitation/InputHomogenization.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/inputsanitation/InputHomogenization.java index 049f0c047..2d34687da 100644 --- 
a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/inputsanitation/InputHomogenization.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/inputsanitation/InputHomogenization.java @@ -25,7 +25,7 @@ import java.text.Normalizer.Form; import java.util.List; public class InputHomogenization { - private String input; + private final String input; private List ignoreCharactersContaining; private boolean preserveCase; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/ContextLabelRetriever.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/ContextLabelRetriever.java index 4c5ea7506..7772413b6 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/ContextLabelRetriever.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/ContextLabelRetriever.java @@ -33,8 +33,8 @@ import java.util.List; public class ContextLabelRetriever { - private static String BEGIN_LABEL = "<([A-Za-z]+|\\d+)>"; - private static String END_LABEL = ""; + private static final String BEGIN_LABEL = "<([A-Za-z]+|\\d+)>"; + private static final String END_LABEL = ""; private ContextLabelRetriever() {} @@ -66,7 +66,7 @@ public class ContextLabelRetriever { //no labels; add these as NONE and begin the new label if (!currTokens.isEmpty()) { - tokensWithSameLabel.add(new Pair<>("NONE", (List) new ArrayList<>(currTokens))); + tokensWithSameLabel.add(new Pair<>("NONE", new ArrayList<>(currTokens))); currTokens.clear(); } @@ -85,7 +85,7 @@ public class ContextLabelRetriever { Preconditions.checkState(!endLabel.isEmpty(), "End label is empty!"); Preconditions.checkState(currLabel.equals(endLabel), "Current label begin and end did not match for the parse. 
Was: %s ending with %s", currLabel, endLabel); - tokensWithSameLabel.add(new Pair<>(currLabel, (List) new ArrayList<>(currTokens))); + tokensWithSameLabel.add(new Pair<>(currLabel, new ArrayList<>(currTokens))); currTokens.clear(); //clear out the tokens @@ -96,7 +96,7 @@ public class ContextLabelRetriever { //no labels; add these as NONE and begin the new label if (!currTokens.isEmpty()) { - tokensWithSameLabel.add(new Pair<>("none", (List) new ArrayList<>(currTokens))); + tokensWithSameLabel.add(new Pair<>("none", new ArrayList<>(currTokens))); currTokens.clear(); } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/Window.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/Window.java index 89244c7ba..739fa0907 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/Window.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/Window.java @@ -37,10 +37,10 @@ public class Window implements Serializable { private String label = "NONE"; private boolean beginLabel; private boolean endLabel; - private int windowSize; + private final int windowSize; private int median; - private static String BEGIN_LABEL = "<([A-Z]+|\\d+)>"; - private static String END_LABEL = ""; + private static final String BEGIN_LABEL = "<([A-Z]+|\\d+)>"; + private static final String END_LABEL = ""; private int begin, end; /** diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/WordConverter.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/WordConverter.java index 5860dbf76..23f59131b 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/WordConverter.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/movingwindow/WordConverter.java @@ -32,7 +32,7 @@ import java.util.List; public class WordConverter { private List sentences = new ArrayList<>(); - private Word2Vec vec; + private final Word2Vec vec; private List windows; public WordConverter(List sentences, Word2Vec vec) { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIterator.java index d140452dd..a88989050 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/AggregatingSentenceIterator.java @@ -28,9 +28,9 @@ import java.util.List; import java.util.concurrent.atomic.AtomicInteger; public class AggregatingSentenceIterator implements SentenceIterator { - private List backendIterators; + private final List backendIterators; private SentencePreProcessor preProcessor; - private AtomicInteger position = new AtomicInteger(0); + private final AtomicInteger position = new AtomicInteger(0); private AggregatingSentenceIterator(@NonNull List list) { this.backendIterators = list; @@ -82,7 +82,7 @@ public class AggregatingSentenceIterator implements SentenceIterator { } public static class Builder { - private List backendIterators = new ArrayList<>(); + private final List backendIterators = new ArrayList<>(); private SentencePreProcessor preProcessor; public Builder() { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/BasicLineIterator.java 
b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/BasicLineIterator.java index 3da5acfd4..0f9f6b0c5 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/BasicLineIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/BasicLineIterator.java @@ -30,7 +30,7 @@ import java.util.Iterator; public class BasicLineIterator implements SentenceIterator, Iterable { private BufferedReader reader; - private InputStream backendStream; + private final InputStream backendStream; private SentencePreProcessor preProcessor; private boolean internal = false; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/BasicResultSetIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/BasicResultSetIterator.java index 0be1cf537..feaa357bb 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/BasicResultSetIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/BasicResultSetIterator.java @@ -25,8 +25,8 @@ import java.sql.SQLException; public class BasicResultSetIterator implements SentenceIterator { - private ResultSet rs; - private String columnName; + private final ResultSet rs; + private final String columnName; private SentencePreProcessor preProcessor; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/CollectionSentenceIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/CollectionSentenceIterator.java index be0a834a7..b3911a407 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/CollectionSentenceIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/CollectionSentenceIterator.java @@ -26,7 +26,7 @@ import java.util.Iterator; public class CollectionSentenceIterator extends BaseSentenceIterator { private Iterator iter; - private Collection coll; + private final Collection coll; public CollectionSentenceIterator(SentencePreProcessor preProcessor, Collection coll) { super(preProcessor); diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/FileSentenceIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/FileSentenceIterator.java index 31c0dfc6c..98772f160 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/FileSentenceIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/FileSentenceIterator.java @@ -29,6 +29,7 @@ import java.io.File; import java.io.FileInputStream; import java.io.IOException; import java.util.Arrays; +import java.util.Collections; import java.util.Iterator; import java.util.Queue; import java.util.zip.GZIPInputStream; @@ -59,7 +60,7 @@ public class FileSentenceIterator extends BaseSentenceIterator { if (file.isDirectory()) fileIterator = FileUtils.iterateFiles(file, null, true); else - fileIterator = Arrays.asList(file).iterator(); + fileIterator = Collections.singletonList(file).iterator(); } public FileSentenceIterator(File dir) { @@ -141,7 +142,7 @@ public class FileSentenceIterator extends BaseSentenceIterator { @Override public void reset() { if (file.isFile()) - fileIterator = Arrays.asList(file).iterator(); + fileIterator = Collections.singletonList(file).iterator(); else fileIterator = 
FileUtils.iterateFiles(file, null, true); diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/LineSentenceIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/LineSentenceIterator.java index 76ac3f37a..77ea21b1c 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/LineSentenceIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/LineSentenceIterator.java @@ -29,7 +29,7 @@ public class LineSentenceIterator extends BaseSentenceIterator { private InputStream file; private LineIterator iter; - private File f; + private final File f; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIterator.java index 576dd6140..f2eda2b4f 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/MutipleEpochsSentenceIterator.java @@ -25,9 +25,9 @@ import lombok.NonNull; import java.util.concurrent.atomic.AtomicInteger; public class MutipleEpochsSentenceIterator implements SentenceIterator { - private SentenceIterator iterator; - private int numEpochs; - private AtomicInteger counter = new AtomicInteger(0); + private final SentenceIterator iterator; + private final int numEpochs; + private final AtomicInteger counter = new AtomicInteger(0); public MutipleEpochsSentenceIterator(@NonNull SentenceIterator iterator, int numEpochs) { this.numEpochs = numEpochs; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIterator.java index ef83f32ba..f513d4eab 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/PrefetchingSentenceIterator.java @@ -60,7 +60,7 @@ public class PrefetchingSentenceIterator implements SentenceIterator { @Override public boolean hasNext() { - return (reader != null) ? 
reader.hasMoreLines() : false; + return reader != null && reader.hasMoreLines(); } @Override @@ -93,7 +93,7 @@ public class PrefetchingSentenceIterator implements SentenceIterator { } public static class Builder { - private SentenceIterator iterator; + private final SentenceIterator iterator; private int fetchSize = 10000; private SentencePreProcessor preProcessor; @@ -123,13 +123,13 @@ public class PrefetchingSentenceIterator implements SentenceIterator { } private class AsyncIteratorReader extends Thread implements Runnable { - private SentenceIterator iterator; - private int fetchSize; - private AtomicBoolean shouldTerminate = new AtomicBoolean(false); - private ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); - private SentencePreProcessor preProcessor; - private AtomicBoolean isRunning = new AtomicBoolean(true); - private ArrayBlockingQueue buffer; + private final SentenceIterator iterator; + private final int fetchSize; + private final AtomicBoolean shouldTerminate = new AtomicBoolean(false); + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + private final SentencePreProcessor preProcessor; + private final AtomicBoolean isRunning = new AtomicBoolean(true); + private final ArrayBlockingQueue buffer; public AsyncIteratorReader(@NonNull SentenceIterator iterator, int fetchSize, SentencePreProcessor preProcessor) { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/StreamLineIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/StreamLineIterator.java index 5bfcff057..6734d9747 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/StreamLineIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/StreamLineIterator.java @@ -35,7 +35,7 @@ import java.util.concurrent.atomic.AtomicBoolean; @Slf4j public class StreamLineIterator implements SentenceIterator { - private DocumentIterator iterator; + private final DocumentIterator iterator; private int linesToFetch; private final LinkedBlockingQueue buffer = new LinkedBlockingQueue<>(); private SentencePreProcessor preProcessor; @@ -118,14 +118,14 @@ public class StreamLineIterator implements SentenceIterator { } public static class Builder { - private DocumentIterator iterator; + private final DocumentIterator iterator; private int linesToFetch = 50; private SentencePreProcessor preProcessor; public Builder(@NonNull final InputStream stream) { this(new DocumentIterator() { private final InputStream onlyStream = stream; - private AtomicBoolean isConsumed = new AtomicBoolean(false); + private final AtomicBoolean isConsumed = new AtomicBoolean(false); @Override public boolean hasNext() { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/SynchronizedSentenceIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/SynchronizedSentenceIterator.java index ceaa3ad4a..aa80f8963 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/SynchronizedSentenceIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/SynchronizedSentenceIterator.java @@ -23,7 +23,7 @@ package org.deeplearning4j.text.sentenceiterator; import lombok.NonNull; public class SynchronizedSentenceIterator implements SentenceIterator { - private SentenceIterator underlyingIterator; + private final SentenceIterator underlyingIterator; 
public SynchronizedSentenceIterator(@NonNull SentenceIterator iterator) { this.underlyingIterator = iterator; diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/interoperability/SentenceIteratorConverter.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/interoperability/SentenceIteratorConverter.java index 74b7172e6..959edf763 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/interoperability/SentenceIteratorConverter.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/interoperability/SentenceIteratorConverter.java @@ -32,8 +32,8 @@ import org.slf4j.LoggerFactory; import java.util.List; public class SentenceIteratorConverter implements LabelAwareIterator { - private SentenceIterator backendIterator; - private LabelsSource generator; + private final SentenceIterator backendIterator; + private final LabelsSource generator; protected static final Logger log = LoggerFactory.getLogger(SentenceIteratorConverter.class); public SentenceIteratorConverter(@NonNull SentenceIterator iterator) { diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/labelaware/LabelAwareFileSentenceIterator.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/labelaware/LabelAwareFileSentenceIterator.java index 38f8124eb..cf329345f 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/labelaware/LabelAwareFileSentenceIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/sentenceiterator/labelaware/LabelAwareFileSentenceIterator.java @@ -25,6 +25,7 @@ import org.deeplearning4j.text.sentenceiterator.SentencePreProcessor; import java.io.File; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class LabelAwareFileSentenceIterator extends FileSentenceIterator implements LabelAwareSentenceIterator { @@ -49,6 +50,6 @@ public class LabelAwareFileSentenceIterator extends FileSentenceIterator impleme @Override public List currentLabels() { - return Arrays.asList(currentFile.getParentFile().getName()); + return Collections.singletonList(currentFile.getParentFile().getName()); } } diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/DefaultStreamTokenizer.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/DefaultStreamTokenizer.java index d7a8a2a1b..4aadef999 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/DefaultStreamTokenizer.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/DefaultStreamTokenizer.java @@ -36,10 +36,10 @@ import java.util.concurrent.atomic.AtomicInteger; */ public class DefaultStreamTokenizer implements Tokenizer { - private StreamTokenizer streamTokenizer; + private final StreamTokenizer streamTokenizer; private TokenPreProcess tokenPreProcess; - private List tokens = new ArrayList<>(); - private AtomicInteger position = new AtomicInteger(0); + private final List tokens = new ArrayList<>(); + private final AtomicInteger position = new AtomicInteger(0); protected static final Logger log = LoggerFactory.getLogger(DefaultStreamTokenizer.class); diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/DefaultTokenizer.java 
b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/DefaultTokenizer.java index d504a2872..93d29ed4c 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/DefaultTokenizer.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/DefaultTokenizer.java @@ -34,7 +34,7 @@ public class DefaultTokenizer implements Tokenizer { tokenizer = new StringTokenizer(tokens); } - private StringTokenizer tokenizer; + private final StringTokenizer tokenizer; private TokenPreProcess tokenPreProcess; @Override diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/CompositePreProcessor.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/CompositePreProcessor.java index 0d861f851..8b3e3ac9d 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/CompositePreProcessor.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/CompositePreProcessor.java @@ -31,7 +31,7 @@ import java.util.List; public class CompositePreProcessor implements TokenPreProcess { - private List preProcessors; + private final List preProcessors; public CompositePreProcessor(@NonNull TokenPreProcess... preProcessors){ Preconditions.checkState(preProcessors.length > 0, "No preprocessors were specified (empty input)"); diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/StringCleaning.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/StringCleaning.java index f6db851b0..d7eaf8543 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/StringCleaning.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizer/preprocessor/StringCleaning.java @@ -24,7 +24,7 @@ import java.util.regex.Pattern; public class StringCleaning { - private static final Pattern punctPattern = Pattern.compile("[\\d\\.:,\"\'\\(\\)\\[\\]|/?!;]+"); + private static final Pattern punctPattern = Pattern.compile("[\\d\\.:,\"'\\(\\)\\[\\]|/?!;]+"); private StringCleaning() {} diff --git a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizerfactory/NGramTokenizerFactory.java b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizerfactory/NGramTokenizerFactory.java index 6203a7715..e42093108 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizerfactory/NGramTokenizerFactory.java +++ b/cavis-dnn/cavis-dnn-nlp/src/main/java/org/deeplearning4j/text/tokenization/tokenizerfactory/NGramTokenizerFactory.java @@ -33,7 +33,7 @@ public class NGramTokenizerFactory implements TokenizerFactory { private TokenPreProcess preProcess; private Integer minN = 1; private Integer maxN = 1; - private TokenizerFactory tokenizerFactory; + private final TokenizerFactory tokenizerFactory; public NGramTokenizerFactory(TokenizerFactory tokenizerFactory, Integer minN, Integer maxN) { this.tokenizerFactory = tokenizerFactory; diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java index f4303f28e..82e7493ab 100644 --- 
a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/iterator/TestBertIterator.java @@ -49,12 +49,12 @@ import static org.junit.jupiter.api.Assertions.*; @Timeout(200) public class TestBertIterator extends BaseDL4JTest { - private static File pathToVocab = Resources.asFile("other/vocab.txt"); - private static Charset c = StandardCharsets.UTF_8; - private static String shortSentence = "I saw a girl with a telescope."; - private static String longSentence = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; - private static String sentenceA = "Goodnight noises everywhere"; - private static String sentenceB = "Goodnight moon"; + private static final File pathToVocab = Resources.asFile("other/vocab.txt"); + private static final Charset c = StandardCharsets.UTF_8; + private static final String shortSentence = "I saw a girl with a telescope."; + private static final String longSentence = "Donaudampfschifffahrts Kapitänsmützeninnenfuttersaum"; + private static final String sentenceA = "Goodnight noises everywhere"; + private static final String sentenceB = "Goodnight moon"; public TestBertIterator() throws IOException { } @@ -534,18 +534,18 @@ public class TestBertIterator extends BaseDL4JTest { @Getter private static class TestSentencePairsHelper { - private List sentencesLeft; - private List sentencesRight; - private List> sentencePairs; - private List> tokenizedSentencesLeft; - private List> tokenizedSentencesRight; - private List labels; + private final List sentencesLeft; + private final List sentencesRight; + private final List> sentencePairs; + private final List> tokenizedSentencesLeft; + private final List> tokenizedSentencesRight; + private final List labels; private int shortL; private int longL; private int sentenceALen; private int sentenceBLen; - private BertWordPieceTokenizerFactory tokenizer; - private CollectionLabeledPairSentenceProvider pairSentenceProvider; + private final BertWordPieceTokenizerFactory tokenizer; + private final CollectionLabeledPairSentenceProvider pairSentenceProvider; private TestSentencePairsHelper() throws IOException { this(3); @@ -596,13 +596,13 @@ public class TestBertIterator extends BaseDL4JTest { @Getter private static class TestSentenceHelper { - private List sentences; - private List> tokenizedSentences; - private List labels; + private final List sentences; + private final List> tokenizedSentences; + private final List labels; private int shortestL = 0; private int longestL = 0; - private BertWordPieceTokenizerFactory tokenizer; - private CollectionLabeledSentenceProvider sentenceProvider; + private final BertWordPieceTokenizerFactory tokenizer; + private final CollectionLabeledSentenceProvider sentenceProvider; private TestSentenceHelper() throws IOException { this(false, 2); diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java index 5faad62ad..30e633075 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/embeddings/inmemory/InMemoryLookupTableTest.java @@ -83,14 +83,14 @@ public class InMemoryLookupTableTest extends BaseDL4JTest { assertEquals(244, cacheSource.numWords()); InMemoryLookupTable mem1 = - 
(InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(100) - .cache(cacheSource).seed(17).build(); + new InMemoryLookupTable.Builder().vectorLength(100) + .cache(cacheSource).seed(17).build(); mem1.resetWeights(true); InMemoryLookupTable mem2 = - (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(100) - .cache(cacheSource).seed(15).build(); + new InMemoryLookupTable.Builder().vectorLength(100) + .cache(cacheSource).seed(15).build(); mem2.resetWeights(true); @@ -130,8 +130,8 @@ public class InMemoryLookupTableTest extends BaseDL4JTest { assertEquals(244, cacheSource.numWords()); InMemoryLookupTable mem1 = - (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(100) - .cache(cacheSource).build(); + new InMemoryLookupTable.Builder().vectorLength(100) + .cache(cacheSource).build(); mem1.resetWeights(true); @@ -160,8 +160,8 @@ public class InMemoryLookupTableTest extends BaseDL4JTest { InMemoryLookupTable mem2 = - (InMemoryLookupTable) new InMemoryLookupTable.Builder().vectorLength(100) - .cache(cacheTarget).seed(18).build(); + new InMemoryLookupTable.Builder().vectorLength(100) + .cache(cacheTarget).seed(18).build(); mem2.resetWeights(true); diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java index 2879fdffe..8b604dfc0 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/fasttext/FastTextTest.java @@ -52,10 +52,10 @@ public class FastTextTest extends BaseDL4JTest { - private File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt"); - private File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin"); - private File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin"); - private File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec"); + private final File inputFile = Resources.asFile("models/fasttext/data/labeled_data.txt"); + private final File supModelFile = Resources.asFile("models/fasttext/supervised.model.bin"); + private final File cbowModelFile = Resources.asFile("models/fasttext/cbow.model.bin"); + private final File supervisedVectors = Resources.asFile("models/fasttext/supervised.model.vec"); @TempDir public File testDir; diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java index 84869467b..eba85a8fc 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java @@ -300,7 +300,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { SerializationUtils.saveObject(vec, tempFile); - ParagraphVectors vec2 = (ParagraphVectors) SerializationUtils.readObject(tempFile); + ParagraphVectors vec2 = SerializationUtils.readObject(tempFile); INDArray day2 = vec2.getWordVectorMatrix("day").dup(); List labelsBinary = vec2.labelsSource.getLabels(); @@ -985,8 +985,8 @@ public class ParagraphVectorsTest extends BaseDL4JTest { assertNotEquals(null, d2v.getLookupTable()); assertNotEquals(null, d2v.getVocab()); - assertTrue(d2v.getVocab() == w2v.getVocab()); - 
assertTrue(d2v.getLookupTable() == w2v.getLookupTable()); + assertSame(d2v.getVocab(), w2v.getVocab()); + assertSame(d2v.getLookupTable(), w2v.getLookupTable()); String textA = "Donald Trump referred to President Obama as \"your president\" during the first presidential debate on Monday, much to many people’s chagrin on social media. Trump, made the reference after saying that the greatest threat facing the world is nuclear weapons. He then turned to Hillary Clinton and said, \"Not global warming like you think and your President thinks,\" referring to Obama."; @@ -1156,7 +1156,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { Word2Vec unserialized = null; try { json = paragraphVectors.toJson(); - log.info("{}", json.toString()); + log.info("{}", json); unserialized = ParagraphVectors.fromJson(json); } catch (Exception e) { @@ -1164,12 +1164,12 @@ public class ParagraphVectorsTest extends BaseDL4JTest { fail(); } - assertEquals(cache.totalWordOccurrences(), ((ParagraphVectors) unserialized).getVocab().totalWordOccurrences()); - assertEquals(cache.totalNumberOfDocs(), ((ParagraphVectors) unserialized).getVocab().totalNumberOfDocs()); + assertEquals(cache.totalWordOccurrences(), unserialized.getVocab().totalWordOccurrences()); + assertEquals(cache.totalNumberOfDocs(), unserialized.getVocab().totalNumberOfDocs()); for (int i = 0; i < words.length; ++i) { val cached = cache.wordAtIndex(i); - val restored = ((ParagraphVectors) unserialized).getVocab().wordAtIndex(i); + val restored = unserialized.getVocab().wordAtIndex(i); assertNotNull(cached); assertEquals(cached, restored); } diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java index 02bfcf733..09c621c4d 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java @@ -472,7 +472,7 @@ public class SequenceVectorsTest extends BaseDL4JTest { @Override public String toString() { return "VocabWord{" + "wordFrequency=" + this.elementFrequency + ", index=" + index + ", codes=" + codes - + ", word='" + String.valueOf(id) + '\'' + ", points=" + points + ", codeLength=" + + ", word='" + id + '\'' + ", points=" + points + ", codeLength=" + codeLength + '}'; } } diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java index 7c150a610..2771d4ae4 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/graph/walkers/impl/RandomWalkerTest.java @@ -93,7 +93,7 @@ public class RandomWalkerTest extends BaseDL4JTest { for (int i = 0; i < 10; i++) { Vertex vertex = graph.getVertex(i); - assertEquals(null, vertex.getValue()); + assertNull(vertex.getValue()); assertEquals(i, vertex.vertexID()); } assertEquals(10, graph.numVertices()); @@ -101,7 +101,7 @@ public class RandomWalkerTest extends BaseDL4JTest { @Test public void testGraphTraverseRandom1() throws Exception { - RandomWalker walker = (RandomWalker) new RandomWalker.Builder<>(graph) + RandomWalker walker 
= new RandomWalker.Builder<>(graph) .setNoEdgeHandling(NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED).setWalkLength(3).build(); int cnt = 0; @@ -123,7 +123,7 @@ public class RandomWalkerTest extends BaseDL4JTest { @Test public void testGraphTraverseRandom2() throws Exception { - RandomWalker walker = (RandomWalker) new RandomWalker.Builder<>(graph) + RandomWalker walker = new RandomWalker.Builder<>(graph) .setSeed(12345) .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).setWalkLength(20) .setWalkDirection(WalkDirection.FORWARD_UNIQUE) @@ -148,7 +148,7 @@ public class RandomWalkerTest extends BaseDL4JTest { @Test public void testGraphTraverseRandom3() throws Exception { - RandomWalker walker = (RandomWalker) new RandomWalker.Builder<>(graph) + RandomWalker walker = new RandomWalker.Builder<>(graph) .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).setWalkLength(20) .setWalkDirection(WalkDirection.FORWARD_UNIQUE) .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).build(); @@ -160,17 +160,17 @@ public class RandomWalkerTest extends BaseDL4JTest { } // if cycle passed without exception - something went bad - assertTrue(false); + fail(); } catch (NoEdgesException e) { // this cycle should throw exception } catch (Exception e) { - assertTrue(false); + fail(); } } @Test public void testGraphTraverseRandom4() throws Exception { - RandomWalker walker = (RandomWalker) new RandomWalker.Builder<>(graphBig) + RandomWalker walker = new RandomWalker.Builder<>(graphBig) .setSeed(12345) .setNoEdgeHandling(NoEdgeHandling.EXCEPTION_ON_DISCONNECTED).setWalkLength(20) .setWalkDirection(WalkDirection.FORWARD_UNIQUE) @@ -187,7 +187,7 @@ public class RandomWalkerTest extends BaseDL4JTest { @Test public void testGraphTraverseRandom5() throws Exception { - RandomWalker walker = (RandomWalker) new RandomWalker.Builder<>(graphBig) + RandomWalker walker = new RandomWalker.Builder<>(graphBig) .setWalkLength(20).setWalkDirection(WalkDirection.FORWARD_UNIQUE) .setNoEdgeHandling(NoEdgeHandling.CUTOFF_ON_DISCONNECTED).build(); diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java index ceef572e1..d82844919 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/transformers/impl/iterables/ParallelTransformerIteratorTest.java @@ -48,7 +48,7 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals; @Slf4j @Timeout(300) public class ParallelTransformerIteratorTest extends BaseDL4JTest { - private TokenizerFactory factory = new DefaultTokenizerFactory(); + private final TokenizerFactory factory = new DefaultTokenizerFactory(); @BeforeEach public void setUp() throws Exception { @@ -165,7 +165,7 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { String testString = ""; String[] stringsArray = new String[100]; for (int i = 0; i < 100; ++i) { - testString += Integer.toString(i) + " "; + testString += i + " "; stringsArray[i] = Integer.toString(i); } InputStream inputStream = IOUtils.toInputStream(testString, "UTF-8"); @@ -196,7 +196,7 @@ public class ParallelTransformerIteratorTest extends BaseDL4JTest { String testStrings = ""; for (int i = 0; 
i < 1000; ++i) { stringsArray[i] = Integer.toString(i); - testStrings += Integer.toString(i) + "\n"; + testStrings += i + "\n"; } InputStream inputStream = IOUtils.toInputStream(testStrings, "UTF-8"); SentenceIterator iterator = new BasicLineIterator(inputStream); diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java index 7d806aafb..948a124b3 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/iterator/Word2VecDataSetIteratorTest.java @@ -93,7 +93,7 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest { iterator.reset(); return new LabelAwareSentenceIterator() { - private AtomicInteger cnt = new AtomicInteger(0); + private final AtomicInteger cnt = new AtomicInteger(0); @Override public String currentLabel() { diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java index 33b2715bf..f51337cba 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/VocabConstructorTest.java @@ -83,8 +83,7 @@ public class VocabConstructorTest extends BaseDL4JTest { continue; cnt++; - if (!set.contains(token)) - set.add(token); + set.add(token); } lines++; @@ -167,7 +166,7 @@ public class VocabConstructorTest extends BaseDL4JTest { public Iterator> iterator() { return new Iterator>() { - private AtomicBoolean switcher = new AtomicBoolean(true); + private final AtomicBoolean switcher = new AtomicBoolean(true); @Override public boolean hasNext() { @@ -216,7 +215,7 @@ public class VocabConstructorTest extends BaseDL4JTest { public Iterator> iterator() { return new Iterator>() { - private AtomicBoolean switcher = new AtomicBoolean(true); + private final AtomicBoolean switcher = new AtomicBoolean(true); @Override public boolean hasNext() { diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java index 9cdf38363..56e7e8bce 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/wordstore/inmemory/AbstractCacheTest.java @@ -136,7 +136,7 @@ public class AbstractCacheTest extends BaseDL4JTest { AbstractCache unserialized = null; try { json = cache.toJson(); - log.info("{}", json.toString()); + log.info("{}", json); unserialized = AbstractCache.fromJson(json); } @@ -159,7 +159,7 @@ public class AbstractCacheTest extends BaseDL4JTest { public void testUserClassSerialization() { AbstractCache cache = new AbstractCache.Builder().build(); - ExtVocabWord words[] = new ExtVocabWord[3]; + ExtVocabWord[] words = new ExtVocabWord[3]; words[0] = new ExtVocabWord("some", 1100, 1.0, "word"); words[1] = new ExtVocabWord("none", 23214, 2.0, "test"); words[2] = new ExtVocabWord("wwew", 13223, 3.0, "tester"); diff --git 
a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java index d36ab414c..665c0b180 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/BertWordPieceTokenizerTests.java @@ -47,8 +47,8 @@ import static org.junit.jupiter.api.Assertions.*; @Timeout(300) public class BertWordPieceTokenizerTests extends BaseDL4JTest { - private File pathToVocab = Resources.asFile("other/vocab.txt"); - private Charset c = StandardCharsets.UTF_8; + private final File pathToVocab = Resources.asFile("other/vocab.txt"); + private final Charset c = StandardCharsets.UTF_8; public BertWordPieceTokenizerTests() throws IOException { } diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java index 4e185df0e..d134f3000 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java @@ -55,7 +55,7 @@ public class NearestNeighborsClient { // Only one time Unirest.setObjectMapper(new ObjectMapper() { - private com.fasterxml.jackson.databind.ObjectMapper jacksonObjectMapper = + private final com.fasterxml.jackson.databind.ObjectMapper jacksonObjectMapper = new com.fasterxml.jackson.databind.ObjectMapper(); public T readValue(String value, Class valueType) { diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java index 6aff39dbb..c6c091837 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java @@ -25,8 +25,8 @@ public enum Distance { JACCARD("jaccard"), HAMMING("hamming"); - private String functionName; - private Distance(String name) { + private final String functionName; + Distance(String name) { functionName = name; } diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java index 25542dc8f..9a55db1da 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java @@ -33,7 +33,8 @@ public class CentersHolder { protected transient INDArray distances; protected transient INDArray argMin; - private long rows, cols; + private final long rows; + private final long cols; public CentersHolder(long rows, long cols) { this.rows = rows; @@ -46,7 +47,7 @@ public class CentersHolder { 
public synchronized void addCenter(INDArray pointView) { if (centers == null) - this.centers = Nd4j.create(pointView.dataType(), new long[] {rows, cols}); + this.centers = Nd4j.create(pointView.dataType(), rows, cols); centers.putRow(index++, pointView); } @@ -56,7 +57,7 @@ public class CentersHolder { distances = Nd4j.create(centers.dataType(), centers.rows()); if (argMin == null) - argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]); + argMin = Nd4j.createUninitialized(DataType.LONG); if (op == null) { op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1); @@ -80,7 +81,7 @@ public class CentersHolder { distances = Nd4j.create(centers.dataType(), centers.rows()); if (argMin == null) - argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]); + argMin = Nd4j.createUninitialized(DataType.LONG); if (op == null) { op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1); diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java index 1c57bc38a..c55834c1b 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java @@ -33,8 +33,8 @@ public class ClusterSetInfo implements Serializable { private Map clustersInfos = new HashMap<>(); private Table distancesBetweenClustersCenters = HashBasedTable.create(); private AtomicInteger pointLocationChange; - private boolean threadSafe; - private boolean inverse; + private final boolean threadSafe; + private final boolean inverse; public ClusterSetInfo(boolean inverse) { this(inverse, false); diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java index 013263629..bbb68bbec 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java @@ -29,10 +29,10 @@ import java.io.Serializable; public class HyperRect implements Serializable { //private List points; - private float[] lowerEnds; - private float[] higherEnds; - private INDArray lowerEndsIND; - private INDArray higherEndsIND; + private final float[] lowerEnds; + private final float[] higherEnds; + private final INDArray lowerEndsIND; + private final INDArray higherEndsIND; public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) { this.lowerEnds = new float[lowerEndsIn.length]; diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java index 68ccf6281..207ccb39b 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java @@ -128,7 +128,7 @@ public class KDTree implements Serializable { // Share this data for recursive calls of 
"knn" private float currentDistance; private INDArray currentPoint; - private INDArray minDistance = Nd4j.scalar(0.f); + private final INDArray minDistance = Nd4j.scalar(0.f); public List> knn(INDArray point, float distance) { diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java index 75e342e78..53307e517 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java @@ -63,14 +63,14 @@ public class RandomProjectionLSH implements LSH { return "cosinedistance"; } - @Getter private int hashLength; + @Getter private final int hashLength; - @Getter private int numTables; + @Getter private final int numTables; - @Getter private int inDimension; + @Getter private final int inDimension; - @Getter private double radius; + @Getter private final double radius; INDArray randomProjection; @@ -190,7 +190,7 @@ public class RandomProjectionLSH implements LSH { INDArray bucketData(INDArray query){ INDArray mask = bucket(query); int nRes = mask.sum(0).getInt(0); - INDArray res = Nd4j.create(new int[] {nRes, inDimension}); + INDArray res = Nd4j.create(nRes, inDimension); int j = 0; for (int i = 0; i < nRes; i++){ while (mask.getInt(j) == 0 && j < mask.length() - 1) { @@ -216,7 +216,7 @@ public class RandomProjectionLSH implements LSH { int accepted = 0; while (accepted < sortedDistances.length() && sortedDistances.getInt(accepted) <= maxRange) accepted +=1; - INDArray res = Nd4j.create(new int[] {accepted, inDimension}); + INDArray res = Nd4j.create(accepted, inDimension); for(int i = 0; i < accepted; i++){ res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i))); } diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java index b26ffc636..0dae7e642 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java @@ -42,9 +42,9 @@ public class QuadTree implements Serializable { private Cell boundary; static final int QT_NO_DIMS = 2; static final int QT_NODE_CAPACITY = 1; - private INDArray buf = Nd4j.create(QT_NO_DIMS); + private final INDArray buf = Nd4j.create(QT_NO_DIMS); private INDArray data, centerOfMass = Nd4j.create(QT_NO_DIMS); - private int[] index = new int[QT_NODE_CAPACITY]; + private final int[] index = new int[QT_NODE_CAPACITY]; /** diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java index 1360a5c92..6bf1d1faa 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java @@ -25,6 +25,7 @@ import org.nd4j.common.primitives.Pair; import 
java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.ExecutorService; @@ -111,15 +112,15 @@ public class RPTree { * @return a list of samples */ public List> queryWithDistances(INDArray query, int numResults) { - return RPUtils.queryAllWithDistances(query,X,Arrays.asList(this),numResults,similarityFunction); + return RPUtils.queryAllWithDistances(query,X, Collections.singletonList(this),numResults,similarityFunction); } public INDArray query(INDArray query,int numResults) { - return RPUtils.queryAll(query,X,Arrays.asList(this),numResults,similarityFunction); + return RPUtils.queryAll(query,X, Collections.singletonList(this),numResults,similarityFunction); } public List getCandidates(INDArray target) { - return RPUtils.getCandidates(target,Arrays.asList(this),similarityFunction); + return RPUtils.getCandidates(target, Collections.singletonList(this),similarityFunction); } diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java index aecdae476..184264b9e 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java @@ -36,7 +36,7 @@ import java.util.*; public class RPUtils { - private static ThreadLocal> functionInstances = new ThreadLocal<>(); + private static final ThreadLocal> functionInstances = new ThreadLocal<>(); public static DifferentialFunction getOp(String name, INDArray x, diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/sptree/Cell.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/sptree/Cell.java index 2781f2ce4..6526ff687 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/sptree/Cell.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/sptree/Cell.java @@ -24,7 +24,7 @@ import java.io.Serializable; * @author Adam Gibson */ public class Cell implements Serializable { - private int dimension; + private final int dimension; private INDArray corner, width; public Cell(int dimension) { diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java index 1ef6dcaf6..df82c0c31 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java @@ -58,7 +58,7 @@ public class SpTree implements Serializable { private boolean isLeaf = true; private Collection indices; private SpTree[] children; - private static Logger log = LoggerFactory.getLogger(SpTree.class); + private static final Logger log = LoggerFactory.getLogger(SpTree.class); private String similarityFunction = Distance.EUCLIDEAN.toString(); diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/util/MathUtils.java 
b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/util/MathUtils.java index 0f657569e..93fbb073a 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/util/MathUtils.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/util/MathUtils.java @@ -144,7 +144,7 @@ public class MathUtils { * * @return the correlation coefficient or r */ - public static double correlation(double[] residuals, double targetAttribute[]) { + public static double correlation(double[] residuals, double[] targetAttribute) { double[] predictedValues = new double[residuals.length]; for (int i = 0; i < predictedValues.length; i++) { predictedValues[i] = targetAttribute[i] - residuals[i]; @@ -1011,7 +1011,7 @@ public class MathUtils { */ public static /*@pure@*/ double roundDouble(double value, int afterDecimalPoint) { - double mask = Math.pow(10.0, (double) afterDecimalPoint); + double mask = Math.pow(10.0, afterDecimalPoint); return (double) (Math.round(value * mask)) / mask; }//end roundDouble diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/util/MultiThreadUtils.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/util/MultiThreadUtils.java index 5ca73a1ac..d3ab0536f 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/util/MultiThreadUtils.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/util/MultiThreadUtils.java @@ -24,7 +24,7 @@ import java.util.concurrent.*; public class MultiThreadUtils { - private static Logger log = LoggerFactory.getLogger(MultiThreadUtils.class); + private static final Logger log = LoggerFactory.getLogger(MultiThreadUtils.class); private static ExecutorService instance; diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTreeFillSearch.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTreeFillSearch.java index 9dbc75416..df4a08991 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTreeFillSearch.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTreeFillSearch.java @@ -35,13 +35,13 @@ import java.util.List; * nearby points by k in a greedy fashion */ public class VPTreeFillSearch { - private VPTree vpTree; - private int k; + private final VPTree vpTree; + private final int k; @Getter private List results; @Getter private List distances; - private INDArray target; + private final INDArray target; public VPTreeFillSearch(VPTree vpTree, int k, INDArray target) { this.vpTree = vpTree; diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java index 00beb9e71..f0b864d84 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java @@ -72,11 +72,11 @@ public class KDTreeTest extends BaseDL4JTest { @Test public void testTree() { KDTree tree = new 
KDTree(2); - INDArray half = Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT); - INDArray one = Nd4j.create(new double[] {1, 1}, new long[]{1,2}).castTo(DataType.FLOAT); + INDArray half = Nd4j.create(new double[] {0.5, 0.5}, 1,2).castTo(DataType.FLOAT); + INDArray one = Nd4j.create(new double[] {1, 1}, 1,2).castTo(DataType.FLOAT); tree.insert(half); tree.insert(one); - Pair pair = tree.nn(Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT)); + Pair pair = tree.nn(Nd4j.create(new double[] {0.5, 0.5}, 1,2).castTo(DataType.FLOAT)); assertEquals(half, pair.getValue()); } diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java index 40683daa9..63c4bcf15 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java @@ -38,7 +38,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; @Timeout(120) public class KMeansTest extends BaseDL4JTest { - private boolean[] useKMeansPlusPlus = {true, false}; + private final boolean[] useKMeansPlusPlus = {true, false}; @Test public void testKMeans() { @@ -178,9 +178,9 @@ public class KMeansTest extends BaseDL4JTest { ClusterSet clusterSet = kMeansClustering.applyTo(points); - INDArray row0 = Nd4j.createFromArray(new double[]{16.6575, 18.4850}); - INDArray row1 = Nd4j.createFromArray(new double[]{32.6050, 31.1500}); - INDArray row2 = Nd4j.createFromArray(new double[]{75.9348, 74.1990}); + INDArray row0 = Nd4j.createFromArray(16.6575, 18.4850); + INDArray row1 = Nd4j.createFromArray(32.6050, 31.1500); + INDArray row2 = Nd4j.createFromArray(75.9348, 74.1990); /*List clusters = clusterSet.getClusters(); assertEquals(row0, clusters.get(0).getCenter().getArray()); @@ -211,9 +211,9 @@ public class KMeansTest extends BaseDL4JTest { int rows = 3, cols = 2; CentersHolder ch = new CentersHolder(rows, cols); - INDArray row0 = Nd4j.createFromArray(new double[]{16.4000, 17.1200}); - INDArray row1 = Nd4j.createFromArray(new double[]{45.8000, 54.2300}); - INDArray row2 = Nd4j.createFromArray(new double[]{95.9348, 94.1990}); + INDArray row0 = Nd4j.createFromArray(16.4000, 17.1200); + INDArray row1 = Nd4j.createFromArray(45.8000, 54.2300); + INDArray row2 = Nd4j.createFromArray(95.9348, 94.1990); ch.addCenter(row0); ch.addCenter(row1); diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java index 5973a1f5a..39cb8b10a 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java @@ -62,18 +62,15 @@ public class SPTreeTest extends BaseDL4JTest { 0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949, 0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860,0.6248951423054205, 0.7431868493349041}; INDArray data = 
Nd4j.createFromArray(aData).reshape(11,5); - INDArray rows = Nd4j.createFromArray(new int[]{ - 0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); - INDArray cols = Nd4j.createFromArray(new int[]{ - 4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); - INDArray vals = Nd4j.createFromArray(new double[] - { 0.6806, 0.1978, 0.1349, 0.0403, 0.0087, 0.0369, 0.0081, 0.0172, 0.0014, 0.0046, 0.0081, 0.3375, 0.2274, 0.0556, 0.0098, 0.0175, 0.0027, 0.0077, 0.0014, 0.0023, 0.0175, 0.6569, 0.1762, 0.0254, 0.0200, 0.0118, 0.0074, 0.0046, 0.0124, 0.0012, 0.1978, 0.0014, 0.0254, 0.7198, 0.0712, 0.0850, 0.0389, 0.0555, 0.0418, 0.0286, 0.6806, 0.3375, 0.0074, 0.0712, 0.2290, 0.0224, 0.0189, 0.0080, 0.0187, 0.0097, 0.0172, 0.0124, 0.0418, 0.7799, 0.0521, 0.0395, 0.0097, 0.0030, 0.0023, 1.706e-5, 0.0087, 0.0027, 0.6569, 0.0850, 0.0080, 0.5562, 0.0173, 0.0015, 1.706e-5, 0.0369, 0.0077, 0.0286, 0.0187, 0.7799, 0.0711, 0.0200, 0.0084, 0.0012, 0.0403, 0.0556, 0.1762, 0.0389, 0.0224, 0.0030, 0.5562, 0.0084, 0.0060, 0.0028, 0.0014, 0.2274, 0.0200, 0.0555, 0.0189, 0.0521, 0.0015, 0.0711, 0.0028, 0.3911, 0.1349, 0.0098, 0.0118, 0.7198, 0.2290, 0.0395, 0.0173, 0.0200, 0.0060, 0.3911}); + INDArray rows = Nd4j.createFromArray(0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99); + INDArray cols = Nd4j.createFromArray(4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1); + INDArray vals = Nd4j.createFromArray(0.6806, 0.1978, 0.1349, 0.0403, 0.0087, 0.0369, 0.0081, 0.0172, 0.0014, 0.0046, 0.0081, 0.3375, 0.2274, 0.0556, 0.0098, 0.0175, 0.0027, 0.0077, 0.0014, 0.0023, 0.0175, 0.6569, 0.1762, 0.0254, 0.0200, 0.0118, 0.0074, 0.0046, 0.0124, 0.0012, 0.1978, 0.0014, 0.0254, 0.7198, 0.0712, 0.0850, 0.0389, 0.0555, 0.0418, 0.0286, 0.6806, 0.3375, 0.0074, 0.0712, 0.2290, 0.0224, 0.0189, 0.0080, 0.0187, 0.0097, 0.0172, 0.0124, 0.0418, 0.7799, 0.0521, 0.0395, 0.0097, 0.0030, 0.0023, 1.706e-5, 0.0087, 0.0027, 0.6569, 0.0850, 0.0080, 0.5562, 0.0173, 0.0015, 1.706e-5, 0.0369, 0.0077, 0.0286, 0.0187, 0.7799, 0.0711, 0.0200, 0.0084, 0.0012, 0.0403, 0.0556, 0.1762, 0.0389, 0.0224, 0.0030, 0.5562, 0.0084, 0.0060, 0.0028, 0.0014, 0.2274, 0.0200, 0.0555, 0.0189, 0.0521, 0.0015, 0.0711, 0.0028, 0.3911, 0.1349, 0.0098, 0.0118, 0.7198, 0.2290, 0.0395, 0.0173, 0.0200, 0.0060, 0.3911); SpTree tree = new SpTree(data); INDArray posF = Nd4j.create(11, 5); /*try (MemoryWorkspace ws = tree.workspace().notifyScopeEntered())*/ { tree.computeEdgeForces(rows, cols, vals, 11, posF); } - INDArray expected = Nd4j.createFromArray(new double[]{ -0.08045664291717945, -0.1010737980370276, 0.01793326162563703, 0.16108447776416351, -0.20679423033936287, -0.15788549368713395, 0.02546624825966788, 0.062309466206907055, -0.165806093080134, 0.15266225270841186, 0.17508365896345726, 0.09588570563583201, 0.34124767300538084, 0.14606666020839956, -0.06786563815470595, -0.09326646571247202, -0.19896040730569928, -0.3618837364446506, 0.13946315445146712, -0.04570186310149667, -0.2473462951783839, -0.41362278505023914, -0.1094083777758208, 0.10705807646770374, 0.24462088260113946, 
0.21722270026621748, -0.21799892431326567, -0.08205544003080587, -0.11170161709042685, -0.2674768703060442, 0.03617747284043274, 0.16430316252598698, 0.04552845070022399, 0.2593696744801452, 0.1439989190892037, -0.059339471967457376, 0.05460893792863096, -0.0595168036583193, -0.2527693197519917, -0.15850951859835274, -0.2945536856938165, 0.15434659331638875, -0.022910846947667776, 0.23598009757792854, -0.11149279745674007, 0.09670616593772939, 0.11125703954547914, -0.08519984596392606, -0.12779827002328714, 0.23025192887225998, 0.13741473964038722, -0.06193553503816597, -0.08349781586292176, 0.1622156410642145, 0.155975447743472}).reshape(11,5); + INDArray expected = Nd4j.createFromArray(-0.08045664291717945, -0.1010737980370276, 0.01793326162563703, 0.16108447776416351, -0.20679423033936287, -0.15788549368713395, 0.02546624825966788, 0.062309466206907055, -0.165806093080134, 0.15266225270841186, 0.17508365896345726, 0.09588570563583201, 0.34124767300538084, 0.14606666020839956, -0.06786563815470595, -0.09326646571247202, -0.19896040730569928, -0.3618837364446506, 0.13946315445146712, -0.04570186310149667, -0.2473462951783839, -0.41362278505023914, -0.1094083777758208, 0.10705807646770374, 0.24462088260113946, 0.21722270026621748, -0.21799892431326567, -0.08205544003080587, -0.11170161709042685, -0.2674768703060442, 0.03617747284043274, 0.16430316252598698, 0.04552845070022399, 0.2593696744801452, 0.1439989190892037, -0.059339471967457376, 0.05460893792863096, -0.0595168036583193, -0.2527693197519917, -0.15850951859835274, -0.2945536856938165, 0.15434659331638875, -0.022910846947667776, 0.23598009757792854, -0.11149279745674007, 0.09670616593772939, 0.11125703954547914, -0.08519984596392606, -0.12779827002328714, 0.23025192887225998, 0.13741473964038722, -0.06193553503816597, -0.08349781586292176, 0.1622156410642145, 0.155975447743472).reshape(11,5); for (int i = 0; i < 11; ++i) assertArrayEquals(expected.getRow(i).toDoubleVector(), posF.getRow(i).toDoubleVector(), 1e-2); diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java index c4146ebe2..9fa41f77f 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java @@ -40,7 +40,7 @@ public class VPTreeSerializationTests extends BaseDL4JTest { @Test public void testSerialization_1() throws Exception { - val points = Nd4j.rand(new int[] {10, 15}); + val points = Nd4j.rand(10, 15); val treeA = new VPTree(points, true, 2); try (val bos = new ByteArrayOutputStream()) { @@ -84,7 +84,7 @@ public class VPTreeSerializationTests extends BaseDL4JTest { @Test public void testNewConstructor_1() { - val points = Nd4j.rand(new int[] {10, 15}); + val points = Nd4j.rand(10, 15); val treeA = new VPTree(points, true, 2); val rows = Nd4j.tear(points, 1); diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java index 99acc67d7..d7f8e0a29 100644 --- 
a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java @@ -406,7 +406,7 @@ public class VpTreeNodeTest extends BaseDL4JTest { i = 0; for (DataPoint p : results) sortedResults.putRow(i++, p.getPoint()); - INDArray[] sortedWithIndices = Nd4j.sortWithIndices(sortedResults, dimensionToSort, true);; + INDArray[] sortedWithIndices = Nd4j.sortWithIndices(sortedResults, dimensionToSort, true); sortedResults = sortedWithIndices[1]; assertEquals(trueResults.sumNumber().doubleValue(), sortedResults.sumNumber().doubleValue(), 1e-5); } diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java index 93cc48675..661d1f46f 100644 --- a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborsServer.java @@ -63,15 +63,15 @@ public class NearestNeighborsServer extends AbstractVerticle { private static class RunArgs { @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true) - private String ndarrayPath = null; + private final String ndarrayPath = null; @Parameter(names = {"--labelsPath"}, arity = 1, required = false) - private String labelsPath = null; + private final String labelsPath = null; @Parameter(names = {"--nearestNeighborsPort"}, arity = 1) - private int port = 9000; + private final int port = 9000; @Parameter(names = {"--similarityFunction"}, arity = 1) - private String similarityFunction = "euclidean"; + private final String similarityFunction = "euclidean"; @Parameter(names = {"--invert"}, arity = 1) - private boolean invert = false; + private final boolean invert = false; } private static RunArgs instanceArgs; @@ -93,7 +93,7 @@ public class NearestNeighborsServer extends AbstractVerticle { log.error("Error in NearestNeighboursServer parameters", e); StringBuilder sb = new StringBuilder(); jcmdr.usage(sb); - log.error("Usage: {}", sb.toString()); + log.error("Usage: {}", sb); //User provides invalid input -> print the usage info jcmdr.usage(); @@ -211,12 +211,10 @@ public class NearestNeighborsServer extends AbstractVerticle { rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) .putHeader("content-type", "application/json") .end(JsonMappers.getMapper().writeValueAsString(results)); - return; } catch (Throwable e) { log.error("Error in POST /knn",e); rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) .end("Error parsing request - " + e.getMessage()); - return; } }); @@ -270,7 +268,6 @@ public class NearestNeighborsServer extends AbstractVerticle { log.error("Error in POST /knnnew",e); rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) .end("Error parsing request - " + e.getMessage()); - return; } }); } diff --git a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java index 22e67ac59..87b1c8c72 100644 --- 
a/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java +++ b/cavis-dnn/cavis-dnn-nn-parent/cavis-dnn-nn-server/src/test/java/org/deeplearning4j/nearestneighbor/server/NearestNeighborTest.java @@ -93,7 +93,7 @@ public class NearestNeighborTest extends BaseDL4JTest { @Test public void vpTreeTest() throws Exception { - INDArray matrix = Nd4j.rand(new int[] {400,10}); + INDArray matrix = Nd4j.rand(400,10); INDArray rowVector = matrix.getRow(70); INDArray resultArr = Nd4j.zeros(400,1); Executor executor = Executors.newSingleThreadExecutor(); @@ -144,7 +144,7 @@ public class NearestNeighborTest extends BaseDL4JTest { int numNeighbors = 42; INDArray points = Nd4j.rand(numRows, numCols); VPTree tree = new VPTree(points); - INDArray query = Nd4j.rand(new int[] {1, numCols}); + INDArray query = Nd4j.rand(1, numCols); VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, numNeighbors, query); fillSearch.search(); List results = fillSearch.getResults(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingConfiguration.java index d0acb2e06..8f55745ed 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingConfiguration.java @@ -93,7 +93,7 @@ public class EarlyStoppingConfiguration implements Serializable private EarlyStoppingModelSaver modelSaver = new InMemoryModelSaver<>(); private List epochTerminationConditions = new ArrayList<>(); - private List iterationTerminationConditions = new ArrayList<>(); + private final List iterationTerminationConditions = new ArrayList<>(); private boolean saveLastModel = false; private int evaluateEveryNEpochs = 1; private ScoreCalculator scoreCalculator; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/saver/LocalFileGraphSaver.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/saver/LocalFileGraphSaver.java index 314747866..4b08e401f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/saver/LocalFileGraphSaver.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/saver/LocalFileGraphSaver.java @@ -34,8 +34,8 @@ public class LocalFileGraphSaver implements EarlyStoppingModelSaver { public enum ROCType {ROC, BINARY, MULTICLASS} - public enum Metric {AUC, AUPRC}; + public enum Metric {AUC, AUPRC} protected final ROCType type; protected final Metric metric; @@ -80,7 +80,7 @@ public class ROCScoreCalculator extends BaseIEvaluationScoreCalculator implements IEarlyStoppingTrainer { - private static Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class); + private static final Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class); protected T model; @@ -294,7 +292,7 @@ public abstract class BaseEarlyStoppingTrainer implements IEarl } if (epochTerminate) { log.info("Hit epoch termination condition at epoch {}. 
Details: {}", epochCount, - termReason.toString()); + termReason); T bestModel; try { bestModel = esConfig.getModelSaver().getBestModel(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingGraphTrainer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingGraphTrainer.java index d78967448..e0011f535 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingGraphTrainer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingGraphTrainer.java @@ -32,7 +32,7 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; public class EarlyStoppingGraphTrainer extends BaseEarlyStoppingTrainer { //implements IEarlyStoppingTrainer { - private ComputationGraph net; + private final ComputationGraph net; /** * @param esConfig Configuration diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingTrainer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingTrainer.java index c96ad86f9..f4df7a3d4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingTrainer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingTrainer.java @@ -34,8 +34,8 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; public class EarlyStoppingTrainer extends BaseEarlyStoppingTrainer { - private MultiLayerNetwork net; - private boolean isMultiEpoch = false; + private final MultiLayerNetwork net; + private final boolean isMultiEpoch = false; public EarlyStoppingTrainer(EarlyStoppingConfiguration earlyStoppingConfiguration, diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/BaseEvaluation.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/BaseEvaluation.java index e7318364d..1c784f9b1 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/BaseEvaluation.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/eval/BaseEvaluation.java @@ -42,9 +42,9 @@ import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; public abstract class BaseEvaluation extends org.nd4j.evaluation.BaseEvaluation { @Getter - private static ObjectMapper objectMapper = configureMapper(new ObjectMapper()); + private static final ObjectMapper objectMapper = configureMapper(new ObjectMapper()); @Getter - private static ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory())); + private static final ObjectMapper yamlMapper = configureMapper(new ObjectMapper(new YAMLFactory())); private static ObjectMapper configureMapper(ObjectMapper ret) { ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java index efe5b0f60..8fe4b99a3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java @@ -929,13 +929,9 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { public GraphBuilder removeVertex(String vertexName, boolean 
removeConnections) { vertices.remove(vertexName); vertexInputs.remove(vertexName); - if (networkInputs.contains(vertexName)) { - networkInputs.remove(vertexName); - } + networkInputs.remove(vertexName); if (removeConnections) { - if (networkOutputs.contains(vertexName)) { - networkOutputs.remove(vertexName); - } + networkOutputs.remove(vertexName); Map> newVertexInputs = new LinkedHashMap<>(); for (Map.Entry> entry : this.vertexInputs.entrySet()) { List inputs = entry.getValue(); @@ -954,9 +950,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { } this.vertexInputs = newVertexInputs; - if (inputPreProcessors.containsKey(vertexName)) { - inputPreProcessors.remove(vertexName); - } + inputPreProcessors.remove(vertexName); } return this; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java index 5ceb3ea63..69ff898e2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java @@ -144,8 +144,8 @@ public class NeuralNetConfiguration implements Serializable, Cloneable { */ public static class ListBuilder extends MultiLayerConfiguration.Builder { private int layerCounter = -1; //Used only for .layer(Layer) method - private Map layerwise; - private Builder globalConfig; + private final Map layerwise; + private final Builder globalConfig; // Constructor public ListBuilder(Builder globalConfig, Map layerMap) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MaxNormConstraint.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MaxNormConstraint.java index 8f2994cf5..43fdc4254 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MaxNormConstraint.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MaxNormConstraint.java @@ -62,7 +62,7 @@ public class MaxNormConstraint extends BaseConstraint { */ public MaxNormConstraint(double maxNorm, int... dimensions) { - this(maxNorm, Collections.emptySet(), dimensions); + this(maxNorm, Collections.emptySet(), dimensions); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MinMaxNormConstraint.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MinMaxNormConstraint.java index 895072c39..6449a9abd 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MinMaxNormConstraint.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MinMaxNormConstraint.java @@ -70,7 +70,7 @@ public class MinMaxNormConstraint extends BaseConstraint { * parameters which have order [depthOut, depthIn, kH, kW] */ public MinMaxNormConstraint(double min, double max, double rate, int... 
dimensions){ - this(min, max, rate, Collections.emptySet(), dimensions); + this(min, max, rate, Collections.emptySet(), dimensions); } /** diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/UnitNormConstraint.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/UnitNormConstraint.java index 8e06315be..a082056a7 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/UnitNormConstraint.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/UnitNormConstraint.java @@ -44,7 +44,7 @@ public class UnitNormConstraint extends BaseConstraint { * parameters which have order [depthOut, depthIn, kH, kW] */ public UnitNormConstraint(int... dimensions){ - this(Collections.emptySet(), dimensions); + this(Collections.emptySet(), dimensions); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/BinomialDistribution.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/BinomialDistribution.java index 883b027eb..14c6d368a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/BinomialDistribution.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/BinomialDistribution.java @@ -77,9 +77,7 @@ public class BinomialDistribution extends Distribution { BinomialDistribution other = (BinomialDistribution) obj; if (numberOfTrials != other.numberOfTrials) return false; - if (Double.doubleToLongBits(probabilityOfSuccess) != Double.doubleToLongBits(other.probabilityOfSuccess)) - return false; - return true; + return Double.doubleToLongBits(probabilityOfSuccess) == Double.doubleToLongBits(other.probabilityOfSuccess); } public String toString() { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/NormalDistribution.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/NormalDistribution.java index 566c58f66..1c867a6ff 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/NormalDistribution.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/NormalDistribution.java @@ -87,9 +87,7 @@ public class NormalDistribution extends Distribution { NormalDistribution other = (NormalDistribution) obj; if (Double.doubleToLongBits(mean) != Double.doubleToLongBits(other.mean)) return false; - if (Double.doubleToLongBits(std) != Double.doubleToLongBits(other.std)) - return false; - return true; + return Double.doubleToLongBits(std) == Double.doubleToLongBits(other.std); } public String toString() { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/serde/LegacyDistributionDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/serde/LegacyDistributionDeserializer.java index 88415f1cf..ecf9fee12 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/serde/LegacyDistributionDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/distribution/serde/LegacyDistributionDeserializer.java @@ -33,7 +33,7 @@ import java.io.IOException; public class LegacyDistributionDeserializer extends JsonDeserializer { @Override public Distribution deserialize(JsonParser jp, DeserializationContext deserializationContext) - throws IOException, JsonProcessingException { + throws IOException { //Manually parse old 
format JsonNode node = jp.getCodec().readTree(jp); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/SubsetVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/SubsetVertex.java index 9c98c8b3b..52f0e059d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/SubsetVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/SubsetVertex.java @@ -63,7 +63,7 @@ public class SubsetVertex extends GraphVertex { @Override public int hashCode() { - return new Integer(from).hashCode() ^ new Integer(to).hashCode(); + return Integer.valueOf(from).hashCode() ^ Integer.valueOf(to).hashCode(); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java index 94cfe157f..a974a7f91 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AbstractLSTM.java @@ -84,7 +84,7 @@ public abstract class AbstractLSTM extends BaseRecurrentLayer { * @param gateActivationFn Activation function for the LSTM gates */ public T gateActivationFunction(String gateActivationFn) { - return (T) gateActivationFunction(Activation.fromString(gateActivationFn)); + return gateActivationFunction(Activation.fromString(gateActivationFn)); } /** @@ -94,7 +94,7 @@ public abstract class AbstractLSTM extends BaseRecurrentLayer { * @param gateActivationFn Activation function for the LSTM gates */ public T gateActivationFunction(Activation gateActivationFn) { - return (T) gateActivationFunction(gateActivationFn.getActivationFunction()); + return gateActivationFunction(gateActivationFn.getActivationFunction()); } /** diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java index 548883015..c6f31faf3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java @@ -150,7 +150,7 @@ public class CapsuleLayer extends SameDiffLayer { public void defineParameters(SDLayerParams params) { params.clear(); params.addWeightParam(WEIGHT_PARAM, - 1, inputCapsules, capsules * capsuleDimensions, inputCapsuleDimensions, 1); + 1, inputCapsules, (long) capsules * capsuleDimensions, inputCapsuleDimensions, 1); if(hasBias){ params.addBiasParam(BIAS_PARAM, @@ -168,7 +168,7 @@ public class CapsuleLayer extends SameDiffLayer { WeightInitUtil.initWeights( inputCapsules * inputCapsuleDimensions, capsules * capsuleDimensions, - new long[]{1, inputCapsules, capsules * capsuleDimensions, + new long[]{1, inputCapsules, (long) capsules * capsuleDimensions, inputCapsuleDimensions, 1}, this.weightInit, null, diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java index 43cc2e9b0..820d73d5d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java @@ -81,12 +81,10 @@ public class CenterLossOutputLayer extends 
BaseOutputLayer { @Override public IUpdater getUpdaterByParam(String paramName) { // center loss utilizes alpha directly for this so any updater can be used for other layers - switch (paramName) { - case CenterLossParamInitializer.CENTER_KEY: - return new NoOp(); - default: - return iUpdater; + if (CenterLossParamInitializer.CENTER_KEY.equals(paramName)) { + return new NoOp(); } + return iUpdater; } public double getAlpha() { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index 9276408a9..ae26e62f0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -49,7 +49,7 @@ public class ConvolutionLayer extends FeedForwardLayer { protected boolean hasBias = true; protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate; //Default to truncate here - default for 0.6.0 and earlier networks on JSON deserialization - protected int dilation[] = new int[] {1, 1}; + protected int[] dilation = new int[] {1, 1}; protected int[] kernelSize; // Square filter protected int[] stride; // Default is 2. Down-sample by a factor of 2 protected int[] padding; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java index 102c0c008..76a943509 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java @@ -176,7 +176,7 @@ public class GravesBidirectionalLSTM extends BaseRecurrentLayer { */ public Builder helperAllowFallback(boolean allowFallback) { this.setHelperAllowFallback(allowFallback); - return (Builder) this; + return this; } @SuppressWarnings("unchecked") diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java index 417edd8ce..d6015e022 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/InputTypeUtil.java @@ -76,8 +76,8 @@ public class InputTypeUtil { return InputType.convolutional(hOut, wOut, outputDepth, i.getFormat()); } - long hOut = sH * (hIn - 1) + kH - 2 * padH; - long wOut = sW * (wIn - 1) + kW - 2 * padW; + long hOut = sH * (hIn - 1) + kH - 2L * padH; + long wOut = sW * (wIn - 1) + kW - 2L * padW; return InputType.convolutional(hOut, wOut, outputDepth, i.getFormat()); } @@ -126,9 +126,9 @@ public class InputTypeUtil { return InputType.convolutional3D(dataFormat, dOut, hOut, wOut, outputDepth); } - long hOut = sH * (hIn - 1) + kH - 2 * padH; - long wOut = sW * (wIn - 1) + kW - 2 * padW; - long dOut = sD * (dIn - 1) + kD - 2 * padD; + long hOut = sH * (hIn - 1) + kH - 2L * padH; + long wOut = sW * (wIn - 1) + kW - 2L * padW; + long dOut = sD * (dIn - 1) + kD - 2L * padD; return InputType.convolutional3D(dataFormat, dOut, hOut, wOut, outputDepth); } @@ -179,20 +179,20 @@ public class InputTypeUtil { stride, padding, outputChannels, convolutionMode)); } - if (kH <= 0 || (padH > 0 && kH > inHeight + 2 * padH)) { + if 
(kH <= 0 || (padH > 0 && kH > inHeight + 2L * padH)) { throw new DL4JInvalidConfigException(getConfigErrorCommonLine(layerIdx, layerName, layerClass, true) + " Invalid input configuration for kernel height. Require 0 < kH <= inHeight + 2*padH; got (kH=" + kH + ", inHeight=" + inHeight + ", padH=" + padH + ")\n" + getConfigErrorCommonLastLine( inputType, kernelSize, stride, padding, outputChannels, convolutionMode)); } - if (kW <= 0 || (padW > 0 && kW > inWidth + 2 * padW)) { + if (kW <= 0 || (padW > 0 && kW > inWidth + 2L * padW)) { throw new DL4JInvalidConfigException(getConfigErrorCommonLine(layerIdx, layerName, layerClass, false) + " Invalid input configuration for kernel width. Require 0 < kW <= inWidth + 2*padW; got (kW=" + kW + ", inWidth=" + inWidth + ", padW=" + padW + ")\n" + getConfigErrorCommonLastLine( inputType, kernelSize, stride, padding, outputChannels, convolutionMode)); } - if (kD <= 0 || (padD > 0 && kD > inDepth + 2 * padD)) { + if (kD <= 0 || (padD > 0 && kD > inDepth + 2L * padD)) { throw new DL4JInvalidConfigException(getConfigErrorCommonLine(layerIdx, layerName, layerClass, false) + " Invalid input configuration for kernel channels. Require 0 < kD <= inDepth + 2*padD; got (kD=" + kD + ", inDepth=" + inDepth + ", padD=" + padD + ")\n" + getConfigErrorCommonLastLine( @@ -200,7 +200,7 @@ public class InputTypeUtil { } //Strict mode: require exactly the right size... if (convolutionMode == ConvolutionMode.Strict) { - if ((inHeight - kH + 2 * padH) % sH != 0) { + if ((inHeight - kH + 2L * padH) % sH != 0) { double d = (inHeight - kH + 2 * padH) / ((double) sH) + 1.0; String str = String.format("%.2f", d); int truncated = (int) d; @@ -218,7 +218,7 @@ public class InputTypeUtil { convolutionMode)); } - if ((inWidth - kW + 2 * padW) % sW != 0) { + if ((inWidth - kW + 2L * padW) % sW != 0) { double d = (inWidth - kW + 2 * padW) / ((double) sW) + 1.0; String str = String.format("%.2f", d); int truncated = (int) d; @@ -236,7 +236,7 @@ public class InputTypeUtil { convolutionMode)); } - if ((inDepth - kD + 2 * padD) % sD != 0) { + if ((inDepth - kD + 2L * padD) % sD != 0) { double d = (inDepth - kD + 2 * padD) / ((double) sD) + 1.0; String str = String.format("%.2f", d); int truncated = (int) d; @@ -262,9 +262,9 @@ public class InputTypeUtil { return InputType.convolutional3D(dataFormat, outD, outH, outW, outputChannels); } - long dOut = (inDepth - kD + 2 * padD) / sD + 1; - long hOut = (inHeight - kH + 2 * padH) / sH + 1; - long wOut = (inWidth - kW + 2 * padW) / sW + 1; + long dOut = (inDepth - kD + 2L * padD) / sD + 1; + long hOut = (inHeight - kH + 2L * padH) / sH + 1; + long wOut = (inWidth - kW + 2L * padW) / sW + 1; return InputType.convolutional3D(dOut, hOut, wOut, outputChannels); } @@ -396,7 +396,7 @@ public class InputTypeUtil { convolutionMode)); } //note the padding check > 0 here. This validation fails for padding == 0. Verified on resnet50 - if (kH <= 0 || padH > 0 && (padH > 0 && kH > inHeight + 2 * padH)) { + if (kH <= 0 || padH > 0 && (padH > 0 && kH > inHeight + 2L * padH)) { throw new DL4JInvalidConfigException(getConfigErrorCommonLine(layerIdx, layerName, layerClass, true) + " Invalid input configuration for kernel height. Require 0 < kH <= inHeight + 2*padH; got (kH=" + kH + ", inHeight=" + inHeight + ", padH=" + padH + ")\n" + getConfigErrorCommonLastLine( @@ -404,7 +404,7 @@ public class InputTypeUtil { } //note the padding check > 0 here. This validation fails for padding == 0. 
Verified on resnet50 - if (kW <= 0 || padW > 0 && (padW > 0 && kW > inWidth + 2 * padW)) { + if (kW <= 0 || padW > 0 && (padW > 0 && kW > inWidth + 2L * padW)) { throw new DL4JInvalidConfigException(getConfigErrorCommonLine(layerIdx, layerName, layerClass, false) + " Invalid input configuration for kernel width. Require 0 < kW <= inWidth + 2*padW; got (kW=" + kW + ", inWidth=" + inWidth + ", padW=" + padW + ")\n" + getConfigErrorCommonLastLine( @@ -413,7 +413,7 @@ public class InputTypeUtil { //Strict mode: require exactly the right size... if (convolutionMode == ConvolutionMode.Strict) { - if ((inHeight - kH + 2 * padH) % sH != 0) { + if ((inHeight - kH + 2L * padH) % sH != 0) { double d = (inHeight - kH + 2 * padH) / ((double) sH) + 1.0; String str = String.format("%.2f", d); int truncated = (int) d; @@ -431,7 +431,7 @@ public class InputTypeUtil { } - if ((inWidth - kW + 2 * padW) % sW != 0) { + if ((inWidth - kW + 2L * padW) % sW != 0) { double d = (inWidth - kW + 2 * padW) / ((double) sW) + 1.0; String str = String.format("%.2f", d); int truncated = (int) d; @@ -455,8 +455,8 @@ public class InputTypeUtil { - long hOut = (inHeight - kH + 2 * padH) / sH + 1; - long wOut = (inWidth - kW + 2 * padW) / sW + 1; + long hOut = (inHeight - kH + 2L * padH) / sH + 1; + long wOut = (inWidth - kW + 2L * padW) / sW + 1; return InputType.convolutional(hOut, wOut, outputDepth, format); } @@ -596,15 +596,13 @@ public class InputTypeUtil { case FF: for(int i = 0; i < vertexInputs.length; i++) { if(vertexInputs[i].getType() != maxType) { - switch(vertexInputs[i].getType()) { - case RNN: - InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) vertexInputs[i]; - if(recurrent.getTimeSeriesLength() == 1) { - vertexInputs[i] = InputType.feedForward(recurrent.getSize()); - } - break; - default: - throw new IllegalArgumentException("Attempted conversion of types and was unable to"); + if (vertexInputs[i].getType() == InputType.Type.RNN) { + InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) vertexInputs[i]; + if (recurrent.getTimeSeriesLength() == 1) { + vertexInputs[i] = InputType.feedForward(recurrent.getSize()); + } + } else { + throw new IllegalArgumentException("Attempted conversion of types and was unable to"); } } } @@ -621,14 +619,11 @@ public class InputTypeUtil { } for(int i = 0; i < vertexInputs.length; i++) { if(vertexInputs[i].getType() != maxType) { - switch(vertexInputs[i].getType()) { - case FF: - InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward) vertexInputs[i]; - vertexInputs[i] = InputType.recurrent(ff.getSize(),rnnFormat); - break; - default: - throw new IllegalArgumentException("Attempted conversion of types and was unable to"); - + if (vertexInputs[i].getType() == InputType.Type.FF) { + InputType.InputTypeFeedForward ff = (InputType.InputTypeFeedForward) vertexInputs[i]; + vertexInputs[i] = InputType.recurrent(ff.getSize(), rnnFormat); + } else { + throw new IllegalArgumentException("Attempted conversion of types and was unable to"); } } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index 724a0c22d..b44055332 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -146,7 +146,7 @@ public class LocallyConnected2D extends 
SameDiffLayer { @Override public void defineParameters(SDLayerParams params) { params.clear(); - val weightsShape = new long[] {outputSize[0] * outputSize[1], featureDim, nOut}; + val weightsShape = new long[] {(long) outputSize[0] * outputSize[1], featureDim, nOut}; params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape); if (hasBias) { val biasShape = new long[] {nOut}; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java index 033b96470..2107bdede 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java @@ -129,10 +129,10 @@ public class PrimaryCapsules extends SameDiffLayer { public void defineParameters(SDLayerParams params) { params.clear(); params.addWeightParam(WEIGHT_PARAM, - kernelSize[0], kernelSize[1], inputChannels, capsuleDimensions * channels); + kernelSize[0], kernelSize[1], inputChannels, (long) capsuleDimensions * channels); if(hasBias){ - params.addBiasParam(BIAS_PARAM, capsuleDimensions * channels); + params.addBiasParam(BIAS_PARAM, (long) capsuleDimensions * channels); } } @@ -165,7 +165,7 @@ public class PrimaryCapsules extends SameDiffLayer { InputTypeConvolutional out = (InputTypeConvolutional) InputTypeUtil .getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, - capsuleDimensions * channels, -1, getLayerName(), PrimaryCapsules.class); + (long) capsuleDimensions * channels, -1, getLayerName(), PrimaryCapsules.class); return InputType.recurrent((int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions), capsuleDimensions); @@ -187,7 +187,7 @@ public class PrimaryCapsules extends SameDiffLayer { InputTypeConvolutional out = (InputTypeConvolutional) InputTypeUtil .getOutputTypeCnnLayers(inputType, kernelSize, stride, padding, dilation, convolutionMode, - capsuleDimensions * channels, -1, getLayerName(), PrimaryCapsules.class); + (long) capsuleDimensions * channels, -1, getLayerName(), PrimaryCapsules.class); this.capsules = (int) (out.getChannels() * out.getHeight() * out.getWidth() / capsuleDimensions); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java index 2175c58ab..cb643cd7b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java @@ -455,7 +455,7 @@ public class Subsampling3DLayer extends NoParamLayer { } public T dilation(int dDepth, int dHeight, int dWidth) { - this.setDilation(new int[] {dDepth, dHeight, dWidth}); + this.setDilation(dDepth, dHeight, dWidth); return (T) this; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java index 3af2e9d55..6a012ed15 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java @@ -128,7 +128,7 @@ public class Upsampling1D extends BaseUpsamplingLayer { */ public Builder size(int size) 
{ - this.setSize(new int[] {size}); + this.setSize(size); return this; } @@ -153,7 +153,7 @@ public class Upsampling1D extends BaseUpsamplingLayer { if(size.length == 2){ if(size[0] == size[1]) { - setSize(new int[]{size[0]}); + setSize(size[0]); return; } else { Preconditions.checkArgument(false, diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java index 984926ec1..ef5d832b4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java @@ -157,7 +157,7 @@ public class Upsampling3D extends BaseUpsamplingLayer { */ public Builder size(int size) { - this.setSize(new int[] {size, size, size}); + this.setSize(size, size, size); return this; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java index ae71c3811..fd2546019 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java @@ -150,7 +150,7 @@ public class Cropping1D extends NoParamLayer { * @param cropBottom Amount of cropping to apply to the bottom of the input activations */ public Builder(int cropTop, int cropBottom) { - this.setCropping(new int[]{cropTop, cropBottom}); + this.setCropping(cropTop, cropBottom); } public Cropping1D build() { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java index 8ea2ea18e..29aad71bd 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java @@ -185,7 +185,7 @@ public class Cropping2D extends NoParamLayer { * @param cropRight Amount of cropping to apply to the right of the input activations */ public Builder(int cropTop, int cropBottom, int cropLeft, int cropRight) { - this.setCropping(new int[] {cropTop, cropBottom, cropLeft, cropRight}); + this.setCropping(cropTop, cropBottom, cropLeft, cropRight); } public Cropping2D build() { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java index df3137629..1ab34b17b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java @@ -167,7 +167,7 @@ public class Cropping3D extends NoParamLayer { * @param cropRightW Amount of cropping to apply to the right of the width dimension */ public Builder(int cropLeftD, int cropRightD, int cropLeftH, int cropRightH, int cropLeftW, int cropRightW) { - this.setCropping(new int[] {cropLeftD, cropRightD, cropLeftH, cropRightH, cropLeftW, cropRightW}); + this.setCropping(cropLeftD, cropRightD, cropLeftH, cropRightH, cropLeftW, cropRightW); } public Cropping3D build() { diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/BoundingBoxesDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/BoundingBoxesDeserializer.java index 747a95320..8dac21edc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/BoundingBoxesDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/BoundingBoxesDeserializer.java @@ -33,7 +33,7 @@ import java.io.IOException; public class BoundingBoxesDeserializer extends JsonDeserializer { @Override - public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { + public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException { JsonNode node = jp.getCodec().readTree(jp); if(node.has("dataBuffer")){ //Must be legacy format serialization diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java index 0e6d6ebb6..d3c10ec2f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java @@ -91,8 +91,8 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex { public class VertexInputs { - private SameDiff sameDiff; - private Map map = new LinkedHashMap<>(); + private final SameDiff sameDiff; + private final Map map = new LinkedHashMap<>(); protected VertexInputs(SameDiff sd) { this.sameDiff = sd; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/CompositeReconstructionDistribution.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/CompositeReconstructionDistribution.java index 47cfffd43..ca96fb46e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/CompositeReconstructionDistribution.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/CompositeReconstructionDistribution.java @@ -227,7 +227,7 @@ public class CompositeReconstructionDistribution implements ReconstructionDistri private INDArray randomSample(INDArray preOutDistributionParams, boolean isMean) { int inputSoFar = 0; int paramsSoFar = 0; - INDArray out = Nd4j.createUninitialized(preOutDistributionParams.dataType(), new long[] {preOutDistributionParams.size(0), totalSize}); + INDArray out = Nd4j.createUninitialized(preOutDistributionParams.dataType(), preOutDistributionParams.size(0), totalSize); for (int i = 0; i < distributionSizes.length; i++) { int thisDataSize = distributionSizes[i]; int thisParamsSize = reconstructionDistributions[i].distributionInputSize(thisDataSize); @@ -254,8 +254,8 @@ public class CompositeReconstructionDistribution implements ReconstructionDistri public static class Builder { - private List distributionSizes = new ArrayList<>(); - private List reconstructionDistributions = new ArrayList<>(); + private final List distributionSizes = new ArrayList<>(); + private final List reconstructionDistributions = new ArrayList<>(); /** * Add another distribution to the composite distribution. 
This will add the distribution for the next 'distributionSize' diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java index 3cc996c4a..ca1f10bd0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java @@ -93,10 +93,7 @@ public class VariationalAutoencoder extends BasePretrainNetwork { if (paramName.startsWith(VariationalAutoencoderParamInitializer.PZX_LOGSTD2_PREFIX)) { return true; } - if (paramName.startsWith(VariationalAutoencoderParamInitializer.PXZ_PREFIX)) { - return true; - } - return false; + return paramName.startsWith(VariationalAutoencoderParamInitializer.PXZ_PREFIX); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/LayerMemoryReport.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/LayerMemoryReport.java index 28725679b..771df513c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/LayerMemoryReport.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/LayerMemoryReport.java @@ -180,10 +180,10 @@ public class LayerMemoryReport extends MemoryReport { public static class Builder { - private String layerName; - private Class layerType; - private InputType inputType; - private InputType outputType; + private final String layerName; + private final Class layerType; + private final InputType inputType; + private final InputType outputType; //Standard memory (in terms of total ND4J array length) private long parameterSize; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java index 02e8a1544..9d667cc07 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/preprocessor/FeedForwardToCnn3DPreProcessor.java @@ -106,7 +106,7 @@ public class FeedForwardToCnn3DPreProcessor implements InputPreProcessor { epsilons = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, epsilons, 'c'); if (shape == null || ArrayUtil.prod(shape) != epsilons.length()) { - INDArray ret = epsilons.reshape('c', epsilons.size(0),inputDepth * inputHeight * inputWidth * numChannels); + INDArray ret = epsilons.reshape('c', epsilons.size(0), (long) inputDepth * inputHeight * inputWidth * numChannels); return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, ret); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java index 513b42aa8..abd52c0c3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java @@ -64,7 +64,7 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im @Override public abstract T deserialize(JsonParser jp, DeserializationContext ctxt) - throws 
IOException, JsonProcessingException; + throws IOException; protected boolean requiresIUpdaterFromLegacy(Layer[] layers){ for(Layer l : layers){ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java index c5a8fe912..8097111d6 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java @@ -32,8 +32,8 @@ import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; @Slf4j public class JsonMappers { - private static ObjectMapper jsonMapper = new ObjectMapper(); - private static ObjectMapper yamlMapper = new ObjectMapper(new YAMLFactory()); + private static final ObjectMapper jsonMapper = new ObjectMapper(); + private static final ObjectMapper yamlMapper = new ObjectMapper(new YAMLFactory()); private static ObjectMapper legacyMapper; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatDeserializer.java index d7c3d636a..e9397126a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/format/DataFormatDeserializer.java @@ -32,7 +32,7 @@ import java.io.IOException; public class DataFormatDeserializer extends JsonDeserializer { @Override - public DataFormat deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { + public DataFormat deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException { JsonNode node = jp.getCodec().readTree(jp); String text = node.textValue(); switch (text){ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyIntArrayDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyIntArrayDeserializer.java index 3bbb5f8f6..804655669 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyIntArrayDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyIntArrayDeserializer.java @@ -31,7 +31,7 @@ import java.io.IOException; public class LegacyIntArrayDeserializer extends JsonDeserializer { @Override - public int[] deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { + public int[] deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException { JsonNode n = jp.getCodec().readTree(jp); if(n.isArray()){ ArrayNode an = (ArrayNode)n; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/DefaultStepFunction.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/DefaultStepFunction.java index dc32f1232..d3a2c9518 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/DefaultStepFunction.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/DefaultStepFunction.java @@ -35,9 +35,7 @@ public class DefaultStepFunction extends StepFunction { return true; if (obj == null) return false; - if (getClass() != obj.getClass()) - return false; - return 
true; + return getClass() == obj.getClass(); } public String toString() { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/GradientStepFunction.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/GradientStepFunction.java index 2a727535f..4b18a4aeb 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/GradientStepFunction.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/GradientStepFunction.java @@ -35,9 +35,7 @@ public class GradientStepFunction extends StepFunction { return true; if (obj == null) return false; - if (getClass() != obj.getClass()) - return false; - return true; + return getClass() == obj.getClass(); } public String toString() { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/NegativeDefaultStepFunction.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/NegativeDefaultStepFunction.java index 867ed7b28..7bd42d9e7 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/NegativeDefaultStepFunction.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/NegativeDefaultStepFunction.java @@ -35,9 +35,7 @@ public class NegativeDefaultStepFunction extends StepFunction { return true; if (obj == null) return false; - if (getClass() != obj.getClass()) - return false; - return true; + return getClass() == obj.getClass(); } public String toString() { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/NegativeGradientStepFunction.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/NegativeGradientStepFunction.java index a7c1d3648..943aed06f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/NegativeGradientStepFunction.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/stepfunctions/NegativeGradientStepFunction.java @@ -35,9 +35,7 @@ public class NegativeGradientStepFunction extends StepFunction { return true; if (obj == null) return false; - if (getClass() != obj.getClass()) - return false; - return true; + return getClass() == obj.getClass(); } public String toString() { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/gradient/DefaultGradient.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/gradient/DefaultGradient.java index 9f8557ccc..23d1651f5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/gradient/DefaultGradient.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/gradient/DefaultGradient.java @@ -31,7 +31,7 @@ import java.util.Map; public class DefaultGradient implements Gradient { public static final char DEFAULT_FLATTENING_ORDER = 'f'; - private Map gradients = new LinkedHashMap<>(); + private final Map gradients = new LinkedHashMap<>(); private Map flatteningOrders; @Setter private INDArray flattenedGradient; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index 44a838df0..ac8a05be4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -191,12 +191,12 @@ public class ComputationGraph 
implements Serializable, Model, NeuralNetwork { * The number of input arrays to the network. Many networks only have 1 input; however, a ComputationGraph may * have an arbitrary number (>=1) separate input arrays */ - private int numInputArrays; + private final int numInputArrays; /** * The number of output arrays to the network. Many networks only have 1 output; however, a ComputationGraph may * have an arbitrary number (>=1) separate output arrays */ - private int numOutputArrays; + private final int numOutputArrays; //Current inputs, labels, input mask arrays and label mask arrays private transient INDArray[] inputs; @@ -2605,7 +2605,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { List outputLayers = configuration.getNetworkOutputs(); for(String s : outputLayers ){ GraphVertex gv = getVertex(s); - if(gv instanceof LayerVertex && ((LayerVertex)gv).getLayer() instanceof IOutputLayer){ + if(gv instanceof LayerVertex && gv.getLayer() instanceof IOutputLayer){ throw new IllegalStateException("Cannot perform backprop with external errors in conjunction with an output layer:" + " output layers cannot use external errors for backprop. Layer name: " + s); } @@ -3923,7 +3923,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { * @return Evaluation object; results of evaluation on all examples in the data set */ public T evaluate(DataSetIterator iterator) { - return (T)evaluate(iterator, (List)null); + return evaluate(iterator, (List)null); } /** @@ -4185,7 +4185,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { @SuppressWarnings("unchecked") @SafeVarargs private final T[] doEvaluationHelper(MultiDataSetIterator iterator, T... evaluations) { - Map map = Collections.singletonMap(0, (IEvaluation[])evaluations); + Map map = Collections.singletonMap(0, evaluations); return (T[])doEvaluationHelper(iterator, map).get(0); } @@ -4311,7 +4311,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { configuration.setTrainingWorkspaceMode(cMode); - return (Map) evaluations; + return evaluations; } /** @@ -4385,7 +4385,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { connections = configuration.getVertexInputs().get(currentVertexName).toString(); List inputTypeList = new ArrayList<>(); if (currentVertex.hasLayer()) { - Layer currentLayer = ((LayerVertex) currentVertex).getLayer(); + Layer currentLayer = currentVertex.getLayer(); classNameArr = currentLayer.getClass().getName().split("\\."); className = classNameArr[classNameArr.length - 1]; paramCount = String.format("%,d", currentLayer.numParams()); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java index 5d28feb9b..f678fb782 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ElementWiseVertex.java @@ -51,7 +51,7 @@ public class ElementWiseVertex extends BaseGraphVertex { Add, Subtract, Product, Average, Max } - private Op op; + private final Op op; private int nInForwardPass; public ElementWiseVertex(ComputationGraph graph, String name, int vertexIndex, Op op, DataType dataType) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java index 28f83c16f..0931bdb98 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2NormalizeVertex.java @@ -43,8 +43,8 @@ public class L2NormalizeVertex extends BaseGraphVertex { private static final int[] DEFAULT_RANK3_DIMS = new int[] {1, 2}; private static final int[] DEFAULT_RANK4_DIMS = new int[] {1, 2, 3}; - private int[] dimension; - private double eps; + private final int[] dimension; + private final double eps; public L2NormalizeVertex(ComputationGraph graph, String name, int vertexIndex, int[] dimension, double eps, DataType dataType) { this(graph, name, vertexIndex, null, null, dimension, eps, dataType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java index b7db002c5..d839b9872 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/L2Vertex.java @@ -38,7 +38,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; public class L2Vertex extends BaseGraphVertex { - private double eps; + private final double eps; public L2Vertex(ComputationGraph graph, String name, int vertexIndex, double eps, DataType dataType) { this(graph, name, vertexIndex, null, null, eps, dataType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java index a57b24eee..fdd05c390 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java @@ -199,11 +199,10 @@ public class LayerVertex extends BaseGraphVertex { @Override public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("LayerVertex(id=").append(vertexIndex).append(",name=\"").append(vertexName).append("\",inputs=") - .append(Arrays.toString(inputVertices)).append(",outputs=") - .append(Arrays.toString(outputVertices)).append(")"); - return sb.toString(); + String sb = "LayerVertex(id=" + vertexIndex + ",name=\"" + vertexName + "\",inputs=" + + Arrays.toString(inputVertices) + ",outputs=" + + Arrays.toString(outputVertices) + ")"; + return sb; } @Override @@ -229,9 +228,7 @@ public class LayerVertex extends BaseGraphVertex { } if (!(resolvedLayer instanceof IOutputLayer)) { - if (epsilon == null) { - return false; - } + return epsilon != null; } return true; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java index f1e4a4f8b..1187bbabb 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/MergeVertex.java @@ -44,7 +44,7 @@ public class MergeVertex extends BaseGraphVertex { private long[][] forwardPassShapes; private int fwdPassRank; - private int mergeAxis; + private final int mergeAxis; 
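The vertex hunks above and below repeatedly convert configuration fields such as op, dimension, eps and mergeAxis to final. A minimal sketch of what that buys (hypothetical ExampleVertex class, not taken from the patch): a final field must be assigned exactly once, in the constructor, and any later reassignment is rejected at compile time.

// Hypothetical sketch, not part of the patch: configuration fixed at construction time.
public class ExampleVertex {
    private final int mergeAxis;          // assigned exactly once, in the constructor below

    public ExampleVertex(int mergeAxis) {
        this.mergeAxis = mergeAxis;
    }

    public int getMergeAxis() {
        // this.mergeAxis = 0;            // would not compile: cannot assign a value to a final field
        return mergeAxis;
    }
}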
public MergeVertex(ComputationGraph graph, String name, int vertexIndex, DataType dataType, int mergeAxis) { this(graph, name, vertexIndex, null, null, dataType, mergeAxis); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java index b8bedadbb..4586dd3d8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/PreprocessorVertex.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; public class PreprocessorVertex extends BaseGraphVertex { @Getter - private InputPreProcessor preProcessor; + private final InputPreProcessor preProcessor; public PreprocessorVertex(ComputationGraph graph, String name, int vertexIndex, InputPreProcessor preProcessor, DataType dataType) { this(graph, name, vertexIndex, null, null, preProcessor, dataType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java index 4c8bbfc16..5ccc81132 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ReshapeVertex.java @@ -34,9 +34,9 @@ import org.nd4j.common.primitives.Pair; public class ReshapeVertex extends BaseGraphVertex { - private char order; - private int[] newShape; - private int[] maskShape; + private final char order; + private final int[] newShape; + private final int[] maskShape; public ReshapeVertex(ComputationGraph graph, String name, int vertexIndex, char order, int[] newShape, int[] maskShape, DataType dataType) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java index f62a9278d..16863434a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ScaleVertex.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; public class ScaleVertex extends BaseGraphVertex { - private double scaleFactor; + private final double scaleFactor; public ScaleVertex(ComputationGraph graph, String name, int vertexIndex, double scaleFactor, DataType dataType) { this(graph, name, vertexIndex, null, null, scaleFactor, dataType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java index 82c2e1155..d289c4e75 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/ShiftVertex.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; public class ShiftVertex extends BaseGraphVertex { - private double shiftFactor; + private final double shiftFactor; public ShiftVertex(ComputationGraph graph, String name, int vertexIndex, double shiftFactor, DataType dataType) { this(graph, name, vertexIndex, null, null, shiftFactor, 
dataType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java index 34d7b63e6..6889a0f39 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/StackVertex.java @@ -72,9 +72,7 @@ public class StackVertex extends BaseGraphVertex { // create the new shape outShape[0] = nStack * inShape[0]; - for (int i = 1; i < inShape.length; i++) { - outShape[i] = inShape[i]; - } + System.arraycopy(inShape, 1, outShape, 1, inShape.length - 1); boolean variableLengthTS = false; if (inShape.length == 3) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java index 50fcf1699..d3271849c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/SubsetVertex.java @@ -37,8 +37,8 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.util.Arrays; public class SubsetVertex extends BaseGraphVertex { - private int from; - private int to; //inclusive + private final int from; + private final int to; //inclusive private long[] forwardShape; public SubsetVertex(ComputationGraph graph, String name, int vertexIndex, int from, int to, DataType dataType) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java index a9c70c27a..c31cd1ae1 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/UnstackVertex.java @@ -37,9 +37,9 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.util.Arrays; public class UnstackVertex extends BaseGraphVertex { - private long from; - private int stackSize; - private long forwardShape[]; + private final long from; + private final int stackSize; + private long[] forwardShape; private long step; public UnstackVertex(ComputationGraph graph, String name, int vertexIndex, int from, int stackSize, DataType dataType) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java index 85dc8b06b..2bfc6ee97 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java @@ -37,8 +37,8 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; public class DuplicateToTimeSeriesVertex extends BaseGraphVertex { - private String inputName; - private int inputVertexIndex; + private final String inputName; + private final int inputVertexIndex; public DuplicateToTimeSeriesVertex(ComputationGraph graph, String name, int vertexIndex, String inputVertexName, DataType dataType) { this(graph, name, vertexIndex, null, null, inputVertexName, dataType); diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java index 4eab20e41..0475936d0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java @@ -38,8 +38,8 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; public class LastTimeStepVertex extends BaseGraphVertex { - private String inputName; - private int inputIdx; + private final String inputName; + private final int inputIdx; /** Shape of the forward pass activations */ private long[] fwdPassShape; /** Indexes of the time steps that were extracted, for each example */ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayer.java index 06f3f53b3..1e6c60add 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayer.java @@ -40,7 +40,7 @@ public class FrozenLayer extends BaseWrapperLayer { private boolean logFit = false; private boolean logTestMode = false; private boolean logGradient = false; - private Gradient zeroGradient; + private final Gradient zeroGradient; private transient DummyConfig config; public FrozenLayer(Layer insideLayer) { @@ -176,7 +176,6 @@ public class FrozenLayer extends BaseWrapperLayer { if (!training) return; if (logTestMode) { - return; } else { OneTimeLogger.info(log, "Frozen layer instance found! Frozen layers are treated as always in test mode. Warning will only be issued once per instance"); @@ -188,7 +187,6 @@ public class FrozenLayer extends BaseWrapperLayer { if (training.equals(TrainingMode.TEST)) return; if (logTestMode) { - return; } else { OneTimeLogger.info(log, "Frozen layer instance found! Frozen layers are treated as always in test mode. Warning will only be issued once per instance"); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java index a5bb54857..918a21a4a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java @@ -38,7 +38,7 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayer { private boolean logTestMode = false; private boolean logGradient = false; - private Gradient zeroGradient; + private final Gradient zeroGradient; public FrozenLayerWithBackprop(final Layer insideLayer) { super(insideLayer); @@ -144,7 +144,6 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayer { if (!training) return; if (logTestMode) { - return; } else { OneTimeLogger.info(log, "Frozen layer instance found! Frozen layers are treated as always in test mode. Warning will only be issued once per instance"); @@ -156,7 +155,6 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayer { if (training.equals(TrainingMode.TEST)) return; if (logTestMode) { - return; } else { OneTimeLogger.info(log, "Frozen layer instance found! Frozen layers are treated as always in test mode. 
Warning will only be issued once per instance"); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java index eb59a2c5f..dfff491e4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java @@ -64,7 +64,7 @@ public class HelperUtils { if("CUDA".equalsIgnoreCase(backend) && cudnnHelperClassName != null && !cudnnHelperClassName.isEmpty()) { if(DL4JClassLoading.loadClassByName(cudnnHelperClassName) != null) { log.debug("Attempting to initialize cudnn helper {}",cudnnHelperClassName); - helperRet = (LayerHelper) DL4JClassLoading.createNewInstance( + helperRet = DL4JClassLoading.createNewInstance( cudnnHelperClassName, (Class) layerHelperSuperClass, new Object[]{arguments}); @@ -76,7 +76,7 @@ public class HelperUtils { ClassLoader classLoader = DL4JClassLoading.getDl4jClassloader(); DL4JClassLoading.setDl4jClassloaderFromClass(layerHelperSuperClass); try { - helperRet = (LayerHelper) DL4JClassLoading.createNewInstance( + helperRet = DL4JClassLoading.createNewInstance( cudnnHelperClassName, (Class) layerHelperSuperClass, arguments); @@ -99,7 +99,7 @@ public class HelperUtils { } } else if("CPU".equalsIgnoreCase(backend) && oneDnnClassName != null && !oneDnnClassName.isEmpty()) { - helperRet = DL4JClassLoading.createNewInstance( + helperRet = DL4JClassLoading.createNewInstance( oneDnnClassName, arguments); log.trace("Created oneDNN helper: {}, layer {}", oneDnnClassName,layerName); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java index 442808357..84dd1fd1f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java @@ -98,13 +98,13 @@ public class RepeatVector extends AbstractLayer { - private int[] cropping; //[padTop, padBottom] + private final int[] cropping; //[padTop, padBottom] public Cropping1DLayer(NeuralNetConfiguration conf, DataType dataType) { super(conf, dataType); @@ -67,7 +67,7 @@ public class Cropping1DLayer extends AbstractLayer { INDArray epsNext = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, dataType, inShape, 'c'); INDArray epsNextSubset = epsNext.get(all(), all(), interval(cropping[0], epsNext.size(2)-cropping[1])); epsNextSubset.assign(epsilon); - return new Pair<>((Gradient) new DefaultGradient(), epsNext); + return new Pair<>(new DefaultGradient(), epsNext); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java index 8e40fc652..3d6beac05 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java @@ -38,7 +38,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval; public class Cropping2DLayer extends AbstractLayer { - private int[] cropping; //[padTop, padBottom, padLeft, padRight] + private final int[] cropping; //[padTop, padBottom, padLeft, padRight] public Cropping2DLayer(NeuralNetConfiguration conf, DataType dataType) { super(conf, 
dataType); @@ -66,7 +66,7 @@ public class Cropping2DLayer extends AbstractLayer((Gradient) new DefaultGradient(), epsNext); + return new Pair<>(new DefaultGradient(), epsNext); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping3DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping3DLayer.java index ea2c5a20a..4dc09217a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping3DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping3DLayer.java @@ -37,7 +37,7 @@ import static org.nd4j.linalg.indexing.NDArrayIndex.interval; public class Cropping3DLayer extends AbstractLayer { - private int[] cropping; //[cropLeftD, cropRightD, cropLeftH, cropRightH, cropLeftW, cropRightW] + private final int[] cropping; //[cropLeftD, cropRightD, cropLeftH, cropRightH, cropLeftW, cropRightW] public Cropping3DLayer(NeuralNetConfiguration conf, DataType dataType) { super(conf, dataType); @@ -65,7 +65,7 @@ public class Cropping3DLayer extends AbstractLayer((Gradient) new DefaultGradient(), epsNext); + return new Pair<>(new DefaultGradient(), epsNext); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java index 6c293c6ab..386c312e6 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java @@ -36,7 +36,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; public class ZeroPadding1DLayer extends AbstractLayer { - private int[] padding; // [padLeft, padRight] + private final int[] padding; // [padLeft, padRight] public ZeroPadding1DLayer(NeuralNetConfiguration conf, DataType dataType) { super(conf, dataType); @@ -66,7 +66,7 @@ public class ZeroPadding1DLayer extends AbstractLayer((Gradient) new DefaultGradient(), workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsNext)); + return new Pair<>(new DefaultGradient(), workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsNext)); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding3DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding3DLayer.java index e39d6886b..bffd04288 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding3DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding3DLayer.java @@ -36,7 +36,7 @@ import org.nd4j.common.primitives.Pair; public class ZeroPadding3DLayer extends AbstractLayer { - private int[] padding; // [padLeft1, padRight1, padLeft2, padRight2, padLeft3, padRight3] + private final int[] padding; // [padLeft1, padRight1, padLeft2, padRight2, padLeft3, padRight3] public ZeroPadding3DLayer(NeuralNetConfiguration conf, DataType dataType) { super(conf, dataType); @@ -69,7 +69,7 @@ public class ZeroPadding3DLayer extends AbstractLayer((Gradient) new DefaultGradient(), epsNext); + return new Pair<>(new DefaultGradient(), epsNext); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPaddingLayer.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPaddingLayer.java index d467474e3..c46167bee 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPaddingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPaddingLayer.java @@ -80,7 +80,7 @@ public class ZeroPaddingLayer extends AbstractLayer((Gradient) new DefaultGradient(), epsNext); + return new Pair<>(new DefaultGradient(), epsNext); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/autoencoder/recursive/Tree.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/autoencoder/recursive/Tree.java index 48c84c2ea..ef0accc18 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/autoencoder/recursive/Tree.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/autoencoder/recursive/Tree.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import java.io.Serializable; import java.util.ArrayList; import java.util.List; +import java.util.Objects; public class Tree implements Serializable { @@ -446,23 +447,23 @@ public class Tree implements Serializable { return false; if (goldLabel != tree.goldLabel) return false; - if (headWord != null ? !headWord.equals(tree.headWord) : tree.headWord != null) + if (!Objects.equals(headWord, tree.headWord)) return false; - if (label != null ? !label.equals(tree.label) : tree.label != null) + if (!Objects.equals(label, tree.label)) return false; - if (parse != null ? !parse.equals(tree.parse) : tree.parse != null) + if (!Objects.equals(parse, tree.parse)) return false; - if (prediction != null ? !prediction.equals(tree.prediction) : tree.prediction != null) + if (!Objects.equals(prediction, tree.prediction)) return false; - if (tags != null ? !tags.equals(tree.tags) : tree.tags != null) + if (!Objects.equals(tags, tree.tags)) return false; - if (tokens != null ? !tokens.equals(tree.tokens) : tree.tokens != null) + if (!Objects.equals(tokens, tree.tokens)) return false; - if (type != null ? !type.equals(tree.type) : tree.type != null) + if (!Objects.equals(type, tree.type)) return false; - if (value != null ? !value.equals(tree.value) : tree.value != null) + if (!Objects.equals(value, tree.value)) return false; - return !(vector != null ? 
!vector.equals(tree.vector) : tree.vector != null); + return !(!Objects.equals(vector, tree.vector)); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java index 51a988f13..762407264 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java @@ -155,7 +155,7 @@ public class EmbeddingSequenceLayer extends BaseLayer gradientViews; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java index a99244c13..3ad4f8b0a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java @@ -178,7 +178,7 @@ public class LSTMHelpers { //initialize prevOutputActivations to zeroes if (prevOutputActivations == null) { - prevOutputActivations = Nd4j.zeros(input.dataType(), new long[] {miniBatchSize, hiddenLayerSize}); + prevOutputActivations = Nd4j.zeros(input.dataType(), miniBatchSize, hiddenLayerSize); } if (helper != null && (layer.helperCountFail == 0 || !isHelperAllowFallback)) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java index 5241d9b41..c591cd18d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayer.java @@ -37,7 +37,7 @@ import static org.deeplearning4j.nn.conf.RNNFormat.NWC; public class MaskZeroLayer extends BaseWrapperLayer { private static final long serialVersionUID = -7369482676002469854L; - private double maskingValue; + private final double maskingValue; public MaskZeroLayer(@NonNull Layer underlying, double maskingValue){ super(underlying); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java index 2655739c9..0176ce720 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java @@ -267,7 +267,7 @@ public class SimpleRnn extends BaseRecurrentLayer 0 || prevStepOut != null) { if(hasLayerNorm()){ - INDArray currRecPreNorm = forBackprop ? recPreNorm.get(all(), all(), point(i)) : workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f');; + INDArray currRecPreNorm = forBackprop ? 
recPreNorm.get(all(), all(), point(i)) : workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f'); Nd4j.gemm(prevStepOut, rw, currRecPreNorm, false, false, 1.0, 0.0); INDArray recNorm = workspaceMgr.createUninitialized(ArrayType.FF_WORKING_MEM, currOut.dataType(), currOut.shape(), 'f'); Nd4j.getExecutioner().exec(new LayerNorm(currRecPreNorm, gr, recNorm, true, 1)); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java index 9f9d6cb43..9a97f6a4a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/TimeDistributedLayer.java @@ -34,7 +34,7 @@ import org.nd4j.common.util.ArrayUtil; public class TimeDistributedLayer extends BaseWrapperLayer { - private RNNFormat rnnDataFormat; + private final RNNFormat rnnDataFormat; public TimeDistributedLayer(Layer underlying, RNNFormat rnnDataFormat) { super(underlying); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/util/MaskLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/util/MaskLayer.java index f5d1b24cf..984fe67ee 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/util/MaskLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/util/MaskLayer.java @@ -35,7 +35,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; import java.util.Arrays; public class MaskLayer extends AbstractLayer { - private Gradient emptyGradient = new DefaultGradient(); + private final Gradient emptyGradient = new DefaultGradient(); public MaskLayer(NeuralNetConfiguration conf, DataType dataType) { super(conf, dataType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java index 5936168be..75df1dfad 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java @@ -796,7 +796,7 @@ public class VariationalAutoencoder implements Layer { @Override public void setListeners(TrainingListener... 
listeners) { - setListeners(Arrays.asList(listeners)); + setListeners(Arrays.asList(listeners)); } @Override @@ -828,8 +828,7 @@ public class VariationalAutoencoder implements Layer { return; } - for (TrainingListener listener : listeners) - trainingListeners.add(listener); + Collections.addAll(trainingListeners, listeners); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 4b4a97c2d..f590a1caa 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -99,8 +99,6 @@ import org.nd4j.common.util.OneTimeLogger; import java.io.*; import java.util.*; -; - @Slf4j public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork { @@ -1997,7 +1995,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura for (Map.Entry entry : currPair.getFirst().gradientForVariable().entrySet()) { String origName = entry.getKey(); - multiGradientKey = String.valueOf(i) + "_" + origName; + multiGradientKey = i + "_" + origName; gradientList.addLast(new Triple<>(multiGradientKey, entry.getValue(), currPair.getFirst().flatteningOrderForVariable(origName))); } @@ -2109,7 +2107,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura rnnClearPreviousState(); for (int i = 0; i < nSubsets; i++) { - long startTimeIdx = i * fwdLen; + long startTimeIdx = (long) i * fwdLen; long endTimeIdx = startTimeIdx + fwdLen; if (endTimeIdx > timeSeriesLength) endTimeIdx = timeSeriesLength; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java index af5ce819d..b9f682818 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java @@ -91,7 +91,7 @@ public class DepthwiseConvolutionParamInitializer implements ParamInitializer { @Override public List weightKeys(Layer layer) { - return Arrays.asList(WEIGHT_KEY); + return Collections.singletonList(WEIGHT_KEY); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java index 05e723eed..d0a93e368 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java @@ -39,8 +39,8 @@ import java.util.Map; public class PReLUParamInitializer implements ParamInitializer { public final static String WEIGHT_KEY = "W"; - private long[] weightShape; - private long[] sharedAxes; + private final long[] weightShape; + private final long[] sharedAxes; public PReLUParamInitializer(long[] shape, long[] sharedAxes) { this.weightShape = shape; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java index ec64afe2c..52ae7c891 100644 --- 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java @@ -51,28 +51,28 @@ import java.util.*; public class TransferLearning { public static class Builder { - private MultiLayerConfiguration origConf; - private MultiLayerNetwork origModel; + private final MultiLayerConfiguration origConf; + private final MultiLayerNetwork origModel; private MultiLayerNetwork editedModel; private FineTuneConfiguration finetuneConfiguration; private int frozenTill = -1; private int popN = 0; private boolean prepDone = false; - private Set editedLayers = new HashSet<>(); - private Map> editedLayersMap = + private final Set editedLayers = new HashSet<>(); + private final Map> editedLayersMap = new HashMap<>(); - private Map> nInEditedMap = new HashMap<>(); - private List editedParams = new ArrayList<>(); - private List editedConfs = new ArrayList<>(); - private List appendParams = new ArrayList<>(); //these could be new arrays, and views from origParams - private List appendConfs = new ArrayList<>(); + private final Map> nInEditedMap = new HashMap<>(); + private final List editedParams = new ArrayList<>(); + private final List editedConfs = new ArrayList<>(); + private final List appendParams = new ArrayList<>(); //these could be new arrays, and views from origParams + private final List appendConfs = new ArrayList<>(); private Map inputPreProcessors = new HashMap<>(); private InputType inputType; private Boolean validateOutputLayerConfig; - private DataType dataType; + private final DataType dataType; /** * Multilayer Network to tweak for transfer learning @@ -430,9 +430,7 @@ public class TransferLearning { int i = 0; while (i < popN) { Integer layerNum = origModel.getnLayers() - i; - if (inputPreProcessors.containsKey(layerNum)) { - inputPreProcessors.remove(layerNum); - } + inputPreProcessors.remove(layerNum); editedConfs.remove(editedConfs.size() - 1); editedParams.remove(editedParams.size() - 1); i++; @@ -543,7 +541,7 @@ public class TransferLearning { MultiLayerConfiguration conf = new MultiLayerConfiguration.Builder().inputPreProcessors(inputPreProcessors) .setInputType(this.inputType).confs(allConfs) - .validateOutputLayerConfig(validateOutputLayerConfig == null ? 
true : validateOutputLayerConfig) + .validateOutputLayerConfig(validateOutputLayerConfig == null || validateOutputLayerConfig) .dataType(origConf.getDataType()) .build(); if (finetuneConfiguration != null) { @@ -554,19 +552,19 @@ public class TransferLearning { } public static class GraphBuilder { - private ComputationGraph origGraph; - private ComputationGraphConfiguration origConfig; + private final ComputationGraph origGraph; + private final ComputationGraphConfiguration origConfig; private FineTuneConfiguration fineTuneConfiguration; private ComputationGraphConfiguration.GraphBuilder editedConfigBuilder; private String[] frozenOutputAt; private boolean hasFrozen = false; - private Set editedVertices = new HashSet<>(); + private final Set editedVertices = new HashSet<>(); private WorkspaceMode workspaceMode; private Boolean validateOutputLayerConfig = null; - private Map nInFromNewConfig = new HashMap<>(); + private final Map nInFromNewConfig = new HashMap<>(); /** * Computation Graph to tweak for transfer learning @@ -960,7 +958,7 @@ public class TransferLearning { initBuilderIfReq(); ComputationGraphConfiguration newConfig = editedConfigBuilder - .validateOutputLayerConfig(validateOutputLayerConfig == null ? true : validateOutputLayerConfig).build(); + .validateOutputLayerConfig(validateOutputLayerConfig == null || validateOutputLayerConfig).build(); if (this.workspaceMode != null) newConfig.setTrainingWorkspaceMode(workspaceMode); ComputationGraph newGraph = new ComputationGraph(newConfig); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java index b2d12a620..4f4d1690f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java @@ -462,8 +462,7 @@ public abstract class BaseMultiLayerUpdater implements Updater return false; BaseMultiLayerUpdater that = (BaseMultiLayerUpdater) o; - return updaterStateViewArray != null ? 
updaterStateViewArray.equals(that.updaterStateViewArray) - : that.updaterStateViewArray == null; + return Objects.equals(updaterStateViewArray, that.updaterStateViewArray); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java index 87c791c54..dea50edd9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java @@ -47,7 +47,7 @@ public class LayerUpdater extends BaseMultiLayerUpdater { @Override protected Trainable[] getOrderedLayers() { - return new Trainable[] {(Trainable)network}; + return new Trainable[] {network}; } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/weights/embeddings/WeightInitEmbedding.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/weights/embeddings/WeightInitEmbedding.java index 6e92b2187..4ca7f2635 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/weights/embeddings/WeightInitEmbedding.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/weights/embeddings/WeightInitEmbedding.java @@ -32,7 +32,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; @EqualsAndHashCode public class WeightInitEmbedding implements IWeightInit { - private EmbeddingInitializer serializableInit; + private final EmbeddingInitializer serializableInit; private EmbeddingInitializer nonSerializableInit; public WeightInitEmbedding(@NonNull EmbeddingInitializer embeddingInitializer){ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/workspace/LayerWorkspaceMgr.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/workspace/LayerWorkspaceMgr.java index bf8126ed5..a7b972a2f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/workspace/LayerWorkspaceMgr.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/workspace/LayerWorkspaceMgr.java @@ -35,12 +35,12 @@ import java.util.*; public class LayerWorkspaceMgr extends BaseWorkspaceMgr { public static String CUDNN_WORKSPACE_KEY = "CUDNN_WORKSPACE"; - private static LayerWorkspaceMgr NO_WS_IMMUTABLE; + private static final LayerWorkspaceMgr NO_WS_IMMUTABLE; static{ Set all = new HashSet<>(); Collections.addAll(all, ArrayType.values()); NO_WS_IMMUTABLE = new LayerWorkspaceMgr( - all, Collections.emptyMap(), Collections.emptyMap()); + all, Collections.emptyMap(), Collections.emptyMap()); } protected Set noLeverageOverride; @@ -136,7 +136,7 @@ public class LayerWorkspaceMgr extends BaseWorkspaceMgr { public static class Builder { - private LayerWorkspaceMgr mgr; + private final LayerWorkspaceMgr mgr; public Builder(){ mgr = new LayerWorkspaceMgr(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/Solver.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/Solver.java index 9f1a64de9..4cb638c5d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/Solver.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/Solver.java @@ -91,7 +91,7 @@ public class Solver { public static class Builder { private NeuralNetConfiguration conf; private Model model; - private List listeners = new ArrayList<>(); + private final List listeners = new ArrayList<>(); public Builder configure(NeuralNetConfiguration conf) { this.conf = conf; diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java index 5871b99a0..4ebf2e050 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java @@ -39,26 +39,27 @@ import java.util.concurrent.TimeUnit; @Slf4j public class CheckpointListener extends BaseTrainingListener implements Serializable { - private enum KeepMode {ALL, LAST, LAST_AND_EVERY}; + private enum KeepMode {ALL, LAST, LAST_AND_EVERY} + private static final String[] MODEL_TYPES = new String[]{"MultiLayerNetwork", "ComputationGraph", "Model"}; - private File rootDir; - private KeepMode keepMode; - private int keepLast; - private int keepEvery; - private boolean logSaving; - private boolean deleteExisting; + private final File rootDir; + private final KeepMode keepMode; + private final int keepLast; + private final int keepEvery; + private final boolean logSaving; + private final boolean deleteExisting; - private Integer saveEveryNEpochs; - private Integer saveEveryNIterations; - private boolean saveEveryNIterSinceLast; - private Long saveEveryAmount; - private TimeUnit saveEveryUnit; + private final Integer saveEveryNEpochs; + private final Integer saveEveryNIterations; + private final boolean saveEveryNIterSinceLast; + private final Long saveEveryAmount; + private final TimeUnit saveEveryUnit; private Long saveEveryMs; - private boolean saveEverySinceLast; + private final boolean saveEverySinceLast; private int lastCheckpointNum = -1; - private File checkpointRecordFile; + private final File checkpointRecordFile; private Checkpoint lastCheckpoint; private long startTime = -1; @@ -151,7 +152,6 @@ public class CheckpointListener extends BaseTrainingListener implements Serializ long lastSaveTime = (lastCheckpoint != null ? lastCheckpoint.getTimestamp() : startTime); if((time - lastSaveTime) >= saveEveryMs){ saveCheckpoint(model); - return; } } else { //Save periodically, regardless of when last model was saved @@ -159,7 +159,6 @@ public class CheckpointListener extends BaseTrainingListener implements Serializ if((time - lastSave) > saveEveryMs){ saveCheckpoint(model); lastSaveEveryMsNoSinceLast = time; - return; } } } @@ -197,7 +196,6 @@ public class CheckpointListener extends BaseTrainingListener implements Serializ //Finally: determine if we should delete some old models... 
if(keepMode == null || keepMode == KeepMode.ALL){ - return; } else if(keepMode == KeepMode.LAST){ List checkpoints = availableCheckpoints(); Iterator iter = checkpoints.iterator(); @@ -490,7 +488,7 @@ public class CheckpointListener extends BaseTrainingListener implements Serializ public static class Builder { - private File rootDir; + private final File rootDir; private KeepMode keepMode; private int keepLast; private int keepEvery; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java index 3112e559c..51f798e26 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java @@ -27,13 +27,14 @@ import java.io.File; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.List; public class CollectScoresIterationListener extends BaseTrainingListener { - private int frequency; + private final int frequency; private int iterationCount = 0; //private List> scoreVsIter = new ArrayList<>(); @@ -42,8 +43,8 @@ public class CollectScoresIterationListener extends BaseTrainingListener { private int position = 0; private int bucketNumber = 1; - private List indexes; - private List scores; + private final List indexes; + private final List scores; public ScoreStat() { indexes = new ArrayList<>(1); @@ -170,7 +171,7 @@ public class CollectScoresIterationListener extends BaseTrainingListener { sb.append("\n").append(indexes[i]).append(delimiter).append(scores[i]); } } - outputStream.write(sb.toString().getBytes("UTF-8")); + outputStream.write(sb.toString().getBytes(StandardCharsets.UTF_8)); } /** diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java index 23d0e81fe..68402f40e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java @@ -50,7 +50,7 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali private transient ThreadLocal> lastGcMs = new ThreadLocal<>(); private transient List gcBeans = null; - private boolean reportScore; + private final boolean reportScore; private boolean reportGC; private boolean reportSample = true; private boolean reportBatch = true; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/TimeIterationListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/TimeIterationListener.java index 0dc166a3f..cc48c216b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/TimeIterationListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/TimeIterationListener.java @@ -32,9 +32,9 @@ import java.util.concurrent.atomic.AtomicLong; @Slf4j public class TimeIterationListener extends BaseTrainingListener implements Serializable { - private long start; - private int iterationCount; - private AtomicLong 
iterationCounter = new AtomicLong(0); + private final long start; + private final int iterationCount; + private final AtomicLong iterationCounter = new AtomicLong(0); /** * Constructor @@ -52,7 +52,7 @@ public class TimeIterationListener extends BaseTrainingListener implements Seria long remaining = (iterationCount - currentIteration) * elapsed / currentIteration; long minutes = remaining / (1000 * 60); Date date = new Date(start + elapsed + remaining); - log.info("Remaining time : " + minutes + "mn - End expected : " + date.toString()); + log.info("Remaining time : " + minutes + "mn - End expected : " + date); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BackTrackLineSearch.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BackTrackLineSearch.java index 39c9f8da2..18e64c081 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BackTrackLineSearch.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BackTrackLineSearch.java @@ -44,9 +44,9 @@ import static org.nd4j.linalg.ops.transforms.Transforms.abs; public class BackTrackLineSearch implements LineOptimizer { private static final Logger log = LoggerFactory.getLogger(BackTrackLineSearch.class); - private Model layer; - private StepFunction stepFunction; - private ConvexOptimizer optimizer; + private final Model layer; + private final StepFunction stepFunction; + private final ConvexOptimizer optimizer; private int maxIterations; double stepMax = 100; private boolean minObjectiveFunction = true; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java index 320b4293a..5760ee337 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java @@ -39,7 +39,7 @@ import java.util.LinkedList; */ public class LBFGS extends BaseOptimizer { private static final long serialVersionUID = 9148732140255034888L; - private int m = 4; + private final int m = 4; public LBFGS(NeuralNetConfiguration conf, StepFunction stepFunction, Collection trainingListeners, Model model) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java index 41a73e577..32c40bdfc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java @@ -141,7 +141,7 @@ public class Convolution1DUtils { if (convolutionMode == ConvolutionMode.Same || convolutionMode == ConvolutionMode.Causal) { return (int) Math.ceil(inH / ((double) strides)); } - return (inH - eKernel + 2 * padding) / strides + 1; + return (inH - eKernel + 2L * padding) / strides + 1; } /** @@ -204,25 +204,24 @@ public class Convolution1DUtils { int truncated = (int) d; int sameSize = (int) Math.ceil(inH / ((double) strides)); - StringBuilder sb = new StringBuilder(); - sb.append("Invalid input data or configuration: Combination of kernel size, " + + String sb = "Invalid input data or configuration: Combination of kernel size, " + "stride and padding are not " + - "valid for given input height, using ConvolutionMode.Strict\n") - .append("ConvolutionMode.Strict requires: output height = (input 
height - kernelSize + " + - "2*padding)/stride + 1 to be an integer. Got: (") - .append(inH).append(" - ").append(eKernel).append(" + 2*").append(padding).append(")/") - .append(strides).append(" + 1 = ") - .append(str).append("\n").append("See \"Constraints on strides\" at http://cs231n.github." + - "io/convolutional-networks/ and ConvolutionType enumeration Javadoc.\n") - .append("To truncate/crop the input, such that output height = floor(") - .append(str).append(") = ") - .append(truncated).append(", use ConvolutionType.Truncate.\n") - .append("Alternatively use ConvolutionType.Same, which will use padding to give an " + - "output height of ceil(") - .append(inH).append("/").append(strides).append(")=").append(sameSize) - .append(getCommonErrorMsg(inputData, eKernel, strides, padding, dilation)); + "valid for given input height, using ConvolutionMode.Strict\n" + + "ConvolutionMode.Strict requires: output height = (input height - kernelSize + " + + "2*padding)/stride + 1 to be an integer. Got: (" + + inH + " - " + eKernel + " + 2*" + padding + ")/" + + strides + " + 1 = " + + str + "\n" + "See \"Constraints on strides\" at http://cs231n.github." + + "io/convolutional-networks/ and ConvolutionType enumeration Javadoc.\n" + + "To truncate/crop the input, such that output height = floor(" + + str + ") = " + + truncated + ", use ConvolutionType.Truncate.\n" + + "Alternatively use ConvolutionType.Same, which will use padding to give an " + + "output height of ceil(" + + inH + "/" + strides + ")=" + sameSize + + getCommonErrorMsg(inputData, eKernel, strides, padding, dilation); - throw new DL4JInvalidConfigException(sb.toString()); + throw new DL4JInvalidConfigException(sb); } } @@ -254,8 +253,7 @@ public class Convolution1DUtils { */ public static void validateConvolutionModePadding(ConvolutionMode mode, int padding) { if (mode == ConvolutionMode.Same) { - boolean nullPadding = true; - if (padding != 0) nullPadding = false; + boolean nullPadding = padding == 0; if (!nullPadding) throw new IllegalArgumentException("Padding cannot be used when using the `same' convolution mode"); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution3DUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution3DUtils.java index e7101ad75..28cafe388 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution3DUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution3DUtils.java @@ -89,7 +89,7 @@ public class Convolution3DUtils { if (convolutionMode != ConvolutionMode.Same) { for (int i = 0; i < 3; i++) { - if ((eKernel[i] <= 0 || eKernel[i] > inShape[i] + 2 * padding[i])) { + if ((eKernel[i] <= 0 || eKernel[i] > inShape[i] + 2L * padding[i])) { StringBuilder sb = new StringBuilder(); sb.append("Invalid input data or configuration: "); if (atrous) sb.append("effective "); @@ -102,7 +102,7 @@ public class Convolution3DUtils { sb.append("kernel = ").append(eKernel[i]).append(", input ").append(dims[i]).append(" = ") .append(inShape[i]).append(" and padding ").append(dims[i]).append(" = ") .append(padding[i]).append(" which do not satisfy 0 < ") - .append(eKernel[i]).append(" <= ").append(inShape[i] + 2 * padding[i]) + .append(eKernel[i]).append(" <= ").append(inShape[i] + 2L * padding[i]) .append(getCommonErrorMsg(inputDataShape, eKernel, strides, padding, dilation)); throw new DL4JInvalidInputException(sb.toString()); @@ -111,30 +111,29 @@ public class Convolution3DUtils { } if (convolutionMode == 
ConvolutionMode.Strict) { for (int j = 0; j < 3; j++) { - if ((inShape[j] - eKernel[0] + 2 * padding[0]) % strides[0] != 0) { + if ((inShape[j] - eKernel[0] + 2L * padding[0]) % strides[0] != 0) { double d = (inShape[j] - eKernel[0] + 2 * padding[0]) / ((double) strides[0]) + 1.0; String str = String.format("%.2f", d); int truncated = (int) d; int sameSize = (int) Math.ceil(inShape[j] / ((double) strides[0])); - StringBuilder sb = new StringBuilder(); - sb.append("Invalid input data or configuration: Combination of kernel size, stride and padding ") - .append("are not valid for given input height, using ConvolutionMode.Strict\n") - .append("ConvolutionMode.Strict requires: output height = (input height - kernelSize + ") - .append( "2*padding)/stride + 1 to be an integer. Got: (") - .append(inShape[j]).append(" - ").append(eKernel[0]).append(" + 2*") - .append(padding[0]).append(")/").append(strides[0]).append(" + 1 = ") - .append(str).append("\n") - .append("See \"Constraints on strides\" at http://cs231n.github.io/convolutional-networks/ ") - .append("and ConvolutionType enumeration Javadoc.\n") - .append("To truncate/crop the input, such that output height = floor(").append(str) - .append(") = ").append(truncated).append(", use ConvolutionType.Truncate.\n") - .append("Alternatively use ConvolutionType.Same, which will use padding to give ") - .append("an output height of ceil(") - .append(inShape[j]).append("/").append(strides[0]).append(")=").append(sameSize) - .append(getCommonErrorMsg(inputDataShape, eKernel, strides, padding, dilation)); + String sb = "Invalid input data or configuration: Combination of kernel size, stride and padding " + + "are not valid for given input height, using ConvolutionMode.Strict\n" + + "ConvolutionMode.Strict requires: output height = (input height - kernelSize + " + + "2*padding)/stride + 1 to be an integer. 
Got: (" + + inShape[j] + " - " + eKernel[0] + " + 2*" + + padding[0] + ")/" + strides[0] + " + 1 = " + + str + "\n" + + "See \"Constraints on strides\" at http://cs231n.github.io/convolutional-networks/ " + + "and ConvolutionType enumeration Javadoc.\n" + + "To truncate/crop the input, such that output height = floor(" + str + + ") = " + truncated + ", use ConvolutionType.Truncate.\n" + + "Alternatively use ConvolutionType.Same, which will use padding to give " + + "an output height of ceil(" + + inShape[j] + "/" + strides[0] + ")=" + sameSize + + getCommonErrorMsg(inputDataShape, eKernel, strides, padding, dilation); - throw new DL4JInvalidConfigException(sb.toString()); + throw new DL4JInvalidConfigException(sb); } } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java index 8737c974e..616f1c620 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java @@ -143,9 +143,9 @@ public class ConvolutionUtils { return new long[]{hOut, wOut, dOut}; } - long hOut = strides[0] * (hIn - 1) + eKernel[0] - 2 * padding[0]; - long wOut = strides[1] * (wIn - 1) + eKernel[1] - 2 * padding[1]; - long dOut = strides[2] * (dIn - 1) + eKernel[2] - 2 * padding[2]; + long hOut = strides[0] * (hIn - 1) + eKernel[0] - 2L * padding[0]; + long wOut = strides[1] * (wIn - 1) + eKernel[1] - 2L * padding[1]; + long dOut = strides[2] * (dIn - 1) + eKernel[2] - 2L * padding[2]; return new long[]{hOut, wOut, dOut}; } @@ -376,17 +376,16 @@ public class ConvolutionUtils { int truncated = (int) d; int sameSize = (int) Math.ceil(inH / ((double) strides[0])); - StringBuilder sb = new StringBuilder(); - sb.append("Invalid input data or configuration: Combination of kernel size, stride and padding are not valid for given input height, using ConvolutionMode.Strict\n") - .append("ConvolutionMode.Strict requires: output height = (input height - kernelSize + 2*padding)/stride + 1 to be an integer. Got: (") - .append(inH).append(" - ").append(eKernel[0]).append(" + 2*").append(padding[0]).append(")/").append(strides[0]).append(" + 1 = ") - .append(str).append("\n").append("See \"Constraints on strides\" at http://cs231n.github.io/convolutional-networks/ and ConvolutionType enumeration Javadoc.\n") - .append("To truncate/crop the input, such that output height = floor(").append(str).append(") = ") - .append(truncated).append(", use ConvolutionType.Truncate.\n") - .append("Alternatively use ConvolutionType.Same, which will use padding to give an output height of ceil(") - .append(inH).append("/").append(strides[0]).append(")=").append(sameSize).append(getCommonErrorMsg(inputData, eKernel, strides, padding, dilation)); + String sb = "Invalid input data or configuration: Combination of kernel size, stride and padding are not valid for given input height, using ConvolutionMode.Strict\n" + + "ConvolutionMode.Strict requires: output height = (input height - kernelSize + 2*padding)/stride + 1 to be an integer. 
Got: (" + + inH + " - " + eKernel[0] + " + 2*" + padding[0] + ")/" + strides[0] + " + 1 = " + + str + "\n" + "See \"Constraints on strides\" at http://cs231n.github.io/convolutional-networks/ and ConvolutionType enumeration Javadoc.\n" + + "To truncate/crop the input, such that output height = floor(" + str + ") = " + + truncated + ", use ConvolutionType.Truncate.\n" + + "Alternatively use ConvolutionType.Same, which will use padding to give an output height of ceil(" + + inH + "/" + strides[0] + ")=" + sameSize + getCommonErrorMsg(inputData, eKernel, strides, padding, dilation); - throw new DL4JInvalidConfigException(sb.toString()); + throw new DL4JInvalidConfigException(sb); } if ((inW - eKernel[1] + 2 * padding[1]) % strides[1] != 0) { @@ -394,19 +393,18 @@ public class ConvolutionUtils { String str = String.format("%.2f", d); int truncated = (int) d; int sameSize = (int) Math.ceil(inW / ((double) strides[1])); - StringBuilder sb = new StringBuilder(); - sb.append("Invalid input data or configuration: Combination of kernel size, stride and padding are not valid for given input width, using ConvolutionMode.Strict\n") - .append("ConvolutionMode.Strict requires: output width = (input - kernelSize + 2*padding)/stride + 1 to be an integer. Got: (") - .append(inW).append(" - ").append(eKernel[1]).append(" + 2*").append(padding[1]) - .append(")/").append(strides[1]).append(" + 1 = ").append(str).append("\n") - .append("See \"Constraints on strides\" at http://cs231n.github.io/convolutional-networks/ and ConvolutionType enumeration Javadoc.\n") - .append("To truncate/crop the input, such that output width = floor(").append(str).append(") = ") - .append(truncated).append(", use ConvolutionType.Truncate.\n") - .append("Alternatively use ConvolutionType.Same, which will use padding to give an output width of ceil(") - .append(inW).append("/").append(strides[1]).append(")=").append(sameSize) - .append(getCommonErrorMsg(inputData, eKernel, strides, padding, dilation)); + String sb = "Invalid input data or configuration: Combination of kernel size, stride and padding are not valid for given input width, using ConvolutionMode.Strict\n" + + "ConvolutionMode.Strict requires: output width = (input - kernelSize + 2*padding)/stride + 1 to be an integer. Got: (" + + inW + " - " + eKernel[1] + " + 2*" + padding[1] + + ")/" + strides[1] + " + 1 = " + str + "\n" + + "See \"Constraints on strides\" at http://cs231n.github.io/convolutional-networks/ and ConvolutionType enumeration Javadoc.\n" + + "To truncate/crop the input, such that output width = floor(" + str + ") = " + + truncated + ", use ConvolutionType.Truncate.\n" + + "Alternatively use ConvolutionType.Same, which will use padding to give an output width of ceil(" + + inW + "/" + strides[1] + ")=" + sameSize + + getCommonErrorMsg(inputData, eKernel, strides, padding, dilation); throw new DL4JInvalidConfigException( - sb.toString()); + sb); } if (eKernel.length == 3 && (inShape[2] - eKernel[2] + 2 * padding[2]) % strides[2] != 0) { @@ -415,19 +413,18 @@ public class ConvolutionUtils { String str = String.format("%.2f", d); int truncated = (int) d; int sameSize = (int) Math.ceil(inD / ((double) strides[2])); - StringBuilder sb = new StringBuilder(); - sb.append("Invalid input data or configuration: Combination of kernel size, stride and padding are not valid for given input width, using ConvolutionMode.Strict\n") - .append("ConvolutionMode.Strict requires: output channels = (input - kernelSize + 2*padding)/stride + 1 to be an integer. 
Got: (") - .append(inD).append(" - ").append(eKernel[2]).append(" + 2*").append(padding[2]) - .append(")/").append(strides[1]).append(" + 1 = ").append(str).append("\n") - .append("See \"Constraints on strides\" at http://cs231n.github.io/convolutional-networks/ and ConvolutionType enumeration Javadoc.\n") - .append("To truncate/crop the input, such that output width = floor(").append(str).append(") = ") - .append(truncated).append(", use ConvolutionType.Truncate.\n") - .append("Alternatively use ConvolutionType.Same, which will use padding to give an output width of ceil(") - .append(inW).append("/").append(strides[2]).append(")=").append(sameSize) - .append(getCommonErrorMsg(inputData, eKernel, strides, padding, dilation)); + String sb = "Invalid input data or configuration: Combination of kernel size, stride and padding are not valid for given input width, using ConvolutionMode.Strict\n" + + "ConvolutionMode.Strict requires: output channels = (input - kernelSize + 2*padding)/stride + 1 to be an integer. Got: (" + + inD + " - " + eKernel[2] + " + 2*" + padding[2] + + ")/" + strides[1] + " + 1 = " + str + "\n" + + "See \"Constraints on strides\" at http://cs231n.github.io/convolutional-networks/ and ConvolutionType enumeration Javadoc.\n" + + "To truncate/crop the input, such that output width = floor(" + str + ") = " + + truncated + ", use ConvolutionType.Truncate.\n" + + "Alternatively use ConvolutionType.Same, which will use padding to give an output width of ceil(" + + inW + "/" + strides[2] + ")=" + sameSize + + getCommonErrorMsg(inputData, eKernel, strides, padding, dilation); throw new DL4JInvalidConfigException( - sb.toString()); + sb); } } @@ -574,7 +571,10 @@ public class ConvolutionUtils { if (mode == ConvolutionMode.Same) { boolean nullPadding = true; for (int i : padding) { - if (i != 0) nullPadding = false; + if (i != 0) { + nullPadding = false; + break; + } } if (!nullPadding) throw new IllegalArgumentException("Padding cannot be used when using the `same' convolution mode"); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java index 9fd95b22d..ac28ced80 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java @@ -234,7 +234,7 @@ public class CrashReportingUtil { sb.append(String.format(wsFormat, ws.getId(), (ws.isScopeActive() ? 
"OPEN" : "CLOSED"), fBytes(ws.getCurrentSize()), - String.valueOf(numCycles))).append("\n"); + numCycles)).append("\n"); } } sb.append(fBytes("Workspaces total size", totalWsSize)); @@ -471,7 +471,7 @@ public class CrashReportingUtil { for(Layer layer : layers){ long numParams = layer.numParams(); sb.append(String.format(format, layer.getIndex(), layer.conf().getLayer().getLayerName(), - layer.getClass().getSimpleName(), String.valueOf(numParams), fBytes(numParams * bytesPerElement))).append("\n"); + layer.getClass().getSimpleName(), numParams, fBytes(numParams * bytesPerElement))).append("\n"); } } @@ -515,7 +515,7 @@ public class CrashReportingUtil { } sb.append(String.format(format, idx, layerName, l.getClass().getSimpleName(), h.getClass().getSimpleName(), - fBytes(layerTotal), mem.toString())).append("\n"); + fBytes(layerTotal), mem)).append("\n"); totalHelperMem += layerTotal; } @@ -567,7 +567,7 @@ public class CrashReportingUtil { bytes = 0; } totalActivationBytes += bytes; - sb.append(String.format(format, String.valueOf(i), layers[i].conf().getLayer().getLayerName(), layers[i].getClass().getSimpleName(), + sb.append(String.format(format, i, layers[i].conf().getLayer().getLayerName(), layers[i].getClass().getSimpleName(), inputTypes.get(i), Arrays.toString(shape), (numElements < 0 ? "" : String.valueOf(numElements)), fBytes(bytes))).append("\n"); last = bytes; } @@ -630,7 +630,7 @@ public class CrashReportingUtil { className = gv.getClass().getSimpleName(); } - sb.append(String.format(format, String.valueOf(i), layerName, className, it, + sb.append(String.format(format, i, layerName, className, it, Arrays.toString(shape), (numElements < 0 ? "" : String.valueOf(numElements)), fBytes(bytes))).append("\n"); if(!net.getConfiguration().getNetworkOutputs().contains(layerName)){ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java index 8f1ba93e4..08a3d086a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java @@ -151,12 +151,9 @@ public class OutputLayerUtil { public static boolean activationExceedsZeroOneRange(IActivation activation, boolean isLossLayer){ if(OUTSIDE_ZERO_ONE_RANGE.contains(activation.getClass())){ - if(isLossLayer && activation instanceof ActivationIdentity){ - //Note: we're intentionally excluding identity here, for situations like dense(softmax) -> loss(identity) - //However, we might miss a few invalid configs like dense(relu) -> loss(identity) - return false; - } - return true; + //Note: we're intentionally excluding identity here, for situations like dense(softmax) -> loss(identity) + //However, we might miss a few invalid configs like dense(relu) -> loss(identity) + return !isLossLayer || !(activation instanceof ActivationIdentity); } return false; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ValidationUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ValidationUtils.java index f2dd9b5d2..0e9952989 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ValidationUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ValidationUtils.java @@ -70,8 +70,9 @@ public class ValidationUtils { boolean nonnegative = true; for(int value : data){ - if(value < 0) { + if (value < 0) { nonnegative = false; + break; } } diff --git 
a/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainerContext.java b/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainerContext.java index 9e509888c..47d04d303 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainerContext.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainerContext.java @@ -35,10 +35,10 @@ public class ParameterServerTrainerContext implements TrainerContext { private ParameterServerNode parameterServerNode; private MediaDriver mediaDriver; private MediaDriver.Context mediaDriverContext; - private int statusServerPort = 33000; - private int numUpdatesPerEpoch = 1; + private final int statusServerPort = 33000; + private final int numUpdatesPerEpoch = 1; private String[] parameterServerArgs; - private int numWorkers = 1; + private final int numWorkers = 1; /** * Initialize the context @@ -52,7 +52,7 @@ public class ParameterServerTrainerContext implements TrainerContext { mediaDriver = MediaDriver.launchEmbedded(mediaDriverContext); parameterServerNode = new ParameterServerNode(mediaDriver, statusServerPort, numWorkers); if (parameterServerArgs == null) - parameterServerArgs = new String[] {"-m", "true", "-s", "1," + String.valueOf(model.numParams()), "-p", + parameterServerArgs = new String[] {"-m", "true", "-s", "1," + model.numParams(), "-p", "40323", "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sh", "localhost", "-sp", String.valueOf(statusServerPort), "-u", String.valueOf(numUpdatesPerEpoch)}; diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java index 683db198a..e1f8b9273 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java @@ -58,9 +58,9 @@ public class EarlyStoppingParallelTrainer implements IEarlyStop private ParallelWrapper wrapper; private double bestModelScore = Double.MAX_VALUE; private int bestModelEpoch = -1; - private AtomicDouble latestScore = new AtomicDouble(0.0); - private AtomicBoolean terminate = new AtomicBoolean(false); - private AtomicInteger iterCount = new AtomicInteger(0); + private final AtomicDouble latestScore = new AtomicDouble(0.0); + private final AtomicBoolean terminate = new AtomicBoolean(false); + private final AtomicInteger iterCount = new AtomicInteger(0); protected volatile IterationTerminationCondition terminationReason = null; public EarlyStoppingParallelTrainer(EarlyStoppingConfiguration earlyStoppingConfiguration, T model, @@ -262,7 +262,7 @@ public class EarlyStoppingParallelTrainer implements IEarlyStop } if (epochTerminate) { log.info("Hit epoch termination condition at epoch {}. 
Details: {}", epochCount, - termReason.toString()); + termReason); T bestModel; try { bestModel = esConfig.getModelSaver().getBestModel(); @@ -316,8 +316,8 @@ public class EarlyStoppingParallelTrainer implements IEarlyStop */ private class AveragingTrainingListener extends BaseTrainingListener { private final Logger log = LoggerFactory.getLogger(AveragingTrainingListener.class); - private IterationTerminationCondition terminationReason = null; - private EarlyStoppingParallelTrainer trainer; + private final IterationTerminationCondition terminationReason = null; + private final EarlyStoppingParallelTrainer trainer; /** Default constructor printing every 10 iterations */ public AveragingTrainingListener(EarlyStoppingParallelTrainer trainer) { diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java index 67a693dc0..52a28606e 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java @@ -284,7 +284,7 @@ public class ParallelInference { public static class Builder { - private Model model; + private final Model model; private int workers = DEFAULT_NUM_WORKERS; private int batchLimit = DEFAULT_BATCH_LIMIT; private InferenceMode inferenceMode = DEFAULT_INFERENCE_MODE; @@ -413,16 +413,16 @@ public class ParallelInference { * */ private class InferenceWorker extends Thread implements Runnable { - private BlockingQueue inputQueue; - private AtomicBoolean shouldWork = new AtomicBoolean(true); - private AtomicBoolean isStopped = new AtomicBoolean(false); + private final BlockingQueue inputQueue; + private final AtomicBoolean shouldWork = new AtomicBoolean(true); + private final AtomicBoolean isStopped = new AtomicBoolean(false); private Model protoModel; private Model replicatedModel; - private AtomicLong counter = new AtomicLong(0); - private boolean rootDevice; - private int deviceId; + private final AtomicLong counter = new AtomicLong(0); + private final boolean rootDevice; + private final int deviceId; - private ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock(); + private final ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock(); private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice, int deviceId) { this.inputQueue = inputQueue; @@ -571,9 +571,9 @@ public class ParallelInference { protected static class ObservablesProvider { - private BlockingQueue targetQueue; - private long nanos; - private int batchLimit; + private final BlockingQueue targetQueue; + private final long nanos; + private final int batchLimit; private volatile BatchedInferenceObservable currentObservable; private final Object locker = new Object(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java index 46390ee1d..8da3b5262 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java @@ -26,7 +26,7 @@ import org.deeplearning4j.core.storage.StatsStorageRouter; import 
org.deeplearning4j.core.storage.listener.RoutingIterationListener; import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler; import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.AdaptiveThresholdAlgorithm; -import org.nd4j.linalg.dataset.AsyncDataSetIterator;; +import org.nd4j.linalg.dataset.AsyncDataSetIterator; import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.DummyBlockDataSetIterator; import org.deeplearning4j.datasets.iterator.DummyBlockMultiDataSetIterator; @@ -932,7 +932,7 @@ public class ParallelWrapper implements AutoCloseable { // memory sie in number of bytes long memorySize = encoderMemory == null || encoderMemory < 0 - ? maxUpdate * 4 * (workers + 3) + ? (long) maxUpdate * 4 * (workers + 3) : encoderMemory; this.accumulator = new EncodedGradientsAccumulator(workers, new EncodingHandler(thresholdAlgorithm, residualPostProcessor, maxUpdate, false), memorySize, workers + 2, Integer.MAX_VALUE, false); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObserver.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObserver.java index cf87082b1..559a82f88 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObserver.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BasicInferenceObserver.java @@ -29,7 +29,7 @@ import java.util.concurrent.locks.LockSupport; @Slf4j public class BasicInferenceObserver implements Observer { - private AtomicBoolean finished; + private final AtomicBoolean finished; public BasicInferenceObserver() { finished = new AtomicBoolean(false); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservable.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservable.java index 1ae8995d8..5ae162931 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservable.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/inference/observers/BatchedInferenceObservable.java @@ -39,18 +39,18 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; @Slf4j public class BatchedInferenceObservable extends BasicInferenceObservable implements InferenceObservable { - private List inputs = new ArrayList<>(); - private List inputMasks = new ArrayList<>(); - private List outputs = new ArrayList<>(); - private AtomicInteger counter = new AtomicInteger(0); - private ThreadLocal position = new ThreadLocal<>(); - private List outputBatchInputArrays = new ArrayList<>(); + private final List inputs = new ArrayList<>(); + private final List inputMasks = new ArrayList<>(); + private final List outputs = new ArrayList<>(); + private final AtomicInteger counter = new AtomicInteger(0); + private final ThreadLocal position = new ThreadLocal<>(); + private final List outputBatchInputArrays = new ArrayList<>(); private final Object locker = new Object(); - private ReentrantReadWriteLock realLocker = new ReentrantReadWriteLock(); - private AtomicBoolean isLocked = new AtomicBoolean(false); - private AtomicBoolean isReadLocked = new AtomicBoolean(false); + private final 
ReentrantReadWriteLock realLocker = new ReentrantReadWriteLock(); + private final AtomicBoolean isLocked = new AtomicBoolean(false); + private final AtomicBoolean isReadLocked = new AtomicBoolean(false); public BatchedInferenceObservable() { diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java index a8db019b4..ecb28ef9b 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java @@ -120,8 +120,8 @@ public class InplaceParallelInferenceTest extends BaseDL4JTest { try { - val result0 = pi.output(new INDArray[]{Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0, 5.0}, new long[]{1, 5})}, null)[0]; - val result1 = pi.output(new INDArray[]{Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0, 5.0}, new long[]{1, 5})}, null)[0]; + val result0 = pi.output(new INDArray[]{Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0, 5.0}, 1, 5)}, null)[0]; + val result1 = pi.output(new INDArray[]{Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0, 5.0}, 1, 5)}, null)[0]; assertNotNull(result0); assertEquals(result0, result1); @@ -153,8 +153,8 @@ public class InplaceParallelInferenceTest extends BaseDL4JTest { try { - val result0 = pi.output(new INDArray[]{Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0, 5.0}, new long[]{1, 5})}, null)[0]; - val result1 = pi.output(new INDArray[]{Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0, 5.0}, new long[]{1, 5})}, null)[0]; + val result0 = pi.output(new INDArray[]{Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0, 5.0}, 1, 5)}, null)[0]; + val result1 = pi.output(new INDArray[]{Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0, 5.0}, 1, 5)}, null)[0]; assertNotNull(result0); assertEquals(result0, result1); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java index 6f694286d..3919bfbc7 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java @@ -233,7 +233,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { assertNotEquals(null, observable1); - assertTrue(observable1 == observable2); + assertSame(observable1, observable2); } @Test @@ -248,7 +248,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { assertNotEquals(null, observable1); - assertTrue(observable1 == observable2); + assertSame(observable1, observable2); List> l = observable1.getInputBatches(); assertEquals(1, l.size()); @@ -276,8 +276,8 @@ public class ParallelInferenceTest extends BaseDL4JTest { assertNotEquals(null, observable1); assertNotEquals(null, observable3); - assertTrue(observable1 == observable2); - assertTrue(observable1 != observable3); + assertSame(observable1, observable2); + assertNotSame(observable1, observable3); List> l = observable1.getInputBatches(); assertEquals(1, l.size()); @@ -439,7 +439,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { List arrs = new ArrayList<>(); List exp = new ArrayList<>(); for (int l : tsLengths) { - INDArray in = Nd4j.rand(new int[]{1, nIn, l}); + 
INDArray in = Nd4j.rand(1, nIn, l); arrs.add(in); INDArray out = net.output(in); exp.add(out); @@ -724,7 +724,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { for (int i = 0; i < nRuns; i++) { int currTSLength = (randomTSLength ? 1 + r.nextInt(tsLength) : tsLength); int currNumEx = 1 + r.nextInt(3); - INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn, currTSLength}); + INDArray inArr = Nd4j.rand(currNumEx, nIn, currTSLength); in.add(inArr); INDArray inMask = null; @@ -857,7 +857,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { int runs = isIntegrationTests() ? 100 : 20; for (int i = 0; i < 100; i++) { int currNumEx = 1 + r.nextInt(3); - INDArray inArr = Nd4j.rand(new int[]{currNumEx, nIn}); + INDArray inArr = Nd4j.rand(currNumEx, nIn); in.add(new INDArray[]{inArr}); INDArray[] out = net.output(inArr); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java index bf525ac67..315788855 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java @@ -95,9 +95,9 @@ public class ParallelWrapperMainTest extends BaseDL4JTest { tmp.deleteOnExit(); ParallelWrapperMain parallelWrapperMain = new ParallelWrapperMain(); try { - parallelWrapperMain.runMain(new String[]{"--modelPath", tempModel.getAbsolutePath(), + parallelWrapperMain.runMain("--modelPath", tempModel.getAbsolutePath(), "--dataSetIteratorFactoryClazz", MnistDataSetIteratorProviderFactory.class.getName(), - "--modelOutputPath", tmp.getAbsolutePath(), "--uiUrl", "localhost:" + uiPort}); + "--modelOutputPath", tmp.getAbsolutePath(), "--uiUrl", "localhost:" + uiPort); } finally { parallelWrapperMain.stop(); } diff --git a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java index 4a4ac3aaa..4eb8846a8 100644 --- a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java +++ b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java @@ -29,8 +29,8 @@ import java.util.concurrent.atomic.AtomicBoolean; public class PythonContextManager { - private static Set contexts = new HashSet<>(); - private static AtomicBoolean init = new AtomicBoolean(false); + private static final Set contexts = new HashSet<>(); + private static final AtomicBoolean init = new AtomicBoolean(false); private static String currentContext; private static final String MAIN_CONTEXT = "main"; private static final String COLLAPSED_KEY = "__collapsed__"; diff --git a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java index 40131a237..c05735def 100644 --- a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java +++ b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java @@ -41,7 +41,7 @@ import static org.bytedeco.cpython.helper.python.Py_SetPath; public class 
PythonExecutioner { private final static String PYTHON_EXCEPTION_KEY = "__python_exception__"; - private static AtomicBoolean init = new AtomicBoolean(false); + private static final AtomicBoolean init = new AtomicBoolean(false); public final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.eclipse.python4j.path"; public final static String JAVACPP_PYTHON_APPEND_TYPE = "org.eclipse.python4j.path.append"; public final static String DEFAULT_APPEND_TYPE = "before"; diff --git a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java index bd0893a72..59ae4b224 100644 --- a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java +++ b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java @@ -91,12 +91,12 @@ public class PythonObject { public PythonObject callWithKwargs(PythonObject kwargs) { if (!Python.callable(this)) { - throw new PythonException("Object is not callable: " + toString()); + throw new PythonException("Object is not callable: " + this); } PyObject tuple = PyTuple_New(0); PyObject dict = kwargs.nativePythonObject; if (PyObject_IsInstance(dict, new PyObject(PyDict_Type())) != 1) { - throw new PythonException("Expected kwargs to be dict. Received: " + kwargs.toString()); + throw new PythonException("Expected kwargs to be dict. Received: " + kwargs); } PythonObject ret = new PythonObject(PyObject_Call(nativePythonObject, tuple, dict)); Py_DecRef(tuple); @@ -109,7 +109,7 @@ public class PythonObject { boolean ownsTuple = false; try { if (!Python.callable(this)) { - throw new PythonException("Object is not callable: " + toString()); + throw new PythonException("Object is not callable: " + this); } if (PyObject_IsInstance(args.nativePythonObject, new PyObject(PyTuple_Type())) == 1) { @@ -118,10 +118,10 @@ public class PythonObject { tuple = PyList_AsTuple(args.nativePythonObject); ownsTuple = true; } else { - throw new PythonException("Expected args to be tuple or list. Received: " + args.toString()); + throw new PythonException("Expected args to be tuple or list. Received: " + args); } if (kwargs != null && PyObject_IsInstance(kwargs.nativePythonObject, new PyObject(PyDict_Type())) != 1) { - throw new PythonException("Expected kwargs to be dict. Received: " + kwargs.toString()); + throw new PythonException("Expected kwargs to be dict. Received: " + kwargs); } return new PythonObject(PyObject_Call(nativePythonObject, tuple, kwargs == null ? 
null : kwargs.nativePythonObject)); } finally { @@ -147,7 +147,7 @@ public class PythonObject { PythonGIL.assertThreadSafe(); try (PythonGC pgc = PythonGC.watch()) { if (!Python.callable(this)) { - throw new PythonException("Object is not callable: " + toString()); + throw new PythonException("Object is not callable: " + this); } PythonObject pyArgs; PythonObject pyKwargs; diff --git a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java index 21f22eaf6..d02234f69 100644 --- a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java +++ b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java @@ -28,12 +28,10 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; public class PythonProcess { - private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class); + private static final String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class); public static String runAndReturn(String... arguments)throws IOException, InterruptedException{ String[] allArgs = new String[arguments.length + 1]; - for (int i = 0; i < arguments.length; i++){ - allArgs[i + 1] = arguments[i]; - } + System.arraycopy(arguments, 0, allArgs, 1, arguments.length); allArgs[0] = pythonExecutable; ProcessBuilder pb = new ProcessBuilder(allArgs); Process process = pb.start(); @@ -45,9 +43,7 @@ public class PythonProcess { public static void run(String... arguments)throws IOException, InterruptedException{ String[] allArgs = new String[arguments.length + 1]; - for (int i = 0; i < arguments.length; i++){ - allArgs[i + 1] = arguments[i]; - } + System.arraycopy(arguments, 0, allArgs, 1, arguments.length); allArgs[0] = pythonExecutable; ProcessBuilder pb = new ProcessBuilder(allArgs); pb.inheritIO().start().waitFor(); diff --git a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java index 9120c82d4..77eb71a25 100644 --- a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java +++ b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java @@ -31,11 +31,11 @@ public class PythonTypes { private static List getPrimitiveTypes() { - return Arrays.asList(STR, INT, FLOAT, BOOL, BYTES); + return Arrays.asList(STR, INT, FLOAT, BOOL, BYTES); } private static List getCollectionTypes() { - return Arrays.asList(LIST, DICT); + return Arrays.asList(LIST, DICT); } private static List getExternalTypes() { @@ -149,7 +149,7 @@ public class PythonTypes { PythonGIL.assertThreadSafe(); long val = PyLong_AsLong(pythonObject.getNativePythonObject()); if (val == -1 && PyErr_Occurred() != null) { - throw new PythonException("Could not convert value to int: " + pythonObject.toString()); + throw new PythonException("Could not convert value to int: " + pythonObject); } return val; } @@ -180,7 +180,7 @@ public class PythonTypes { PythonGIL.assertThreadSafe(); double val = PyFloat_AsDouble(pythonObject.getNativePythonObject()); if (val == -1 && PyErr_Occurred() != null) { - throw new PythonException("Could not convert value to float: " + pythonObject.toString()); + throw new PythonException("Could not convert value to float: " + pythonObject); } return 
val; } @@ -344,7 +344,7 @@ public class PythonTypes { HashMap ret = new HashMap(); PyObject dictType = new PyObject(PyDict_Type()); if (PyObject_IsInstance(pythonObject.getNativePythonObject(), dictType) != 1) { - throw new PythonException("Expected dict, received: " + pythonObject.toString()); + throw new PythonException("Expected dict, received: " + pythonObject); } PyObject keys = PyDict_Keys(pythonObject.getNativePythonObject()); diff --git a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyBasicTest.java b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyBasicTest.java index 2d9851977..17f1b246b 100644 --- a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyBasicTest.java +++ b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyBasicTest.java @@ -37,8 +37,8 @@ import java.util.List; @NotThreadSafe ////@RunWith(Parameterized.class) public class PythonNumpyBasicTest { - private DataType dataType; - private long[] shape; + private final DataType dataType; + private final long[] shape; public PythonNumpyBasicTest(DataType dataType, long[] shape, String dummyArg) { this.dataType = dataType; diff --git a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java index 58c466d13..e3c4fb311 100644 --- a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java +++ b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java @@ -35,7 +35,7 @@ import java.util.*; @NotThreadSafe ////@RunWith(Parameterized.class) public class PythonNumpyCollectionsTest { - private DataType dataType; + private final DataType dataType; public PythonNumpyCollectionsTest(DataType dataType){ this.dataType = dataType; diff --git a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java index c18d0a925..49bf7fd61 100644 --- a/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java +++ b/cavis-dnn/cavis-dnn-python4j/cavis-python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java @@ -35,7 +35,7 @@ import java.util.List; @NotThreadSafe public class PythonNumpyMultiThreadTest { - private DataType dataType; + private final DataType dataType; public PythonNumpyMultiThreadTest(DataType dataType) { this.dataType = dataType; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java index 7e769cbb5..5a9735a8e 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java @@ -39,10 +39,10 @@ public class StatsCalculationHelper { private long lastDataSetBefore; private long lastProcessBefore; private long totalExampleCount; - private List dataSetGetTimes = new ArrayList<>(); - private List processMiniBatchTimes = new ArrayList<>(); + private final List dataSetGetTimes = new ArrayList<>(); + private final List processMiniBatchTimes = new 
ArrayList<>(); - private TimeSource timeSource = TimeSourceProvider.getInstance(); + private final TimeSource timeSource = TimeSourceProvider.getInstance(); public void logMethodStartTime() { methodStartTime = timeSource.currentTimeMillis(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java index b012f3a0d..1515e4ee3 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java @@ -36,7 +36,7 @@ import java.util.List; public class ExecuteWorkerPathMDSFlatMap implements FlatMapFunction, R> { private final FlatMapFunction, R> workerFlatMap; - private MultiDataSetLoader loader; + private final MultiDataSetLoader loader; private final int maxDataSetObjects; private final Broadcast hadoopConfig; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/BatchAndExportDataSetsFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/BatchAndExportDataSetsFunction.java index ac9a0a256..6a42c2259 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/BatchAndExportDataSetsFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/BatchAndExportDataSetsFunction.java @@ -103,7 +103,7 @@ public class BatchAndExportDataSetsFunction implements Function2(countBefore, Collections.emptyList()); + return new Pair<>(countBefore, Collections.emptyList()); } List exportPaths = new ArrayList<>(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/BatchAndExportMultiDataSetsFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/BatchAndExportMultiDataSetsFunction.java index b7e30b351..a5f607ee6 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/BatchAndExportMultiDataSetsFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/BatchAndExportMultiDataSetsFunction.java @@ -105,7 +105,7 @@ public class BatchAndExportMultiDataSetsFunction if (tempList.isEmpty() || (numExamples < minibatchSize && !finalExport)) { //No op - return new Pair<>(countBefore, Collections.emptyList()); + return new Pair<>(countBefore, Collections.emptyList()); } List exportPaths = new ArrayList<>(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java index f6b12a1eb..b72484095 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java @@ -34,7 +34,7 @@ import java.util.Random; public class 
SplitDataSetExamplesPairFlatMapFunction implements PairFlatMapFunction { private transient Random r; - private int maxKeyIndex; + private final int maxKeyIndex; public SplitDataSetExamplesPairFlatMapFunction(int maxKeyIndex) { this.maxKeyIndex = maxKeyIndex; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java index f8413037b..10ad4847d 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java @@ -44,12 +44,12 @@ import java.util.List; public class DataVecByteDataSetFunction implements PairFunction, Double, DataSet> { private int labelIndex = 0; - private int numPossibleLabels; - private int byteFileLen; - private int batchSize; + private final int numPossibleLabels; + private final int byteFileLen; + private final int batchSize; private int numExamples; private boolean regression = false; - private DataSetPreProcessor preProcessor; + private final DataSetPreProcessor preProcessor; public DataVecByteDataSetFunction(int labelIndex, int numPossibleLabels, int batchSize, int byteFileLen) { this(labelIndex, numPossibleLabels, batchSize, byteFileLen, false, null); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java index 4c0da6832..926051ba1 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java @@ -31,8 +31,8 @@ import java.util.Iterator; import java.util.List; public class RDDMiniBatches implements Serializable { - private int miniBatches; - private JavaRDD toSplitJava; + private final int miniBatches; + private final JavaRDD toSplitJava; public RDDMiniBatches(int miniBatches, JavaRDD toSplit) { this.miniBatches = miniBatches; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java index 8d24bba6a..48ff6d0b0 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java @@ -34,10 +34,10 @@ import java.util.ArrayList; import java.util.List; public class RecordReaderFunction implements Function { - private RecordReader recordReader; + private final RecordReader recordReader; private int labelIndex = -1; private int numPossibleLabels = -1; - private WritableConverter converter; + private final WritableConverter converter; public RecordReaderFunction(RecordReader recordReader, int labelIndex, int numPossibleLabels, WritableConverter converter) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java index 5f1029131..5ed1848b7 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java @@ -41,11 +41,11 @@ import java.util.Map; public abstract class BaseSparkEarlyStoppingTrainer implements IEarlyStoppingTrainer { - private static Logger log = LoggerFactory.getLogger(BaseSparkEarlyStoppingTrainer.class); + private static final Logger log = LoggerFactory.getLogger(BaseSparkEarlyStoppingTrainer.class); - private JavaSparkContext sc; + private final JavaSparkContext sc; private final EarlyStoppingConfiguration esConfig; - private T net; + private final T net; private final JavaRDD train; private final JavaRDD trainMulti; private EarlyStoppingListener listener; @@ -206,7 +206,7 @@ public abstract class BaseSparkEarlyStoppingTrainer implements } if (epochTerminate) { log.info("Hit epoch termination condition at epoch {}. Details: {}", epochCount, - termReason.toString()); + termReason); T bestModel; try { bestModel = esConfig.getModelSaver().getBestModel(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkDataSetLossCalculator.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkDataSetLossCalculator.java index be71c408c..fd5590e7f 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkDataSetLossCalculator.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkDataSetLossCalculator.java @@ -30,9 +30,9 @@ import org.nd4j.linalg.dataset.DataSet; public class SparkDataSetLossCalculator implements ScoreCalculator { - private JavaRDD data; - private boolean average; - private SparkContext sc; + private final JavaRDD data; + private final boolean average; + private final SparkContext sc; /**Calculate the score (loss function value) on a given data set (usually a test set) * diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingGraphTrainer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingGraphTrainer.java index efdab70aa..fb052d008 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingGraphTrainer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingGraphTrainer.java @@ -35,7 +35,7 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; public class SparkEarlyStoppingGraphTrainer extends BaseSparkEarlyStoppingTrainer { - private SparkComputationGraph sparkNet; + private final SparkComputationGraph sparkNet; public SparkEarlyStoppingGraphTrainer(SparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration esConfig, ComputationGraph net, diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingTrainer.java 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingTrainer.java index 3e61bd7cd..cab795894 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingTrainer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingTrainer.java @@ -34,7 +34,7 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; public class SparkEarlyStoppingTrainer extends BaseSparkEarlyStoppingTrainer { - private SparkDl4jMultiLayer sparkNet; + private final SparkDl4jMultiLayer sparkNet; public SparkEarlyStoppingTrainer(SparkContext sc, TrainingMaster trainingMaster, EarlyStoppingConfiguration esConfig, MultiLayerNetwork net, diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkLossCalculatorComputationGraph.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkLossCalculatorComputationGraph.java index be03c85af..5227aeef7 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkLossCalculatorComputationGraph.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/SparkLossCalculatorComputationGraph.java @@ -30,9 +30,9 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; public class SparkLossCalculatorComputationGraph implements ScoreCalculator { - private JavaRDD data; - private boolean average; - private SparkContext sc; + private final JavaRDD data; + private final boolean average; + private final SparkContext sc; /** * Calculate the score (loss function value) on a given data set (usually a test set) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/SparkListenable.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/SparkListenable.java index 36011825d..6762b7486 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/SparkListenable.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/SparkListenable.java @@ -39,7 +39,7 @@ import java.util.List; public class SparkListenable { protected TrainingMaster trainingMaster; - private List listeners = new ArrayList<>(); + private final List listeners = new ArrayList<>(); /** diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java index e2f5814bd..b0f532a54 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java @@ -42,11 +42,11 @@ public class HashingBalancedPartitioner extends Partitioner { // avg # red elems per partition : 2.33 // avg # blue elems per partition : 3.33 // partitionWeightsByClass = [[1.714, .429, .857], [0.9, 0.6, 1.5]] - private List> partitionWeightsByClass; + private final List> partitionWeightsByClass; // The cumulative distribution of 
jump probabilities of extra elements by partition, by class // 0 for partitions that already have enough elements - private List> jumpTable; + private final List> jumpTable; private Random r; public HashingBalancedPartitioner(List> partitionWeightsByClass) { @@ -63,7 +63,7 @@ public class HashingBalancedPartitioner extends Partitioner { } this.partitionWeightsByClass = partitionWeightsByClass; // p_(j, i) - List> jumpsByClass = new ArrayList<>();; + List> jumpsByClass = new ArrayList<>(); for (int j = 0; j < numClasses; j++) { Double totalImbalance = 0D; // i_j = sum(max(1 - p_(j, i), 0) , i = 1..numPartitions) for (int i = 0; i < numPartitions; i++) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java index 8550c6e3c..a38322234 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java @@ -53,14 +53,14 @@ public class EvaluationRunner { } private final AtomicInteger workerCount = new AtomicInteger(0); - private Queue queue = new ConcurrentLinkedQueue<>(); + private final Queue queue = new ConcurrentLinkedQueue<>(); //parameters map for device local parameters for a given broadcast //Note: byte[] doesn't override Object.equals hence this is effectively an *identity* weak hash map, which is what we want here //i.e., DeviceLocal can be GC'd once the Broadcast is no longer referenced anywhere //This approach relies on the fact that a single Broadcast object's *content* will be shared by all of Spark's threads, // even though the Broadcast object itself mayb not be //Also by storing params as a byte[] (i.e., in serialized form), we sidestep a lot of the thread locality issues - private Map paramsMap = new WeakHashMap<>(); + private final Map paramsMap = new WeakHashMap<>(); private EvaluationRunner(){ } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java index 14d08dc99..67b120ddf 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java @@ -78,13 +78,13 @@ public class SparkComputationGraph extends SparkListenable { public static final int DEFAULT_ROC_THRESHOLD_STEPS = 32; public static final int DEFAULT_EVAL_SCORE_BATCH_SIZE = 64; public static final int DEFAULT_EVAL_WORKERS = 4; - private transient JavaSparkContext sc; - private ComputationGraphConfiguration conf; + private final transient JavaSparkContext sc; + private final ComputationGraphConfiguration conf; private ComputationGraph network; private double lastScore; private int defaultEvaluationWorkers = DEFAULT_EVAL_WORKERS; - private transient AtomicInteger iterationsCount = new AtomicInteger(0); + private final transient AtomicInteger iterationsCount = new AtomicInteger(0); /** * Instantiate a ComputationGraph instance with the given context, network and training master. 
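The EvaluationRunner hunk above keeps its WeakHashMap of serialized parameters, and the surrounding comment explains why byte[] keys give it identity-map semantics: every executor thread sees the same byte[] reference for a given Broadcast's content, so lookups by that reference hit the same cached entry, and the entry becomes collectable once the content is no longer referenced anywhere. A minimal, self-contained sketch of that behaviour (illustrative only; the class name and values here are hypothetical, not part of this patch):

    import java.util.Map;
    import java.util.WeakHashMap;

    public class IdentityKeyedCacheSketch {
        public static void main(String[] args) {
            Map<byte[], String> cache = new WeakHashMap<>();

            byte[] shared = {1, 2, 3};   // the reference all threads would share for one Broadcast
            byte[] copy   = {1, 2, 3};   // equal content, but a different reference

            cache.put(shared, "device-local params");

            // byte[] does not override equals()/hashCode(), so WeakHashMap falls back to
            // reference identity: the shared reference hits, an equal-content copy does not.
            System.out.println(cache.get(shared)); // device-local params
            System.out.println(cache.get(copy));   // null
        }
    }

Once nothing else holds the byte[] key, the entry becomes eligible for collection, which is what lets the cached parameters disappear together with the Broadcast content.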
diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java index d8aadc3f1..3fa3312d7 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java @@ -44,9 +44,9 @@ public class CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWith @Override public VariationalAutoencoder getVaeLayer() { ComputationGraph network = - new ComputationGraph(ComputationGraphConfiguration.fromJson((String) jsonConfig.getValue())); + new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue())); network.init(); - INDArray val = ((INDArray) params.value()).unsafeDuplication(); + INDArray val = params.value().unsafeDuplication(); if (val.length() != network.numParams(false)) throw new IllegalStateException( "Network did not have same number of parameters as the broadcasted set parameters"); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java index 57c568239..a71912367 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java @@ -46,9 +46,9 @@ public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstruc @Override public VariationalAutoencoder getVaeLayer() { ComputationGraph network = - new ComputationGraph(ComputationGraphConfiguration.fromJson((String) jsonConfig.getValue())); + new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue())); network.init(); - INDArray val = ((INDArray) params.value()).unsafeDuplication(); + INDArray val = params.value().unsafeDuplication(); if (val.length() != network.numParams(false)) throw new IllegalStateException( "Network did not have same number of parameters as the broadcasted set parameters"); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java index 7acae9d8f..578165fc7 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java @@ -41,9 +41,9 @@ import java.util.List; public class ScoreFlatMapFunctionCGDataSet implements FlatMapFunction, Tuple2> { private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGDataSet.class); - private String json; - private Broadcast params; - private int 
minibatchSize; + private final String json; + private final Broadcast params; + private final int minibatchSize; public ScoreFlatMapFunctionCGDataSet(String json, Broadcast params, int minibatchSize) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java index 60ba08857..5ea855fbd 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java @@ -41,9 +41,9 @@ import java.util.List; public class ScoreFlatMapFunctionCGMultiDataSet implements FlatMapFunction, Tuple2> { private static final Logger log = LoggerFactory.getLogger(ScoreFlatMapFunctionCGMultiDataSet.class); - private String json; - private Broadcast params; - private int minibatchSize; + private final String json; + private final Broadcast params; + private final int minibatchSize; public ScoreFlatMapFunctionCGMultiDataSet(String json, Broadcast params, int minibatchSize) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java index be7780f2f..d8e1c1437 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java @@ -79,8 +79,8 @@ public class SparkDl4jMultiLayer extends SparkListenable { public static final int DEFAULT_EVAL_SCORE_BATCH_SIZE = 64; public static final int DEFAULT_ROC_THRESHOLD_STEPS = 32; public static final int DEFAULT_EVAL_WORKERS = 4; - private transient JavaSparkContext sc; - private MultiLayerConfiguration conf; + private final transient JavaSparkContext sc; + private final MultiLayerConfiguration conf; private MultiLayerNetwork network; private double lastScore; private int defaultEvaluationWorkers = DEFAULT_EVAL_WORKERS; @@ -157,7 +157,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { } /** - * Set the network that underlies this SparkDl4jMultiLayer instacne + * Set the network that underlies this SparkDl4jMultiLayer instance * * @param network network to set */ diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java index 3f7c5ba6c..e1c2f760d 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java @@ -47,9 +47,9 @@ public class VaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKe @Override public VariationalAutoencoder getVaeLayer() { 
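            // Side note, not part of this patch: the cast removals in this hunk assume the broadcast
            // fields carry their type parameters, e.g. (hypothetical declarations)
            //     Broadcast<String>   jsonConfig;  // jsonConfig.getValue() already returns String
            //     Broadcast<INDArray> params;      // params.value() already returns INDArray
            // With those signatures, MultiLayerConfiguration.fromJson(jsonConfig.getValue()) and
            // params.value().unsafeDuplication() compile without the explicit (String)/(INDArray) casts.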
MultiLayerNetwork network = - new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) jsonConfig.getValue())); + new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); network.init(); - INDArray val = ((INDArray) params.value()).unsafeDuplication(); + INDArray val = params.value().unsafeDuplication(); if (val.length() != network.numParams(false)) throw new IllegalStateException( "Network did not have same number of parameters as the broadcast set parameters"); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java index d9dd8a155..12fbbbeb6 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java @@ -47,9 +47,9 @@ public class VaeReconstructionProbWithKeyFunction extends BaseVaeReconstructi @Override public VariationalAutoencoder getVaeLayer() { MultiLayerNetwork network = - new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) jsonConfig.getValue())); + new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); network.init(); - INDArray val = ((INDArray) params.value()).unsafeDuplication(); + INDArray val = params.value().unsafeDuplication(); if (val.length() != network.numParams(false)) throw new IllegalStateException( "Network did not have same number of parameters as the broadcast set parameters"); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java index 5030a21b6..87374a584 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java @@ -59,8 +59,8 @@ public class ParameterAveragingTrainingWorker extends BaseTrainingWorker trainingHooks; private final WorkerConfiguration configuration; private ParameterAveragingTrainingWorkerStats.ParameterAveragingTrainingWorkerStatsHelper stats = null; - private Collection trainingListeners; - private StatsStorageRouterProvider listenerRouterProvider; + private final Collection trainingListeners; + private final StatsStorageRouterProvider listenerRouterProvider; public ParameterAveragingTrainingWorker(Broadcast broadcast, boolean saveUpdater, WorkerConfiguration configuration, Collection trainingHooks, @@ -172,9 +172,9 @@ public class ParameterAveragingTrainingWorker extends BaseTrainingWorker exportTimes = new ArrayList<>(); //Starts for exporting data - private List countTimes = new ArrayList<>(); - private List broadcastTimes = new ArrayList<>(); - private List repartitionTimes = new ArrayList<>(); - private List fitTimes = new ArrayList<>(); - private List splitTimes = new ArrayList<>(); - private List mapPartitions = new ArrayList<>(); - private List aggregateTimes = new ArrayList<>(); - private 
List processParamsUpdaterTimes = new ArrayList<>(); + private final List exportTimes = new ArrayList<>(); //Starts for exporting data + private final List countTimes = new ArrayList<>(); + private final List broadcastTimes = new ArrayList<>(); + private final List repartitionTimes = new ArrayList<>(); + private final List fitTimes = new ArrayList<>(); + private final List splitTimes = new ArrayList<>(); + private final List mapPartitions = new ArrayList<>(); + private final List aggregateTimes = new ArrayList<>(); + private final List processParamsUpdaterTimes = new ArrayList<>(); private final TimeSource timeSource = TimeSourceProvider.getInstance(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingWorkerStats.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingWorkerStats.java index fce3ec751..35e6ba9f0 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingWorkerStats.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingWorkerStats.java @@ -173,7 +173,7 @@ public class ParameterAveragingTrainingWorkerStats implements SparkTrainingStats private long initEndTime; private long lastFitStartTime; //TODO replace with fast int collection (no boxing) - private List fitTimes = new ArrayList<>(); + private final List fitTimes = new ArrayList<>(); private final TimeSource timeSource = TimeSourceProvider.getInstance(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java index 2e7c6bad5..a4d06bfba 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java @@ -38,8 +38,8 @@ public class PathSparkDataSetIterator extends BaseDataSetIterator { public static final int BUFFER_SIZE = 4194304; //4 MB private FileSystem fileSystem; - private DataSetLoader dataSetLoader; - private Broadcast hadoopConfig; + private final DataSetLoader dataSetLoader; + private final Broadcast hadoopConfig; public PathSparkDataSetIterator(Iterator iter, DataSetLoader dataSetLoader, Broadcast hadoopConfig) { this.dataSetStreams = null; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/SparkADSI.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/SparkADSI.java index 09ed9973c..f462352d0 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/SparkADSI.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/SparkADSI.java @@ -90,7 +90,7 @@ public class SparkADSI extends AsyncDataSetIterator { this.buffer = queue; this.prefetchSize = queueSize; this.backedIterator = iterator; - this.workspaceId = "SADSI_ITER-" + java.util.UUID.randomUUID().toString(); + this.workspaceId = "SADSI_ITER-" + java.util.UUID.randomUUID(); if (iterator.resetSupported()) 
this.backedIterator.reset(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java index 128db97a7..ab5a3ee20 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java @@ -91,7 +91,7 @@ public class SparkAMDSI extends AsyncMultiDataSetIterator { this.backedIterator = iterator; this.useWorkspaces = useWorkspace; this.prefetchSize = queueSize; - this.workspaceId = "SAMDSI_ITER-" + java.util.UUID.randomUUID().toString(); + this.workspaceId = "SAMDSI_ITER-" + java.util.UUID.randomUUID(); this.deviceId = deviceId; if (iterator.resetSupported()) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/stats/StatsUtils.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/stats/StatsUtils.java index 867d89795..0e083f31a 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/stats/StatsUtils.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/stats/StatsUtils.java @@ -44,6 +44,7 @@ import java.awt.*; import java.io.BufferedOutputStream; import java.io.IOException; import java.io.OutputStream; +import java.nio.charset.StandardCharsets; import java.util.*; import java.util.List; @@ -238,7 +239,7 @@ public class StatsUtils { } String html = StaticPageUtil.renderHTML(components); - outputStream.write(html.getBytes("UTF-8")); + outputStream.write(html.getBytes(StandardCharsets.UTF_8)); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/time/NTPTimeSource.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/time/NTPTimeSource.java index 8b6332ba4..b55f560da 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/time/NTPTimeSource.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/time/NTPTimeSource.java @@ -49,7 +49,7 @@ public class NTPTimeSource implements TimeSource { public static final String DEFAULT_NTP_SERVER = "0.pool.ntp.org"; - private static Logger log = LoggerFactory.getLogger(NTPTimeSource.class); + private static final Logger log = LoggerFactory.getLogger(NTPTimeSource.class); private static NTPTimeSource instance; public static synchronized TimeSource getInstance() { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java index dbde9f862..576cda013 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java @@ -428,7 +428,7 @@ public class MLLibUtil { // FIXMEL int cast double[] fArr = features.toArray(); - return new DataSet(Nd4j.create(fArr, new long[]{1,fArr.length}), + return new DataSet(Nd4j.create(fArr, 1,fArr.length), FeatureUtil.toOutcomeVector((int) label, (int) numPossibleLabels)); } diff --git 
a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/SparkUtils.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/SparkUtils.java index 6e88fbfa8..c6d935912 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/SparkUtils.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/SparkUtils.java @@ -61,6 +61,7 @@ import java.io.*; import java.lang.reflect.Array; import java.net.URI; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.*; @Slf4j @@ -113,7 +114,7 @@ public class SparkUtils { boolean equals; INDArray deserialized; try { - deserialized = (INDArray) si.deserialize(bb, null); + deserialized = si.deserialize(bb, null); //Equals method may fail on malformed INDArrays, hence should be within the try-catch equals = Nd4j.linspace(1, 5, 5).equals(deserialized); } catch (Exception e) { @@ -153,7 +154,7 @@ public class SparkUtils { public static void writeStringToFile(String path, String toWrite, SparkContext sc) throws IOException { FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration()); try (BufferedOutputStream bos = new BufferedOutputStream(fileSystem.create(new Path(path)))) { - bos.write(toWrite.getBytes("UTF-8")); + bos.write(toWrite.getBytes(StandardCharsets.UTF_8)); } } @@ -177,7 +178,7 @@ public class SparkUtils { FileSystem fileSystem = FileSystem.get(sc.hadoopConfiguration()); try (BufferedInputStream bis = new BufferedInputStream(fileSystem.open(new Path(path)))) { byte[] asBytes = IOUtils.toByteArray(bis); - return new String(asBytes, "UTF-8"); + return new String(asBytes, StandardCharsets.UTF_8); } } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelDeserializer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelDeserializer.java index cc9490a9a..f5831a6ef 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelDeserializer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelDeserializer.java @@ -32,7 +32,7 @@ import java.io.IOException; public class StorageLevelDeserializer extends JsonDeserializer { @Override public StorageLevel deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) - throws IOException, JsonProcessingException { + throws IOException { JsonNode node = jsonParser.getCodec().readTree(jsonParser); String value = node.textValue(); if (value == null || "null".equals(value)) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelSerializer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelSerializer.java index db02ea278..2b9257df5 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelSerializer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelSerializer.java @@ -53,7 +53,7 @@ public class StorageLevelSerializer extends JsonSerializer { @Override public void serialize(StorageLevel storageLevel, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) - throws 
IOException, JsonProcessingException { + throws IOException { //This is a little ugly, but Spark doesn't provide many options here... String s = null; if (storageLevel != null) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java index ed8de3623..f4e9f674e 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java @@ -284,7 +284,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { private static class LoggingEarlyStoppingListener implements EarlyStoppingListener { - private static Logger log = LoggerFactory.getLogger(LoggingEarlyStoppingListener.class); + private static final Logger log = LoggerFactory.getLogger(LoggingEarlyStoppingListener.class); private int onStartCallCount = 0; private int onEpochCallCount = 0; private int onCompletionCallCount = 0; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java index 3de17a742..39618055e 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java @@ -290,7 +290,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { private static class LoggingEarlyStoppingListener implements EarlyStoppingListener { - private static Logger log = LoggerFactory.getLogger(LoggingEarlyStoppingListener.class); + private static final Logger log = LoggerFactory.getLogger(LoggingEarlyStoppingListener.class); private int onStartCallCount = 0; private int onEpochCallCount = 0; private int onCompletionCallCount = 0; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestKryo.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestKryo.java index 48212f814..da5d7822a 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestKryo.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestKryo.java @@ -53,10 +53,10 @@ public class TestKryo extends BaseSparkKryoTest { private void testSerialization(T in, SerializerInstance si) { ByteBuffer bb = si.serialize(in, null); - T deserialized = (T)si.deserialize(bb, null); + T deserialized = si.deserialize(bb, null); boolean equals = in.equals(deserialized); - assertTrue(equals, in.getClass() + "\t" + in.toString()); + assertTrue(equals, in.getClass() + "\t" + in); } @Test @@ -105,7 +105,7 @@ public class TestKryo extends BaseSparkKryoTest { GraphVertex[] vertices = new GraphVertex[] {new ElementWiseVertex(ElementWiseVertex.Op.Add), new L2NormalizeVertex(), new LayerVertex(null, null), new MergeVertex(), new PoolHelperVertex(), new PreprocessorVertex(new CnnToFeedForwardPreProcessor(28, 28, 1)), - new ReshapeVertex(new int[] {1, 1}), new ScaleVertex(1.0), new ShiftVertex(1.0), + new ReshapeVertex(1, 1), new ScaleVertex(1.0), new ShiftVertex(1.0), new 
SubsetVertex(1, 1), new UnstackVertex(0, 2), new DuplicateToTimeSeriesVertex("in1"), new LastTimeStepVertex("in1")}; @@ -118,26 +118,26 @@ public class TestKryo extends BaseSparkKryoTest { public void testSerializationEvaluation() { Evaluation e = new Evaluation(); - e.eval(Nd4j.create(new double[] {1, 0, 0}, new long[]{1, 3}), Nd4j.create(new double[] {0.2, 0.5, 0.3}, new long[]{1, 3})); + e.eval(Nd4j.create(new double[] {1, 0, 0}, 1, 3), Nd4j.create(new double[] {0.2, 0.5, 0.3}, 1, 3)); EvaluationBinary eb = new EvaluationBinary(); - eb.eval(Nd4j.create(new double[] {1, 0, 0}, new long[]{1, 3}), Nd4j.create(new double[] {0.2, 0.6, 0.3}, new long[]{1, 3})); + eb.eval(Nd4j.create(new double[] {1, 0, 0}, 1, 3), Nd4j.create(new double[] {0.2, 0.6, 0.3}, 1, 3)); ROC roc = new ROC(30); - roc.eval(Nd4j.create(new double[] {1}, new long[]{1, 1}), Nd4j.create(new double[] {0.2}, new long[]{1, 1})); + roc.eval(Nd4j.create(new double[] {1}, 1, 1), Nd4j.create(new double[] {0.2}, 1, 1)); ROC roc2 = new ROC(); - roc2.eval(Nd4j.create(new double[] {1}, new long[]{1, 1}), Nd4j.create(new double[] {0.2}, new long[]{1, 1})); + roc2.eval(Nd4j.create(new double[] {1}, 1, 1), Nd4j.create(new double[] {0.2}, 1, 1)); ROCMultiClass rocM = new ROCMultiClass(30); - rocM.eval(Nd4j.create(new double[] {1, 0, 0}, new long[]{1, 3}), Nd4j.create(new double[] {0.2, 0.5, 0.3}, new long[]{1, 3})); + rocM.eval(Nd4j.create(new double[] {1, 0, 0}, 1, 3), Nd4j.create(new double[] {0.2, 0.5, 0.3}, 1, 3)); ROCMultiClass rocM2 = new ROCMultiClass(); - rocM2.eval(Nd4j.create(new double[] {1, 0, 0}, new long[]{1, 3}), Nd4j.create(new double[] {0.2, 0.5, 0.3}, new long[]{1, 3})); + rocM2.eval(Nd4j.create(new double[] {1, 0, 0}, 1, 3), Nd4j.create(new double[] {0.2, 0.5, 0.3}, 1, 3)); ROCBinary rocB = new ROCBinary(30); - rocB.eval(Nd4j.create(new double[] {1, 0, 0}, new long[]{1, 3}), Nd4j.create(new double[] {0.2, 0.6, 0.3}, new long[]{1, 3})); + rocB.eval(Nd4j.create(new double[] {1, 0, 0}, 1, 3), Nd4j.create(new double[] {0.2, 0.6, 0.3}, 1, 3)); ROCBinary rocB2 = new ROCBinary(); - rocB2.eval(Nd4j.create(new double[] {1, 0, 0}, new long[]{1, 3}), Nd4j.create(new double[] {0.2, 0.6, 0.3}, new long[]{1, 3})); + rocB2.eval(Nd4j.create(new double[] {1, 0, 0}, 1, 3), Nd4j.create(new double[] {0.2, 0.6, 0.3}, 1, 3)); RegressionEvaluation re = new RegressionEvaluation(); re.eval(Nd4j.rand(1, 5), Nd4j.rand(1, 5)); @@ -184,16 +184,16 @@ public class TestKryo extends BaseSparkKryoTest { testSerialization(new ConcurrentHashMap<>(m), si); testSerialization(Collections.unmodifiableMap(m), si); - testSerialization(Arrays.asList("s"), si); + testSerialization(Collections.singletonList("s"), si); testSerialization(Collections.singleton("s"), si); - testSerialization(Collections.synchronizedList(Arrays.asList("s")), si); + testSerialization(Collections.synchronizedList(Collections.singletonList("s")), si); testSerialization(Collections.emptyList(), si); - testSerialization(new CopyOnWriteArrayList<>(Arrays.asList("s")), si); - testSerialization(Collections.unmodifiableList(Arrays.asList("s")), si); + testSerialization(new CopyOnWriteArrayList<>(Collections.singletonList("s")), si); + testSerialization(Collections.unmodifiableList(Collections.singletonList("s")), si); testSerialization(Collections.singleton("s"), si); - testSerialization(Collections.synchronizedSet(new HashSet<>(Arrays.asList("s"))), si); + testSerialization(Collections.synchronizedSet(new HashSet<>(Collections.singletonList("s"))), si); 
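            // Side note, not part of this patch: Collections.singletonList("s") is an immutable,
            // fixed-size single-element list, so it states the one-element intent of these fixtures
            // more directly than Arrays.asList("s"). Both compare equal as lists (element-wise equals()),
            // e.g. (hypothetical):
            //     List<String> one = Collections.singletonList("s");
            //     one.equals(Arrays.asList("s"));   // true
            //     // one.add("t") would throw UnsupportedOperationException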
testSerialization(Collections.emptySet(), si); - testSerialization(Collections.unmodifiableSet(new HashSet<>(Arrays.asList("s"))), si); + testSerialization(Collections.unmodifiableSet(new HashSet<>(Collections.singletonList("s"))), si); } } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java index 43c50fdeb..79e6da95c 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java @@ -74,8 +74,8 @@ public class MiniBatchTests extends BaseSparkTest { @Override public Object call(DataSet dataSet) throws Exception { - assertTrue(dataSet.getFeatures().columns() == 150); - assertTrue(dataSet.numExamples() == 30); + assertEquals(150, dataSet.getFeatures().columns()); + assertEquals(30, dataSet.numExamples()); return null; } } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java index fad1b4092..bd2e0f389 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java @@ -263,7 +263,7 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { Path p = new File(testDir,"dl4j_testSeqPairFn").toPath(); p.toFile().deleteOnExit(); - String outPath = p.toString() + "/out"; + String outPath = p + "/out"; new File(outPath).deleteOnExit(); toWrite.saveAsNewAPIHadoopFile(outPath, Text.class, BytesPairWritable.class, SequenceFileOutputFormat.class); @@ -540,11 +540,7 @@ public class TestDataVecDataSetFunctions extends BaseSparkTest { if (m1 != null && !m1.equals(Nd4j.ones(m1.shape()))) { return false; } - if (m2 != null && !m2.equals(Nd4j.ones(m2.shape()))) { - return false; - } - - return true; + return m2 == null || m2.equals(Nd4j.ones(m2.shape())); } } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java index cc6e5f9ec..579effe1a 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java @@ -119,7 +119,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { TrainingMaster tm = new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 10, 1, 0); SparkComputationGraph scg = new SparkComputationGraph(sc, cg, tm); - scg.setListeners(Collections.singleton((TrainingListener) new ScoreIterationListener(5))); + scg.setListeners(Collections.singleton(new ScoreIterationListener(5))); JavaRDD rdd = sc.parallelize(list); scg.fitMultiDataSet(rdd); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java index 550ccc9b2..8b5a8b46c 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java @@ -120,9 +120,9 @@ public class TestMiscFunctions extends BaseSparkTest { net.init(); List ds = Arrays.asList( - new org.nd4j.linalg.dataset.DataSet(Nd4j.rand(new int[]{1, 4, 5}), Nd4j.create(new double[]{1,1,1,0,0})), - new org.nd4j.linalg.dataset.DataSet(Nd4j.rand(new int[]{1, 4, 5}), Nd4j.create(new double[]{1,1,1,1,0})), - new org.nd4j.linalg.dataset.DataSet(Nd4j.rand(new int[]{1, 4, 5}), Nd4j.create(new double[]{1,1,1,1,1})) + new org.nd4j.linalg.dataset.DataSet(Nd4j.rand(1, 4, 5), Nd4j.create(new double[]{1,1,1,0,0})), + new org.nd4j.linalg.dataset.DataSet(Nd4j.rand(1, 4, 5), Nd4j.create(new double[]{1,1,1,1,0})), + new org.nd4j.linalg.dataset.DataSet(Nd4j.rand(1, 4, 5), Nd4j.create(new double[]{1,1,1,1,1})) ); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java index 9c7f783e0..d2c0d66bc 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java @@ -47,8 +47,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.ArrayList; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; @Slf4j public class TestSparkDl4jMultiLayer extends BaseSparkTest { @@ -136,7 +135,7 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest { tm.deleteTempFiles(sc); assertEquals(10000, evaluation.getNumRowCounter()); //10k test set - assertTrue(!Double.isNaN(evaluation.accuracy())); + assertFalse(Double.isNaN(evaluation.accuracy())); assertTrue(evaluation.accuracy() >= 0.10); assertTrue(evaluation.precision() >= 0.10); assertTrue(evaluation.recall() >= 0.10); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java index cbe7247bd..050e6279c 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java @@ -157,7 +157,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { Nd4j.getRandom().setSeed(seed); List list = new ArrayList<>(); for (int i = 0; i < totalExamples; i++) { - INDArray f = Nd4j.rand(new int[] {1, 3, 10, 10}); + INDArray f = Nd4j.rand(1, 3, 10, 10); INDArray l = Nd4j.rand(1, 10); DataSet ds = new DataSet(f, l); list.add(ds); diff --git 
a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java index f4939e369..5d33e82c6 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java @@ -47,6 +47,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.io.ByteArrayOutputStream; import java.lang.reflect.Field; +import java.nio.charset.StandardCharsets; import java.util.*; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -255,8 +256,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest { ByteArrayOutputStream baos = new ByteArrayOutputStream(); StatsUtils.exportStatsAsHTML(stats, baos); baos.close(); - byte[] bytes = baos.toByteArray(); - String str = new String(bytes, "UTF-8"); + String str = baos.toString(StandardCharsets.UTF_8); // System.out.println(str); } finally { sc.stop(); @@ -294,8 +294,8 @@ public class TestTrainingStatsCollection extends BaseSparkTest { jvmIDs.add(e.getJvmID()); threadIDs.add(e.getThreadID()); } - assertTrue(machineIDs.size() == expNMachineIDs); - assertTrue(jvmIDs.size() == expNumJvmIds); - assertTrue(threadIDs.size() == expNumThreadIds); + assertEquals(machineIDs.size(), expNMachineIDs); + assertEquals(jvmIDs.size(), expNumJvmIds); + assertEquals(threadIDs.size(), expNumThreadIds); } } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java index 6f79d7595..aadf69cdd 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java @@ -119,7 +119,7 @@ public class TestListeners extends BaseSparkTest { String widSubstring = wid.substring(0, wid.length() - 1); assertEquals(firstWorkerSubstring, widSubstring); - String counterVal = wid.substring(wid.length() - 1, wid.length()); + String counterVal = wid.substring(wid.length() - 1); int cv = Integer.parseInt(counterVal); assertTrue(0 <= cv && cv < numExecutors()); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java index 77fdff58e..d7f705e55 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java @@ -39,9 +39,7 @@ import java.util.Arrays; import java.util.List; import java.util.Random; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; @Timeout(300) public class TestRepartitioning extends BaseSparkTest { @@ -61,7 +59,7 @@ public class TestRepartitioning extends BaseSparkTest { rdd = rdd.repartition(200); JavaRDD 
rdd2 = SparkUtils.repartitionBalanceIfRequired(rdd, Repartition.Always, 100, 10); - assertFalse(rdd == rdd2); //Should be different objects due to repartitioning + assertNotSame(rdd, rdd2); //Should be different objects due to repartitioning assertEquals(10, rdd2.partitions().size()); for (int i = 0; i < 10; i++) { @@ -255,7 +253,7 @@ public class TestRepartitioning extends BaseSparkTest { rdd = rdd.repartition(200); JavaRDD rdd2 = SparkUtils.repartitionApproximateBalance(rdd, Repartition.Always, 10); - assertFalse(rdd == rdd2); //Should be different objects due to repartitioning + assertNotSame(rdd, rdd2); //Should be different objects due to repartitioning assertEquals(10, rdd2.partitions().size()); @@ -277,7 +275,7 @@ public class TestRepartitioning extends BaseSparkTest { JavaRDD rdd = sc.parallelize(list); JavaRDD rdd2 = SparkUtils.repartitionApproximateBalance(rdd, Repartition.Always, 100); - assertFalse(rdd == rdd2); //Should be different objects due to repartitioning + assertNotSame(rdd, rdd2); //Should be different objects due to repartitioning assertEquals(100, rdd2.partitions().size()); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestValidation.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestValidation.java index 21ba9fc23..fe2968ed8 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestValidation.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/util/TestValidation.java @@ -36,6 +36,7 @@ import org.nd4j.linalg.factory.Nd4j; import java.io.File; import java.util.Arrays; +import java.util.Collections; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -169,8 +170,8 @@ public class TestValidation extends BaseSparkTest { //Add MultiDataSet with incorrect labels shape: new MultiDataSet(Nd4j.create(1,10), Nd4j.create(1,20)).save(f3); - r = SparkDataValidation.validateMultiDataSets(sc, f.toURI().toString(), Arrays.asList(new int[]{-1,10}), - Arrays.asList(new int[]{-1,10})); + r = SparkDataValidation.validateMultiDataSets(sc, f.toURI().toString(), Collections.singletonList(new int[]{-1, 10}), + Collections.singletonList(new int[]{-1, 10})); exp = ValidationResult.builder() .countTotal(4) .countTotalValid(3) @@ -183,8 +184,8 @@ public class TestValidation extends BaseSparkTest { //Add a MultiDataSet with incorrect number of feature arrays: new MultiDataSet(new INDArray[]{Nd4j.create(1,10), Nd4j.create(1,10)}, new INDArray[]{Nd4j.create(1,10)}).save(f3); - r = SparkDataValidation.validateMultiDataSets(sc, f.toURI().toString(), Arrays.asList(new int[]{-1,10}), - Arrays.asList(new int[]{-1,10})); + r = SparkDataValidation.validateMultiDataSets(sc, f.toURI().toString(), Collections.singletonList(new int[]{-1, 10}), + Collections.singletonList(new int[]{-1, 10})); exp = ValidationResult.builder() .countTotal(4) .countTotalValid(3) @@ -194,8 +195,8 @@ public class TestValidation extends BaseSparkTest { assertEquals(exp, r); - r = SparkDataValidation.deleteInvalidMultiDataSets(sc, f.toURI().toString(), Arrays.asList(new int[]{-1,10}), - Arrays.asList(new int[]{-1,10})); + r = SparkDataValidation.deleteInvalidMultiDataSets(sc, f.toURI().toString(), Collections.singletonList(new int[]{-1, 10}), + Collections.singletonList(new int[]{-1, 10})); exp.setCountInvalidDeleted(1); assertEquals(exp, r); 
assertFalse(f3.exists()); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java index d8e6f235d..aafb35fcf 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java @@ -36,27 +36,27 @@ import java.util.concurrent.atomic.AtomicLong; public class FirstIterationFunction implements FlatMapFunction, Long>>, Entry> { - private int ithIteration = 1; - private int vectorLength; - private boolean useAdaGrad; + private final int ithIteration = 1; + private final int vectorLength; + private final boolean useAdaGrad; private int batchSize = 0; - private double negative; - private int window; - private double alpha; - private double minAlpha; - private long totalWordCount; - private long seed; - private int maxExp; - private double[] expTable; - private int iterations; - private Map indexSyn0VecMap; - private Map pointSyn1VecMap; - private AtomicLong nextRandom = new AtomicLong(5); + private final double negative; + private final int window; + private final double alpha; + private final double minAlpha; + private final long totalWordCount; + private final long seed; + private final int maxExp; + private final double[] expTable; + private final int iterations; + private final Map indexSyn0VecMap; + private final Map pointSyn1VecMap; + private final AtomicLong nextRandom = new AtomicLong(5); - private volatile VocabCache vocab; + private final VocabCache vocab; private volatile NegativeHolder negativeHolder; - private AtomicLong cid = new AtomicLong(0); - private AtomicLong aff = new AtomicLong(0); + private final AtomicLong cid = new AtomicLong(0); + private final AtomicLong aff = new AtomicLong(0); @@ -123,7 +123,7 @@ public class FirstIterationFunction implements for (int ithWordInSentence = 0; ithWordInSentence < vocabWordsList.size(); ithWordInSentence++) { // Random value ranging from 0 to window size nextRandom.set(Math.abs(nextRandom.get() * 25214903917L + 11)); - int b = (int) (long) this.nextRandom.get() % window; + int b = (int) this.nextRandom.get() % window; VocabWord currentWord = vocabWordsList.get(ithWordInSentence); if (currentWord != null) { skipGram(ithWordInSentence, vocabWordsList, b, currentSentenceAlpha); @@ -164,7 +164,7 @@ public class FirstIterationFunction implements if (indexSyn0VecMap.containsKey(vocab.elementAtIndex(currentWordIndex))) { l1 = indexSyn0VecMap.get(vocab.elementAtIndex(currentWordIndex)); } else { - l1 = getRandomSyn0Vec(vectorLength, (long) currentWordIndex); + l1 = getRandomSyn0Vec(vectorLength, currentWordIndex); } // diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java index 5b788562b..bade562b0 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java +++ 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java @@ -32,7 +32,7 @@ import java.io.Serializable; import java.util.concurrent.atomic.AtomicBoolean; public class NegativeHolder implements Serializable { - private static NegativeHolder ourInstance = new NegativeHolder(); + private static final NegativeHolder ourInstance = new NegativeHolder(); public static NegativeHolder getInstance() { return ourInstance; @@ -43,7 +43,7 @@ public class NegativeHolder implements Serializable { @Getter private volatile INDArray table; - private transient AtomicBoolean wasInit = new AtomicBoolean(false); + private final transient AtomicBoolean wasInit = new AtomicBoolean(false); private transient VocabCache vocab; private NegativeHolder() { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java index 205d54ae0..7907821cb 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java @@ -38,27 +38,27 @@ import java.util.concurrent.atomic.AtomicLong; public class SecondIterationFunction implements FlatMapFunction, Long>>, Entry> { - private int ithIteration = 1; - private int vectorLength; - private boolean useAdaGrad; + private final int ithIteration = 1; + private final int vectorLength; + private final boolean useAdaGrad; private int batchSize = 0; - private double negative; - private int window; - private double alpha; - private double minAlpha; - private long totalWordCount; - private long seed; - private int maxExp; - private double[] expTable; - private int iterations; + private final double negative; + private final int window; + private final double alpha; + private final double minAlpha; + private final long totalWordCount; + private final long seed; + private final int maxExp; + private final double[] expTable; + private final int iterations; - private AtomicLong nextRandom = new AtomicLong(5); + private final AtomicLong nextRandom = new AtomicLong(5); - private volatile VocabCache vocab; + private final VocabCache vocab; private transient volatile NegativeHolder negativeHolder; private transient volatile VocabHolder vocabHolder; - private AtomicLong cid = new AtomicLong(0); - private AtomicLong aff = new AtomicLong(0); + private final AtomicLong cid = new AtomicLong(0); + private final AtomicLong aff = new AtomicLong(0); @@ -133,7 +133,7 @@ public class SecondIterationFunction implements FlatMapFunction { - private AtomicLong nextRandom = new AtomicLong(5); + private final AtomicLong nextRandom = new AtomicLong(5); // private static Logger log = LoggerFactory.getLogger(SentenceBatch.class); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/VocabHolder.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/VocabHolder.java index 1a983c68d..480215b2c 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/VocabHolder.java +++ 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/VocabHolder.java @@ -35,14 +35,14 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; public class VocabHolder implements Serializable { - private static VocabHolder ourInstance = new VocabHolder(); + private static final VocabHolder ourInstance = new VocabHolder(); - private Map indexSyn0VecMap = new ConcurrentHashMap<>(); - private Map pointSyn1VecMap = new ConcurrentHashMap<>(); - private HashSet workers = new LinkedHashSet<>(); + private final Map indexSyn0VecMap = new ConcurrentHashMap<>(); + private final Map pointSyn1VecMap = new ConcurrentHashMap<>(); + private final HashSet workers = new LinkedHashSet<>(); - private AtomicLong seed = new AtomicLong(0); - private AtomicInteger vectorLength = new AtomicInteger(0); + private final AtomicLong seed = new AtomicLong(0); + private final AtomicInteger vectorLength = new AtomicInteger(0); public static VocabHolder getInstance() { return ourInstance; @@ -56,8 +56,7 @@ public class VocabHolder implements Serializable { } public INDArray getSyn0Vector(Integer wordIndex, VocabCache vocabCache) { - if (!workers.contains(Thread.currentThread().getId())) - workers.add(Thread.currentThread().getId()); + workers.add(Thread.currentThread().getId()); VocabWord word = vocabCache.elementAtIndex(wordIndex); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2Vec.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2Vec.java index b5146f74d..f52676adc 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2Vec.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2Vec.java @@ -57,10 +57,10 @@ import java.util.concurrent.atomic.AtomicLong; public class Word2Vec extends WordVectorsImpl implements Serializable { private INDArray trainedSyn1; - private static Logger log = LoggerFactory.getLogger(Word2Vec.class); - private int MAX_EXP = 6; + private static final Logger log = LoggerFactory.getLogger(Word2Vec.class); + private final int MAX_EXP = 6; @Getter - private double[] expTable; + private final double[] expTable; @Getter protected VectorsConfiguration configuration; @@ -68,8 +68,8 @@ public class Word2Vec extends WordVectorsImpl implements Serializable private int nGrams = 1; private String tokenizer = "org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFactory"; private String tokenPreprocessor = "org.deeplearning4j.text.tokenization.tokenizer.preprocessor.CommonPreprocessor"; - private boolean removeStop = false; - private long seed = 42L; + private final boolean removeStop = false; + private final long seed = 42L; private boolean useUnknown = false; // Constructor to take InMemoryLookupCache table from an already trained model diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecChange.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecChange.java index 5ce201f0c..eebdb1eb2 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecChange.java +++ 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecChange.java @@ -33,7 +33,7 @@ import java.util.*; */ @Deprecated public class Word2VecChange implements Serializable { - private Map> changes = new HashMap<>(); + private final Map> changes = new HashMap<>(); public Word2VecChange(List> counterMap, Word2VecParam param) { Iterator> iter = counterMap.iterator(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecParam.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecParam.java index 1e7f81133..68c464377 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecParam.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecParam.java @@ -44,7 +44,7 @@ public class Word2VecParam implements Serializable { private double alpha = 0.025; private double minAlpha = 1e-2; private int totalWords = 1; - private static transient final Logger log = LoggerFactory.getLogger(Word2VecPerformer.class); + private static final Logger log = LoggerFactory.getLogger(Word2VecPerformer.class); private int lastChecked = 0; private Broadcast wordCount; private InMemoryLookupTable weights; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java index 3b65a353d..6bd786052 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java @@ -41,21 +41,21 @@ import java.util.concurrent.atomic.AtomicLong; @Deprecated public class Word2VecPerformer implements VoidFunction, AtomicLong>> { - private static double MAX_EXP = 6; + private static final double MAX_EXP = 6; private boolean useAdaGrad = false; private double negative = 5; private int numWords = 1; private INDArray table; private int window = 5; - private AtomicLong nextRandom = new AtomicLong(5); + private final AtomicLong nextRandom = new AtomicLong(5); private double alpha = 0.025; private double minAlpha = 1e-2; private int totalWords = 1; - private static transient final Logger log = LoggerFactory.getLogger(Word2VecPerformer.class); + private static final Logger log = LoggerFactory.getLogger(Word2VecPerformer.class); private int lastChecked = 0; - private Broadcast wordCount; - private InMemoryLookupTable weights; - private double[] expTable = new double[1000]; + private final Broadcast wordCount; + private final InMemoryLookupTable weights; + private final double[] expTable = new double[1000]; private int vectorLength; @@ -239,7 +239,7 @@ public class Word2VecPerformer implements VoidFunction, Ato double numWordsSoFar = wordCount.getValue().doubleValue(); List sentence = pair.getFirst(); - double alpha2 = Math.max(minAlpha, alpha * (1 - (1.0 * numWordsSoFar / (double) totalWords))); + double alpha2 = Math.max(minAlpha, alpha * (1 - (numWordsSoFar / (double) totalWords))); int totalNewWords = 0; trainSentence(sentence, alpha2); totalNewWords += 
sentence.size(); @@ -253,7 +253,7 @@ public class Word2VecPerformer implements VoidFunction, Ato log.info("Words so far " + newWords + " out of " + totalWords); } - pair.getSecond().getAndAdd((long) totalNewWords); + pair.getSecond().getAndAdd(totalNewWords); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java index 7bb7c44d8..b8330b064 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java @@ -64,7 +64,7 @@ public class Word2VecPerformerVoid implements VoidFunction, private double minAlpha = 1e-2; private int totalWords = 1; private int iterations = 5; - private static transient final Logger log = LoggerFactory.getLogger(Word2VecPerformerVoid.class); + private static final Logger log = LoggerFactory.getLogger(Word2VecPerformerVoid.class); private int lastChecked = 0; private Broadcast wordCount; private InMemoryLookupTable weights; @@ -389,7 +389,7 @@ public class Word2VecPerformerVoid implements VoidFunction, double numWordsSoFar = wordCount.getValue().doubleValue(); List sentence = pair.getFirst(); - double alpha2 = Math.max(minAlpha, alpha * (1 - (1.0 * numWordsSoFar / (double) totalWords))); + double alpha2 = Math.max(minAlpha, alpha * (1 - (numWordsSoFar / (double) totalWords))); int totalNewWords = 0; trainSentence(sentence, alpha2); totalNewWords += sentence.size(); @@ -403,7 +403,7 @@ public class Word2VecPerformerVoid implements VoidFunction, log.info("Words so far " + newWords + " out of " + totalWords); } - pair.getSecond().getAndAdd((long) totalNewWords); + pair.getSecond().getAndAdd(totalNewWords); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecSetup.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecSetup.java index 677fb3738..db0e7d5ef 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecSetup.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecSetup.java @@ -29,7 +29,7 @@ import java.util.List; @Deprecated public class Word2VecSetup implements Function, Long>, Word2VecFuncCall> { - private Broadcast param; + private final Broadcast param; public Word2VecSetup(Broadcast param) { this.param = param; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecVariables.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecVariables.java index 6adcc7d1f..21908df41 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecVariables.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecVariables.java @@ -77,19 +77,19 @@ public class Word2VecVariables { public static T assignVar(String 
variableName, SparkConf conf, Class clazz) throws Exception { Object ret; if (clazz.equals(Integer.class)) { - ret = conf.getInt(variableName, (Integer) getDefault(variableName)); + ret = conf.getInt(variableName, getDefault(variableName)); } else if (clazz.equals(Double.class)) { - ret = conf.getDouble(variableName, (Double) getDefault(variableName)); + ret = conf.getDouble(variableName, getDefault(variableName)); } else if (clazz.equals(Boolean.class)) { - ret = conf.getBoolean(variableName, (Boolean) getDefault(variableName)); + ret = conf.getBoolean(variableName, getDefault(variableName)); } else if (clazz.equals(String.class)) { - ret = conf.get(variableName, (String) getDefault(variableName)); + ret = conf.get(variableName, getDefault(variableName)); } else if (clazz.equals(Long.class)) { - ret = conf.getLong(variableName, (Long) getDefault(variableName)); + ret = conf.getLong(variableName, getDefault(variableName)); } else { throw new Exception("Variable Type not supported. Only boolean, int, double and String supported."); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/CountCumSum.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/CountCumSum.java index 4b757ec5f..dcd600e18 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/CountCumSum.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/CountCumSum.java @@ -36,8 +36,8 @@ import java.util.concurrent.atomic.AtomicLong; public class CountCumSum { // Starting variables - private JavaSparkContext sc; - private JavaRDD sentenceCountRDD; + private final JavaSparkContext sc; + private final JavaRDD sentenceCountRDD; // Variables to fill in as we go private JavaRDD foldWithinPartitionRDD; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldBetweenPartitionFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldBetweenPartitionFunction.java index f332c1f92..a4d54ba9c 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldBetweenPartitionFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldBetweenPartitionFunction.java @@ -33,7 +33,7 @@ import java.util.concurrent.atomic.AtomicLong; * @author jeffreytang */ public class FoldBetweenPartitionFunction implements Function2, Iterator> { - private Broadcast> broadcastedMaxPerPartitionCounter; + private final Broadcast> broadcastedMaxPerPartitionCounter; public FoldBetweenPartitionFunction(Broadcast> broadcastedMaxPerPartitionCounter) { this.broadcastedMaxPerPartitionCounter = broadcastedMaxPerPartitionCounter; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldWithinPartitionFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldWithinPartitionFunction.java index 38910c623..7e4b3e7f8 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldWithinPartitionFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldWithinPartitionFunction.java @@ -38,7 
+38,7 @@ public class FoldWithinPartitionFunction implements Function2> maxPerPartitionAcc; + private final CollectionAccumulator> maxPerPartitionAcc; @Override diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TextPipeline.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TextPipeline.java index 5fb7b0fbc..49329436b 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TextPipeline.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TextPipeline.java @@ -55,7 +55,7 @@ public class TextPipeline { private Broadcast> stopWordBroadCast; // Return values private JavaRDD, AtomicLong>> sentenceWordsCountRDD; - private VocabCache vocabCache = new AbstractCache<>(); + private final VocabCache vocabCache = new AbstractCache<>(); private Broadcast> vocabCacheBroadcast; private JavaRDD> vocabWordListRDD; private JavaRDD sentenceCountRDD; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java index 75b855695..335294d4f 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java @@ -34,8 +34,8 @@ import java.util.List; @SuppressWarnings("unchecked") @Slf4j public class TokenizerFunction implements Function> { - private String tokenizerFactoryClazz; - private String tokenizerPreprocessorClazz; + private final String tokenizerFactoryClazz; + private final String tokenizerPreprocessorClazz; private transient TokenizerFactory tokenizerFactory; private int nGrams = 1; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/UpdateWordFreqAccumulatorFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/UpdateWordFreqAccumulatorFunction.java index 312677c98..e66067be5 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/UpdateWordFreqAccumulatorFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/UpdateWordFreqAccumulatorFunction.java @@ -34,8 +34,8 @@ import java.util.concurrent.atomic.AtomicLong; */ public class UpdateWordFreqAccumulatorFunction implements Function, Pair, AtomicLong>> { - private Broadcast> stopWords; - private CollectionAccumulator> wordFreqAcc; + private final Broadcast> stopWords; + private final CollectionAccumulator> wordFreqAcc; public UpdateWordFreqAccumulatorFunction(Broadcast> stopWords, CollectionAccumulator> wordFreqAcc) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java index 4859b91a6..983855af9 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java +++ 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java @@ -45,6 +45,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import java.io.File; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import static org.junit.jupiter.api.Assertions.*; @@ -85,7 +86,7 @@ public class Word2VecTest { // .setRemoveStop(false) .tokenizerFactory(t).seed(42L).negative(10).useAdaGrad(false).layerSize(150).windowSize(5) .learningRate(0.025).minLearningRate(0.0001).iterations(1).batchSize(100).minWordFrequency(5) - .stopWords(Arrays.asList("three")).useUnknown(true).build(); + .stopWords(Collections.singletonList("three")).useUnknown(true).build(); word2Vec.train(corpus); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TestFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TestFunction.java index 618bf0ac7..12f52f26c 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TestFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TestFunction.java @@ -37,7 +37,7 @@ public class TestFunction implements Function { return a; } - private List lst; + private final List lst; private int a; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java index 7e4a4944e..e570bce58 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java @@ -48,8 +48,7 @@ import scala.Tuple2; import java.util.*; import java.util.concurrent.atomic.AtomicLong; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.*; /** * @author Jeffrey Tang @@ -206,7 +205,7 @@ public class TextPipelineTest extends BaseSparkTest { pipeline.filterMinWordAddVocab(wordFreqCounter); VocabCache vocabCache = pipeline.getVocabCache(); - assertTrue(vocabCache != null); + assertNotNull(vocabCache); VocabWord redVocab = vocabCache.tokenFor("red"); VocabWord flowerVocab = vocabCache.tokenFor("flowers"); @@ -239,7 +238,7 @@ public class TextPipelineTest extends BaseSparkTest { pipeline.buildVocabCache(); VocabCache vocabCache = pipeline.getVocabCache(); - assertTrue(vocabCache != null); + assertNotNull(vocabCache); log.info("VocabWords: " + vocabCache.words()); assertEquals(5, vocabCache.numWords()); @@ -349,8 +348,8 @@ public class TextPipelineTest extends BaseSparkTest { CountCumSum countCumSum = new CountCumSum(sentenceCountRDD); JavaRDD sentenceCountCumSumRDD = countCumSum.buildCumSum(); List sentenceCountCumSumList = sentenceCountCumSumRDD.collect(); - assertTrue(sentenceCountCumSumList.get(0) == 6L); - assertTrue(sentenceCountCumSumList.get(1) == 9L); + assertEquals(6L, (long) sentenceCountCumSumList.get(0)); + assertEquals(9L, (long) sentenceCountCumSumList.get(1)); sc.stop(); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/SilentTrainingDriver.java 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/SilentTrainingDriver.java index 64d83910f..1b9ba7539 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/SilentTrainingDriver.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/SilentTrainingDriver.java @@ -139,7 +139,7 @@ public class SilentTrainingDriver implements TrainingDriver maxIter / 2 || i >= stopLyingIteration)) { diff --git a/cavis-dnn/cavis-dnn-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java b/cavis-dnn/cavis-dnn-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java index de88c6851..280da3878 100644 --- a/cavis-dnn/cavis-dnn-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java +++ b/cavis-dnn/cavis-dnn-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java @@ -52,7 +52,7 @@ public class Test6058 extends BaseDL4JTest { .build(); System.out.println("fit"); - INDArray weights = Nd4j.rand(new int[]{nWords, 100}); + INDArray weights = Nd4j.rand(nWords, 100); weights.getRow(1).assign(0); try { tsne.fit(weights); diff --git a/cavis-native/cavis-native-blas/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java b/cavis-native/cavis-native-blas/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java index 2a544489a..3044db667 100644 --- a/cavis-native/cavis-native-blas/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java +++ b/cavis-native/cavis-native-blas/src/main/java/org/nd4j/nativeblas/BaseNativeNDArrayFactory.java @@ -43,6 +43,7 @@ import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.util.HashMap; import java.util.Map; @@ -114,8 +115,8 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory { DataBuffer data = null; Pointer shapeBufferPointer = nativeOps.shapeBufferForNumpy(pointer); int length = nativeOps.lengthForShapeBufferPointer(shapeBufferPointer); - shapeBufferPointer.capacity(8 * length); - shapeBufferPointer.limit(8 * length); + shapeBufferPointer.capacity(8L * length); + shapeBufferPointer.limit(8L * length); shapeBufferPointer.position(0); @@ -307,8 +308,8 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory { DataBuffer data = null; Pointer shapeBufferPointer = nativeOps.shapeBufferForNumpyHeader(pointer); int length = nativeOps.lengthForShapeBufferPointer(shapeBufferPointer); - shapeBufferPointer.capacity(8 * length); - shapeBufferPointer.limit(8 * length); + shapeBufferPointer.capacity(8L * length); + shapeBufferPointer.limit(8L * length); shapeBufferPointer.position(0); @@ -488,7 +489,7 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory { */ @Override public INDArray createFromNpyFile(File file) { - byte[] pathBytes = file.getAbsolutePath().getBytes(Charset.forName("UTF-8")); + byte[] pathBytes = file.getAbsolutePath().getBytes(StandardCharsets.UTF_8); ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder()); directBuffer.put(pathBytes); ((Buffer) directBuffer).rewind(); @@ -668,7 +669,7 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory { public Map _createFromNpzFile(File file) throws Exception{ // TODO: Fix libnd4j implementation - byte[] pathBytes = 
file.getAbsolutePath().getBytes(Charset.forName("UTF-8")); + byte[] pathBytes = file.getAbsolutePath().getBytes(StandardCharsets.UTF_8); ByteBuffer directBuffer = ByteBuffer.allocateDirect(pathBytes.length).order(ByteOrder.nativeOrder()); directBuffer.put(pathBytes); ((Buffer) directBuffer).rewind(); @@ -735,7 +736,7 @@ public abstract class BaseNativeNDArrayFactory extends BaseNDArrayFactory { } else{ - throw new Exception("Unsupported data type: " + String.valueOf(elemSize)); + throw new Exception("Unsupported data type: " + elemSize); } diff --git a/cavis-native/cavis-native-common/src/main/java/org/nd4j/compression/impl/NoOp.java b/cavis-native/cavis-native-common/src/main/java/org/nd4j/compression/impl/NoOp.java index bd21a9a5e..9be6daf7c 100644 --- a/cavis-native/cavis-native-common/src/main/java/org/nd4j/compression/impl/NoOp.java +++ b/cavis-native/cavis-native-common/src/main/java/org/nd4j/compression/impl/NoOp.java @@ -84,20 +84,20 @@ public class NoOp extends AbstractCompressor { CompressionDescriptor descriptor = new CompressionDescriptor(); descriptor.setCompressionType(getCompressionType()); - descriptor.setOriginalLength(length * elementSize); + descriptor.setOriginalLength((long) length * elementSize); descriptor.setCompressionAlgorithm(getDescriptor()); descriptor.setOriginalElementSize(elementSize); - descriptor.setCompressedLength(length * elementSize); + descriptor.setCompressedLength((long) length * elementSize); descriptor.setNumberOfElements(length); - BytePointer ptr = new BytePointer(length * elementSize); + BytePointer ptr = new BytePointer((long) length * elementSize); val perfD = PerformanceTracker.getInstance().helperStartTransaction(); // this Pointer.memcpy is used intentionally. This method operates on host memory ALWAYS - Pointer.memcpy(ptr, srcPointer, length * elementSize); + Pointer.memcpy(ptr, srcPointer, (long) length * elementSize); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, length * elementSize, MemcpyDirection.HOST_TO_HOST); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, (long) length * elementSize, MemcpyDirection.HOST_TO_HOST); CompressedDataBuffer buffer = new CompressedDataBuffer(ptr, descriptor); diff --git a/cavis-native/cavis-native-common/src/main/java/org/nd4j/rng/deallocator/GarbageStateReference.java b/cavis-native/cavis-native-common/src/main/java/org/nd4j/rng/deallocator/GarbageStateReference.java index bbf81f2e7..9ed8a5608 100644 --- a/cavis-native/cavis-native-common/src/main/java/org/nd4j/rng/deallocator/GarbageStateReference.java +++ b/cavis-native/cavis-native-common/src/main/java/org/nd4j/rng/deallocator/GarbageStateReference.java @@ -28,7 +28,7 @@ import java.lang.ref.WeakReference; public class GarbageStateReference extends WeakReference { @Getter - private Pointer statePointer; + private final Pointer statePointer; public GarbageStateReference(NativePack referent, ReferenceQueue queue) { super(referent, queue); diff --git a/cavis-native/cavis-native-common/src/main/java/org/nd4j/rng/deallocator/NativeRandomDeallocator.java b/cavis-native/cavis-native-common/src/main/java/org/nd4j/rng/deallocator/NativeRandomDeallocator.java index 362989583..091fa590e 100644 --- a/cavis-native/cavis-native-common/src/main/java/org/nd4j/rng/deallocator/NativeRandomDeallocator.java +++ b/cavis-native/cavis-native-common/src/main/java/org/nd4j/rng/deallocator/NativeRandomDeallocator.java @@ -38,7 +38,7 @@ public class NativeRandomDeallocator { // we don't really need concurrency here, so 1 queue 
will be just fine private final ReferenceQueue queue; private final Map referenceMap; - private List deallocatorThreads = new ArrayList<>(); + private final List deallocatorThreads = new ArrayList<>(); private NativeRandomDeallocator() { this.queue = new ReferenceQueue<>(); diff --git a/cavis-native/cavis-native-common/src/main/java/org/nd4j/storage/CompressedRamStorage.java b/cavis-native/cavis-native-common/src/main/java/org/nd4j/storage/CompressedRamStorage.java index 73b3f2c66..8c6cd0500 100644 --- a/cavis-native/cavis-native-common/src/main/java/org/nd4j/storage/CompressedRamStorage.java +++ b/cavis-native/cavis-native-common/src/main/java/org/nd4j/storage/CompressedRamStorage.java @@ -35,9 +35,9 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; public class CompressedRamStorage implements AbstractStorage { private NDArrayCompressor compressor = new NoOp(); - private Map compressedEntries = new ConcurrentHashMap<>(); + private final Map compressedEntries = new ConcurrentHashMap<>(); private boolean useInplaceCompression = false; - private ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); private boolean emulateIsAbsent = false; private CompressedRamStorage() { diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java index 668b4e25b..d216cca3c 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuNDArrayFactory.java @@ -272,12 +272,12 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { @Override public INDArray create(double[] data, long[] shape, char ordering) { - return create(data, shape, (Character) ordering); + return create(data, shape, ordering); } @Override public INDArray create(float[] data, long[] shape, char ordering) { - return create(data, shape, (Character) ordering); + return create(data, shape, ordering); } @Override @@ -682,9 +682,9 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { val tadManager = Nd4j.getExecutioner().getTADManager(); - val tadBuffers = tadManager.getTADOnlyShapeInfo(source, new int[] {sourceDimension}); + val tadBuffers = tadManager.getTADOnlyShapeInfo(source, sourceDimension); - val zTadBuffers = tadManager.getTADOnlyShapeInfo(ret, new int[] {sourceDimension}); + val zTadBuffers = tadManager.getTADOnlyShapeInfo(ret, sourceDimension); val hostTadShapeInfo = tadBuffers.getFirst().addressPointer(); @@ -970,10 +970,7 @@ public class CpuNDArrayFactory extends BaseNativeNDArrayFactory { source.setData(buffer); - if (buffer instanceof CompressedDataBuffer) - source.markAsCompressed(true); - else - source.markAsCompressed(false); + source.markAsCompressed(buffer instanceof CompressedDataBuffer); return source; } diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java index f3179af14..adb7438cc 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/CpuTADManager.java @@ -41,8 +41,8 @@ public class CpuTADManager implements TADManager { private Map> cache = new 
ConcurrentHashMap<>(); private NativeOps nativeOps; private ConstantHandler constantHandler; - private AtomicLong bytes = new AtomicLong(0); - private AtomicInteger counter = new AtomicInteger(0); + private final AtomicLong bytes = new AtomicLong(0); + private final AtomicInteger counter = new AtomicInteger(0); private static final int MAX_ENTRIES = 100; public CpuTADManager() { diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/DirectShapeInfoProvider.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/DirectShapeInfoProvider.java index 3e2cc6619..538caaa7e 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/DirectShapeInfoProvider.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/DirectShapeInfoProvider.java @@ -40,8 +40,8 @@ public class DirectShapeInfoProvider extends BaseShapeInfoProvider { // TODO: to be removed private Map> shapeCache = new ConcurrentHashMap<>(); - private Map> longCache = new ConcurrentHashMap<>(); - private AtomicInteger counter = new AtomicInteger(0); + private final Map> longCache = new ConcurrentHashMap<>(); + private final AtomicInteger counter = new AtomicInteger(0); private static final int MAX_ENTRIES = 1000; public Pair createShapeInformation(long[] shape, long[] stride, long elementWiseStride, char order, DataType dataType) { diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/blas/CpuLapack.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/blas/CpuLapack.java index bf23aec07..16841b11c 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/blas/CpuLapack.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/blas/CpuLapack.java @@ -44,7 +44,7 @@ public class CpuLapack extends BaseLapack { if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } - return A.ordering() == 'f' ? (int) A.rows() : (int) A.columns(); + return A.ordering() == 'f' ? 
A.rows() : A.columns(); } //========================= // L U DECOMP @@ -86,7 +86,7 @@ public class CpuLapack extends BaseLapack { // Copy R ( upper part of Q ) into result if( R != null ) { R.assign( A.get( NDArrayIndex.interval( 0, A.columns() ), NDArrayIndex.all() ) ) ; - INDArrayIndex ix[] = new INDArrayIndex[ 2 ] ; + INDArrayIndex[] ix = new INDArrayIndex[ 2 ] ; for( int i=1 ; i strings) { // header size first - long size = (strings.size() + 1) * 8; + long size = (strings.size() + 1) * 8L; for (val s:strings) size += s.length(); diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/cache/ConstantBuffersCache.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/cache/ConstantBuffersCache.java index dd61c6d83..8d00cc5cc 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/cache/ConstantBuffersCache.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/cache/ConstantBuffersCache.java @@ -35,8 +35,8 @@ import java.util.concurrent.atomic.AtomicLong; public class ConstantBuffersCache extends BasicConstantHandler { protected Map buffersCache = new ConcurrentHashMap<>(); - private AtomicInteger counter = new AtomicInteger(0); - private AtomicLong bytes = new AtomicLong(0); + private final AtomicInteger counter = new AtomicInteger(0); + private final AtomicLong bytes = new AtomicLong(0); private static final int MAX_ENTRIES = 1000; /** @@ -58,8 +58,8 @@ public class ConstantBuffersCache extends BasicConstantHandler { counter.incrementAndGet(); buffersCache.put(descriptor, buffer); - bytes.addAndGet(array.length * Nd4j.sizeOfDataType(dataType)); - AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, array.length * Nd4j.sizeOfDataType(dataType)); + bytes.addAndGet((long) array.length * Nd4j.sizeOfDataType(dataType)); + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, (long) array.length * Nd4j.sizeOfDataType(dataType)); } return buffer; } @@ -78,8 +78,8 @@ public class ConstantBuffersCache extends BasicConstantHandler { counter.incrementAndGet(); buffersCache.put(descriptor, buffer); - bytes.addAndGet(array.length * Nd4j.sizeOfDataType(dataType)); - AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, array.length * Nd4j.sizeOfDataType(dataType)); + bytes.addAndGet((long) array.length * Nd4j.sizeOfDataType(dataType)); + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, (long) array.length * Nd4j.sizeOfDataType(dataType)); } return buffer; } @@ -98,8 +98,8 @@ public class ConstantBuffersCache extends BasicConstantHandler { counter.incrementAndGet(); buffersCache.put(descriptor, buffer); - bytes.addAndGet(array.length * Nd4j.sizeOfDataType(dataType)); - AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, array.length * Nd4j.sizeOfDataType(dataType)); + bytes.addAndGet((long) array.length * Nd4j.sizeOfDataType(dataType)); + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, (long) array.length * Nd4j.sizeOfDataType(dataType)); } return buffer; } @@ -118,8 +118,8 @@ public class ConstantBuffersCache extends BasicConstantHandler { counter.incrementAndGet(); buffersCache.put(descriptor, buffer); - bytes.addAndGet(array.length * Nd4j.sizeOfDataType(dataType)); - AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, array.length * Nd4j.sizeOfDataType(dataType)); + bytes.addAndGet((long) array.length * 
Nd4j.sizeOfDataType(dataType)); + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, (long) array.length * Nd4j.sizeOfDataType(dataType)); } return buffer; } @@ -138,8 +138,8 @@ public class ConstantBuffersCache extends BasicConstantHandler { counter.incrementAndGet(); buffersCache.put(descriptor, buffer); - bytes.addAndGet(array.length * Nd4j.sizeOfDataType(dataType)); - AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, array.length * Nd4j.sizeOfDataType(dataType)); + bytes.addAndGet((long) array.length * Nd4j.sizeOfDataType(dataType)); + AllocationsTracker.getInstance().markAllocated(AllocationKind.CONSTANT, 0, (long) array.length * Nd4j.sizeOfDataType(dataType)); } return buffer; } diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/compression/CpuFlexibleThreshold.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/compression/CpuFlexibleThreshold.java index e13355cfe..9b43a1414 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/compression/CpuFlexibleThreshold.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/compression/CpuFlexibleThreshold.java @@ -77,7 +77,7 @@ public class CpuFlexibleThreshold extends CpuThreshold { pointer.put(3, 0); CompressionDescriptor descriptor = new CompressionDescriptor(); - descriptor.setCompressedLength(compressedLength * 4); // sizeOf(INT) + descriptor.setCompressedLength(compressedLength * 4L); // sizeOf(INT) descriptor.setOriginalLength(originalLength); descriptor.setOriginalElementSize(Nd4j.sizeOfDataType(buffer.dataType())); descriptor.setNumberOfElements(buffer.length()); diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/compression/CpuThreshold.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/compression/CpuThreshold.java index 209747157..96d0ff617 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/compression/CpuThreshold.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/compression/CpuThreshold.java @@ -124,7 +124,7 @@ public class CpuThreshold extends AbstractCompressor { pointer.put(3, 0); CompressionDescriptor descriptor = new CompressionDescriptor(); - descriptor.setCompressedLength(compressedLength * 4); // sizeOf(INT) + descriptor.setCompressedLength(compressedLength * 4L); // sizeOf(INT) descriptor.setOriginalLength(originalLength); descriptor.setOriginalElementSize(Nd4j.sizeOfDataType(buffer.dataType())); descriptor.setNumberOfElements(buffer.length()); diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index 0a922c704..d6ddf49de 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -40,8 +40,8 @@ import org.nd4j.nativeblas.OpaqueRandomGenerator; public class CpuOpContext extends BaseOpContext implements OpContext, Deallocatable { // we might want to have configurable - private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - private OpaqueContext context = nativeOps.createGraphContext(1); + private final NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + 
private final OpaqueContext context = nativeOps.createGraphContext(1); private final transient long id = Nd4j.getDeallocatorService().nextValue(); public CpuOpContext() { @@ -74,7 +74,7 @@ public class CpuOpContext extends BaseOpContext implements OpContext, Deallocata if (arguments.length > 0) { super.setTArguments(arguments); nativeOps.setGraphContextTArguments(context, new DoublePointer(arguments), arguments.length); - }; + } } @Override @@ -86,7 +86,7 @@ public class CpuOpContext extends BaseOpContext implements OpContext, Deallocata args[e] = arguments[e].toInt(); nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length); - }; + } } @Override @@ -150,7 +150,7 @@ public class CpuOpContext extends BaseOpContext implements OpContext, Deallocata @Override public String getUniqueId() { - return new String("CTX_" + id); + return "CTX_" + id; } @Override diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index 52ab50235..53fc39d7e 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -76,20 +76,20 @@ import java.util.*; @Slf4j public class NativeOpExecutioner extends DefaultOpExecutioner { private final NativeOpsHolder holder = NativeOpsHolder.getInstance(); - private NativeOps loop = holder.getDeviceNativeOps(); - private ConstantHandler constantHandler = Nd4j.getConstantHandler(); - private CpuTADManager tadManager = new CpuTADManager(); + private final NativeOps loop = holder.getDeviceNativeOps(); + private final ConstantHandler constantHandler = Nd4j.getConstantHandler(); + private final CpuTADManager tadManager = new CpuTADManager(); //thread locals for custom op inputs and outputs to prevent allocations //every time exec(CustomOp) is called - private ThreadLocal> inputShapes = new ThreadLocal<>(); - private ThreadLocal> inputBuffers = new ThreadLocal<>(); - private ThreadLocal> outputShapes = new ThreadLocal<>(); - private ThreadLocal> outputBuffers = new ThreadLocal<>(); - private ThreadLocal> iArgsPointer = new ThreadLocal<>(); - private ThreadLocal> tArgsPointer = new ThreadLocal<>(); - private ThreadLocal> bArgsPointer = new ThreadLocal<>(); - private ThreadLocal> halfArgsPointer = new ThreadLocal<>(); + private final ThreadLocal> inputShapes = new ThreadLocal<>(); + private final ThreadLocal> inputBuffers = new ThreadLocal<>(); + private final ThreadLocal> outputShapes = new ThreadLocal<>(); + private final ThreadLocal> outputBuffers = new ThreadLocal<>(); + private final ThreadLocal> iArgsPointer = new ThreadLocal<>(); + private final ThreadLocal> tArgsPointer = new ThreadLocal<>(); + private final ThreadLocal> bArgsPointer = new ThreadLocal<>(); + private final ThreadLocal> halfArgsPointer = new ThreadLocal<>(); protected Map customOps = null; @@ -103,8 +103,8 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * Instead of allocating new memory chunks for each batch invocation, we reuse them on thread/opNum basis * Since for NativeOpExecutioner all executions are synchronous */ - private ThreadLocal> batchPointers = new ThreadLocal<>(); - private ThreadLocal> memoryBlocks = new ThreadLocal<>(); + private final ThreadLocal> batchPointers = new ThreadLocal<>(); + private final ThreadLocal> 
memoryBlocks = new ThreadLocal<>(); public NativeOpExecutioner() { tadManager.init(loop, constantHandler); @@ -120,7 +120,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } else { val split = env.toLowerCase().split(","); for (val name:split) { - mklOverrides.put(name, new Boolean(true)); + mklOverrides.put(name, Boolean.TRUE); } } } @@ -298,7 +298,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { long xT = x.tensorsAlongDimension(dimension); long yT = y.tensorsAlongDimension(dimension); - ret = Nd4j.create(op.resultType(), new long[]{xT, yT}); + ret = Nd4j.create(op.resultType(), xT, yT); } else { if (y != null) { @@ -354,7 +354,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { * and the associated offsets for each {@link INDArray#tensorAlongDimension(int, int...)} * The first item is the shape information. The second one is the offsets. */ - Pair tadBuffers = x.isEmpty() ? Pair.makePair(x.data(), null): tadManager.getTADOnlyShapeInfo(x, dimension); + Pair tadBuffers = x.isEmpty() ? Pair.makePair(x.data(), null): tadManager.getTADOnlyShapeInfo(x, dimension); Pair yTadBuffers = null; /** * Note that we use addresses in libnd4j. @@ -1699,7 +1699,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { int nOut = opContext != null ? opContext.numOutputArguments() : op.numOutputArguments(); log.error("Failed to calculate output shapes for op {}. Attempted to execute with {} inputs, {} outputs, " + "{} targs, {} iargs, {} bargs and {} dargs. {} - Please see above message (printed out from c++) for a possible cause of error.", - op.opName(), nIn, nOut, nTArgs, nIArgs, nBArgs, nDArgs, sb.toString()); + op.opName(), nIn, nOut, nTArgs, nIArgs, nBArgs, nDArgs, sb); throw t; } @@ -1898,7 +1898,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { if (Nd4jCpu.Environment.getInstance().isUseMKLDNN()) { val opName = op.opName(); val state = mklOverrides.get(op); - if (state != null && state == true) { + if (state != null && state) { mklOverride = true; Nd4jCpu.Environment.getInstance().setUseMKLDNN(true); } @@ -1972,7 +1972,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { nT + " targs," + nB + " bargs and " + nI + " iargs. 
" + - sb.toString() + + sb + " - Please see above message (printed out from c++) for a possible cause of error."); throw e; } finally { diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java index 91f29101d..671c39a93 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/workspace/CpuWorkspaceDeallocator.java @@ -36,10 +36,10 @@ import java.util.Queue; @Slf4j public class CpuWorkspaceDeallocator implements Deallocator { - private PointersPair pointersPair; - private Queue pinnedPointers; - private List externalPointers; - private LocationPolicy location; + private final PointersPair pointersPair; + private final Queue pinnedPointers; + private final List externalPointers; + private final LocationPolicy location; private Pair mmapInfo; public CpuWorkspaceDeallocator(@NonNull CpuWorkspace workspace) { diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/concurrency/DeviceAllocationsTracker.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/concurrency/DeviceAllocationsTracker.java index 907fd8103..870989777 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/concurrency/DeviceAllocationsTracker.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/concurrency/DeviceAllocationsTracker.java @@ -38,7 +38,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; */ @Slf4j public class DeviceAllocationsTracker { - private Configuration configuration; + private final Configuration configuration; private final ReentrantReadWriteLock globalLock = new ReentrantReadWriteLock(); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/concurrency/RRWLock.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/concurrency/RRWLock.java index 35aa72207..0c298bc83 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/concurrency/RRWLock.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/concurrency/RRWLock.java @@ -32,10 +32,10 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; * @author raver119@gmail.com */ public class RRWLock implements Lock { - private ReentrantReadWriteLock globalLock = new ReentrantReadWriteLock(); - private ReentrantReadWriteLock externalsLock = new ReentrantReadWriteLock(); + private final ReentrantReadWriteLock globalLock = new ReentrantReadWriteLock(); + private final ReentrantReadWriteLock externalsLock = new ReentrantReadWriteLock(); - private Map objectLocks = new ConcurrentHashMap<>(); + private final Map objectLocks = new ConcurrentHashMap<>(); /** diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java index e07bbf544..97378d94f 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AllocationPoint.java @@ -54,7 +54,7 @@ import java.util.concurrent.atomic.AtomicBoolean; public class AllocationPoint { 
@Getter - private OpaqueDataBuffer ptrDataBuffer; + private final OpaqueDataBuffer ptrDataBuffer; @Getter @Setter @@ -75,14 +75,14 @@ public class AllocationPoint { // thread safety is guaranteed by allocLock private AllocationStatus allocationStatus = AllocationStatus.UNDEFINED; - private transient TimeProvider timeProvider = new OperativeProvider(); + private final transient TimeProvider timeProvider = new OperativeProvider(); // corresponding access times in TimeProvider quants - private long accessHostRead = 0L; + private final long accessHostRead = 0L; private long accessDeviceRead = 0L; - private long accessHostWrite = 0L; - private long accessDeviceWrite = 0L; + private final long accessHostWrite = 0L; + private final long accessDeviceWrite = 0L; protected static final NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); /* @@ -103,7 +103,7 @@ public class AllocationPoint { */ private volatile int deviceId; - private long bytes; + private final long bytes; public AllocationPoint(@NonNull OpaqueDataBuffer opaqueDataBuffer, long bytes) { ptrDataBuffer = opaqueDataBuffer; @@ -124,7 +124,7 @@ public class AllocationPoint { NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetDeviceId(ptrDataBuffer, deviceId); } - private AtomicBoolean enqueued = new AtomicBoolean(false); + private final AtomicBoolean enqueued = new AtomicBoolean(false); @Getter @Setter diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java index 21ba561f8..e0102fcac 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java @@ -89,18 +89,18 @@ public class AtomicAllocator implements Allocator { @Getter private transient MemoryHandler memoryHandler; - private AtomicLong allocationsCounter = new AtomicLong(0); + private final AtomicLong allocationsCounter = new AtomicLong(0); - private AtomicLong objectsTracker = new AtomicLong(0); + private final AtomicLong objectsTracker = new AtomicLong(0); // we have single tracking point for allocation points, since we're not going to cycle through it any time soon - private Map allocationsMap = new ConcurrentHashMap<>(); + private final Map allocationsMap = new ConcurrentHashMap<>(); /* locks for internal resources */ - private ReentrantReadWriteLock globalLock = new ReentrantReadWriteLock(); - private ReentrantReadWriteLock externalsLock = new ReentrantReadWriteLock(); + private final ReentrantReadWriteLock globalLock = new ReentrantReadWriteLock(); + private final ReentrantReadWriteLock externalsLock = new ReentrantReadWriteLock(); /* here we have handles for garbage collector threads @@ -118,8 +118,8 @@ public class AtomicAllocator implements Allocator { private final Ring zeroLong = new LockedRing(30); private final Ring zeroShort = new LockedRing(30); - private ConstantHandler constantHandler = Nd4j.getConstantHandler(); - private AtomicLong useTracker = new AtomicLong(System.currentTimeMillis()); + private final ConstantHandler constantHandler = Nd4j.getConstantHandler(); + private final AtomicLong useTracker = new AtomicLong(System.currentTimeMillis()); public static AtomicAllocator getInstance() { if (INSTANCE == null) @@ -139,7 +139,7 @@ public class AtomicAllocator implements Allocator { /*initDeviceCollectors(); initHostCollectors();*/ - 
this.protector = ConstantProtector.getInstance(); + protector = ConstantProtector.getInstance(); } @@ -607,7 +607,7 @@ public class AtomicAllocator implements Allocator { //elementsDropped.incrementAndGet(); //continue; throw new UnsupportedOperationException("Pew-pew"); - } ; + } } else { elementsSurvived.incrementAndGet(); } @@ -777,7 +777,7 @@ public class AtomicAllocator implements Allocator { if (memoryHandler.getAllocatedHostMemory() < (configuration.getMaximumZeroAllocation() * 0.25) && (memoryHandler.getAllocatedHostObjects(bucketId) < 5000) && lastCheck > System.currentTimeMillis() - 30000) { - ; // i don't want deallocation to be fired on lower thresholds. just no sense locking stuff + // i don't want deallocation to be fired on lower thresholds. just no sense locking stuff //log.debug("Skipping zero GC round: ["+zeroUseCounter.get()+"/" +zeroAllocations.get(threadId).size() + "]"); } else { seekUnusedZero(bucketId, aggressiveness); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java index 72848face..81a411f85 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/CudaDeallocator.java @@ -28,7 +28,7 @@ import org.nd4j.nativeblas.OpaqueDataBuffer; @Slf4j public class CudaDeallocator implements Deallocator { - private OpaqueDataBuffer opaqueDataBuffer; + private final OpaqueDataBuffer opaqueDataBuffer; public CudaDeallocator(@NonNull BaseCudaDataBuffer buffer) { opaqueDataBuffer = buffer.getOpaqueDataBuffer(); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java index 4db458499..bacf78bbc 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/MemoryTracker.java @@ -30,13 +30,13 @@ import java.util.concurrent.atomic.AtomicLong; @Slf4j public class MemoryTracker { - private List allocatedPerDevice = new ArrayList<>(); - private List cachedPerDevice = new ArrayList<>(); - private List totalPerDevice = new ArrayList<>(); - private List freePerDevice = new ArrayList<>(); - private List workspacesPerDevice = new ArrayList<>(); - private AtomicLong cachedHost = new AtomicLong(0); - private AtomicLong allocatedHost = new AtomicLong(0); + private final List allocatedPerDevice = new ArrayList<>(); + private final List cachedPerDevice = new ArrayList<>(); + private final List totalPerDevice = new ArrayList<>(); + private final List freePerDevice = new ArrayList<>(); + private final List workspacesPerDevice = new ArrayList<>(); + private final AtomicLong cachedHost = new AtomicLong(0); + private final AtomicLong allocatedHost = new AtomicLong(0); private final static MemoryTracker INSTANCE = new MemoryTracker(); public MemoryTracker() { diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/NestedPoint.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/NestedPoint.java index 118c3628b..208e87ae9 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/NestedPoint.java +++ 
b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/NestedPoint.java @@ -29,6 +29,7 @@ import org.nd4j.jita.allocator.enums.AllocationStatus; import org.nd4j.jita.allocator.time.RateTimer; import org.nd4j.jita.allocator.time.impl.BinaryTimer; +import java.util.Objects; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; @@ -47,9 +48,9 @@ public class NestedPoint { private AtomicState accessState; private AtomicLong accessTime; @Getter - private RateTimer timerShort = new BinaryTimer(10, TimeUnit.SECONDS); + private final RateTimer timerShort = new BinaryTimer(10, TimeUnit.SECONDS); @Getter - private RateTimer timerLong = new BinaryTimer(60, TimeUnit.SECONDS); + private final RateTimer timerLong = new BinaryTimer(60, TimeUnit.SECONDS); // by default memory is UNDEFINED, and depends on parent memory chunk for now @@ -57,7 +58,7 @@ public class NestedPoint { @Setter private AllocationStatus nestedStatus = AllocationStatus.UNDEFINED; - private AtomicLong counter = new AtomicLong(0); + private final AtomicLong counter = new AtomicLong(0); public NestedPoint(@NonNull AllocationShape shape) { this.shape = shape; @@ -94,7 +95,7 @@ public class NestedPoint { NestedPoint that = (NestedPoint) o; - return shape != null ? shape.equals(that.shape) : that.shape == null; + return Objects.equals(shape, that.shape); } diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/tad/DeviceTADManager.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/tad/DeviceTADManager.java index 133a2044e..cd0efcad8 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/tad/DeviceTADManager.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/tad/DeviceTADManager.java @@ -41,7 +41,7 @@ import java.util.concurrent.Semaphore; @Slf4j public class DeviceTADManager extends BasicTADManager { protected List>> tadCache = new ArrayList<>(); - private Semaphore lock = new Semaphore(1); + private final Semaphore lock = new Semaphore(1); public DeviceTADManager() { int numDevices = Nd4j.getAffinityManager().getNumberOfDevices(); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/impl/BinaryTimer.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/impl/BinaryTimer.java index f71b46bf8..281135cb7 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/impl/BinaryTimer.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/impl/BinaryTimer.java @@ -32,8 +32,8 @@ import java.util.concurrent.atomic.AtomicLong; * @author raver119@gmail.com */ public class BinaryTimer implements RateTimer { - private AtomicLong timer; - private long timeframeMilliseconds; + private final AtomicLong timer; + private final long timeframeMilliseconds; public BinaryTimer(long timeframe, TimeUnit timeUnit) { timer = new AtomicLong(System.currentTimeMillis()); @@ -80,10 +80,6 @@ public class BinaryTimer implements RateTimer { protected boolean isAlive() { long currentTime = System.currentTimeMillis(); - if (currentTime - timer.get() > timeframeMilliseconds) { - return false; - } - - return true; + return currentTime - timer.get() <= timeframeMilliseconds; } } diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/impl/SimpleTimer.java 
b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/impl/SimpleTimer.java index 1089ffdff..b89aec51d 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/impl/SimpleTimer.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/impl/SimpleTimer.java @@ -91,7 +91,6 @@ public class SimpleTimer implements RateTimer { buckets[x] = 0; } else { // do nothing here probably - ; } } diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/providers/OperativeProvider.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/providers/OperativeProvider.java index 001796fae..e6dacf8c7 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/providers/OperativeProvider.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/providers/OperativeProvider.java @@ -28,7 +28,7 @@ import java.util.concurrent.atomic.AtomicLong; * @author raver119@gmail.com */ public class OperativeProvider implements TimeProvider { - private AtomicLong time = new AtomicLong(0); + private final AtomicLong time = new AtomicLong(0); /** diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/rings/LockedRing.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/rings/LockedRing.java index 86e6195f9..bf0f5efa5 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/rings/LockedRing.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/time/rings/LockedRing.java @@ -33,7 +33,7 @@ public class LockedRing implements Ring { private final float[] ring; private final AtomicInteger position = new AtomicInteger(0); - private ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); + private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(); /** * Builds new BasicRing with specified number of elements stored diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/utils/AllocationUtils.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/utils/AllocationUtils.java index e8f137506..afbb06e46 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/utils/AllocationUtils.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/utils/AllocationUtils.java @@ -92,7 +92,7 @@ public class AllocationUtils { public static DataBuffer getPointersBuffer(long[] pointers) { CudaDoubleDataBuffer tempX = new CudaDoubleDataBuffer(pointers.length); - AtomicAllocator.getInstance().memcpyBlocking(tempX, new LongPointer(pointers), pointers.length * 8, 0); + AtomicAllocator.getInstance().memcpyBlocking(tempX, new LongPointer(pointers), pointers.length * 8L, 0); return tempX; } } diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java index 5e90a12a6..5e8377e84 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java @@ -47,11 +47,11 @@ import java.util.concurrent.atomic.AtomicInteger; */ @Slf4j public class CudaAffinityManager extends BasicAffinityManager { - private Map affinityMap = new 
ConcurrentHashMap<>(); - private AtomicInteger devPtr = new AtomicInteger(0); - private ThreadLocal affiliated = new ThreadLocal<>(); + private final Map affinityMap = new ConcurrentHashMap<>(); + private final AtomicInteger devPtr = new AtomicInteger(0); + private final ThreadLocal affiliated = new ThreadLocal<>(); - private AtomicInteger numberOfDevices = new AtomicInteger(-1); + private final AtomicInteger numberOfDevices = new AtomicInteger(-1); public CudaAffinityManager() { super(); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java index 4e5ee96a8..d1354765f 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java @@ -35,9 +35,9 @@ import java.util.concurrent.atomic.AtomicLong; */ @Deprecated public class EventsProvider { - private List> queue = new ArrayList<>(); - private AtomicLong newCounter = new AtomicLong(0); - private AtomicLong cacheCounter = new AtomicLong(0); + private final List> queue = new ArrayList<>(); + private final AtomicLong newCounter = new AtomicLong(0); + private final AtomicLong cacheCounter = new AtomicLong(0); public EventsProvider() { int numDev = Nd4j.getAffinityManager().getNumberOfDevices(); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/conf/Configuration.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/conf/Configuration.java index f25c90698..b20c23d4c 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/conf/Configuration.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/conf/Configuration.java @@ -89,13 +89,13 @@ public class Configuration implements Serializable { * Minimal number of activations for relocation threshold */ @Getter - private int minimumRelocationThreshold = 5; + private final int minimumRelocationThreshold = 5; /** * Minimal guaranteed TTL for memory chunk */ @Getter - private long minimumTTLMilliseconds = 10 * 1000L; + private final long minimumTTLMilliseconds = 10 * 1000L; /** * Number of buckets/garbage collectors for host memory @@ -108,18 +108,18 @@ public class Configuration implements Serializable { */ @Deprecated @Getter - private Aggressiveness hostDeallocAggressiveness = Aggressiveness.REASONABLE; + private final Aggressiveness hostDeallocAggressiveness = Aggressiveness.REASONABLE; @Deprecated @Getter - private Aggressiveness gpuDeallocAggressiveness = Aggressiveness.REASONABLE; + private final Aggressiveness gpuDeallocAggressiveness = Aggressiveness.REASONABLE; /** * Allocation aggressiveness */ @Deprecated @Getter - private Aggressiveness gpuAllocAggressiveness = Aggressiveness.REASONABLE; + private final Aggressiveness gpuAllocAggressiveness = Aggressiveness.REASONABLE; /** @@ -157,10 +157,10 @@ public class Configuration implements Serializable { private long maximumSingleDeviceAllocation = 1024 * 1024 * 1024L; @Getter - private List availableDevices = new ArrayList<>(); + private final List availableDevices = new ArrayList<>(); @Getter - private List bannedDevices = new ArrayList<>(); + private final List bannedDevices = new ArrayList<>(); @Getter private int maximumGridSize = 4096; diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/conf/CudaEnvironment.java 
b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/conf/CudaEnvironment.java index 69f600cd5..802e631cb 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/conf/CudaEnvironment.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/conf/CudaEnvironment.java @@ -37,7 +37,7 @@ import java.util.concurrent.ConcurrentHashMap; public class CudaEnvironment { private static final CudaEnvironment INSTANCE = new CudaEnvironment(); private static volatile Configuration configuration; - private static Map arch = new ConcurrentHashMap<>(); + private static final Map arch = new ConcurrentHashMap<>(); private CudaEnvironment() { configuration = new Configuration(); @@ -67,7 +67,7 @@ public class CudaEnvironment { if (!arch.containsKey(deviceId)) { int major = NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMajor(deviceId); int minor = NativeOpsHolder.getInstance().getDeviceNativeOps().getDeviceMinor(deviceId); - Integer cc = Integer.parseInt(new String("" + major + minor)); + Integer cc = Integer.parseInt("" + major + minor); arch.put(deviceId, cc); return cc; } diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ConstantProtector.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ConstantProtector.java index 635b4d3dd..2a36e3f6d 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ConstantProtector.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ConstantProtector.java @@ -38,13 +38,13 @@ import java.util.concurrent.CopyOnWriteArrayList; * @author raver119@gmail.com */ public class ConstantProtector { - private static ConstantProtector ourInstance = new ConstantProtector(); + private static final ConstantProtector ourInstance = new ConstantProtector(); public static ConstantProtector getInstance() { return ourInstance; } - private List protectorLegacy = new CopyOnWriteArrayList<>(); + private final List protectorLegacy = new CopyOnWriteArrayList<>(); private List> protector = new CopyOnWriteArrayList<>(); private List>> deviceCache = new ArrayList<>(); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java index 239fa6a8e..c43171f3d 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ProtectedCudaConstantHandler.java @@ -55,7 +55,7 @@ import java.util.concurrent.atomic.AtomicLong; */ @Slf4j public class ProtectedCudaConstantHandler implements ConstantHandler { - private static ProtectedCudaConstantHandler ourInstance = new ProtectedCudaConstantHandler(); + private static final ProtectedCudaConstantHandler ourInstance = new ProtectedCudaConstantHandler(); protected Map constantOffsets = new HashMap<>(); protected Map deviceLocks = new ConcurrentHashMap<>(); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ProtectedCudaShapeInfoProvider.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ProtectedCudaShapeInfoProvider.java index e225e68c8..a1eac1e5f 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ProtectedCudaShapeInfoProvider.java +++ 
b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/constant/ProtectedCudaShapeInfoProvider.java @@ -41,14 +41,14 @@ public class ProtectedCudaShapeInfoProvider extends BaseShapeInfoProvider { private AtomicAllocator allocator; - private AtomicLong cacheHit = new AtomicLong(1); - private AtomicLong cacheMiss = new AtomicLong(1); + private final AtomicLong cacheHit = new AtomicLong(1); + private final AtomicLong cacheMiss = new AtomicLong(1); - private Semaphore lock = new Semaphore(1); + private final Semaphore lock = new Semaphore(1); protected static final ConstantProtector protector = ConstantProtector.getInstance(); - private static ProtectedCudaShapeInfoProvider ourInstance = new ProtectedCudaShapeInfoProvider(); + private static final ProtectedCudaShapeInfoProvider ourInstance = new ProtectedCudaShapeInfoProvider(); private ProtectedCudaShapeInfoProvider() { diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java index abc3aa5f0..a36f1d4c1 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java @@ -92,9 +92,9 @@ public class CudaZeroHandler implements MemoryHandler { // tracker for thread->device affinity protected Map devicesAffinity = new ConcurrentHashMap<>(); - private ReentrantReadWriteLock deviceLock = new ReentrantReadWriteLock(); + private final ReentrantReadWriteLock deviceLock = new ReentrantReadWriteLock(); - private AtomicInteger devPtr = new AtomicInteger(0); + private final AtomicInteger devPtr = new AtomicInteger(0); private final AtomicBoolean wasInitialised = new AtomicBoolean(false); @@ -127,7 +127,7 @@ public class CudaZeroHandler implements MemoryHandler { private final Map> zeroAllocations = new ConcurrentHashMap<>(); - private AtomicLong zeroCounter = new AtomicLong(0); + private final AtomicLong zeroCounter = new AtomicLong(0); protected NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); @@ -137,13 +137,10 @@ public class CudaZeroHandler implements MemoryHandler { this.INITIAL_LOCATION = configuration.getFirstMemory(); - switch (configuration.getExecutionModel()) { - case SEQUENTIAL: { - this.flowController = new GridFlowController(); - } - break; - default: - throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]"); + if (configuration.getExecutionModel() == Configuration.ExecutionModel.SEQUENTIAL) { + this.flowController = new GridFlowController(); + } else { + throw new RuntimeException("Unknown ExecutionModel: [" + configuration.getExecutionModel() + "]"); } int numDevices = NativeOpsHolder.getInstance().getDeviceNativeOps().getAvailableDevices(); @@ -165,9 +162,9 @@ public class CudaZeroHandler implements MemoryHandler { */ @Override public void init(@NonNull Configuration configuration, @NonNull Allocator allocator) { - this.configuration = configuration; + CudaZeroHandler.configuration = configuration; - this.deviceMemoryTracker = new DeviceAllocationsTracker(this.configuration); + this.deviceMemoryTracker = new DeviceAllocationsTracker(CudaZeroHandler.configuration); this.flowController.init(allocator); } diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java 
b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java index 62a02fd12..39dad7bd4 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java @@ -391,8 +391,8 @@ public class CudaWorkspace extends Nd4jWorkspace { @Override protected void resetWorkspace() { - if (currentSize.get() < 1) - return; + if (currentSize.get() < 1) { + } /* diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspaceDeallocator.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspaceDeallocator.java index 4c7b15450..806986fc7 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspaceDeallocator.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspaceDeallocator.java @@ -36,9 +36,9 @@ import java.util.Queue; */ @Slf4j public class CudaWorkspaceDeallocator implements Deallocator { - private PointersPair pointersPair; - private Queue pinnedPointers; - private List externalPointers; + private final PointersPair pointersPair; + private final Queue pinnedPointers; + private final List externalPointers; public CudaWorkspaceDeallocator(@NonNull CudaWorkspace workspace) { this.pointersPair = workspace.workspace(); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/CublasPointer.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/CublasPointer.java index 911348239..289e79416 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/CublasPointer.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/CublasPointer.java @@ -44,9 +44,9 @@ public class CublasPointer implements AutoCloseable { private Pointer devicePointer; private Pointer hostPointer; @Getter - private boolean closed = false; + private final boolean closed = false; private INDArray arr; - private CudaContext cudaContext; + private final CudaContext cudaContext; private boolean resultPointer = false; @@ -161,9 +161,7 @@ public class CublasPointer implements AutoCloseable { @Override public String toString() { - StringBuffer sb = new StringBuffer(); - sb.append("NativePointer: [" + devicePointer.address() + "]"); - return sb.toString(); + return "NativePointer: [" + devicePointer.address() + "]"; } diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java index 88c08ceaa..95339aa30 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/JCublasNDArrayFactory.java @@ -523,14 +523,14 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { context.getOldStream(), allocator.getDeviceIdPointer()); val tempIndexes = new CudaLongDataBuffer(indexes.length); - AtomicAllocator.getInstance().memcpyBlocking(tempIndexes, new LongPointer(ArrayUtil.toLongArray(indexes)), indexes.length * 8, 0); + AtomicAllocator.getInstance().memcpyBlocking(tempIndexes, new LongPointer(ArrayUtil.toLongArray(indexes)), indexes.length * 8L, 0); Pointer pIndex = AtomicAllocator.getInstance().getPointer(tempIndexes, context); TADManager tadManager = 
Nd4j.getExecutioner().getTADManager(); - Pair tadBuffers = tadManager.getTADOnlyShapeInfo(source, new int[]{sourceDimension}); - Pair zTadBuffers = tadManager.getTADOnlyShapeInfo(ret, new int[]{sourceDimension}); + Pair tadBuffers = tadManager.getTADOnlyShapeInfo(source, sourceDimension); + Pair zTadBuffers = tadManager.getTADOnlyShapeInfo(ret, sourceDimension); Pointer tadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context); Pointer zTadShapeInfo = AtomicAllocator.getInstance().getPointer(zTadBuffers.getFirst(), context); @@ -598,7 +598,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { CudaDoubleDataBuffer tempX = new CudaDoubleDataBuffer(arrays.length); - allocator.memcpyBlocking(tempX, new LongPointer(xPointers), xPointers.length * 8, 0); + allocator.memcpyBlocking(tempX, new LongPointer(xPointers), xPointers.length * 8L, 0); PointerPointer x = new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context)); @@ -707,7 +707,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { CudaDoubleDataBuffer tempX = new CudaDoubleDataBuffer(arrays.length); - allocator.memcpyBlocking(tempX, new LongPointer(xPointers), xPointers.length * 8, 0); + allocator.memcpyBlocking(tempX, new LongPointer(xPointers), xPointers.length * 8L, 0); PointerPointer x = new PointerPointer(AtomicAllocator.getInstance().getPointer(tempX, context)); @@ -930,10 +930,10 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { val tempTAD = new CudaDoubleDataBuffer(arrays.size()); val tempOffsets = new CudaDoubleDataBuffer(arrays.size()); - AtomicAllocator.getInstance().memcpyBlocking(tempX, new LongPointer(xPointers), xPointers.length * 8, 0); - AtomicAllocator.getInstance().memcpyBlocking(tempShapes, new LongPointer(xShapes), xPointers.length * 8, 0); - AtomicAllocator.getInstance().memcpyBlocking(tempTAD, new LongPointer(tadShapes), xPointers.length * 8, 0); - AtomicAllocator.getInstance().memcpyBlocking(tempOffsets, new LongPointer(tadOffsets), xPointers.length * 8, 0); + AtomicAllocator.getInstance().memcpyBlocking(tempX, new LongPointer(xPointers), xPointers.length * 8L, 0); + AtomicAllocator.getInstance().memcpyBlocking(tempShapes, new LongPointer(xShapes), xPointers.length * 8L, 0); + AtomicAllocator.getInstance().memcpyBlocking(tempTAD, new LongPointer(tadShapes), xPointers.length * 8L, 0); + AtomicAllocator.getInstance().memcpyBlocking(tempOffsets, new LongPointer(tadOffsets), xPointers.length * 8L, 0); nativeOps.shuffle(extras, null, @@ -1078,10 +1078,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { DataBuffer buffer = convertDataEx(typeSrc, source.data(), typeDst); source.setData(buffer); - if (buffer instanceof CompressedDataBuffer) - source.markAsCompressed(true); - else - source.markAsCompressed(false); + source.markAsCompressed(buffer instanceof CompressedDataBuffer); return source; } @@ -1307,7 +1304,7 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory { CudaDoubleDataBuffer tempX = new CudaDoubleDataBuffer(numTads); - AtomicAllocator.getInstance().memcpyBlocking(tempX, new LongPointer(xPointers), xPointers.length * 8, 0); + AtomicAllocator.getInstance().memcpyBlocking(tempX, new LongPointer(xPointers), xPointers.length * 8L, 0); PointerPointer extraz = new PointerPointer(null, // not used context.getOldStream(), AtomicAllocator.getInstance().getDeviceIdPointer()); diff --git 
a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java index 912d2c388..c581f5696 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLapack.java @@ -64,8 +64,8 @@ import static org.bytedeco.cuda.global.cusolver.*; @Slf4j public class JcublasLapack extends BaseLapack { - private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - private Allocator allocator = AtomicAllocator.getInstance(); + private final NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + private final Allocator allocator = AtomicAllocator.getInstance(); @Override public void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO) { @@ -109,7 +109,7 @@ public class JcublasLapack extends BaseLapack { int worksize = worksizeBuffer.getInt(0); // Now allocate memory for the workspace, the permutation matrix and a return code - Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType()); + Pointer workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType()); // Do the actual LU decomp stat = cusolverDnSgetrf(solverDn, M, N, (FloatPointer) xAPointer.getDevicePointer(), M, @@ -176,7 +176,7 @@ public class JcublasLapack extends BaseLapack { int worksize = worksizeBuffer.getInt(0); // Now allocate memory for the workspace, the permutation matrix and a return code - val workspace = new Workspace(worksize * Nd4j.sizeOfDataType()); + val workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType()); // Do the actual LU decomp stat = cusolverDnDgetrf(solverDn, M, N, (DoublePointer) xAPointer.getDevicePointer(), M, @@ -250,7 +250,7 @@ public class JcublasLapack extends BaseLapack { } int worksize = worksizeBuffer.getInt(0); // Now allocate memory for the workspace, the permutation matrix and a return code - Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType()); + Pointer workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType()); // Do the actual QR decomp stat = cusolverDnSgeqrf(solverDn, M, N, @@ -275,7 +275,7 @@ public class JcublasLapack extends BaseLapack { if (r != null) { r.assign(a.get(NDArrayIndex.interval(0, a.columns()), NDArrayIndex.all())); - INDArrayIndex ix[] = new INDArrayIndex[2]; + INDArrayIndex[] ix = new INDArrayIndex[2]; for (int i = 1; i < Math.min(a.rows(), a.columns()); i++) { ix[0] = NDArrayIndex.point(i); ix[1] = NDArrayIndex.interval(0, i); @@ -289,7 +289,7 @@ public class JcublasLapack extends BaseLapack { (IntPointer) worksizeBuffer.addressPointer() ); worksize = worksizeBuffer.getInt(0); - workspace = new Workspace(worksize * Nd4j.sizeOfDataType()); + workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType()); stat = cusolverDnSorgqr(solverDn, M, N, N, (FloatPointer) xAPointer.getDevicePointer(), M, @@ -365,7 +365,7 @@ public class JcublasLapack extends BaseLapack { } int worksize = worksizeBuffer.getInt(0); // Now allocate memory for the workspace, the permutation matrix and a return code - Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType()); + Pointer workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType()); // Do the actual QR decomp stat = cusolverDnDgeqrf(solverDn, M, N, @@ -390,7 +390,7 @@ public class JcublasLapack extends BaseLapack { if (r != null) { 
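Several hunks in this file, and in JCublasNDArrayFactory and AllocationUtils earlier in the patch, widen one operand of a size computation to long: pointers.length * 8 becomes pointers.length * 8L, and worksize * Nd4j.sizeOfDataType() becomes (long) worksize * Nd4j.sizeOfDataType(). In Java, multiplying two int values is evaluated in 32-bit arithmetic even when the result is assigned to a long, so a large element count times a byte width can silently wrap around. The sketch below is self-contained plain JDK code with values invented only to trigger the overflow; it is not taken from this repository.

public class IntOverflowDemo {
    public static void main(String[] args) {
        int elements = 300_000_000;  // hypothetical buffer length
        int width = 8;               // bytes per element, e.g. one double

        long wrong = elements * width;        // int * int overflows, then widens
        long right = (long) elements * width; // promote first, multiply in 64 bits

        System.out.println(wrong); // negative: the 32-bit product wrapped around
        System.out.println(right); // 2400000000
    }
}

Using a long literal (8L) or an explicit (long) cast on the first operand are two ways of forcing the promotion, which is presumably why both spellings appear in these hunks.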
r.assign(a.get(NDArrayIndex.interval(0, a.columns()), NDArrayIndex.all())); - INDArrayIndex ix[] = new INDArrayIndex[2]; + INDArrayIndex[] ix = new INDArrayIndex[2]; for (int i = 1; i < Math.min(a.rows(), a.columns()); i++) { ix[0] = NDArrayIndex.point(i); ix[1] = NDArrayIndex.interval(0, i); @@ -403,7 +403,7 @@ public class JcublasLapack extends BaseLapack { (IntPointer) worksizeBuffer.addressPointer() ); worksize = worksizeBuffer.getInt(0); - workspace = new Workspace(worksize * Nd4j.sizeOfDataType()); + workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType()); stat = cusolverDnDorgqr(solverDn, M, N, N, (DoublePointer) xAPointer.getDevicePointer(), M, @@ -476,7 +476,7 @@ public class JcublasLapack extends BaseLapack { int worksize = worksizeBuffer.getInt(0); // Now allocate memory for the workspace, the permutation matrix and a return code - Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType()); + Pointer workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType()); // Do the actual decomp stat = cusolverDnSpotrf(solverDn, uplo, N, @@ -498,14 +498,14 @@ public class JcublasLapack extends BaseLapack { if (uplo == cublas.CUBLAS_FILL_MODE_UPPER ) { A.assign(A.transpose()); - INDArrayIndex ix[] = new INDArrayIndex[2]; + INDArrayIndex[] ix = new INDArrayIndex[2]; for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) { ix[0] = NDArrayIndex.point(i); ix[1] = NDArrayIndex.interval(0, i); A.put(ix, 0); } } else { - INDArrayIndex ix[] = new INDArrayIndex[2]; + INDArrayIndex[] ix = new INDArrayIndex[2]; for (int i = 0; i < Math.min(A.rows(), A.columns() - 1); i++) { ix[0] = NDArrayIndex.point(i); ix[1] = NDArrayIndex.interval(i + 1, A.columns()); @@ -562,7 +562,7 @@ public class JcublasLapack extends BaseLapack { int worksize = worksizeBuffer.getInt(0); // Now allocate memory for the workspace, the permutation matrix and a return code - Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType(DataType.DOUBLE)); + Pointer workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType(DataType.DOUBLE)); // Do the actual decomp stat = cusolverDnDpotrf(solverDn, uplo, N, @@ -584,14 +584,14 @@ public class JcublasLapack extends BaseLapack { if (uplo == cublas.CUBLAS_FILL_MODE_UPPER ) { A.assign(A.transpose()); - INDArrayIndex ix[] = new INDArrayIndex[2]; + INDArrayIndex[] ix = new INDArrayIndex[2]; for (int i = 1; i < Math.min(A.rows(), A.columns()); i++) { ix[0] = NDArrayIndex.point(i); ix[1] = NDArrayIndex.interval(0, i); A.put(ix, 0); } } else { - INDArrayIndex ix[] = new INDArrayIndex[2]; + INDArrayIndex[] ix = new INDArrayIndex[2]; for (int i = 0; i < Math.min(A.rows(), A.columns() - 1); i++) { ix[0] = NDArrayIndex.point(i); ix[1] = NDArrayIndex.interval(i + 1, A.columns()); @@ -691,7 +691,7 @@ public class JcublasLapack extends BaseLapack { } int worksize = worksizeBuffer.getInt(0); - Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType()); + Pointer workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType()); DataBuffer rwork = Nd4j.getDataBufferFactory().createFloat((M < N ? M : N) - 1); // Do the actual decomp @@ -803,7 +803,7 @@ public class JcublasLapack extends BaseLapack { int worksize = worksizeBuffer.getInt(0); // Now allocate memory for the workspace, the non-converging row buffer and a return code - Pointer workspace = new Workspace(worksize * Nd4j.sizeOfDataType()); + Pointer workspace = new Workspace((long) worksize * Nd4j.sizeOfDataType()); DataBuffer rwork = Nd4j.getDataBufferFactory().createDouble((M < N ? 
M : N) - 1); // Do the actual decomp @@ -858,7 +858,7 @@ public class JcublasLapack extends BaseLapack { if (A.rows() > Integer.MAX_VALUE) { throw new RuntimeException("Rows overflow"); } - int M = (int) A.rows(); + int M = A.rows(); if (Nd4j.getExecutioner() instanceof GridExecutioner) ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); @@ -892,7 +892,7 @@ public class JcublasLapack extends BaseLapack { int worksize = worksizeBuffer.getInt(0); // allocate memory for the workspace, the non-converging row buffer and a return code - val workspace = new Workspace(worksize * 4); //4 = float width + val workspace = new Workspace(worksize * 4L); //4 = float width INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, A.dataType())); @@ -936,7 +936,7 @@ public class JcublasLapack extends BaseLapack { throw new RuntimeException("Rows overflow"); } - int M = (int) A.rows(); + int M = A.rows(); if (Nd4j.getExecutioner() instanceof GridExecutioner) ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); @@ -970,7 +970,7 @@ public class JcublasLapack extends BaseLapack { int worksize = worksizeBuffer.getInt(0); // allocate memory for the workspace, the non-converging row buffer and a return code - Pointer workspace = new Workspace(worksize * 8); //8 = double width + Pointer workspace = new Workspace(worksize * 8L); //8 = double width INDArray INFO = Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(1), Nd4j.getShapeInfoProvider().createShapeInformation(new long[]{1, 1}, A.dataType())); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java index e20a9f1d4..5ace8e798 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel1.java @@ -59,9 +59,9 @@ import static org.bytedeco.cuda.global.cublas.*; */ @Slf4j public class JcublasLevel1 extends BaseLevel1 { - private Allocator allocator = AtomicAllocator.getInstance(); - private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas(); - private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + private final Allocator allocator = AtomicAllocator.getInstance(); + private final Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas(); + private final NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); @Override protected float sdsdot(long N, float alpha, INDArray X, int incX, INDArray Y, int incY) { @@ -403,7 +403,7 @@ public class JcublasLevel1 extends BaseLevel1 { // cublasHandle_t handle = ctx.getCublasHandle(); - ((CudaExecutioner) Nd4j.getExecutioner()).exec(new Axpy(X, Y, Y, alpha)); + Nd4j.getExecutioner().exec(new Axpy(X, Y, Y, alpha)); OpExecutionerUtil.checkForAny(Y); } diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java index ef6a5a567..a2fa256a6 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel2.java @@ -49,9 +49,9 @@ import static 
org.nd4j.linalg.jcublas.blas.CudaBlas.convertTranspose; */ @Slf4j public class JcublasLevel2 extends BaseLevel2 { - private Allocator allocator = AtomicAllocator.getInstance(); - private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas(); - private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + private final Allocator allocator = AtomicAllocator.getInstance(); + private final Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas(); + private final NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); @Override protected void sgemv(char order, char TransA, int M, int N, float alpha, INDArray A, int lda, INDArray X, int incX, diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java index 69338ed7b..1543bee72 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/blas/JcublasLevel3.java @@ -75,9 +75,9 @@ import static org.nd4j.linalg.jcublas.blas.CudaBlas.convertUplo; */ @Slf4j public class JcublasLevel3 extends BaseLevel3 { - private Allocator allocator = AtomicAllocator.getInstance(); - private Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas(); - private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + private final Allocator allocator = AtomicAllocator.getInstance(); + private final Nd4jBlas nd4jBlas = (Nd4jBlas) Nd4j.factory().blas(); + private final NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); @Override protected void hgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda, @@ -114,9 +114,9 @@ public class JcublasLevel3 extends BaseLevel3 { // CUDA_R_16F == 2 for CUDA 8 // CUBLAS_DATA_HALF == 2 for CUDA 7.5 cublasSgemmEx(new cublasContext(handle), convertTranspose(TransA), convertTranspose(TransB), M, N, K, - new FloatPointer(alpha), (ShortPointer) cAPointer.getDevicePointer(), 2, lda, - (ShortPointer) cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta), - (ShortPointer) cCPointer.getDevicePointer(), 2, ldc); + new FloatPointer(alpha), cAPointer.getDevicePointer(), 2, lda, + cBPointer.getDevicePointer(), 2, ldb, new FloatPointer(beta), + cCPointer.getDevicePointer(), 2, ldc); } diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 3f97c6818..6b4793704 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -83,7 +83,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Getter protected transient volatile AllocationPoint allocationPoint; - private static AtomicAllocator allocator = AtomicAllocator.getInstance(); + private static final AtomicAllocator allocator = AtomicAllocator.getInstance(); @@ -1366,10 +1366,7 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda public boolean equals(Object o) { if (o == null) return false; - if (this == o) - return true; - - return false; + return this == o; } @Override diff --git 
a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUtf8Buffer.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUtf8Buffer.java index 096646312..f5908afba 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUtf8Buffer.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/CudaUtf8Buffer.java @@ -35,6 +35,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace; import java.io.UnsupportedEncodingException; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collection; @@ -190,11 +191,7 @@ public class CudaUtf8Buffer extends BaseCudaDataBuffer { bytes[e] = dataPointer.get(idx); } - try { - return new String(bytes, "UTF-8"); - } catch (UnsupportedEncodingException e) { - throw new RuntimeException(e); - } + return new String(bytes, StandardCharsets.UTF_8); } @Override @@ -219,7 +216,7 @@ public class CudaUtf8Buffer extends BaseCudaDataBuffer { private static long stringBufferRequiredLength(@NonNull Collection strings) { // header size first - long size = (strings.size() + 1) * 8; + long size = (strings.size() + 1) * 8L; for (val s:strings) size += s.length(); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java index 7be9e08c1..b7cb87275 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/context/CudaContext.java @@ -67,7 +67,7 @@ public class CudaContext { @Builder.Default private int deviceId = -1; - private transient final static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + private final static NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); @Override public String toString() { diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 2cc5077e4..612dfdda8 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -477,7 +477,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { val yT = op.y().tensorsAlongDimension(dimension); // we intentionally want to set it to 0.0 - ret = Nd4j.createUninitialized(dtype, new long[] {xT, yT}); + ret = Nd4j.createUninitialized(dtype, xT, yT); } else { if (op.y() != null) { //2 options here: either pairwise, equal sizes - OR every X TAD vs. entirety of Y @@ -823,7 +823,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer()); Pointer hostZShapeInfo = z == null ? 
null : AddressRetriever.retrieveHostPointer(z.shapeInfoDataBuffer()); - int fdimension[] = dimension; + int[] fdimension = dimension; if (fdimension == null) fdimension = new int[] {0}; @@ -940,7 +940,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context); DataBuffer offsets = x.isEmpty() ? null : tadBuffers.getSecond(); - Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer((DataBuffer) offsets, context); + Pointer devTadOffsets = offsets == null ? null : AtomicAllocator.getInstance().getPointer(offsets, context); Pointer xShapeInfo = AtomicAllocator.getInstance().getPointer(x.shapeInfoDataBuffer(), context); @@ -1337,7 +1337,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { Pointer dimensionHostPointer = null; Pointer retPointer = null; Pointer retHostShape = null; - int dimension[] = null; + int[] dimension = null; Pointer hostXShapeInfo = x == null ? null : AddressRetriever.retrieveHostPointer(x.shapeInfoDataBuffer()); Pointer hostYShapeInfo = y == null ? null : AddressRetriever.retrieveHostPointer(y.shapeInfoDataBuffer()); @@ -1742,7 +1742,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { return Collections.emptyList(); } - val inputBuffers = new PointerPointer<>(nIn * 2); + val inputBuffers = new PointerPointer<>(nIn * 2L); val inputShapes = new PointerPointer<>(nIn); val inputArgs = opContext != null ? opContext.getInputArrays() : op.inputArguments(); @@ -1934,8 +1934,8 @@ public class CudaExecutioner extends DefaultOpExecutioner { Nd4j.getExecutioner().commit(); - val ptrBuffers = new PointerPointer(map.size() * 2); - val ptrShapes = new PointerPointer(map.size() * 2); + val ptrBuffers = new PointerPointer(map.size() * 2L); + val ptrShapes = new PointerPointer(map.size() * 2L); val ptrIndices = new IntPointer(map.size()); int cnt = 0; @@ -1980,7 +1980,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { val order = Shape.order(jshape); val array = Nd4j.create(shapeOf, stridesOf, 0, order); - Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, ArrayUtil.prod(shapeOf) * Nd4j.sizeOfDataType()); + Pointer.memcpy(AtomicAllocator.getInstance().getHostPointer(array), buffer, (long) ArrayUtil.prod(shapeOf) * Nd4j.sizeOfDataType()); //AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite(); if (1 > 0) throw new UnsupportedOperationException("Pew-pew"); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java index 2787e7282..6f2343fcd 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaGridExecutioner.java @@ -78,20 +78,20 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio //private List> deviceQueues = new ArrayList<>(); // last op - private ThreadLocal lastOp = new ThreadLocal<>(); + private final ThreadLocal lastOp = new ThreadLocal<>(); // private ThreadLocal extraz = new ThreadLocal<>(); - private ThreadLocal> deviceQueues = new ThreadLocal<>(); + private final ThreadLocal> deviceQueues = new ThreadLocal<>(); - private ThreadLocal opCounter = new ThreadLocal<>(); + private final 
ThreadLocal opCounter = new ThreadLocal<>(); - private AtomicLong metaCounter = new AtomicLong(0); - private AtomicLong execCounter = new AtomicLong(0); + private final AtomicLong metaCounter = new AtomicLong(0); + private final AtomicLong execCounter = new AtomicLong(0); - private List watchdog = new CopyOnWriteArrayList<>(); + private final List watchdog = new CopyOnWriteArrayList<>(); - private List> aggregates = new ArrayList<>(); + private final List> aggregates = new ArrayList<>(); - private AtomicBoolean experimental = new AtomicBoolean(false); + private final AtomicBoolean experimental = new AtomicBoolean(false); public CudaGridExecutioner() { // extraz.set(new PointerPointer(10)); @@ -125,9 +125,9 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio invokeWatchdog(op); if (op instanceof ReduceOp) { - exec((ReduceOp) op, new int[] {Integer.MAX_VALUE}); + exec((ReduceOp) op, Integer.MAX_VALUE); } else if (op instanceof IndexAccumulation) { - exec((IndexAccumulation) op, new int[] {Integer.MAX_VALUE}); + exec((IndexAccumulation) op, Integer.MAX_VALUE); } else if (op instanceof ScalarOp || op instanceof TransformOp) { // the only entry place for TADless ops processAsGridOp(op); @@ -188,12 +188,8 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio return true; } - if (opX == pointer.address()) { - //logger.error("op.X matched: {}", pointer.address()); - return true; - } - - return false; + //logger.error("op.X matched: {}", pointer.address()); + return opX == pointer.address(); } @@ -207,17 +203,11 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio long opY = op.y() == null ? 0 : AtomicAllocator.getInstance().getHostPointer(op.y()).address(); - if (opZ == pointer.address() || opY == pointer.address() || opX == pointer.address()) - return true; - - return false; + return opZ == pointer.address() || opY == pointer.address() || opX == pointer.address(); } protected boolean compareArrays(INDArray array, Op op) { - if (op.x() == array || op.y() == array || op.z() == array) - return true; - - return false; + return op.x() == array || op.y() == array || op.z() == array; } /** @@ -476,10 +466,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio * @return */ protected boolean isMatchingZX(Op opA, Op opB) { - if (opA.x() == opB.x() && opA.z() == opB.z() && opA.x() == opB.z()) - return true; - - return false; + return opA.x() == opB.x() && opA.z() == opB.z() && opA.x() == opB.z(); } /** @@ -490,10 +477,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio * @return */ protected boolean isMatchingZXY(Op opA, Op opB) { - if (opA.z() == opB.x() || opA.z() == opB.y()) - return true; - - return false; + return opA.z() == opB.x() || opA.z() == opB.y(); } protected GridPointers pointerizeOp(OpDescriptor descriptor) { @@ -694,7 +678,7 @@ public class CudaGridExecutioner extends CudaExecutioner implements GridExecutio // So, that's scalar. 
We'll have to flush queue flushQueue(); - buildZ(op, new int[] {Integer.MAX_VALUE}); + buildZ(op, Integer.MAX_VALUE); super.invoke(op, null, new int[] {Integer.MAX_VALUE}); } else { buildZ(op, dimension); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index c14b9c7eb..3ba143e36 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -46,8 +46,8 @@ import org.nd4j.nativeblas.OpaqueRandomGenerator; */ public class CudaOpContext extends BaseOpContext implements OpContext, Deallocatable { // we might want to have configurable - private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - private OpaqueContext context = nativeOps.createGraphContext(1); + private final NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + private final OpaqueContext context = nativeOps.createGraphContext(1); private final transient long id = Nd4j.getDeallocatorService().nextValue(); public CudaOpContext() { @@ -92,7 +92,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext, Deallocat args[e] = arguments[e].toInt(); nativeOps.setGraphContextDArguments(context, new IntPointer(args), arguments.length); - }; + } } @Override @@ -161,7 +161,7 @@ public class CudaOpContext extends BaseOpContext implements OpContext, Deallocat @Override public String getUniqueId() { - return new String("CTX_" + id); + return "CTX_" + id; } @Override diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 0a638ff15..1d083f0ce 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -212,7 +212,7 @@ tasks.withType(org.bytedeco.gradle.javacpp.BuildTask) { // Disable the standard javacpp generated tasks and use own // versions below. 
This allows to build for each variant [javacppBuildParser, javacppBuildCommand, javacppCompileJava, javacppBuildCompiler].each { - it.enabled false; + it.enabled false } chipList.each { thisChip -> diff --git a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArrayPublisher.java b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArrayPublisher.java index 9e8d1660f..8a1bfaaee 100644 --- a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArrayPublisher.java +++ b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronNDArrayPublisher.java @@ -105,7 +105,7 @@ public class AeronNDArrayPublisher implements AutoCloseable { publication = aeron.addPublication(channel, streamId); log.info("Created publication on channel " + channel + " and stream " + streamId); } catch (DriverTimeoutException e) { - Thread.sleep(1000 * (connectionTries + 1)); + Thread.sleep(1000L * (connectionTries + 1)); log.warn("Failed to connect due to driver time out on channel " + channel + " and stream " + streamId + "...retrying in " + connectionTries + " seconds"); connectionTries++; diff --git a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronUtil.java b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronUtil.java index 97e83a5aa..be3a29781 100644 --- a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronUtil.java +++ b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronUtil.java @@ -135,8 +135,8 @@ public class AeronUtil { final byte[] data = new byte[length]; buffer.getBytes(offset, data); - System.out.println(String.format("Message to stream %d from session %d (%d@%d) <<%s>>", streamId, - header.sessionId(), length, offset, new String(data))); + System.out.printf("Message to stream %d from session %d (%d@%d) <<%s>>%n", streamId, + header.sessionId(), length, offset, new String(data)); }; } @@ -165,8 +165,8 @@ public class AeronUtil { */ public static void printRate(final double messagesPerSec, final double bytesPerSec, final long totalMessages, final long totalBytes) { - System.out.println(String.format("%.02g msgs/sec, %.02g bytes/sec, totals %d messages %d MB", messagesPerSec, - bytesPerSec, totalMessages, totalBytes / (1024 * 1024))); + System.out.printf("%.02g msgs/sec, %.02g bytes/sec, totals %d messages %d MB%n", messagesPerSec, + bytesPerSec, totalMessages, totalBytes / (1024 * 1024)); } /** @@ -176,8 +176,8 @@ public class AeronUtil { */ public static void printAvailableImage(final Image image) { final Subscription subscription = image.subscription(); - System.out.println(String.format("Available image on %s streamId=%d sessionId=%d from %s", - subscription.channel(), subscription.streamId(), image.sessionId(), image.sourceIdentity())); + System.out.printf("Available image on %s streamId=%d sessionId=%d from %s%n", + subscription.channel(), subscription.streamId(), image.sessionId(), image.sourceIdentity()); } /** @@ -187,8 +187,8 @@ public class AeronUtil { */ public static void printUnavailableImage(final Image image) { final Subscription subscription = image.subscription(); - System.out.println(String.format("Unavailable image on %s streamId=%d sessionId=%d", subscription.channel(), - subscription.streamId(), image.sessionId())); + System.out.printf("Unavailable image on %s streamId=%d sessionId=%d%n", subscription.channel(), + subscription.streamId(), image.sessionId()); } private static final AtomicInteger conductorCount = new AtomicInteger(); diff --git 
a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayFragmentHandler.java b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayFragmentHandler.java index 59afadde0..0495e4877 100644 --- a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayFragmentHandler.java +++ b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayFragmentHandler.java @@ -40,8 +40,8 @@ import java.nio.ByteOrder; */ @Slf4j public class NDArrayFragmentHandler implements FragmentHandler { - private NDArrayCallback ndArrayCallback; - private ChunkAccumulator chunkAccumulator = new InMemoryChunkAccumulator(); + private final NDArrayCallback ndArrayCallback; + private final ChunkAccumulator chunkAccumulator = new InMemoryChunkAccumulator(); public NDArrayFragmentHandler(NDArrayCallback ndArrayCallback) { this.ndArrayCallback = ndArrayCallback; diff --git a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java index 3658a8006..c73f9c3bb 100644 --- a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java +++ b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java @@ -256,7 +256,7 @@ public class NDArrayMessage implements Serializable { String messageId = UUID.randomUUID().toString(); for (int i = 0; i < ret.length; i++) { //data: only grab a chunk of the data - ByteBuffer view = (ByteBuffer) wholeBuffer.byteBuffer().asReadOnlyBuffer().position(i * chunkSize); + ByteBuffer view = wholeBuffer.byteBuffer().asReadOnlyBuffer().position(i * chunkSize); view.limit(Math.min(i * chunkSize + chunkSize, wholeBuffer.capacity())); view.order(ByteOrder.nativeOrder()); view = view.slice(); diff --git a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java index b725d3c04..963413423 100644 --- a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java +++ b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/chunk/InMemoryChunkAccumulator.java @@ -30,7 +30,7 @@ import java.util.Map; @Slf4j public class InMemoryChunkAccumulator implements ChunkAccumulator { - private Map> chunks = Maps.newConcurrentMap(); + private final Map> chunks = Maps.newConcurrentMap(); /** * Returns the number of chunks diff --git a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ndarrayholder/InMemoryNDArrayHolder.java b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ndarrayholder/InMemoryNDArrayHolder.java index 20b082819..21ea8a465 100644 --- a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ndarrayholder/InMemoryNDArrayHolder.java +++ b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ndarrayholder/InMemoryNDArrayHolder.java @@ -36,8 +36,8 @@ import java.util.concurrent.atomic.AtomicReference; @NoArgsConstructor public class InMemoryNDArrayHolder implements NDArrayHolder { - private AtomicReference arr = new AtomicReference<>(); - private AtomicInteger totalUpdates = new AtomicInteger(0); + private final AtomicReference arr = new AtomicReference<>(); + private final AtomicInteger totalUpdates = new AtomicInteger(0); public InMemoryNDArrayHolder(int[] shape) { diff --git a/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java 
b/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java index de0f74356..477dd1e1d 100644 --- a/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java +++ b/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java @@ -41,9 +41,9 @@ import static org.junit.jupiter.api.Assertions.assertFalse; public class LargeNdArrayIpcTest extends BaseND4JTest { private MediaDriver mediaDriver; private Aeron.Context ctx; - private String channel = "aeron:udp?endpoint=localhost:" + (40123 + new java.util.Random().nextInt(130)); - private int streamId = 10; - private int length = (int) 1e7; + private final String channel = "aeron:udp?endpoint=localhost:" + (40123 + new java.util.Random().nextInt(130)); + private final int streamId = 10; + private final int length = (int) 1e7; @Override public long getTimeoutMilliseconds() { diff --git a/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java b/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java index 999e11281..e0c680952 100644 --- a/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java +++ b/cavis-nd4j/cavis-nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java @@ -37,11 +37,11 @@ import java.util.concurrent.atomic.AtomicBoolean; @Timeout(120) public class NdArrayIpcTest extends BaseND4JTest { private MediaDriver mediaDriver; - private static Logger log = LoggerFactory.getLogger(NdArrayIpcTest.class); + private static final Logger log = LoggerFactory.getLogger(NdArrayIpcTest.class); private Aeron.Context ctx; - private String channel = "aeron:udp?endpoint=localhost:" + (40132 + new java.util.Random().nextInt(3000)); - private int streamId = 10; - private int length = (int) 1e7; + private final String channel = "aeron:udp?endpoint=localhost:" + (40132 + new java.util.Random().nextInt(3000)); + private final int streamId = 10; + private final int length = (int) 1e7; @Override public long getTimeoutMilliseconds() { diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/base/Preconditions.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/base/Preconditions.java index c8bc3966d..c709d29b8 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/base/Preconditions.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/base/Preconditions.java @@ -687,12 +687,12 @@ public final class Preconditions { } else { if(nextCustom < 0 || (nextIdx > 0 && nextIdx < nextCustom)){ //%s tag - sb.append(message.substring(indexOfStart, nextIdx)) + sb.append(message, indexOfStart, nextIdx) .append(formatArg(args[i])); indexOfStart = nextIdx + 2; } else { //Custom tag - sb.append(message.substring(indexOfStart, nextCustom)); + sb.append(message, indexOfStart, nextCustom); String s = FORMATTERS.get(nextCustomTag).format(nextCustomTag, args[i]); sb.append(s); indexOfStart = nextCustom + nextCustomTag.length(); diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/CompactHeapStringList.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/CompactHeapStringList.java index b7a25f248..dec86b34b 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/CompactHeapStringList.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/CompactHeapStringList.java @@ -286,7 +286,7 @@ public class CompactHeapStringList implements List { while (e1.hasNext() 
&& e2.hasNext()) { String o1 = e1.next(); Object o2 = e2.next(); - if (!(o1 == null ? o2 == null : o1.equals(o2))) + if (!(Objects.equals(o1, o2))) return false; } return !(e1.hasNext() || e2.hasNext()); diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/IntArrayKeyMap.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/IntArrayKeyMap.java index 2ed1b154d..84730c572 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/IntArrayKeyMap.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/IntArrayKeyMap.java @@ -28,7 +28,7 @@ import java.util.*; public class IntArrayKeyMap implements Map { - private Map map = new LinkedHashMap<>(); + private final Map map = new LinkedHashMap<>(); @Override public int size() { @@ -120,7 +120,7 @@ public class IntArrayKeyMap implements Map { public static class IntArray implements Comparable { @Getter - private int[] backingArray; + private final int[] backingArray; public IntArray(int[] backingArray) { Preconditions.checkNotNull(backingArray,"Backing array must not be null!"); diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/IntArrayKeySet.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/IntArrayKeySet.java index 1a8893cda..b1db74f72 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/IntArrayKeySet.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/IntArrayKeySet.java @@ -23,7 +23,7 @@ package org.nd4j.common.collection; import java.util.*; public class IntArrayKeySet implements Set { - private Set set = new LinkedHashSet<>(); + private final Set set = new LinkedHashSet<>(); @Override public int size() { return set.size(); diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java index a88871152..03ec92701 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java @@ -361,7 +361,7 @@ public class MultiDimensionalMap implements Serializable { MultiDimensionalMap that = (MultiDimensionalMap) o; - return !(backedMap != null ? 
!backedMap.equals(that.backedMap) : that.backedMap != null); + return !(!Objects.equals(backedMap, that.backedMap)); } diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java index c5712d3eb..d16c190cb 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java @@ -28,7 +28,7 @@ import java.util.concurrent.ConcurrentSkipListSet; public class MultiDimensionalSet implements Set> { - private Set> backedSet; + private final Set> backedSet; public MultiDimensionalSet(Set> backedSet) { this.backedSet = backedSet; diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/holder/ObjectMapperHolder.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/holder/ObjectMapperHolder.java index 0cd5166a1..9df59f5f7 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/holder/ObjectMapperHolder.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/holder/ObjectMapperHolder.java @@ -26,7 +26,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; public class ObjectMapperHolder { - private static ObjectMapper objectMapper = getMapper(); + private static final ObjectMapper objectMapper = getMapper(); private ObjectMapperHolder() {} diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/AbstractFileResolvingResource.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/AbstractFileResolvingResource.java index 69728a1c3..d15d109d3 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/AbstractFileResolvingResource.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/AbstractFileResolvingResource.java @@ -124,7 +124,7 @@ public abstract class AbstractFileResolvingResource extends AbstractResource { ((HttpURLConnection) con).setRequestMethod("HEAD"); } - return (long) con.getContentLength(); + return con.getContentLength(); } } diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/AbstractResource.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/AbstractResource.java index a6595a0e3..cf7ac3f38 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/AbstractResource.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/AbstractResource.java @@ -79,8 +79,7 @@ public abstract class AbstractResource implements Resource { long size = 0L; int read; - for (byte[] buf = new byte[255]; (read = is.read(buf)) != -1; size += (long) read) { - ; + for (byte[] buf = new byte[255]; (read = is.read(buf)) != -1; size += read) { } long var6 = size; @@ -89,7 +88,6 @@ public abstract class AbstractResource implements Resource { try { is.close(); } catch (IOException var14) { - ; } } diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ClassPathResource.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ClassPathResource.java index cf3d45944..0ef4cde64 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ClassPathResource.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ClassPathResource.java @@ -365,7 +365,7 @@ public class ClassPathResource extends AbstractFileResolvingResource { private ZipFile zipFile; private ZipEntry entry; private 
InputStream stream; - private String resourceName; + private final String resourceName; public GetStreamFromZip(URL url, String resourceName) { this.url = url; diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/CollectionUtils.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/CollectionUtils.java index 268c6fac0..9d224b63f 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/CollectionUtils.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/CollectionUtils.java @@ -50,10 +50,7 @@ public abstract class CollectionUtils { Object[] arr$ = arr; int len$ = arr.length; - for (int i$ = 0; i$ < len$; ++i$) { - Object elem = arr$[i$]; - collection.add(elem); - } + collection.addAll(Arrays.asList(arr$).subList(0, len$)); } } @@ -157,7 +154,7 @@ public abstract class CollectionUtils { } public static T findValueOfType(Collection collection, Class type) { - if (isEmpty((Collection) collection)) { + if (isEmpty(collection)) { return null; } else { Object value = null; @@ -179,7 +176,7 @@ public abstract class CollectionUtils { } public static Object findValueOfType(Collection collection, Class[] types) { - if (!isEmpty((Collection) collection) && !ObjectUtils.isEmpty(types)) { + if (!isEmpty(collection) && !ObjectUtils.isEmpty(types)) { Class[] arr$ = types; int len$ = types.length; @@ -260,7 +257,7 @@ public abstract class CollectionUtils { } public static MultiValueMap unmodifiableMultiValueMap(MultiValueMap map) { - Assert.notNull(map, "\'map\' must not be null"); + Assert.notNull(map, "'map' must not be null"); LinkedHashMap result = new LinkedHashMap(map.size()); Iterator unmodifiableMap = map.entrySet().iterator(); @@ -278,7 +275,7 @@ public abstract class CollectionUtils { private final Map> map; public MultiValueMapAdapter(Map> map) { - Assert.notNull(map, "\'map\' must not be null"); + Assert.notNull(map, "'map' must not be null"); this.map = map; } @@ -374,7 +371,7 @@ public abstract class CollectionUtils { } public boolean equals(Object other) { - return this == other ? 
true : this.map.equals(other); + return this == other || this.map.equals(other); } public int hashCode() { @@ -387,7 +384,7 @@ public abstract class CollectionUtils { } private static class EnumerationIterator implements Iterator { - private Enumeration enumeration; + private final Enumeration enumeration; public EnumerationIterator(Enumeration enumeration) { this.enumeration = enumeration; diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ObjectUtils.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ObjectUtils.java index e1dcf32e9..43f6db46b 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ObjectUtils.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ObjectUtils.java @@ -122,7 +122,7 @@ public abstract class ObjectUtils { } throw new IllegalArgumentException(String.format("constant [%s] does not exist in enum opType %s", - new Object[] {constant, enumValues.getClass().getComponentType().getName()})); + constant, enumValues.getClass().getComponentType().getName())); } public static A[] addObjectToArray(A[] array, O obj) { @@ -479,7 +479,7 @@ public abstract class ObjectUtils { sb.append(", "); } - sb.append(String.valueOf(array[i])); + sb.append(array[i]); } sb.append("}"); @@ -557,7 +557,7 @@ public abstract class ObjectUtils { sb.append(", "); } - sb.append("\'").append(array[i]).append("\'"); + sb.append("'").append(array[i]).append("'"); } sb.append("}"); diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java index 2a70e13d5..2332fcecc 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/ReflectionUtils.java @@ -289,7 +289,7 @@ public abstract class ReflectionUtils { mc.doWith(superIfc); } catch (IllegalAccessException var9) { throw new IllegalStateException( - "Shouldn\'t be illegal to access method \'" + superIfc.getName() + "\': " + var9); + "Shouldn't be illegal to access method '" + superIfc.getName() + "': " + var9); } } } @@ -374,7 +374,7 @@ public abstract class ReflectionUtils { fc.doWith(field); } catch (IllegalAccessException var10) { throw new IllegalStateException( - "Shouldn\'t be illegal to access field \'" + field.getName() + "\': " + var10); + "Shouldn't be illegal to access field '" + field.getName() + "': " + var10); } } } diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/StringUtils.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/StringUtils.java index 9f4fecbec..264f76cf5 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/StringUtils.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/StringUtils.java @@ -242,7 +242,7 @@ public abstract class StringUtils { int index = inString.indexOf(oldPattern); for (int patLen = oldPattern.length(); index >= 0; index = inString.indexOf(oldPattern, pos)) { - sb.append(inString.substring(pos, index)); + sb.append(inString, pos, index); sb.append(newPattern); pos = index + patLen; } @@ -276,7 +276,7 @@ public abstract class StringUtils { } public static String quote(String str) { - return str != null ? "\'" + str + "\'" : null; + return str != null ? 
"'" + str + "'" : null; } public static Object quoteIfString(Object obj) { @@ -536,10 +536,7 @@ public abstract class StringUtils { String[] arr$ = array; int len$ = array.length; - for (int i$ = 0; i$ < len$; ++i$) { - String element = arr$[i$]; - set.add(element); - } + set.addAll(Arrays.asList(arr$).subList(0, len$)); return toStringArray(set); } @@ -656,10 +653,7 @@ public abstract class StringUtils { String[] arr$ = tokens; int len$ = tokens.length; - for (int i$ = 0; i$ < len$; ++i$) { - String token = arr$[i$]; - set.add(token); - } + set.addAll(Arrays.asList(arr$).subList(0, len$)); return set; } diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/VfsUtils.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/VfsUtils.java index 2255c8176..502859a73 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/VfsUtils.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/io/VfsUtils.java @@ -43,14 +43,14 @@ public abstract class VfsUtils { private static Method VFS_METHOD_GET_ROOT_URL = null; private static Method VFS_METHOD_GET_ROOT_URI = null; private static Method VIRTUAL_FILE_METHOD_EXISTS = null; - private static Method VIRTUAL_FILE_METHOD_GET_INPUT_STREAM; - private static Method VIRTUAL_FILE_METHOD_GET_SIZE; - private static Method VIRTUAL_FILE_METHOD_GET_LAST_MODIFIED; - private static Method VIRTUAL_FILE_METHOD_TO_URL; - private static Method VIRTUAL_FILE_METHOD_TO_URI; - private static Method VIRTUAL_FILE_METHOD_GET_NAME; - private static Method VIRTUAL_FILE_METHOD_GET_PATH_NAME; - private static Method VIRTUAL_FILE_METHOD_GET_CHILD; + private static final Method VIRTUAL_FILE_METHOD_GET_INPUT_STREAM; + private static final Method VIRTUAL_FILE_METHOD_GET_SIZE; + private static final Method VIRTUAL_FILE_METHOD_GET_LAST_MODIFIED; + private static final Method VIRTUAL_FILE_METHOD_TO_URL; + private static final Method VIRTUAL_FILE_METHOD_TO_URI; + private static final Method VIRTUAL_FILE_METHOD_GET_NAME; + private static final Method VIRTUAL_FILE_METHOD_GET_PATH_NAME; + private static final Method VIRTUAL_FILE_METHOD_GET_CHILD; protected static Class VIRTUAL_FILE_VISITOR_INTERFACE; protected static Method VIRTUAL_FILE_METHOD_VISIT; private static Method VFS_UTILS_METHOD_IS_NESTED_FILE = null; @@ -122,11 +122,11 @@ public abstract class VfsUtils { } static Object getRelative(URL url) throws IOException { - return invokeVfsMethod(VFS_METHOD_GET_ROOT_URL, null, new Object[] {url}); + return invokeVfsMethod(VFS_METHOD_GET_ROOT_URL, null, url); } static Object getChild(Object vfsResource, String path) throws IOException { - return invokeVfsMethod(VIRTUAL_FILE_METHOD_GET_CHILD, vfsResource, new Object[] {path}); + return invokeVfsMethod(VIRTUAL_FILE_METHOD_GET_CHILD, vfsResource, path); } static File getFile(Object vfsResource) throws IOException { @@ -148,11 +148,11 @@ public abstract class VfsUtils { } static Object getRoot(URI url) throws IOException { - return invokeVfsMethod(VFS_METHOD_GET_ROOT_URI, null, new Object[] {url}); + return invokeVfsMethod(VFS_METHOD_GET_ROOT_URI, null, url); } protected static Object getRoot(URL url) throws IOException { - return invokeVfsMethod(VFS_METHOD_GET_ROOT_URL, null, new Object[] {url}); + return invokeVfsMethod(VFS_METHOD_GET_ROOT_URL, null, url); } protected static Object doGetVisitorAttribute() { @@ -195,8 +195,8 @@ public abstract class VfsUtils { try { String ex = VfsUtils.VFS_VER.V3.equals(version) ? 
"getChild" : "getRoot"; - VFS_METHOD_GET_ROOT_URL = ReflectionUtils.findMethod(vfsClass, ex, new Class[] {URL.class}); - VFS_METHOD_GET_ROOT_URI = ReflectionUtils.findMethod(vfsClass, ex, new Class[] {URI.class}); + VFS_METHOD_GET_ROOT_URL = ReflectionUtils.findMethod(vfsClass, ex, URL.class); + VFS_METHOD_GET_ROOT_URI = ReflectionUtils.findMethod(vfsClass, ex, URI.class); Class virtualFile = loader.loadClass(pkg + "VirtualFile"); VIRTUAL_FILE_METHOD_EXISTS = ReflectionUtils.findMethod(virtualFile, "exists"); VIRTUAL_FILE_METHOD_GET_INPUT_STREAM = ReflectionUtils.findMethod(virtualFile, "openStream"); @@ -208,15 +208,15 @@ public abstract class VfsUtils { VIRTUAL_FILE_METHOD_GET_PATH_NAME = ReflectionUtils.findMethod(virtualFile, "getPathName"); GET_PHYSICAL_FILE = ReflectionUtils.findMethod(virtualFile, "getPhysicalFile"); ex = VfsUtils.VFS_VER.V3.equals(version) ? "getChild" : "findChild"; - VIRTUAL_FILE_METHOD_GET_CHILD = ReflectionUtils.findMethod(virtualFile, ex, new Class[] {String.class}); + VIRTUAL_FILE_METHOD_GET_CHILD = ReflectionUtils.findMethod(virtualFile, ex, String.class); Class utilsClass = loader.loadClass(pkg + "VFSUtils"); VFS_UTILS_METHOD_GET_COMPATIBLE_URI = - ReflectionUtils.findMethod(utilsClass, "getCompatibleURI", new Class[] {virtualFile}); + ReflectionUtils.findMethod(utilsClass, "getCompatibleURI", virtualFile); VFS_UTILS_METHOD_IS_NESTED_FILE = - ReflectionUtils.findMethod(utilsClass, "isNestedFile", new Class[] {virtualFile}); + ReflectionUtils.findMethod(utilsClass, "isNestedFile", virtualFile); VIRTUAL_FILE_VISITOR_INTERFACE = loader.loadClass(pkg + "VirtualFileVisitor"); VIRTUAL_FILE_METHOD_VISIT = ReflectionUtils.findMethod(virtualFile, "visit", - new Class[] {VIRTUAL_FILE_VISITOR_INTERFACE}); + VIRTUAL_FILE_VISITOR_INTERFACE); Class visitorAttributesClass = loader.loadClass(pkg + "VisitorAttributes"); VISITOR_ATTRIBUTES_FIELD_RECURSE = ReflectionUtils.findField(visitorAttributesClass, "RECURSE"); } catch (ClassNotFoundException var7) { @@ -224,9 +224,9 @@ public abstract class VfsUtils { } } - private static enum VFS_VER { + private enum VFS_VER { V2, V3; - private VFS_VER() {} + VFS_VER() {} } } diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/CounterMap.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/CounterMap.java index 1cc6758e6..597513300 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/CounterMap.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/CounterMap.java @@ -192,7 +192,7 @@ public class CounterMap implements Serializable{ public Iterator> getIterator() { return new Iterator>() { - Iterator outerIt; + final Iterator outerIt; Iterator innerIt; F curKey; diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicBoolean.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicBoolean.java index 6c807feea..13a9e523a 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicBoolean.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicBoolean.java @@ -31,7 +31,7 @@ import java.io.IOException; public class JsonDeserializerAtomicBoolean extends JsonDeserializer { @Override - public AtomicBoolean deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException, JsonProcessingException 
{ + public AtomicBoolean deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException { JsonNode node = jsonParser.getCodec().readTree(jsonParser); boolean value = node.asBoolean(); return new AtomicBoolean(value); diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicDouble.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicDouble.java index d777b0072..2b152e750 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicDouble.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonDeserializerAtomicDouble.java @@ -31,7 +31,7 @@ import java.io.IOException; public class JsonDeserializerAtomicDouble extends JsonDeserializer { @Override - public AtomicDouble deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { + public AtomicDouble deserialize(JsonParser jsonParser, DeserializationContext deserializationContext) throws IOException { JsonNode node = jsonParser.getCodec().readTree(jsonParser); double value = node.asDouble(); return new AtomicDouble(value); diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicBoolean.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicBoolean.java index c10f1bc95..e2d51b105 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicBoolean.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicBoolean.java @@ -30,7 +30,7 @@ import java.io.IOException; public class JsonSerializerAtomicBoolean extends JsonSerializer { @Override - public void serialize(AtomicBoolean atomicDouble, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException, JsonProcessingException { + public void serialize(AtomicBoolean atomicDouble, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException { jsonGenerator.writeBoolean(atomicDouble.get()); } } diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicDouble.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicDouble.java index 1f9041ccd..9e00819d4 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicDouble.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/primitives/serde/JsonSerializerAtomicDouble.java @@ -30,7 +30,7 @@ import java.io.IOException; public class JsonSerializerAtomicDouble extends JsonSerializer { @Override - public void serialize(AtomicDouble atomicDouble, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException, JsonProcessingException { + public void serialize(AtomicDouble atomicDouble, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) throws IOException { jsonGenerator.writeNumber(atomicDouble.doubleValue()); } } diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/Resources.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/Resources.java index f8fa974f4..aec97ba3e 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/Resources.java +++ 
b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/Resources.java @@ -31,7 +31,7 @@ import java.util.*; @Slf4j public class Resources { - private static Resources INSTANCE = new Resources(); + private static final Resources INSTANCE = new Resources(); protected final List resolvers; @@ -123,7 +123,7 @@ public class Resources { } throw new IllegalStateException("Cannot resolve resource (not found): none of " + resolvers.size() + - " resolvers can resolve resource \"" + resourcePath + "\" - available resolvers: " + resolvers.toString()); + " resolvers can resolve resource \"" + resourcePath + "\" - available resolvers: " + resolvers); } public InputStream getAsStream(String resourcePath) { @@ -135,7 +135,7 @@ public class Resources { } throw new IllegalStateException("Cannot resolve resource (not found): none of " + resolvers.size() + - " resolvers can resolve resource \"" + resourcePath + "\" - available resolvers: " + resolvers.toString()); + " resolvers can resolve resource \"" + resourcePath + "\" - available resolvers: " + resolvers); } public void copyDir(String directoryPath, File destinationDir) { diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/strumpf/ResourceFile.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/strumpf/ResourceFile.java index 0141be02f..8bdeae89c 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/strumpf/ResourceFile.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/strumpf/ResourceFile.java @@ -118,10 +118,7 @@ public class ResourceFile { Preconditions.checkState(expSha256 != null, "Expected JSON property %s was not found in resource reference file %s", sha256Property, filePath); String actualSha256 = sha256(file); - if (!expSha256.equals(actualSha256)) { - return false; - } - return true; + return expSha256.equals(actualSha256); } /** diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/strumpf/StrumpfResolver.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/strumpf/StrumpfResolver.java index 54ff89459..ba879f740 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/strumpf/StrumpfResolver.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/resources/strumpf/StrumpfResolver.java @@ -94,11 +94,7 @@ public class StrumpfResolver implements Resolver { } cpr = new ClassPathResource(resourcePath); - if (cpr.exists()) { - return true; - } - - return false; + return cpr.exists(); } @Override @@ -116,11 +112,7 @@ public class StrumpfResolver implements Resolver { //Second: Check classpath ClassPathResource cpr = new ClassPathResource(dirPath); - if (cpr.exists()) { - return true; - } - - return false; + return cpr.exists(); } @Override diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/tools/BTools.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/tools/BTools.java index 7e4d06b49..d22b22998 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/tools/BTools.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/tools/BTools.java @@ -272,10 +272,10 @@ public class BTools { // String FormatS = ""; if ( LeadingChar == '0' ) { - FormatS = "%" + LeadingChar + Integer.toString( CharsCount ) + "d"; + FormatS = "%" + LeadingChar + CharsCount + "d"; } else { - FormatS = "%" + Integer.toString( CharsCount ) + "d"; + FormatS = "%" + CharsCount + "d"; } // Result = 
String.format( FormatS, Value ); diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/tools/SIS.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/tools/SIS.java index b10296fcc..a2ee4f925 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/tools/SIS.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/tools/SIS.java @@ -33,7 +33,7 @@ import java.time.format.DateTimeFormatter; public class SIS { // System Informations Saving // - private String baseModuleCode = "SIS"; + private final String baseModuleCode = "SIS"; private String moduleCode = "?"; // private PrintStream out; diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArchiveUtils.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArchiveUtils.java index 317c5a23d..cd682f3b2 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArchiveUtils.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArchiveUtils.java @@ -80,7 +80,7 @@ public class ArchiveUtils { new File(dest).mkdirs(); FileInputStream fin = new FileInputStream(target); int BUFFER = 2048; - byte data[] = new byte[BUFFER]; + byte[] data = new byte[BUFFER]; if (file.endsWith(".zip") || file.endsWith(".jar")) { try(ZipInputStream zis = new ZipInputStream(fin)) { @@ -152,7 +152,7 @@ public class ArchiveUtils { else { int count; try(FileOutputStream fos = new FileOutputStream(dest + File.separator + entry.getName()); - BufferedOutputStream destStream = new BufferedOutputStream(fos, BUFFER);) { + BufferedOutputStream destStream = new BufferedOutputStream(fos, BUFFER)) { while ((count = tarIn.read(data, 0, BUFFER)) != -1) { destStream.write(data, 0, count); } diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java index 8a30f0e48..13780f3a6 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java @@ -295,7 +295,7 @@ public class ArrayUtil { public static long[] toLongs(byte[] data) { val ret = new long[data.length]; for (int i = 0; i < ret.length; i++) { - ret[i] = (long) data[i]; + ret[i] = data[i]; } return ret; } @@ -311,7 +311,7 @@ public class ArrayUtil { public static long[] toLongs(short[] data) { val ret = new long[data.length]; for (int i = 0; i < ret.length; i++) { - ret[i] = (long) data[i]; + ret[i] = data[i]; } return ret; } @@ -319,7 +319,7 @@ public class ArrayUtil { public static long[] toLongs(int[] data) { val ret = new long[data.length]; for (int i = 0; i < ret.length; i++) { - ret[i] = (long) data[i]; + ret[i] = data[i]; } return ret; } @@ -1105,7 +1105,7 @@ public class ArrayUtil { public static double[] toDoubles(int[] ints) { double[] ret = new double[ints.length]; for (int i = 0; i < ints.length; i++) - ret[i] = (double) ints[i]; + ret[i] = ints[i]; return ret; } @@ -1119,7 +1119,7 @@ public class ArrayUtil { public static double[] toDoubles(float[] ints) { double[] ret = new double[ints.length]; for (int i = 0; i < ints.length; i++) - ret[i] = (double) ints[i]; + ret[i] = ints[i]; return ret; } diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/Index.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/Index.java index cc64e145d..ff91a9a4e 100644 --- 
a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/Index.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/Index.java @@ -23,14 +23,15 @@ package org.nd4j.common.util; import java.io.Serializable; import java.util.Map; +import java.util.Objects; import java.util.concurrent.ConcurrentHashMap; @SuppressWarnings({"rawtypes", "unchecked"}) public class Index implements Serializable { private static final long serialVersionUID = 1160629777026141078L; - private Map objects = new ConcurrentHashMap<>(); - private Map indexes = new ConcurrentHashMap<>(); + private final Map objects = new ConcurrentHashMap<>(); + private final Map indexes = new ConcurrentHashMap<>(); public synchronized boolean add(Object o, int idx) { if (o instanceof String && o.toString().isEmpty()) { @@ -103,9 +104,9 @@ public class Index implements Serializable { Index index = (Index) o; - if (objects != null ? !objects.equals(index.objects) : index.objects != null) + if (!Objects.equals(objects, index.objects)) return false; - return !(indexes != null ? !indexes.equals(index.indexes) : index.indexes != null); + return !(!Objects.equals(indexes, index.indexes)); } diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/MathUtils.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/MathUtils.java index 58d72eace..6e249ffbd 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/MathUtils.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/MathUtils.java @@ -163,7 +163,7 @@ public class MathUtils { * @param targetAttribute target attribute vector * @return the correlation coefficient or r */ - public static double correlation(double[] residuals, double targetAttribute[]) { + public static double correlation(double[] residuals, double[] targetAttribute) { double[] predictedValues = new double[residuals.length]; for (int i = 0; i < predictedValues.length; i++) { predictedValues[i] = targetAttribute[i] - residuals[i]; @@ -1042,7 +1042,7 @@ public class MathUtils { */ public static /*@pure@*/ double roundDouble(double value, int afterDecimalPoint) { - double mask = Math.pow(10.0, (double) afterDecimalPoint); + double mask = Math.pow(10.0, afterDecimalPoint); return (double) (Math.round(value * mask)) / mask; }//end roundDouble diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/Rational.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/Rational.java index 404874016..e9914479c 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/Rational.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/Rational.java @@ -234,10 +234,10 @@ class Rational implements Cloneable { public Rational pow(BigInteger exponent) throws NumberFormatException { /* test for overflow */ if (exponent.compareTo(MAX_INT) == 1) { - throw new NumberFormatException("Exponent " + exponent.toString() + " too large."); + throw new NumberFormatException("Exponent " + exponent + " too large."); } if (exponent.compareTo(MIN_INT) == -1) { - throw new NumberFormatException("Exponent " + exponent.toString() + " too small."); + throw new NumberFormatException("Exponent " + exponent + " too small."); } /* promote to the simpler interface above */ return pow(exponent.intValue()); diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/SynchronizedTable.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/SynchronizedTable.java index 
ace0bf5f1..37c16114e 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/SynchronizedTable.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/SynchronizedTable.java @@ -27,7 +27,7 @@ import java.util.Map; import java.util.Set; public class SynchronizedTable implements Table { - private Table wrapped; + private final Table wrapped; public SynchronizedTable(Table wrapped) { this.wrapped = wrapped; diff --git a/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java b/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java index b4c86b2e9..b08be4f36 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java +++ b/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/function/FunctionalUtilsTest.java @@ -45,9 +45,9 @@ public class FunctionalUtilsTest { //[(fish,([],[alex])), (dog,([adam],[steve])), (cat,([adam],[alice]))] Map,List>> assertion = new HashMap<>(); - assertion.put("cat",Pair.of(Arrays.asList("adam"),Arrays.asList("alice"))); - assertion.put("dog",Pair.of(Arrays.asList("adam"),Arrays.asList("steve"))); - assertion.put("fish",Pair.of(Collections.emptyList(),Arrays.asList("alex"))); + assertion.put("cat",Pair.of(Collections.singletonList("adam"), Collections.singletonList("alice"))); + assertion.put("dog",Pair.of(Collections.singletonList("adam"), Collections.singletonList("steve"))); + assertion.put("fish",Pair.of(Collections.emptyList(), Collections.singletonList("alex"))); Map, List>> cogroup = FunctionalUtils.cogroup(leftMap, rightMap); assertEquals(assertion,cogroup); diff --git a/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java b/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java index b3c924919..b215da12a 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java +++ b/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/loader/TestFileBatch.java @@ -83,7 +83,7 @@ public class TestFileBatch { //Check that it is indeed a valid zip file: - File f = new File(FileUtils.getTempDirectoryPath()+"/"+UUID.randomUUID().toString()); + File f = new File(FileUtils.getTempDirectoryPath()+"/"+ UUID.randomUUID()); f.delete(); fb.writeAsZip(f); diff --git a/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java b/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java index ee40a1089..992cff871 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java +++ b/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/tools/InfoValuesTest.java @@ -27,9 +27,9 @@ import static org.junit.jupiter.api.Assertions.*; public class InfoValuesTest { // - private String[] t1_titleA = { "T0", "T1", "T2", "T3", "T4", "T5" }; + private final String[] t1_titleA = { "T0", "T1", "T2", "T3", "T4", "T5" }; // - private String[] t2_titleA = { "", "T1", "T2" }; + private final String[] t2_titleA = { "", "T1", "T2" }; // @Test diff --git a/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/tools/SISTest.java b/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/tools/SISTest.java index e89fdd324..95b625c32 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/tools/SISTest.java +++ b/cavis-nd4j/cavis-nd4j-common/src/test/java/org/nd4j/common/tools/SISTest.java @@ -50,7 +50,7 @@ public class SISTest { // 
assertEquals( 33, fFName.length() ); assertEquals( "Z", fFName.substring( 0, 1 ) ); - assertEquals( "_Test_ABC.txt", fFName.substring( fFName.length() - 13, fFName.length() ) ); + assertEquals( "_Test_ABC.txt", fFName.substring( fFName.length() - 13) ); // assertEquals( "", fFName ); // assertEquals( "", tmpFld.getRoot().getAbsolutePath() ); // diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/BackgroundDaemonStarter.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/BackgroundDaemonStarter.java index 7a582400a..67c79803b 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/BackgroundDaemonStarter.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/BackgroundDaemonStarter.java @@ -47,7 +47,7 @@ public class BackgroundDaemonStarter { * @throws InterruptedException */ public static int startSlave(int parameterLength, String masterUrl, String mediaDriverDirectory) throws Exception { - return exec(ParameterServerSubscriber.class, mediaDriverDirectory, "-s", "1," + String.valueOf(parameterLength), + return exec(ParameterServerSubscriber.class, mediaDriverDirectory, "-s", "1," + parameterLength, "-p", "40126", "-h", "localhost", "-id", "10", "-pm", masterUrl, "-sp", "9500", "--updatesPerEpoch", "1"); } @@ -96,7 +96,7 @@ public class BackgroundDaemonStarter { */ public static int startMaster(int parameterLength, String mediaDriverDirectory) throws Exception { return exec(ParameterServerSubscriber.class, mediaDriverDirectory, "-m", "true", "-s", - "1," + String.valueOf(parameterLength), "-p", "40123", "-h", "localhost", "-id", "11", "-sp", + "1," + parameterLength, "-p", "40123", "-h", "localhost", "-id", "11", "-sp", "9200", "--updatesPerEpoch", "1"); } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java index 005443fe3..8568637b4 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/background/RemoteParameterServerClientTests.java @@ -43,11 +43,11 @@ import static org.junit.jupiter.api.Assertions.assertEquals; @Slf4j public class RemoteParameterServerClientTests extends BaseND4JTest { - private int parameterLength = 1000; + private final int parameterLength = 1000; private Aeron.Context ctx; private MediaDriver mediaDriver; - private AtomicInteger masterStatus = new AtomicInteger(0); - private AtomicInteger slaveStatus = new AtomicInteger(0); + private final AtomicInteger masterStatus = new AtomicInteger(0); + private final AtomicInteger slaveStatus = new AtomicInteger(0); private Aeron aeron; @BeforeEach diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java 
b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java index b57618211..2309aae84 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientPartialTest.java @@ -46,7 +46,7 @@ public class ParameterServerClientPartialTest extends BaseND4JTest { private static MediaDriver mediaDriver; private static Aeron.Context ctx; private static ParameterServerSubscriber masterNode, slaveNode; - private int[] shape = {2, 2}; + private final int[] shape = {2, 2}; private static Aeron aeron; @BeforeAll @@ -74,7 +74,7 @@ public class ParameterServerClientPartialTest extends BaseND4JTest { assertEquals("localhost", masterNode.getHost()); assertEquals(11, masterNode.getStreamId()); assertEquals(12, masterNode.getResponder().getStreamId()); - assertEquals(masterNode.getMasterArray(), Nd4j.create(new int[] {2, 2})); + assertEquals(masterNode.getMasterArray(), Nd4j.create(2, 2)); slaveNode = new ParameterServerSubscriber(mediaDriver); slaveNode.setAeron(aeron); @@ -127,7 +127,7 @@ public class ParameterServerClientPartialTest extends BaseND4JTest { Thread.sleep(30000); ParameterServerListener listener = (ParameterServerListener) masterNode.getCallback(); assertEquals(1, listener.getUpdater().numUpdates()); - INDArray assertion = Nd4j.create(new int[] {2, 2}); + INDArray assertion = Nd4j.create(2, 2); assertion.getColumn(0).addi(1.0); assertEquals(assertion, listener.getUpdater().ndArrayHolder().get()); INDArray arr = client.getArray(); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java index 985c77ec8..5492fb0f4 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-client/src/test/java/org/nd4j/parameterserver/client/ParameterServerClientTest.java @@ -41,10 +41,10 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class ParameterServerClientTest extends BaseND4JTest { private static MediaDriver mediaDriver; - private static Logger log = LoggerFactory.getLogger(ParameterServerClientTest.class); + private static final Logger log = LoggerFactory.getLogger(ParameterServerClientTest.class); private static Aeron aeron; private static ParameterServerSubscriber masterNode, slaveNode; - private static int parameterLength = 1000; + private static final int parameterLength = 1000; @BeforeAll public static void beforeClass() throws Exception { @@ -54,7 +54,7 @@ public class ParameterServerClientTest extends BaseND4JTest { masterNode = new ParameterServerSubscriber(mediaDriver); masterNode.setAeron(aeron); int masterPort = 40323 + new java.util.Random().nextInt(3000); - masterNode.run(new String[] {"-m", "true", "-s", "1," + String.valueOf(parameterLength), "-p", + masterNode.run(new String[] {"-m", "true", "-s", "1," + parameterLength, "-p", String.valueOf(masterPort), "-h", "localhost", "-id", "11", "-md", 
mediaDriver.aeronDirectoryName(), "-sp", "33000", "-u", String.valueOf(1)}); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/SoftSyncParameterUpdater.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/SoftSyncParameterUpdater.java index 59e19ab10..5ba28b6bc 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/SoftSyncParameterUpdater.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/SoftSyncParameterUpdater.java @@ -32,7 +32,7 @@ public class SoftSyncParameterUpdater extends BaseParameterUpdater { //s is the number of updates private int s; private int currentVersion; - private int accumulatedUpdates = 0; + private final int accumulatedUpdates = 0; private double scalingFactor; diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/SynchronousParameterUpdater.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/SynchronousParameterUpdater.java index 9ebf5bcbd..12635f433 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/SynchronousParameterUpdater.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/SynchronousParameterUpdater.java @@ -33,7 +33,7 @@ import java.util.Map; public class SynchronousParameterUpdater extends BaseParameterUpdater { private int workers = Runtime.getRuntime().availableProcessors(); - private static ObjectMapper objectMapper = new ObjectMapper(); + private static final ObjectMapper objectMapper = new ObjectMapper(); /** * Returns the number of required diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/storage/InMemoryUpdateStorage.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/storage/InMemoryUpdateStorage.java index 73202e0d2..7c4a81eb0 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/storage/InMemoryUpdateStorage.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/storage/InMemoryUpdateStorage.java @@ -28,7 +28,7 @@ import java.util.concurrent.CopyOnWriteArrayList; public class InMemoryUpdateStorage extends BaseUpdateStorage { - private List updates = new CopyOnWriteArrayList<>(); + private final List updates = new CopyOnWriteArrayList<>(); /** * Add an ndarray to the storage diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/storage/NoUpdateStorage.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/storage/NoUpdateStorage.java index a44db1a40..16ec168fc 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/storage/NoUpdateStorage.java +++ 
b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-core/src/main/java/org/nd4j/parameterserver/updater/storage/NoUpdateStorage.java @@ -27,7 +27,7 @@ import java.util.concurrent.atomic.AtomicInteger; @Slf4j public class NoUpdateStorage extends BaseUpdateStorage { - private AtomicInteger updateCount = new AtomicInteger(0); + private final AtomicInteger updateCount = new AtomicInteger(0); /** * Add an ndarray to the storage diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/RetransmissionHandler.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/RetransmissionHandler.java index 59db0f670..af46e8a7e 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/RetransmissionHandler.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/RetransmissionHandler.java @@ -27,7 +27,7 @@ import org.nd4j.parameterserver.distributed.transport.Transport; @Deprecated public interface RetransmissionHandler { - public enum TransmissionStatus { + enum TransmissionStatus { MESSAGE_SENT, NOT_CONNECTED, BACKPRESSURE, ADMIN_ACTION, } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/completion/FrameCompletionHandler.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/completion/FrameCompletionHandler.java index 6f12029b3..9ba7014ae 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/completion/FrameCompletionHandler.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/completion/FrameCompletionHandler.java @@ -32,7 +32,7 @@ import java.util.concurrent.atomic.AtomicInteger; @Deprecated public class FrameCompletionHandler { - private Map frames = new ConcurrentHashMap<>(); + private final Map frames = new ConcurrentHashMap<>(); public boolean isTrackingFrame(RequestDescriptor descriptor) { return frames.containsKey(descriptor); @@ -104,12 +104,12 @@ public class FrameCompletionHandler { public static class FrameDescriptor { @Getter - private long frameOriginatorId; + private final long frameOriginatorId; // messageId within frame, and it's state - private Map states = new ConcurrentHashMap<>(); - private AtomicInteger messages = new AtomicInteger(0); - private AtomicInteger finished = new AtomicInteger(0); + private final Map states = new ConcurrentHashMap<>(); + private final AtomicInteger messages = new AtomicInteger(0); + private final AtomicInteger finished = new AtomicInteger(0); public FrameDescriptor(long frameOriginatorId) { diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/storage/BaseStorage.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/storage/BaseStorage.java index 846a689d1..25d74566f 100644 --- 
a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/storage/BaseStorage.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/logic/storage/BaseStorage.java @@ -31,7 +31,7 @@ import java.util.concurrent.ConcurrentHashMap; @Deprecated public abstract class BaseStorage implements Storage { - private ConcurrentHashMap storage = new ConcurrentHashMap<>(); + private final ConcurrentHashMap storage = new ConcurrentHashMap<>(); @Override diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/Frame.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/Frame.java index c6f4f8134..42313290b 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/Frame.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/Frame.java @@ -217,7 +217,7 @@ public class Frame implements Serializable, Iterable< //log.info("Firing message {}; originator: {}; frameId: {}; taskId: {}", message.getClass().getSimpleName(), message.getOriginatorId(), message.getFrameId(), message.getTaskId()); message.processMessage(); } - } ; + } } @Override diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedCbowDotMessage.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedCbowDotMessage.java index 258738a61..5df923c91 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedCbowDotMessage.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedCbowDotMessage.java @@ -87,7 +87,7 @@ public class DistributedCbowDotMessage extends BaseVoidMessage implements Distri CbowRequestMessage cbrm = new CbowRequestMessage(rowsA, rowsB, w1, codes, negSamples, alpha, 119); if (negSamples > 0) { // unfortunately we have to get copy of negSamples here - int negatives[] = Arrays.copyOfRange(rowsB, codes.length, rowsB.length); + int[] negatives = Arrays.copyOfRange(rowsB, codes.length, rowsB.length); cbrm.setNegatives(negatives); } cbrm.setFrameId(-119L); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedSgDotMessage.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedSgDotMessage.java index 003cb6d15..0d3ec17f7 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedSgDotMessage.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/messages/intercom/DistributedSgDotMessage.java @@ -84,7 +84,7 @@ public class 
DistributedSgDotMessage extends BaseVoidMessage implements Distribu SkipGramRequestMessage sgrm = new SkipGramRequestMessage(w1, w2, rowsB, codes, negSamples, alpha, 119); if (negSamples > 0) { // unfortunately we have to get copy of negSamples here - int negatives[] = Arrays.copyOfRange(rowsB, codes.length, rowsB.length); + int[] negatives = Arrays.copyOfRange(rowsB, codes.length, rowsB.length); sgrm.setNegatives(negatives); } sgrm.setTaskId(this.taskId); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/impl/CbowTrainer.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/impl/CbowTrainer.java index 78aa40bd3..b6eaf3e44 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/impl/CbowTrainer.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/impl/CbowTrainer.java @@ -57,11 +57,11 @@ public class CbowTrainer extends BaseTrainer { chains.put(RequestDescriptor.createDescriptor(message.getOriginatorId(), message.getTaskId()), chain); - int row_syn1[] = message.getSyn1rows(); + int[] row_syn1 = message.getSyn1rows(); if (message.getNegSamples() > 0) { - int rows = (int) storage.getArray(WordVectorStorage.SYN_0).rows(); - int tempArray[] = new int[message.getNegSamples() + 1]; + int rows = storage.getArray(WordVectorStorage.SYN_0).rows(); + int[] tempArray = new int[message.getNegSamples() + 1]; tempArray[0] = message.getW1(); for (int e = 1; e < message.getNegSamples() + 1; e++) { @@ -118,7 +118,7 @@ public class CbowTrainer extends BaseTrainer { + "]; taskId: [" + aggregation.getTaskId() + "]"); } - chain.addElement((DotAggregation) aggregation); + chain.addElement(aggregation); finishTraining(aggregation.getOriginatorId(), aggregation.getTaskId()); } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/impl/SkipGramTrainer.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/impl/SkipGramTrainer.java index add4c0867..8805ee65b 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/impl/SkipGramTrainer.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/training/impl/SkipGramTrainer.java @@ -70,13 +70,13 @@ public class SkipGramTrainer extends BaseTrainer { // we assume this is HS round //if (message.getPoints() != null && message.getPoints().length > 0) { - int row_syn0[] = new int[0]; //replicate(message.getW2(), message.getPoints().length); + int[] row_syn0 = new int[0]; //replicate(message.getW2(), message.getPoints().length); - int row_syn1[] = message.getPoints(); + int[] row_syn1 = message.getPoints(); if (message.getNegSamples() > 0) { - int rows = (int) storage.getArray(WordVectorStorage.SYN_0).rows(); - int tempArray[] = new int[message.getNegSamples() + 1]; + int rows = storage.getArray(WordVectorStorage.SYN_0).rows(); + int[] tempArray = new int[message.getNegSamples() + 1]; tempArray[0] = message.getW1(); for (int e = 1; e < message.getNegSamples() + 1; e++) { @@ -156,7 +156,7 @@ 
public class SkipGramTrainer extends BaseTrainer { + "]; taskId: [" + aggregation.getTaskId() + "]"); } - chain.addElement((DotAggregation) aggregation); + chain.addElement(aggregation); finishTraining(aggregation.getOriginatorId(), aggregation.getTaskId()); } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java index bad3a3fb4..222ff3546 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java @@ -279,7 +279,7 @@ public abstract class BaseTransport implements Transport { byte[] data = new byte[length]; buffer.getBytes(offset, data); - MeaningfulMessage message = (MeaningfulMessage) VoidMessage.fromBytes(data); + MeaningfulMessage message = VoidMessage.fromBytes(data); completed.put(message.getTaskId(), message); } @@ -412,7 +412,7 @@ public abstract class BaseTransport implements Transport { } break; default: - throw new IllegalStateException("Unknown thread model: [" + threading.toString() + "]"); + throw new IllegalStateException("Unknown thread model: [" + threading + "]"); } } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java index 7ee969015..8f49c7e7a 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java @@ -562,7 +562,7 @@ public class RoutedTransport extends BaseTransport { completed.put(message.getTaskId(), msg); } else if (message instanceof RequestMessage) { try { - messages.put((RequestMessage) message); + messages.put(message); } catch (InterruptedException e) { // do nothing } catch (Exception e) { @@ -570,7 +570,7 @@ public class RoutedTransport extends BaseTransport { } } else if (message instanceof DistributedMessage) { try { - messages.put((DistributedMessage) message); + messages.put(message); } catch (InterruptedException e) { // do nothing } catch (Exception e) { @@ -578,7 +578,7 @@ public class RoutedTransport extends BaseTransport { } } else if (message instanceof TrainingMessage) { try { - messages.put((TrainingMessage) message); + messages.put(message); } catch (InterruptedException e) { // do nothing } catch (Exception e) { @@ -586,7 +586,7 @@ public class RoutedTransport extends BaseTransport { } } else if (message instanceof VoidAggregation) { try { - messages.put((VoidAggregation) message); + messages.put(message); } catch (InterruptedException e) { // do nothing } catch (Exception e) { @@ -594,7 +594,7 @@ public class RoutedTransport extends BaseTransport { } } else if (message instanceof Frame) { try { - messages.put((Frame) message); + messages.put(message); } catch (InterruptedException e) { // do 
nothing } catch (Exception e) { @@ -664,8 +664,8 @@ public class RoutedTransport extends BaseTransport { public static class RemoteConnectionBuilder { - private Object locker = new Object(); - private AtomicBoolean activated = new AtomicBoolean(); + private final Object locker = new Object(); + private final AtomicBoolean activated = new AtomicBoolean(); } } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/util/NetworkInformation.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/util/NetworkInformation.java index 4937478eb..e55dc4b1c 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/util/NetworkInformation.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/util/NetworkInformation.java @@ -27,6 +27,7 @@ import lombok.NonNull; import java.io.Serializable; import java.util.ArrayList; import java.util.List; +import java.util.Objects; @NoArgsConstructor @Data @@ -49,7 +50,7 @@ public class NetworkInformation implements Serializable { NetworkInformation that = (NetworkInformation) o; - return ipAddresses != null ? ipAddresses.equals(that.ipAddresses) : that.ipAddresses == null; + return Objects.equals(ipAddresses, that.ipAddresses); } @Override diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServer.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServer.java index 0f1fc902c..fb6768334 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServer.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/ModelParameterServer.java @@ -62,7 +62,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; */ @Slf4j public final class ModelParameterServer { - protected static final ModelParameterServer INSTANCE = new ModelParameterServer(); + private static final ModelParameterServer INSTANCE = new ModelParameterServer(); @Getter private Transport transport; @@ -79,33 +79,33 @@ public final class ModelParameterServer { private final BlockingQueue updatesQueue = new LinkedBlockingQueue<>(4096); // subsribers that are connected to actual model - protected final List updatesSubscribers = new CopyOnWriteArrayList<>(); - protected final List> modelParamsSubsribers = new CopyOnWriteArrayList<>(); - protected final List> updaterParamsSubscribers = new CopyOnWriteArrayList<>(); + private final List updatesSubscribers = new CopyOnWriteArrayList<>(); + private final List> modelParamsSubsribers = new CopyOnWriteArrayList<>(); + private final List> updaterParamsSubscribers = new CopyOnWriteArrayList<>(); private boolean masterMode; - protected VoidConfiguration configuration; + private VoidConfiguration configuration; // this flag is true once mps is launched private final AtomicBoolean launchLock = new AtomicBoolean(false); private final AtomicBoolean stopLock = new AtomicBoolean(false); // this queue is used as temporary storage for updates received during restart event. 
- protected BlockingQueue updatesBacklog = new LinkedBlockingQueue<>(); + private BlockingQueue updatesBacklog = new LinkedBlockingQueue<>(); // these two fields only used at master node, to store latest updater copy - protected final Atomic updaterParameters = new Atomic<>(); - protected final ReentrantReadWriteLock updaterParamsLock = new ReentrantReadWriteLock(); - protected final AtomicBoolean gotFinalState = new AtomicBoolean(false); + private final Atomic updaterParameters = new Atomic<>(); + private final ReentrantReadWriteLock updaterParamsLock = new ReentrantReadWriteLock(); + private final AtomicBoolean gotFinalState = new AtomicBoolean(false); private Disposable disposable; - private AtomicInteger iterationNumber = new AtomicInteger(0); - private AtomicInteger epochNumber = new AtomicInteger(0); + private final AtomicInteger iterationNumber = new AtomicInteger(0); + private final AtomicInteger epochNumber = new AtomicInteger(0); - protected ModelParameterServer() { + private ModelParameterServer() { // } @@ -118,7 +118,7 @@ public final class ModelParameterServer { * * @param transport */ - protected ModelParameterServer(@NonNull Transport transport) { + private ModelParameterServer(@NonNull Transport transport) { this(transport, false); } @@ -128,7 +128,7 @@ public final class ModelParameterServer { * @param transport * @param isMasterNode */ - protected ModelParameterServer(@NonNull Transport transport, boolean isMasterNode) { + ModelParameterServer(@NonNull Transport transport, boolean isMasterNode) { this(VoidConfiguration.builder().portSupplier(new StaticPortSupplier(40123)).streamId(119).build(), transport, isMasterNode); } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTracker.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTracker.java index 18ff83fb3..07545a2d2 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTracker.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/chunks/impl/FileChunksTracker.java @@ -43,9 +43,9 @@ public class FileChunksTracker implements ChunksTracker map = new ConcurrentHashMap<>(); + private final Map map = new ConcurrentHashMap<>(); - private File holder; + private final File holder; private final long size; @@ -87,7 +87,7 @@ public class FileChunksTracker implements ChunksTracker implements ChunksTrack private final int numChunks; - private Map map = new ConcurrentHashMap<>(); + private final Map map = new ConcurrentHashMap<>(); private final byte[] buffer; diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/messages/pairs/params/UpdaterParametersMessage.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/messages/pairs/params/UpdaterParametersMessage.java index bb59958d2..82b92b810 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/messages/pairs/params/UpdaterParametersMessage.java +++ 
b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/messages/pairs/params/UpdaterParametersMessage.java @@ -34,7 +34,7 @@ public final class UpdaterParametersMessage extends BaseINDArrayMessage implemen @Getter @Setter - protected boolean finalState = false; + private boolean finalState = false; public UpdaterParametersMessage(@NonNull String messageId, INDArray payload) { super(messageId, payload); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java index 3cc941171..81b01e915 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java @@ -332,7 +332,7 @@ public abstract class BaseTransport implements Transport { if (!isLoopedNode(n, originatorId, relayId)) { sendMessage(voidMessage, n.getId()); } - }; + } } } @@ -637,7 +637,7 @@ public abstract class BaseTransport implements Transport { * @param */ public static class MessageFlow implements Consumer, Publisher { - private List> subscribers = new CopyOnWriteArrayList<>(); + private final List> subscribers = new CopyOnWriteArrayList<>(); @Override public void accept(T voidMessage) throws Exception { diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransport.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransport.java index 4ed3fdb4b..508485244 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransport.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/DummyTransport.java @@ -141,8 +141,8 @@ public class DummyTransport extends BaseTransport { * This class is written to mimic network connectivity locally */ public static class Connector { - private Map transports = new ConcurrentHashMap<>(); - private ThreadPoolExecutor executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors(), new ThreadFactory() { + private final Map transports = new ConcurrentHashMap<>(); + private final ThreadPoolExecutor executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors(), new ThreadFactory() { @Override public Thread newThread(@NonNull Runnable r) { val t = Executors.defaultThreadFactory().newThread(r); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizer.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizer.java index 3adaec329..7b0108fe8 100644 --- 
a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizer.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizer.java @@ -51,10 +51,10 @@ public class MeshOrganizer implements Serializable { // just shortcut to the root node of the tree @Getter(AccessLevel.PUBLIC) - private Node rootNode = new Node(true); + private final Node rootNode = new Node(true); // SortedSet, with sort by number of downstreams - private transient List sortedNodes = new ArrayList<>(); + private final transient List sortedNodes = new ArrayList<>(); // flattened map of the tree, ID -> Node private transient Map nodeMap = new HashMap<>(); @@ -325,7 +325,7 @@ public class MeshOrganizer implements Serializable { * @return */ protected long flatSize() { - return (long) nodeMap.size(); + return nodeMap.size(); } /** @@ -476,7 +476,7 @@ public class MeshOrganizer implements Serializable { val distance = distanceFromRoot(); for (val d: downstream) - if (d.numberOfDescendants() < MeshOrganizer.MAX_DOWNSTREAMS * (MeshOrganizer.MAX_DEPTH - distance)) + if (d.numberOfDescendants() < (long) MeshOrganizer.MAX_DOWNSTREAMS * (MeshOrganizer.MAX_DEPTH - distance)) return d.pushDownstreamNode(node); return addDownstreamNode(node); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java index 8f95c5ff9..66bc47f81 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerStressTest.java @@ -279,7 +279,7 @@ public class VoidParameterServerStressTest extends BaseND4JTest { log.info("p50: {} us", newTimes.get(newTimes.size() / 2) / 1000); - parameterServer.shutdown();; + parameterServer.shutdown(); for (VoidParameterServer server : shards) { server.shutdown(); @@ -492,7 +492,7 @@ public class VoidParameterServerStressTest extends BaseND4JTest { @Test @Timeout(60) public void testPerformanceUnicast3() throws Exception { VoidConfiguration voidConfiguration = VoidConfiguration.builder().numberOfShards(1) - .shardAddresses(Arrays.asList("127.0.0.1:49823")).build(); + .shardAddresses(Collections.singletonList("127.0.0.1:49823")).build(); voidConfiguration.setUnicastControllerPort(49823); Transport transport = new RoutedTransport(); @@ -538,7 +538,7 @@ public class VoidParameterServerStressTest extends BaseND4JTest { @Test @Timeout(60) public void testPerformanceUnicast4() throws Exception { VoidConfiguration voidConfiguration = VoidConfiguration.builder().numberOfShards(1) - .shardAddresses(Arrays.asList("127.0.0.1:49823")).build(); + .shardAddresses(Collections.singletonList("127.0.0.1:49823")).build(); voidConfiguration.setUnicastControllerPort(49823); Transport transport = new RoutedTransport(); @@ -635,7 +635,7 @@ public class VoidParameterServerStressTest extends BaseND4JTest { protected static CbowRequestMessage getCRM() { int w1 = RandomUtils.nextInt(0, NUM_WORDS); - int syn0[] = new int[5]; + int[] syn0 = new int[5]; for (int 
e = 0; e < syn0.length; e++) { syn0[e] = RandomUtils.nextInt(0, NUM_WORDS); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java index c29f7d6e3..5b716db55 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/VoidParameterServerTest.java @@ -48,6 +48,7 @@ import org.nd4j.parameterserver.distributed.transport.Transport; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -67,7 +68,7 @@ public class VoidParameterServerTest extends BaseND4JTest { if (localIPs == null) { localIPs = new ArrayList<>(VoidParameterServer.getLocalAddresses()); - badIPs = Arrays.asList("127.0.0.1"); + badIPs = Collections.singletonList("127.0.0.1"); } } @@ -277,8 +278,8 @@ public class VoidParameterServerTest extends BaseND4JTest { * Now we're checking how data storage was initialized */ - assertEquals(null, shards[t].getNegTable()); - assertEquals(null, shards[t].getSyn1()); + assertNull(shards[t].getNegTable()); + assertNull(shards[t].getSyn1()); assertNotEquals(null, shards[t].getExpTable()); @@ -302,7 +303,7 @@ public class VoidParameterServerTest extends BaseND4JTest { // now we assign each row to something for (int t = 0; t < threads.length; t++) { - shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 1, (double) t)); + shards[t].handleMessage(new DistributedAssignMessage(WordVectorStorage.SYN_0, 1, t)); assertEquals(Nd4j.create(message.getColumnsPerShard()).assign((double) t), shards[t].getSyn0().getRow(1)); } @@ -342,8 +343,8 @@ public class VoidParameterServerTest extends BaseND4JTest { } // and at this moment, Shard_0 should contain aggregated vector for us - assertEquals(true, shards[0].clipboard.isTracking(0L, 1L)); - assertEquals(true, shards[0].clipboard.isReady(0L, 1L)); + assertTrue(shards[0].clipboard.isTracking(0L, 1L)); + assertTrue(shards[0].clipboard.isReady(0L, 1L)); INDArray jointVector = shards[0].clipboard.nextCandidate().getAccumulatedResult(); @@ -385,7 +386,7 @@ public class VoidParameterServerTest extends BaseND4JTest { // at this moment ot should be caclulated everywhere exp = Nd4j.create(new double[] {0.0, 30.0, 120.0}); for (int t = 0; t < threads.length; t++) { - assertEquals(true, shards[t].clipboard.isReady(0L, 2L)); + assertTrue(shards[t].clipboard.isReady(0L, 2L)); DotAggregation dot = (DotAggregation) shards[t].clipboard.unpin(0L, 2L); INDArray aggregated = dot.getAccumulatedResult(); assertEquals(exp, aggregated); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java index 48b024e38..73bb74a1e 100644 --- 
a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/ClipboardTest.java @@ -66,7 +66,7 @@ public class ClipboardTest extends BaseND4JTest { clipboard.pin(aggregation); } - assertEquals(false, clipboard.hasCandidates()); + assertFalse(clipboard.hasCandidates()); assertEquals(0, clipboard.getNumberOfCompleteStacks()); assertEquals(100, clipboard.getNumberOfPinnedStacks()); } @@ -98,7 +98,7 @@ public class ClipboardTest extends BaseND4JTest { assertEquals(0, aggregation.getMissingChunks()); - assertEquals(true, clipboard.hasCandidates()); + assertTrue(clipboard.hasCandidates()); assertEquals(1, clipboard.getNumberOfCompleteStacks()); } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java index 16638a32f..b398f61a6 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/FrameCompletionHandlerTest.java @@ -66,7 +66,7 @@ public class FrameCompletionHandlerTest extends BaseND4JTest { for (Long originator : originators) { for (Long frame : frames) { - assertEquals(true, handler.isCompleted(originator, frame)); + assertTrue(handler.isCompleted(originator, frame)); } } } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java index 2890ed7d2..c39fb6afd 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/logic/routing/InterleavedRouterTest.java @@ -86,7 +86,7 @@ public class InterleavedRouterTest extends BaseND4JTest { InterleavedRouter router = new InterleavedRouter(); router.init(configuration, transport); - int w1[] = new int[] {512, 345, 486, 212}; + int[] w1 = new int[] {512, 345, 486, 212}; for (int i = 0; i < w1.length; i++) { SkipGramRequestMessage message = new SkipGramRequestMessage(w1[i], 1, new int[] {1, 2, 3}, diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java index 23b94d76f..7d283148e 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java +++ 
b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/VoidMessageTest.java @@ -51,7 +51,7 @@ public class VoidMessageTest extends BaseND4JTest { byte[] bytes = message.asBytes(); - SkipGramRequestMessage restored = (SkipGramRequestMessage) VoidMessage.fromBytes(bytes); + SkipGramRequestMessage restored = VoidMessage.fromBytes(bytes); assertNotEquals(null, restored); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java index 4456c6d04..d3a5f56f6 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/messages/aggregations/VoidAggregationTest.java @@ -124,7 +124,7 @@ public class VoidAggregationTest extends BaseND4JTest { } INDArray result = aggregation.getAccumulatedResult(); - assertEquals(true, result.isScalar()); + assertTrue(result.isScalar()); assertEquals(exp, result.getDouble(0), 1e-5); } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java index f977912ec..51640d444 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/transport/RoutedTransportTest.java @@ -110,7 +110,7 @@ public class RoutedTransportTest extends BaseND4JTest { for (int t = 1; t < transports.length; t++) { message = transports[t].messages.poll(1, TimeUnit.SECONDS); - assertEquals(null, message); + assertNull(message); } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java index 24b8915e9..896754a41 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/util/NetworkOrganizerTest.java @@ -426,23 +426,21 @@ public class NetworkOrganizerTest extends BaseND4JTest { } protected String getRandomIp() { - StringBuilder builder = new StringBuilder(); - builder.append(RandomUtils.nextInt(1, 172)).append("."); - builder.append(RandomUtils.nextInt(0, 255)).append("."); - builder.append(RandomUtils.nextInt(0, 255)).append("."); - builder.append(RandomUtils.nextInt(1, 255)); + String builder = RandomUtils.nextInt(1, 172) + "." + + RandomUtils.nextInt(0, 255) + "." 
+ + RandomUtils.nextInt(0, 255) + "." + + RandomUtils.nextInt(1, 255); - return builder.toString(); + return builder; } protected String getRandomAwsIp() { - StringBuilder builder = new StringBuilder("172."); - builder.append(RandomUtils.nextInt(16, 32)).append("."); - builder.append(RandomUtils.nextInt(0, 255)).append("."); - builder.append(RandomUtils.nextInt(1, 255)); + String builder = "172." + RandomUtils.nextInt(16, 32) + "." + + RandomUtils.nextInt(0, 255) + "." + + RandomUtils.nextInt(1, 255); - return builder.toString(); + return builder; } } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java index 92a6a668c..83fd4c324 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MeshOrganizerTest.java @@ -330,7 +330,7 @@ public class MeshOrganizerTest extends BaseND4JTest { mesh1.addNode(java.util.UUID.randomUUID().toString()); - try(val baos = new ByteArrayOutputStream();) { + try(val baos = new ByteArrayOutputStream()) { SerializationUtils.serialize(mesh1, baos); try(val bais = new ByteArrayInputStream(baos.toByteArray())) { diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java index 17e9dd7aa..08d763cf8 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/distributed/v2/util/MessageSplitterTest.java @@ -100,7 +100,7 @@ public class MessageSplitterTest extends BaseND4JTest { assertNotNull(ref.get()); assertEquals(array, ref.get().getPayload()); assertEquals(0, splitter.memoryUse.intValue()); - assertEquals(false, splitter.isTrackedMessage(message.getMessageId())); + assertFalse(splitter.isTrackedMessage(message.getMessageId())); assertEquals(0, splitter.trackers.size()); } } diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java index 7fee8e9c0..1626fc398 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/test/java/org/nd4j/parameterserver/node/ParameterServerNodeTest.java @@ -44,9 +44,9 @@ public class ParameterServerNodeTest extends BaseND4JTest { private static MediaDriver mediaDriver; private static Aeron aeron; private static ParameterServerNode parameterServerNode; - private static int parameterLength = 4; - private static int 
masterStatusPort = 40323 + new java.util.Random().nextInt(15999); - private static int statusPort = masterStatusPort - 1299; + private static final int parameterLength = 4; + private static final int masterStatusPort = 40323 + new java.util.Random().nextInt(15999); + private static final int statusPort = masterStatusPort - 1299; @BeforeAll public static void before() throws Exception { @@ -54,7 +54,7 @@ public class ParameterServerNodeTest extends BaseND4JTest { System.setProperty("play.server.dir", "/tmp"); aeron = Aeron.connect(getContext()); parameterServerNode = new ParameterServerNode(mediaDriver, statusPort); - parameterServerNode.runMain(new String[] {"-m", "true", "-s", "1," + String.valueOf(parameterLength), "-p", + parameterServerNode.runMain(new String[] {"-m", "true", "-s", "1," + parameterLength, "-p", String.valueOf(masterStatusPort), "-h", "localhost", "-id", "11", "-md", mediaDriver.aeronDirectoryName(), "-sp", String.valueOf(statusPort), "-sh", "localhost", "-u", String.valueOf(Runtime.getRuntime().availableProcessors())}); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java index 9d49e454f..dc281a329 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-status/src/test/java/org/nd4j/parameterserver/status/play/StorageTests.java @@ -42,7 +42,7 @@ public class StorageTests extends BaseND4JTest { assertEquals(noEmpty, mapDb.getState(1)); Thread.sleep(10000); - assertTrue(mapDb.numStates() == 0); + assertEquals(0, mapDb.numStates()); } @@ -57,7 +57,7 @@ public class StorageTests extends BaseND4JTest { assertEquals(noEmpty, statusStorage.getState(1)); Thread.sleep(10000); - assertTrue(statusStorage.numStates() == 0); + assertEquals(0, statusStorage.numStates()); } diff --git a/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/DummyDeAllocator.java b/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/DummyDeAllocator.java index 36fd694b7..94b284035 100644 --- a/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/DummyDeAllocator.java +++ b/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/DummyDeAllocator.java @@ -24,7 +24,7 @@ import org.bytedeco.javacpp.Pointer; import org.bytedeco.tensorflow.Deallocator_Pointer_long_Pointer; public class DummyDeAllocator extends Deallocator_Pointer_long_Pointer { - private static DummyDeAllocator INSTANCE = new DummyDeAllocator(); + private static final DummyDeAllocator INSTANCE = new DummyDeAllocator(); public static DummyDeAllocator getInstance() { return INSTANCE; diff --git a/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorDataType.java b/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorDataType.java index 08cb4610d..2d9ba91cc 100644 --- a/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorDataType.java +++ b/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorDataType.java @@ -113,22 +113,26 @@ public enum TensorDataType { public static TensorDataType fromNd4jType(INDArray array) 
{ DataType dataType = array.dataType(); - switch(dataType) { - case COMPRESSED: - CompressedDataBuffer compressedData = (CompressedDataBuffer) array.data(); - CompressionDescriptor desc = compressedData.getCompressionDescriptor(); - String algo = desc.getCompressionAlgorithm(); - switch (algo) { - case "FLOAT16": return HALF; - case "INT8": return INT8; - case "UINT8": return UINT8; - case "INT16": return INT16; - case "UINT16": return UINT16; - default: throw new IllegalArgumentException("Unsupported compression algorithm: " + algo); - } - - default: return fromNd4jType(dataType); + if (dataType == DataType.COMPRESSED) { + CompressedDataBuffer compressedData = (CompressedDataBuffer) array.data(); + CompressionDescriptor desc = compressedData.getCompressionDescriptor(); + String algo = desc.getCompressionAlgorithm(); + switch (algo) { + case "FLOAT16": + return HALF; + case "INT8": + return INT8; + case "UINT8": + return UINT8; + case "INT16": + return INT16; + case "UINT16": + return UINT16; + default: + throw new IllegalArgumentException("Unsupported compression algorithm: " + algo); + } } + return fromNd4jType(dataType); } } diff --git a/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java b/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java index 7d9d9cb59..def412f06 100644 --- a/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java +++ b/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/TensorflowConversion.java @@ -177,7 +177,7 @@ public class TensorflowConversion { BytePointer tf_data = new BytePointer(TF_TensorData(tf_tensor)).capacity(TF_TensorByteSize(tf_tensor)); TF_Status status = TF_NewStatus(); for (int i = 0; i < length; i++) { - tf_data.position(8 * i).putLong(offset); + tf_data.position(8L * i).putLong(offset); offset += TF_StringEncode(strings[i], strings[i].capacity() - 1, tf_data.position(8 * length + offset), tf_data.capacity() - tf_data.position(), status); if (TF_GetCode(status) != TF_OK) { throw new IllegalStateException("ERROR: Unable to convert tensor " + TF_Message(status).getString()); @@ -233,8 +233,8 @@ public class TensorflowConversion { SizeTPointer size = new SizeTPointer(1); TF_Status status = TF_NewStatus(); for (int i = 0; i < length; i++) { - long offset = data.position(8 * i).getLong(); - TF_StringDecode(data.position(8 * length + offset), data.capacity() - data.position(), str, size, status); + long offset = data.position(8L * i).getLong(); + TF_StringDecode(data.position(8L * length + offset), data.capacity() - data.position(), str, size, status); if (TF_GetCode(status) != TF_OK) { throw new IllegalStateException("ERROR: Unable to convert tensor " + TF_Message(status).getString()); } diff --git a/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java b/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java index 11e76c519..fa309b96e 100644 --- a/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java +++ b/cavis-nd4j/cavis-nd4j-tensorflow/src/main/java/org/nd4j/tensorflow/conversion/graphrunner/GraphRunner.java @@ -57,7 +57,7 @@ public class GraphRunner implements Closeable { //the in memory representation parsed from protobuf private TF_Graph graph; //the conversion between nd4j and TensorFlow - private 
TensorflowConversion conversion = TensorflowConversion.getInstance(); + private final TensorflowConversion conversion = TensorflowConversion.getInstance(); //a persistent session to be used when running the graph private TF_Session session; //the options for the model @@ -74,7 +74,7 @@ public class GraphRunner implements Closeable { @Setter @Singular private Map inputDataTypes,outputDataTypes; - private static Map,GraphRunner> recastGraphDefs; + private static final Map,GraphRunner> recastGraphDefs; static { recastGraphDefs = new ConcurrentHashMap<>(); @@ -598,8 +598,8 @@ public class GraphRunner implements Closeable { byte[] graphForDataType = graphForDataType(from,to); GraphRunner graphRunner = GraphRunner.builder() .graphBytes(graphForDataType) - .inputNames(Arrays.asList("input")) - .outputNames(Arrays.asList("cast_output")) + .inputNames(Collections.singletonList("input")) + .outputNames(Collections.singletonList("cast_output")) .build(); recastGraphDefs.put(key,graphRunner); diff --git a/cavis-ui/cavis-ui-common/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java b/cavis-ui/cavis-ui-common/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java index d3680f43a..4770b2d76 100644 --- a/cavis-ui/cavis-ui-common/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java +++ b/cavis-ui/cavis-ui-common/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java @@ -63,11 +63,11 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { private static final Logger log = LoggerFactory.getLogger(ConvolutionalIterationListener.class); private int minibatchNum = 0; private boolean openBrowser = true; - private String path; - private boolean firstIteration = true; + private final String path; + private final boolean firstIteration = true; - private Color borderColor = new Color(140, 140, 140); - private Color bgColor = new Color(255, 255, 255); + private final Color borderColor = new Color(140, 140, 140); + private final Color bgColor = new Color(255, 255, 255); private final StatsStorageRouter ssr; private final String sessionID; @@ -217,7 +217,7 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { try { sourceImage = restoreRGBImage( - inputs.tensorAlongDimension(sampleDim, new int[] {3, 2, 1})); + inputs.tensorAlongDimension(sampleDim, 3, 2, 1)); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/Chart.java b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/Chart.java index d6fc0f165..458134b75 100644 --- a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/Chart.java +++ b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/Chart.java @@ -70,8 +70,8 @@ public abstract class Chart extends Component { @SuppressWarnings("unchecked") public static abstract class Builder> { - private String title; - private StyleChart style; + private final String title; + private final StyleChart style; private Boolean suppressAxisHorizontal; private Boolean suppressAxisVertical; private boolean showLegend; diff --git a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartHistogram.java b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartHistogram.java index dee8b4f9f..bbc3c5fd1 100644 --- 
a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartHistogram.java +++ b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartHistogram.java @@ -52,9 +52,9 @@ public class ChartHistogram extends Chart { public static class Builder extends Chart.Builder { - private List lowerBounds = new ArrayList<>(); - private List upperBounds = new ArrayList<>(); - private List yValues = new ArrayList<>(); + private final List lowerBounds = new ArrayList<>(); + private final List upperBounds = new ArrayList<>(); + private final List yValues = new ArrayList<>(); public Builder(String title, StyleChart style) { super(title, style); diff --git a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartHorizontalBar.java b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartHorizontalBar.java index bfc15f9e4..afd23f73b 100644 --- a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartHorizontalBar.java +++ b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartHorizontalBar.java @@ -55,8 +55,8 @@ public class ChartHorizontalBar extends Chart { public static class Builder extends Chart.Builder { - private List labels = new ArrayList<>(); - private List values = new ArrayList<>(); + private final List labels = new ArrayList<>(); + private final List values = new ArrayList<>(); private Double xMin; private Double xMax; diff --git a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartLine.java b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartLine.java index d40b63682..e99cbb317 100644 --- a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartLine.java +++ b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartLine.java @@ -53,10 +53,10 @@ public class ChartLine extends Chart { public static class Builder extends Chart.Builder { - private List x = new ArrayList<>(); - private List y = new ArrayList<>(); - private List seriesNames = new ArrayList<>(); - private boolean showLegend = true; + private final List x = new ArrayList<>(); + private final List y = new ArrayList<>(); + private final List seriesNames = new ArrayList<>(); + private final boolean showLegend = true; public Builder(String title, StyleChart style) { diff --git a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartScatter.java b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartScatter.java index ae79a8c77..d200b8906 100644 --- a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartScatter.java +++ b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartScatter.java @@ -54,9 +54,9 @@ public class ChartScatter extends Chart { public static class Builder extends Chart.Builder { - private List x = new ArrayList<>(); - private List y = new ArrayList<>(); - private List seriesNames = new ArrayList<>(); + private final List x = new ArrayList<>(); + private final List y = new ArrayList<>(); + private final List seriesNames = new ArrayList<>(); public Builder(String title, StyleChart style) { super(title, style); diff --git a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartStackedArea.java 
b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartStackedArea.java index 199357c26..238d44f76 100644 --- a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartStackedArea.java +++ b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartStackedArea.java @@ -54,8 +54,8 @@ public class ChartStackedArea extends Chart { public static class Builder extends Chart.Builder { private double[] x; - private List y = new ArrayList<>(); - private List seriesNames = new ArrayList<>(); + private final List y = new ArrayList<>(); + private final List seriesNames = new ArrayList<>(); public Builder(String title, StyleChart style) { diff --git a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartTimeline.java b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartTimeline.java index 57b2bf4f6..c5818f9a7 100644 --- a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartTimeline.java +++ b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/chart/ChartTimeline.java @@ -55,8 +55,8 @@ public class ChartTimeline extends Chart { public static class Builder extends Chart.Builder { - private List laneNames = new ArrayList<>(); - private List> laneData = new ArrayList<>(); + private final List laneNames = new ArrayList<>(); + private final List> laneData = new ArrayList<>(); public Builder(String title, StyleChart style) { diff --git a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/decorator/DecoratorAccordion.java b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/decorator/DecoratorAccordion.java index 543282d74..1e11ee29d 100644 --- a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/decorator/DecoratorAccordion.java +++ b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/decorator/DecoratorAccordion.java @@ -54,9 +54,9 @@ public class DecoratorAccordion extends Component { public static class Builder { - private StyleAccordion style; + private final StyleAccordion style; private String title; - private List innerComponents = new ArrayList<>(); + private final List innerComponents = new ArrayList<>(); private boolean defaultCollapsed; public Builder(StyleAccordion style) { diff --git a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/table/ComponentTable.java b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/table/ComponentTable.java index aebb3d0e6..80dedd965 100644 --- a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/table/ComponentTable.java +++ b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/table/ComponentTable.java @@ -56,7 +56,7 @@ public class ComponentTable extends Component { public static class Builder { - private StyleTable style; + private final StyleTable style; private String[] header; private String[][] content; diff --git a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/text/ComponentText.java b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/text/ComponentText.java index fc5ee7d7a..19f9c87f3 100644 --- a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/text/ComponentText.java +++ 
b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/components/text/ComponentText.java @@ -56,8 +56,8 @@ public class ComponentText extends Component { public static class Builder { - private StyleText style; - private String text; + private final StyleText style; + private final String text; public Builder(String text, StyleText style) { this.text = text; diff --git a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/standalone/StaticPageUtil.java b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/standalone/StaticPageUtil.java index 0ee0d2527..7bffdaa04 100644 --- a/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/standalone/StaticPageUtil.java +++ b/cavis-ui/cavis-ui-components/src/main/java/org/deeplearning4j/ui/standalone/StaticPageUtil.java @@ -37,6 +37,7 @@ import java.io.File; import java.io.IOException; import java.io.StringWriter; import java.io.Writer; +import java.nio.charset.StandardCharsets; import java.util.*; public class StaticPageUtil { @@ -87,7 +88,7 @@ public class StaticPageUtil { cfg.setTemplateExceptionHandler(TemplateExceptionHandler.RETHROW_HANDLER); ClassPathResource cpr = new ClassPathResource("assets/dl4j-ui.js"); - String scriptContents = IOUtils.toString(cpr.getInputStream(), "UTF-8"); + String scriptContents = IOUtils.toString(cpr.getInputStream(), StandardCharsets.UTF_8); Map pageElements = new HashMap<>(); List list = new ArrayList<>(); diff --git a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/nearestneighbors/word2vec/NearestNeighborsQuery.java b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/nearestneighbors/word2vec/NearestNeighborsQuery.java index f379dba6d..b912acda8 100644 --- a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/nearestneighbors/word2vec/NearestNeighborsQuery.java +++ b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/nearestneighbors/word2vec/NearestNeighborsQuery.java @@ -21,6 +21,7 @@ package org.deeplearning4j.ui.model.nearestneighbors.word2vec; import java.io.Serializable; +import java.util.Objects; /** * @author Adam Gibson @@ -63,7 +64,7 @@ public class NearestNeighborsQuery implements Serializable { if (numWords != that.numWords) return false; - return !(word != null ? !word.equals(that.word) : that.word != null); + return !(!Objects.equals(word, that.word)); } diff --git a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java index e56ad1f67..3ecc58fd5 100644 --- a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java +++ b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java @@ -73,7 +73,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { private Map> gcStatsAtLastReport; //NOTE: may have multiple models, due to multiple pretrain layers all using the same StatsListener - private List modelInfos = new ArrayList<>(); + private final List modelInfos = new ArrayList<>(); private Map activationHistograms; private Map meanActivations; //TODO replace with Eclipse collections primitive maps... 
@@ -687,7 +687,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { router.putStaticInfo(initReport); //TODO error handling } - private Map devPointers = new HashMap<>(); + private final Map devPointers = new HashMap<>(); private synchronized Pointer getDevicePointer(int device) { if (devPointers.containsKey(device)) { diff --git a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/impl/SbeUtil.java b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/impl/SbeUtil.java index 0e38e7799..7f5bda056 100644 --- a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/impl/SbeUtil.java +++ b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/impl/SbeUtil.java @@ -22,11 +22,12 @@ package org.deeplearning4j.ui.model.stats.impl; import java.io.*; import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.util.Map; public class SbeUtil { - public static final Charset UTF8 = Charset.forName("UTF-8"); + public static final Charset UTF8 = StandardCharsets.UTF_8; public static final byte[] EMPTY_BYTES = new byte[0]; //Also equivalent to "".getBytes(UTF8); private SbeUtil() {} diff --git a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/InMemoryStatsStorage.java b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/InMemoryStatsStorage.java index 76f00ef2b..f99a88155 100644 --- a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/InMemoryStatsStorage.java +++ b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/InMemoryStatsStorage.java @@ -62,9 +62,7 @@ public class InMemoryStatsStorage extends BaseCollectionStatsStorage { @Override public void putStaticInfo(Persistable staticInfo) { List sses = checkStorageEvents(staticInfo); - if (!sessionIDs.contains(staticInfo.getSessionID())) { - sessionIDs.add(staticInfo.getSessionID()); - } + sessionIDs.add(staticInfo.getSessionID()); SessionTypeWorkerId id = new SessionTypeWorkerId(staticInfo.getSessionID(), staticInfo.getTypeID(), staticInfo.getWorkerID()); diff --git a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/mapdb/MapDBStatsStorage.java b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/mapdb/MapDBStatsStorage.java index b7a2fecf9..b05d43d67 100644 --- a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/mapdb/MapDBStatsStorage.java +++ b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/mapdb/MapDBStatsStorage.java @@ -44,12 +44,12 @@ public class MapDBStatsStorage extends BaseCollectionStatsStorage { private static final String COMPOSITE_KEY_SEPARATOR = "@@@"; private boolean isClosed = false; - private DB db; - private Lock updateMapLock = new ReentrantLock(true); + private final DB db; + private final Lock updateMapLock = new ReentrantLock(true); - private Map classToInteger; //For storage - private Map integerToClass; //For storage - private Atomic.Integer classCounter; + private final Map classToInteger; //For storage + private final Map integerToClass; //For storage + private final Atomic.Integer classCounter; public MapDBStatsStorage() { this(new Builder()); @@ -147,9 +147,7 @@ public class MapDBStatsStorage extends BaseCollectionStatsStorage { @Override public void putStaticInfo(Persistable staticInfo) { List sses = checkStorageEvents(staticInfo); - if (!sessionIDs.contains(staticInfo.getSessionID())) { - 
sessionIDs.add(staticInfo.getSessionID()); - } + sessionIDs.add(staticInfo.getSessionID()); SessionTypeWorkerId id = new SessionTypeWorkerId(staticInfo.getSessionID(), staticInfo.getTypeID(), staticInfo.getWorkerID()); diff --git a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/sqlite/J7FileStatsStorage.java b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/sqlite/J7FileStatsStorage.java index 07dee8b04..1b370b13a 100644 --- a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/sqlite/J7FileStatsStorage.java +++ b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/storage/sqlite/J7FileStatsStorage.java @@ -45,7 +45,7 @@ public class J7FileStatsStorage implements StatsStorage { private final File file; private final Connection connection; - private List listeners = new ArrayList<>(); + private final List listeners = new ArrayList<>(); /** * @param file Storage location for the stats @@ -445,7 +445,7 @@ public class J7FileStatsStorage implements StatsStorage { List out = new ArrayList<>(); while (rs.next()) { byte[] bytes = rs.getBytes(5); - out.add((Persistable) deserialize(bytes)); + out.add(deserialize(bytes)); } return out; } catch (SQLException e) { @@ -561,7 +561,7 @@ public class J7FileStatsStorage implements StatsStorage { List out = new ArrayList<>(); while (rs.next()) { byte[] bytes = rs.getBytes(6); - out.add((Persistable) deserialize(bytes)); + out.add(deserialize(bytes)); } return out; } catch (SQLException e) { @@ -621,7 +621,7 @@ public class J7FileStatsStorage implements StatsStorage { List out = new ArrayList<>(); while (rs.next()) { byte[] bytes = rs.getBytes(1); - out.add((Persistable) deserialize(bytes)); + out.add(deserialize(bytes)); } return out; } catch (SQLException e) { diff --git a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/weights/HistogramBin.java b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/weights/HistogramBin.java index a8f4dffc3..37fc4e87f 100644 --- a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/weights/HistogramBin.java +++ b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/weights/HistogramBin.java @@ -29,6 +29,7 @@ import org.slf4j.LoggerFactory; import java.io.Serializable; import java.math.BigDecimal; +import java.math.RoundingMode; import java.util.LinkedHashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicInteger; @@ -87,7 +88,7 @@ public class HistogramBin implements Serializable { BigDecimal[] keys = new BigDecimal[numberOfBins]; for (int x = 0; x < numberOfBins; x++) { - BigDecimal pos = new BigDecimal((min + (x * binSize))).setScale(rounds, BigDecimal.ROUND_CEILING); + BigDecimal pos = BigDecimal.valueOf(min + (x * binSize)).setScale(rounds, RoundingMode.CEILING); data.put(pos, new AtomicInteger(0)); keys[x] = pos; } @@ -110,7 +111,7 @@ public class HistogramBin implements Serializable { } public static class Builder { - private INDArray source; + private final INDArray source; private int binCount; private int rounds = 2; diff --git a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/weights/beans/CompactModelAndGradient.java b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/weights/beans/CompactModelAndGradient.java index 5a52f0dce..19b7a552a 100644 --- a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/weights/beans/CompactModelAndGradient.java +++ 
b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/weights/beans/CompactModelAndGradient.java @@ -22,10 +22,7 @@ package org.deeplearning4j.ui.model.weights.beans; import java.io.Serializable; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; /** * Slightly modified version of ModelAndGradient, with binned params/gradients, suitable for fast network transfers for HistogramIterationListener @@ -136,9 +133,9 @@ public class CompactModelAndGradient implements Serializable { if (Double.compare(that.score, score) != 0) return false; - if (parameters != null ? !parameters.equals(that.parameters) : that.parameters != null) + if (!Objects.equals(parameters, that.parameters)) return false; - return !(gradients != null ? !gradients.equals(that.gradients) : that.gradients != null); + return !(!Objects.equals(gradients, that.gradients)); } @Override diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java index 21b77bab2..08c87d92c 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/VertxUIServer.java @@ -83,7 +83,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { private static VertxUIServer instance; @Getter - private static AtomicBoolean multiSession = new AtomicBoolean(false); + private static final AtomicBoolean multiSession = new AtomicBoolean(false); @Getter @Setter private static Function statsStorageProvider; @@ -217,7 +217,7 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { } - private List uiModules = new CopyOnWriteArrayList<>(); + private final List uiModules = new CopyOnWriteArrayList<>(); private RemoteReceiverModule remoteReceiverModule; /** * Loader that attaches {@code StatsStorage} provided by {@code #statsStorageProvider} for the given session ID @@ -226,16 +226,16 @@ public class VertxUIServer extends AbstractVerticle implements UIServer { private Function statsStorageLoader; //typeIDModuleMap: Records which modules are registered for which type IDs - private Map> typeIDModuleMap = new ConcurrentHashMap<>(); + private final Map> typeIDModuleMap = new ConcurrentHashMap<>(); private HttpServer server; - private AtomicBoolean shutdown = new AtomicBoolean(false); - private long uiProcessingDelay = 500; //500ms. TODO make configurable + private final AtomicBoolean shutdown = new AtomicBoolean(false); + private final long uiProcessingDelay = 500; //500ms. 
TODO make configurable private final BlockingQueue eventQueue = new LinkedBlockingQueue<>(); - private List> listeners = new CopyOnWriteArrayList<>(); - private List statsStorageInstances = new CopyOnWriteArrayList<>(); + private final List> listeners = new CopyOnWriteArrayList<>(); + private final List statsStorageInstances = new CopyOnWriteArrayList<>(); private Thread uiEventRoutingThread; diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java index f0cb3a29c..a14f45dea 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/api/UIServer.java @@ -175,5 +175,5 @@ public interface UIServer { */ static Thread getShutdownHook() { return VertxUIServer.getShutdownHook(); - }; + } } diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java index 25af6684e..d1cc5a01f 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/i18n/DefaultI18N.java @@ -38,12 +38,12 @@ public class DefaultI18N implements I18N { public static final String FALLBACK_LANGUAGE = "en"; //use this if the specified language doesn't have the requested message private static DefaultI18N instance; - private static Map sessionInstances = Collections.synchronizedMap(new HashMap<>()); + private static final Map sessionInstances = Collections.synchronizedMap(new HashMap<>()); private static Throwable languageLoadingException = null; private String currentLanguage = DEFAULT_LANGUAGE; - private Map> messagesByLanguage = new HashMap<>(); + private final Map> messagesByLanguage = new HashMap<>(); /** * Get global instance (used in single-session mode) diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/i18n/I18NProvider.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/i18n/I18NProvider.java index 15ddbca99..10a744602 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/i18n/I18NProvider.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/i18n/I18NProvider.java @@ -28,7 +28,7 @@ public class I18NProvider { /** * Current I18N instance */ - private static I18N i18n = DefaultI18N.getInstance(); + private static final I18N i18n = DefaultI18N.getInstance(); /** * Get the current/global I18N instance (used in single-session mode) diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java index dd0ea2bf0..096b56113 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/remote/RemoteReceiverModule.java @@ -43,7 +43,7 @@ import java.util.concurrent.atomic.AtomicBoolean; @Slf4j public class RemoteReceiverModule implements UIModule { - private AtomicBoolean enabled = new AtomicBoolean(false); + private final AtomicBoolean enabled = new AtomicBoolean(false); private StatsStorageRouter statsStorage; public void setEnabled(boolean enabled) { diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java 
b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index 975d78a3f..858648018 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -80,19 +80,19 @@ public class TrainModule implements UIModule { public static final double NAN_REPLACEMENT_VALUE = 0.0; //UI front-end chokes on NaN in JSON public static final int DEFAULT_MAX_CHART_POINTS = 512; private static final DecimalFormat df2 = new DecimalFormat("#.00"); - private static DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); + private static final DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); private enum ModelType { MLN, CG, Layer } private final int maxChartPoints; //Technically, the way it's set up: won't exceed 2*maxChartPoints - private Map knownSessionIDs = Collections.synchronizedMap(new HashMap<>()); + private final Map knownSessionIDs = Collections.synchronizedMap(new HashMap<>()); private String currentSessionID; private int currentWorkerIdx; - private Map workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID - private Map> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID - private Map lastUpdateForSession = new ConcurrentHashMap<>(); + private final Map workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID + private final Map> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID + private final Map lastUpdateForSession = new ConcurrentHashMap<>(); private final Configuration configuration; @@ -795,7 +795,7 @@ public class TrainModule implements UIModule { {i18N.getMessage("train.overview.perftable.examplesPerSec"), ""}}; if (last != null) { - perfInfo[2][1] = String.valueOf(dateFormat.format(new Date(last.getTimeStamp()))); + perfInfo[2][1] = dateFormat.format(new Date(last.getTimeStamp())); perfInfo[3][1] = String.valueOf(last.getTotalMinibatches()); perfInfo[4][1] = String.valueOf(df2.format(last.getMinibatchesPerSecond())); perfInfo[5][1] = String.valueOf(df2.format(last.getExamplesPerSecond())); @@ -1334,7 +1334,7 @@ public class TrainModule implements UIModule { return new MeanMagnitudes(iterCounts, ratioValues, outParamMM, outUpdateMM); } - private static Triple EMPTY_TRIPLE = new Triple<>(new int[0], new float[0], new float[0]); + private static final Triple EMPTY_TRIPLE = new Triple<>(new int[0], new float[0], new float[0]); private static Triple getLayerActivations(int index, TrainModuleUtils.GraphInfo gi, List updates, List iterationCounts) { diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java index f90fb1f24..ff6f00901 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java @@ -245,7 +245,7 @@ public class TrainModuleUtils { if (layerName == null) layerName = "layer0"; vertexNames.add(layerName); - originalVertexName.add(String.valueOf("0")); + originalVertexName.add("0"); String layerType = config.getLayer().getClass().getSimpleName().replaceAll("Layer$", ""); layerTypes.add(layerType); diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java 
b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java index b7602ecb4..7e0b79194 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/tsne/TsneModule.java @@ -41,7 +41,7 @@ import java.util.*; public class TsneModule implements UIModule { private static final String UPLOADED_FILE = "UploadedFile"; - private Map> knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>()); + private final Map> knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>()); private List uploadedFileLines = null; public TsneModule() { diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ModelMetaData.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ModelMetaData.java index 397899c29..24e5fbc98 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ModelMetaData.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ModelMetaData.java @@ -37,6 +37,6 @@ public class ModelMetaData { * @return */ public boolean useMDS() { - return inputShape.length > 1 ? true : false; + return inputShape.length > 1; } } diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ZooModel.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ZooModel.java index 977de99ae..958edec33 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ZooModel.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ZooModel.java @@ -71,10 +71,10 @@ public abstract class ZooModel implements InstantiableModel { File cachedFile = new File(rootCacheDir, localFilename); if (!cachedFile.exists()) { - log.info("Downloading model to " + cachedFile.toString()); + log.info("Downloading model to " + cachedFile); FileUtils.copyURLToFile(new URL(remoteUrl), cachedFile,Integer.MAX_VALUE,Integer.MAX_VALUE); } else { - log.info("Using cached model at " + cachedFile.toString()); + log.info("Using cached model at " + cachedFile); } long expectedChecksum = pretrainedChecksum(pretrainedType); diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/util/darknet/DarknetLabels.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/util/darknet/DarknetLabels.java index 5d9f91a96..c62e18b32 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/util/darknet/DarknetLabels.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/util/darknet/DarknetLabels.java @@ -31,8 +31,8 @@ import java.util.List; public class DarknetLabels extends BaseLabels { - private boolean shortNames; - private int numClasses; + private final boolean shortNames; + private final int numClasses; /** Calls {@code this(true)}. 
* Defaults to 1000 clasess diff --git a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java index 04d1f8fce..0e6fdfb38 100644 --- a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java +++ b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java @@ -129,7 +129,7 @@ public class TestImageNet extends BaseDL4JTest { labels = new VOCLabels(); for (DetectedObject obj : objs) { ClassPrediction classPrediction = labels.decodePredictions(obj.getClassPredictions(), 1).get(0).get(0); - log.info(obj.toString() + " " + classPrediction); + log.info(obj + " " + classPrediction); assertEquals("dog", classPrediction.getLabel()); } @@ -155,7 +155,7 @@ public class TestImageNet extends BaseDL4JTest { labels = new COCOLabels(); for (DetectedObject obj : objs) { ClassPrediction classPrediction = labels.decodePredictions(obj.getClassPredictions(), 1).get(0).get(0); - log.info(obj.toString() + " " + classPrediction); + log.info(obj + " " + classPrediction); assertEquals("dog", classPrediction.getLabel()); } diff --git a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java index 9abe0b848..f9e8b83a1 100644 --- a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java +++ b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java @@ -136,7 +136,7 @@ public class TestInstantiation extends BaseDL4JTest { assertTrue(model.pretrainedAvailable(PretrainedType.IMAGENET)); ComputationGraph initializedModel = (ComputationGraph) model.initPretrained(); - INDArray f = Nd4j.rand(new int[]{1, 3, 224, 224}); + INDArray f = Nd4j.rand(1, 3, 224, 224); INDArray[] result = initializedModel.output(f); assertArrayEquals(result[0].shape(), new long[]{1, 1000}); diff --git a/vsconfig.gradle b/vsconfig.gradle index 229600cf4..a247ceb10 100644 --- a/vsconfig.gradle +++ b/vsconfig.gradle @@ -51,7 +51,7 @@ def configureVisualStudio() { return } def vswhereOutput = "${vswherePath} -latest -format json".execute().text.trim() - def vswhereJson = new groovy.json.JsonSlurper().parseText(vswhereOutput); + def vswhereJson = new groovy.json.JsonSlurper().parseText(vswhereOutput) if (vswhereJson.isEmpty()) { println "Visual Studio not found!" 
return From 4dfc637305371a1762d34c8540d10c5bb957438d Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 21 Oct 2022 16:36:56 +0200 Subject: [PATCH 053/126] Fix javadoc and cleanup Signed-off-by: brian --- build.gradle | 19 ++ .../java/org/datavec/api/writable/Text.java | 2 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 11 +- .../nd4j/autodiff/samediff/ops/SDBaseOps.java | 60 ++-- .../evaluation/classification/Evaluation.java | 4 +- .../nd4j/linalg/api/blas/BlasException.java | 2 +- .../java/org/nd4j/linalg/api/blas/Lapack.java | 21 +- .../java/org/nd4j/linalg/api/blas/Level1.java | 8 +- .../linalg/api/buffer/BaseDataBuffer.java | 9 +- .../linalg/checkutil/NDArrayCreationUtil.java | 2 +- .../org/nd4j/linalg/dataset/api/DataSet.java | 2 +- .../nd4j/linalg/dataset/api/MultiDataSet.java | 2 +- .../org/nd4j/linalg/util/ND4JTestUtils.java | 256 +++++++++--------- .../iterator/impl/EmnistDataSetIterator.java | 4 +- .../iterator/impl/LFWDataSetIterator.java | 159 ++++++----- cavis-full/build.gradle | 2 +- .../org/nd4j/aeron/ipc/NDArrayMessage.java | 2 +- .../transport/RoutedTransport.java | 4 +- 18 files changed, 317 insertions(+), 252 deletions(-) diff --git a/build.gradle b/build.gradle index cd5911461..fc9167f30 100644 --- a/build.gradle +++ b/build.gradle @@ -56,6 +56,7 @@ configurations.all { } + allprojects { Project proj -> apply plugin: 'com.google.osdetector' @@ -162,3 +163,21 @@ allprojects { Project proj -> } } } + + +task aggregatedJavadocs(type: Javadoc, description: 'Generate javadocs from all child projects as if it was a single project', group: 'Documentation') { + subprojects.each { proj -> + proj.tasks.withType(Javadoc).each { javadocTask -> + logger.quiet("Adding javadoc for project " + proj.name) + source += javadocTask.source + classpath += javadocTask.classpath + excludes += javadocTask.excludes + includes += javadocTask.includes + } + } + destinationDir = file("$buildDir/docs/javadoc") + title = "$project.name $version API" + options.author true + options.links 'http://docs.oracle.com/javase/8/docs/api/' + options.addStringOption('Xdoclint:none', '-quiet') +} \ No newline at end of file diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java index b36452a0d..b80d491d2 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java @@ -106,7 +106,7 @@ public class Text extends BinaryComparable implements WritableComparable 0 + * @param numEpochs The number of epochs for training. Must be > 0 * @param validationIter The DataSetIterator to use for validation (null to skip validation) * @param validationFrequency The frequency with which to run validation. 1 is every epoch, 2 is every other, etc. * @param listeners Additional listeners to use during this operation @@ -1479,7 +1479,7 @@ public class SameDiff extends SDBaseOps { * A special case of {@link #fit()}. * * @param iter The iterator to train the SameDiff instance with - * @param numEpochs The number of epochs for training. Must be > 0 + * @param numEpochs The number of epochs for training. Must be > 0 * @param listeners Additional listeners to use during this operation * @return a {@link History} object containing the history information for this training operation * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information). 
@@ -1497,7 +1497,7 @@ public class SameDiff extends SDBaseOps { * A special case of {@link #fit()}. * * @param iter The iterator to train the SameDiff instance with - * @param numEpochs The number of epochs for training. Must be > 0 + * @param numEpochs The number of epochs for training. Must be > 0 * @param validationIter The MultiDataSetIterator to use for validation (null to skip validation) * @param validationFrequency The frequency with which to run validation. 1 is every epoch, 2 is every other, etc. * @param listeners Additional listeners to use during this operation @@ -1514,7 +1514,7 @@ public class SameDiff extends SDBaseOps { * A special case of {@link #fit()}. * * @param iter The iterator to train the SameDiff instance with - * @param numEpochs The number of epochs for training. Must be > 0 + * @param numEpochs The number of epochs for training. Must be > 0 * @param listeners Additional listeners to use during this operation * @return a {@link History} object containing the history information for this training operation * (evaluations specified in the {@link TrainingConfig}, loss values, and timing information). @@ -3036,7 +3036,6 @@ public class SameDiff extends SDBaseOps { * See also: {@link VariableType} * * @param variables Variables to convert to constants - * @return The (now constant) SDVariables */ public void convertToConstants(List variables) { if (variables.size() == 0) @@ -3201,7 +3200,7 @@ public class SameDiff extends SDBaseOps { * For example, {@code z(float) = x(float)+y(float)}, changing both x and y to double results in {@code z(double) = x(double)+y(double)} * without doing anything to change z's datatype directly (z datatype is inferred from x + y + add op).
* ARRAY type SDVariables cannot be converted directly, as their datatypes are determined by the function + - * input datatypes. + * input datatypes.
* Note that this method should be used with caution: incorrect datatype modifications may leave your network * in an incorrect state. For example, {@code op(x(float),y(float)) -> op(x(double),y(float))} may not be * supported by all ops. diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index a80836439..2886fca90 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -382,7 +382,7 @@ public class SDBaseOps { } /** - * Cast the array to a new datatype - for example, Integer -> Float
+ * Cast the array to a new datatype - for example, Integer -> Float
* * @param arg Input variable to cast (NDARRAY type) * @param datatype Datatype to cast to @@ -393,7 +393,7 @@ public class SDBaseOps { } /** - * Cast the array to a new datatype - for example, Integer -> Float
+ * Cast the array to a new datatype - for example, Integer -> Float
* * @param name name May be null. Name for the output variable * @param arg Input variable to cast (NDARRAY type) @@ -654,7 +654,7 @@ public class SDBaseOps { * * @param x Input variable (NUMERIC type) * @param partitions 1D input with values 0 to numPartitions-1 (INT type) - * @param numPartitions Number of partitions, >= 1 + * @param numPartitions Number of partitions, >= 1 */ public SDVariable[] dynamicPartition(SDVariable x, SDVariable partitions, int numPartitions) { SDValidation.validateNumerical("dynamicPartition", "x", x); @@ -676,7 +676,7 @@ public class SDBaseOps { * @param names names May be null. Arrays of names for the output variables. * @param x Input variable (NUMERIC type) * @param partitions 1D input with values 0 to numPartitions-1 (INT type) - * @param numPartitions Number of partitions, >= 1 + * @param numPartitions Number of partitions, >= 1 */ public SDVariable[] dynamicPartition(String[] names, SDVariable x, SDVariable partitions, int numPartitions) { @@ -689,7 +689,7 @@ public class SDBaseOps { /** * Dynamically merge the specified input arrays into a single array, using the specified indices
* - * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) + * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) * @param x Input variables. (NUMERIC type) * @return output Merged output variable (NUMERIC type) */ @@ -705,7 +705,7 @@ public class SDBaseOps { * Dynamically merge the specified input arrays into a single array, using the specified indices
* * @param name name May be null. Name for the output variable - * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) + * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) * @param x Input variables. (NUMERIC type) * @return output Merged output variable (NUMERIC type) */ @@ -943,7 +943,7 @@ public class SDBaseOps { } /** - * Greater than operation: elementwise x > y
+ * Greater than operation: elementwise x > y
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -957,7 +957,7 @@ public class SDBaseOps { } /** - * Greater than operation: elementwise x > y
+ * Greater than operation: elementwise x > y
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -973,7 +973,7 @@ public class SDBaseOps { } /** - * Greater than operation: elementwise x > y
+ * Greater than operation: elementwise x > y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
@@ -993,7 +993,7 @@ public class SDBaseOps { } /** - * Greater than operation: elementwise x > y
+ * Greater than operation: elementwise x > y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
@@ -1015,7 +1015,7 @@ public class SDBaseOps { } /** - * Greater than or equals operation: elementwise x >= y
+ * Greater than or equals operation: elementwise x >= y
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1029,7 +1029,7 @@ public class SDBaseOps { } /** - * Greater than or equals operation: elementwise x >= y
+ * Greater than or equals operation: elementwise x >= y
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1045,7 +1045,7 @@ public class SDBaseOps { } /** - * Greater than or equal to operation: elementwise x >= y
+ * Greater than or equal to operation: elementwise x >= y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
@@ -1065,7 +1065,7 @@ public class SDBaseOps { } /** - * Greater than or equal to operation: elementwise x >= y
+ * Greater than or equal to operation: elementwise x >= y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
@@ -1232,7 +1232,7 @@ public class SDBaseOps { } /** - * Less than operation: elementwise x < y
+ * Less than operation: elementwise x < y
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1246,7 +1246,7 @@ public class SDBaseOps { } /** - * Less than operation: elementwise x < y
+ * Less than operation: elementwise x < y
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1262,7 +1262,7 @@ public class SDBaseOps { } /** - * Less than operation: elementwise x < y
+ * Less than operation: elementwise x < y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
@@ -1282,7 +1282,7 @@ public class SDBaseOps { } /** - * Less than operation: elementwise x < y
+ * Less than operation: elementwise x < y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
@@ -1304,7 +1304,7 @@ public class SDBaseOps { } /** - * Less than or equals operation: elementwise x <= y
+ * Less than or equals operation: elementwise x <= y
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1318,7 +1318,7 @@ public class SDBaseOps { } /** - * Less than or equals operation: elementwise x <= y
+ * Less than or equals operation: elementwise x <= y
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1334,7 +1334,7 @@ public class SDBaseOps { } /** - * Less than or equal to operation: elementwise x <= y
+ * Less than or equal to operation: elementwise x <= y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
@@ -1354,7 +1354,7 @@ public class SDBaseOps { } /** - * Less than or equal to operation: elementwise x <= y
+ * Less than or equal to operation: elementwise x <= y
* If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
@@ -3590,7 +3590,7 @@ public class SDBaseOps { /** * Generate a sequence mask (with values 0 or 1) based on the specified lengths
- * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * {@code Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)}
* * @param lengths Lengths of the sequences (NUMERIC type) * @param maxLen Maximum sequence length @@ -3604,7 +3604,7 @@ public class SDBaseOps { /** * Generate a sequence mask (with values 0 or 1) based on the specified lengths
- * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * {@code Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)}
* * @param name name May be null. Name for the output variable * @param lengths Lengths of the sequences (NUMERIC type) @@ -3620,7 +3620,7 @@ public class SDBaseOps { /** * Generate a sequence mask (with values 0 or 1) based on the specified lengths
- * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * {@code Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)}
* * @param lengths Lengths of the sequences (NUMERIC type) * @param maxLen Maximum sequence length (INT type) @@ -3635,7 +3635,7 @@ public class SDBaseOps { /** * Generate a sequence mask (with values 0 or 1) based on the specified lengths
- * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
+ * {@code Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)}
* * @param name name May be null. Name for the output variable * @param lengths Lengths of the sequences (NUMERIC type) @@ -3761,7 +3761,7 @@ public class SDBaseOps { * then slice(input, begin=[0,1], size=[2,1] will return:
* [b]
* [e]
- * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
* * @param input input Variable to get subset of (NUMERIC type) * @param begin Beginning index. Must be same length as rank of input array (Size: AtLeast(min=1)) @@ -3783,7 +3783,7 @@ public class SDBaseOps { * then slice(input, begin=[0,1], size=[2,1] will return:
* [b]
* [e]
- * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
* * @param name name May be null. Name for the output variable * @param input input Variable to get subset of (NUMERIC type) @@ -3807,7 +3807,7 @@ public class SDBaseOps { * then slice(input, begin=[0,1], size=[2,1] will return:
* [b]
* [e]
- * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
* * @param input input Variable to get subset of (NUMERIC type) * @param begin Beginning index. Must be same length as rank of input array (INT type) @@ -3829,7 +3829,7 @@ public class SDBaseOps { * then slice(input, begin=[0,1], size=[2,1] will return:
* [b]
* [e]
- * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
+ * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
* * @param name name May be null. Name for the output variable * @param input input Variable to get subset of (NUMERIC type) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java index df2151210..9766a5b7c 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java @@ -176,7 +176,7 @@ public class Evaluation extends BaseEvaluation { * Constructor to use for top N accuracy * * @param labels Labels for the classes (may be null) - * @param topN Value to use for top N accuracy calculation (<=1: standard accuracy). Note that with top N + * @param topN Value to use for top N accuracy calculation (<=1: standard accuracy). Note that with top N * accuracy, an example is considered 'correct' if the probability for the true class is one of the * highest N values */ @@ -1173,7 +1173,7 @@ public class Evaluation extends BaseEvaluation { /** * False Alarm Rate (FAR) reflects rate of misclassified to classified records - * http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw
+ * {@link }http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw}
* Note: value returned will differ depending on number of classes and settings.
* 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/BlasException.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/BlasException.java index 9e9e807b5..5b103b278 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/BlasException.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/BlasException.java @@ -34,7 +34,7 @@ public class BlasException extends Error { } /** - * Principal constructor - error message & error code + * Principal constructor - error message & error code * @param message the error message to put into the Exception * @param errorCode the library error number */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Lapack.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Lapack.java index b112c73fd..3bf6699ca 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Lapack.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Lapack.java @@ -28,15 +28,15 @@ public interface Lapack { * LU decomposiiton of a matrix * Factorize a matrix A * - * The matrix A is overridden by the L & U combined. + * The matrix A is overridden by the L & U combined. * The permutation results are returned directly as a vector. To * create the permutation matrix use getPFactor method - * To split out the L & U matrix use getLFactor and getUFactor methods + * To split out the L & U matrix use getLFactor and getUFactor methods * * getrf = triangular factorization (TRF) of a general matrix (GE) * * @param A the input matrix, it will be overwritten with the factors - * @returns Permutation array + * @return Permutation array * @throws Error - with a message to indicate failure (usu. bad params) */ INDArray getrf(INDArray A); @@ -53,7 +53,7 @@ public interface Lapack { * matrix Q and an upper triangular R matrix * * @param A the input matrix, it will be overwritten with the factors - * @param The R array if null R is not returned + * @param R The R array if null R is not returned * @throws Error - with a message to indicate failure (usu. bad params) */ void geqrf(INDArray A, INDArray R); @@ -71,8 +71,7 @@ public interface Lapack { * lower L ( or upper U ) triangular matrix * * @param A the input matrix, it will be overwritten with the factors - * @param whether to return the upper (false) or lower factor - * @returns Permutation array + * @param lower whether to return the upper (false) or lower factor * @throws Error - with a message to indicate failure (usu. bad params) */ void potrf(INDArray A, boolean lower); @@ -122,7 +121,7 @@ public interface Lapack { * * @param M - the size of the permutation matrix ( usu. 
the # rows in factored matrix ) * @param ipiv - the vector returned from a refactoring - * @returned the square permutation matrix - size is the M x M + * @return the square permutation matrix - size is the M x M */ INDArray getPFactor(int M, INDArray ipiv); @@ -131,8 +130,8 @@ public interface Lapack { * extracts the L (lower triangular) matrix from the LU factor result * L will be the same dimensions as A * - * @param A - the combined L & U matrices returned from factorization - * @returned the lower triangular with unit diagonal + * @param A - the combined L & U matrices returned from factorization + * @return the lower triangular with unit diagonal */ INDArray getLFactor(INDArray A); @@ -141,8 +140,8 @@ public interface Lapack { * extracts the U (upper triangular) matrix from the LU factor result * U will be n x n matrix where n = num cols in A * - * @param A - the combined L & U matrices returned from factorization - * @returned the upper triangular matrix + * @param A - the combined L & U matrices returned from factorization + * @return the upper triangular matrix */ INDArray getUFactor(INDArray A); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Level1.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Level1.java index 3b82f6cbd..3658c0e74 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Level1.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/Level1.java @@ -26,7 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; public interface Level1 { /** * computes a vector-vector dot product. - * @param n + * @param N * @param alpha * @param X * @param Y @@ -65,7 +65,7 @@ public interface Level1 { /** * finds the element of a * vector that has the largest absolute value. - * @param n the length to iterate for + * @param N the length to iterate for * @param arr the array to get the max * index for * @param stride the stride for the array @@ -105,7 +105,7 @@ public interface Level1 { /** * computes a vector-scalar product and adds the result to a vector. - * @param n + * @param N * @param alpha * @param x * @param y @@ -115,7 +115,7 @@ public interface Level1 { /** * computes a vector-scalar product and adds the result to a vector. 
* y = a*x + y - * @param n number of operations + * @param N number of operations * @param alpha * @param x X * @param offsetX offset of first element of X in buffer diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index ecb2110c2..29340385b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -1297,7 +1297,7 @@ public abstract class BaseDataBuffer implements DataBuffer { if (offset() == 0) { return wrappedBuffer().asIntBuffer(); } else - return wrappedBuffer().asIntBuffer().position((int) offset()); + return (IntBuffer) wrappedBuffer().asIntBuffer().position((int) offset()); } @Override @@ -1308,7 +1308,7 @@ public abstract class BaseDataBuffer implements DataBuffer { if (offset() == 0) { return wrappedBuffer().asLongBuffer(); } else - return wrappedBuffer().asLongBuffer().position((int) offset()); + return (LongBuffer) wrappedBuffer().asLongBuffer().position((int) offset()); } @Override @@ -1319,7 +1319,7 @@ public abstract class BaseDataBuffer implements DataBuffer { if (offset() == 0) { return wrappedBuffer().asDoubleBuffer(); } else { - return wrappedBuffer().asDoubleBuffer().position((int) (offset())); + return (DoubleBuffer) wrappedBuffer().asDoubleBuffer().position((int) (offset())); } } @@ -1331,7 +1331,8 @@ public abstract class BaseDataBuffer implements DataBuffer { if (offset() == 0) { return wrappedBuffer().asFloatBuffer(); } else { - return wrappedBuffer().asFloatBuffer().position((int) (offset())); + return (FloatBuffer) wrappedBuffer().asFloatBuffer() + .position((int) (offset())); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java index c683e47ee..1665ca165 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java @@ -34,7 +34,7 @@ import java.util.*; public class NDArrayCreationUtil { private NDArrayCreationUtil() {} - /** Get an array of INDArrays (2d) all with the specified shape. Pair returned to aid + /** Get an array of INDArrays (2d) all with the specified shape. {@code Pair} returned to aid * debugging: String contains information on how to reproduce the matrix (i.e., which function, and arguments) * Each NDArray in the returned array has been obtained by applying an operation such as transpose, tensorAlongDimension, * etc to an original array. 
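A minimal standalone sketch of the buffer-view pattern used in the BaseDataBuffer hunks above (hypothetical class and method names, not part of the patch): before the covariant return-type overrides added in Java 9, Buffer.position(int) is declared to return java.nio.Buffer, so the result of calling position() on an IntBuffer/LongBuffer/FloatBuffer/DoubleBuffer view has to be cast back to the concrete view type wherever that type is required, which is why the patch adds the explicit casts.

import java.nio.ByteBuffer;
import java.nio.IntBuffer;

public class BufferViewSketch {
    // On Java 8 class files, position(int) is inherited from Buffer and returns Buffer,
    // so an explicit cast restores the IntBuffer view type for the caller.
    static IntBuffer intViewAt(ByteBuffer wrapped, int offset) {
        return (IntBuffer) wrapped.asIntBuffer().position(offset);
    }

    public static void main(String[] args) {
        ByteBuffer bb = ByteBuffer.allocate(64); // 64 bytes -> room for 16 ints
        IntBuffer view = intViewAt(bb, 2);
        System.out.println(view.position());     // prints 2
    }
}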
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java index 1aadba8c8..9613d9141 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java @@ -302,7 +302,7 @@ public interface DataSet extends Iterable, Seri * Get the example metadata, or null if no metadata has been set * * @return List of metadata instances - * @see {@link #getExampleMetaData(Class)} for convenience method for types + * {@link #getExampleMetaData(Class)} for convenience method for types */ List getExampleMetaData(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/MultiDataSet.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/MultiDataSet.java index 675f66d0b..377beee0d 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/MultiDataSet.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/MultiDataSet.java @@ -180,7 +180,7 @@ public interface MultiDataSet extends Serializable { * Get the example metadata, or null if no metadata has been set * * @return List of metadata instances - * @see {@link #getExampleMetaData(Class)} for convenience method for types + * {@link #getExampleMetaData(Class)} for convenience method for types */ List getExampleMetaData(); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/ND4JTestUtils.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/ND4JTestUtils.java index ad386be2f..a1ed3050a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/ND4JTestUtils.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/util/ND4JTestUtils.java @@ -35,140 +35,154 @@ import java.util.*; public class ND4JTestUtils { - private ND4JTestUtils(){ } + private ND4JTestUtils() { + } - @AllArgsConstructor - @Data - public static class ComparisonResult { - List> allResults; - List> passed; - List> failed; - List skippedDir1; - List skippedDir2; + @AllArgsConstructor + @Data + public static class ComparisonResult { + + List> allResults; + List> passed; + List> failed; + List skippedDir1; + List skippedDir2; + } + + /** + * A function for use with {@link #validateSerializedArrays(File, File, boolean, BiFunction)} + * using {@code INDArray#equals(Object)} + */ + public static class EqualsFn implements BiFunction { + + @Override + public Boolean apply(INDArray i1, INDArray i2) { + return i1.equals(i2); + } + } + + /** + * A function for use with {@link #validateSerializedArrays(File, File, boolean, BiFunction)} + * using {@link INDArray#equalsWithEps(Object, double)} + */ + @AllArgsConstructor + public static class EqualsWithEpsFn implements BiFunction { + + private final double eps; + + @Override + public Boolean apply(INDArray i1, INDArray i2) { + return i1.equalsWithEps(i2, eps); + } + } + + /** + * Scan the specified directories for matching files (i.e., same path relative to their respective + * root directories) and compare the contents using INDArray.equals (via {@link EqualsFn} Assumes + * the saved files represent INDArrays saved with {@link Nd4j#saveBinary(INDArray, File)} + * + * @param dir1 First directory + * @param dir2 Second directory + * @param recursive Whether to search recursively (i.e., include files in subdirectories + * @return Comparison results + */ + public static ComparisonResult 
validateSerializedArrays(File dir1, File dir2, boolean recursive) + throws Exception { + return validateSerializedArrays(dir1, dir2, recursive, new EqualsFn()); + } + + /** + * Scan the specified directories for matching files (i.e., same path relative to their respective + * root directories) and compare the contents using a provided function.
Assumes the saved + * files represent INDArrays saved with {@link Nd4j#saveBinary(INDArray, File)} + * + * @param dir1 First directory + * @param dir2 Second directory + * @param recursive Whether to search recursively (i.e., include files in subdirectories + * @return Comparison results + */ + public static ComparisonResult validateSerializedArrays(File dir1, File dir2, boolean recursive, + BiFunction evalFn) throws Exception { + File[] f1 = FileUtils.listFiles(dir1, null, recursive).toArray(new File[0]); + File[] f2 = FileUtils.listFiles(dir2, null, recursive).toArray(new File[0]); + + Preconditions.checkState(f1.length > 0, "No files found for directory 1: %s", + dir1.getAbsolutePath()); + Preconditions.checkState(f2.length > 0, "No files found for directory 2: %s", + dir2.getAbsolutePath()); + + Map relativized1 = new HashMap<>(); + Map relativized2 = new HashMap<>(); + + URI u = dir1.toURI(); + for (File f : f1) { + if (!f.isFile()) { + continue; + } + String relative = u.relativize(f.toURI()).getPath(); + relativized1.put(relative, f); } - /** - * A function for use with {@link #validateSerializedArrays(File, File, boolean, BiFunction)} using {@code INDArray#equals(Object)} - */ - public static class EqualsFn implements BiFunction { - @Override - public Boolean apply(INDArray i1, INDArray i2) { - return i1.equals(i2); + u = dir2.toURI(); + for (File f : f2) { + if (!f.isFile()) { + continue; } + String relative = u.relativize(f.toURI()).getPath(); + relativized2.put(relative, f); } - /** - * A function for use with {@link #validateSerializedArrays(File, File, boolean, BiFunction)} using {@link INDArray#equalsWithEps(Object, double)} - */ - @AllArgsConstructor - public static class EqualsWithEpsFn implements BiFunction { - private final double eps; - - @Override - public Boolean apply(INDArray i1, INDArray i2) { - return i1.equalsWithEps(i2, eps); - } + List skipped1 = new ArrayList<>(); + for (String s : relativized1.keySet()) { + if (!relativized2.containsKey(s)) { + skipped1.add(relativized1.get(s)); + } } - /** - * Scan the specified directories for matching files (i.e., same path relative to their respective root directories) - * and compare the contents using INDArray.equals (via {@link EqualsFn} - * Assumes the saved files represent INDArrays saved with {@link Nd4j#saveBinary(INDArray, File)} - * @param dir1 First directory - * @param dir2 Second directory - * @param recursive Whether to search recursively (i.e., include files in subdirectories - * @return Comparison results - */ - public static ComparisonResult validateSerializedArrays(File dir1, File dir2, boolean recursive) throws Exception { - return validateSerializedArrays(dir1, dir2, recursive, new EqualsFn()); + List skipped2 = new ArrayList<>(); + for (String s : relativized2.keySet()) { + if (!relativized1.containsKey(s)) { + skipped2.add(relativized1.get(s)); + } } - /** - * Scan the specified directories for matching files (i.e., same path relative to their respective root directories) - * and compare the contents using a provided function.
- * Assumes the saved files represent INDArrays saved with {@link Nd4j#saveBinary(INDArray, File)} - * @param dir1 First directory - * @param dir2 Second directory - * @param recursive Whether to search recursively (i.e., include files in subdirectories - * @return Comparison results - */ - public static ComparisonResult validateSerializedArrays(File dir1, File dir2, boolean recursive, BiFunction evalFn) throws Exception { - File[] f1 = FileUtils.listFiles(dir1, null, recursive).toArray(new File[0]); - File[] f2 = FileUtils.listFiles(dir2, null, recursive).toArray(new File[0]); + List> allResults = new ArrayList<>(); + List> passed = new ArrayList<>(); + List> failed = new ArrayList<>(); + for (Map.Entry e : relativized1.entrySet()) { + File file1 = e.getValue(); + File file2 = relativized2.get(e.getKey()); - Preconditions.checkState(f1.length > 0, "No files found for directory 1: %s", dir1.getAbsolutePath() ); - Preconditions.checkState(f2.length > 0, "No files found for directory 2: %s", dir2.getAbsolutePath() ); - - Map relativized1 = new HashMap<>(); - Map relativized2 = new HashMap<>(); - - URI u = dir1.toURI(); - for(File f : f1){ - if(!f.isFile()) - continue; - String relative = u.relativize(f.toURI()).getPath(); - relativized1.put(relative, f); + if (file2 == null) { + continue; } - u = dir2.toURI(); - for(File f : f2){ - if(!f.isFile()) - continue; - String relative = u.relativize(f.toURI()).getPath(); - relativized2.put(relative, f); - } - - List skipped1 = new ArrayList<>(); - for(String s : relativized1.keySet()){ - if(!relativized2.containsKey(s)){ - skipped1.add(relativized1.get(s)); - } - } - - List skipped2 = new ArrayList<>(); - for(String s : relativized2.keySet()){ - if(!relativized1.containsKey(s)){ - skipped2.add(relativized1.get(s)); - } - } - - List> allResults = new ArrayList<>(); - List> passed = new ArrayList<>(); - List> failed = new ArrayList<>(); - for(Map.Entry e : relativized1.entrySet()){ - File file1 = e.getValue(); - File file2 = relativized2.get(e.getKey()); - - if(file2 == null) - continue; - - INDArray i1 = Nd4j.readBinary(file1); - INDArray i2 = Nd4j.readBinary(file2); - boolean b = evalFn.apply(i1, i2); - Triple t = new Triple<>(file1, file2, b); - allResults.add(t); - if(b){ - passed.add(t); - } else { - failed.add(t); - } - } - - Comparator> c = new Comparator>() { - @Override - public int compare(Triple o1, Triple o2) { - return o1.getFirst().compareTo(o2.getFirst()); - } - }; - - Collections.sort(allResults, c); - Collections.sort(passed, c); - Collections.sort(failed, c); - Collections.sort(skipped1); - Collections.sort(skipped2); - - - return new ComparisonResult(allResults, passed, failed, skipped1, skipped2); + INDArray i1 = Nd4j.readBinary(file1); + INDArray i2 = Nd4j.readBinary(file2); + boolean b = evalFn.apply(i1, i2); + Triple t = new Triple<>(file1, file2, b); + allResults.add(t); + if (b) { + passed.add(t); + } else { + failed.add(t); + } } + + Comparator> c = new Comparator>() { + @Override + public int compare(Triple o1, Triple o2) { + return o1.getFirst().compareTo(o2.getFirst()); + } + }; + + Collections.sort(allResults, c); + Collections.sort(passed, c); + Collections.sort(failed, c); + Collections.sort(skipped1); + Collections.sort(skipped2); + + return new ComparisonResult(allResults, passed, failed, skipped1, skipped2); + } } diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/iterator/impl/EmnistDataSetIterator.java 
b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/iterator/impl/EmnistDataSetIterator.java index 6d000582e..27ce9f464 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/iterator/impl/EmnistDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/iterator/impl/EmnistDataSetIterator.java @@ -211,7 +211,7 @@ public class EmnistDataSetIterator extends BaseDatasetIterator { } /** - * Get the labels as a List + * Get the labels as a {@code List} * * @return Labels */ @@ -244,7 +244,7 @@ public class EmnistDataSetIterator extends BaseDatasetIterator { } /** - * Get the label assignments for the given set as a List + * Get the label assignments for the given set as a {@code List} * * @param dataSet DataSet to get the label assignment for * @return Label assignment and given dataset diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/iterator/impl/LFWDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/iterator/impl/LFWDataSetIterator.java index 37ce72ea4..295798bdd 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/iterator/impl/LFWDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datasets/src/main/java/org/deeplearning4j/datasets/iterator/impl/LFWDataSetIterator.java @@ -31,76 +31,109 @@ import java.util.Random; public class LFWDataSetIterator extends RecordReaderDataSetIterator { - /** Loads subset of images with given imgDim returned by the generator. */ - public LFWDataSetIterator(int[] imgDim) { - this(LFWLoader.SUB_NUM_IMAGES, LFWLoader.SUB_NUM_IMAGES, imgDim, LFWLoader.SUB_NUM_LABELS, false, - new ParentPathLabelGenerator(), true, 1, null, new Random(System.currentTimeMillis())); - } + /** + * Loads subset of images with given imgDim returned by the generator. + */ + public LFWDataSetIterator(int[] imgDim) { + this(LFWLoader.SUB_NUM_IMAGES, LFWLoader.SUB_NUM_IMAGES, imgDim, LFWLoader.SUB_NUM_LABELS, + false, + new ParentPathLabelGenerator(), true, 1, null, new Random(System.currentTimeMillis())); + } - /** Loads images with given batchSize, numExamples returned by the generator. */ - public LFWDataSetIterator(int batchSize, int numExamples) { - this(batchSize, numExamples, new int[] {LFWLoader.HEIGHT, LFWLoader.WIDTH, LFWLoader.CHANNELS}, - LFWLoader.NUM_LABELS, false, LFWLoader.LABEL_PATTERN, true, 1, null, - new Random(System.currentTimeMillis())); - } + /** + * Loads images with given batchSize, numExamples returned by the generator. + */ + public LFWDataSetIterator(int batchSize, int numExamples) { + this(batchSize, numExamples, new int[]{LFWLoader.HEIGHT, LFWLoader.WIDTH, LFWLoader.CHANNELS}, + LFWLoader.NUM_LABELS, false, LFWLoader.LABEL_PATTERN, true, 1, null, + new Random(System.currentTimeMillis())); + } - /** Loads images with given batchSize, numExamples, imgDim returned by the generator. */ - public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim) { - this(batchSize, numExamples, imgDim, LFWLoader.NUM_LABELS, false, LFWLoader.LABEL_PATTERN, true, 1, null, - new Random(System.currentTimeMillis())); - } + /** + * Loads images with given batchSize, numExamples, imgDim returned by the generator. 
+ */ + public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim) { + this(batchSize, numExamples, imgDim, LFWLoader.NUM_LABELS, false, LFWLoader.LABEL_PATTERN, true, + 1, null, + new Random(System.currentTimeMillis())); + } - /** Loads images with given batchSize, imgDim, useSubset, returned by the generator. */ - public LFWDataSetIterator(int batchSize, int[] imgDim, boolean useSubset) { - this(batchSize, useSubset ? LFWLoader.SUB_NUM_IMAGES : LFWLoader.NUM_IMAGES, imgDim, - useSubset ? LFWLoader.SUB_NUM_LABELS : LFWLoader.NUM_LABELS, useSubset, LFWLoader.LABEL_PATTERN, - true, 1, null, new Random(System.currentTimeMillis())); - } + /** + * Loads images with given batchSize, imgDim, useSubset, returned by the generator. + */ + public LFWDataSetIterator(int batchSize, int[] imgDim, boolean useSubset) { + this(batchSize, useSubset ? LFWLoader.SUB_NUM_IMAGES : LFWLoader.NUM_IMAGES, imgDim, + useSubset ? LFWLoader.SUB_NUM_LABELS : LFWLoader.NUM_LABELS, useSubset, + LFWLoader.LABEL_PATTERN, + true, 1, null, new Random(System.currentTimeMillis())); + } - /** Loads images with given batchSize, numExamples, imgDim, train, & splitTrainTest returned by the generator. */ - public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim, boolean train, double splitTrainTest) { - this(batchSize, numExamples, imgDim, LFWLoader.NUM_LABELS, false, LFWLoader.LABEL_PATTERN, train, - splitTrainTest, null, new Random(System.currentTimeMillis())); - } + /** + * Loads images with given batchSize, numExamples, imgDim, train, & splitTrainTest returned + * by the generator. + */ + public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim, boolean train, + double splitTrainTest) { + this(batchSize, numExamples, imgDim, LFWLoader.NUM_LABELS, false, LFWLoader.LABEL_PATTERN, + train, + splitTrainTest, null, new Random(System.currentTimeMillis())); + } - /** Loads images with given batchSize, numExamples, numLabels, train, & splitTrainTest returned by the generator. */ - public LFWDataSetIterator(int batchSize, int numExamples, int numLabels, boolean train, double splitTrainTest) { - this(batchSize, numExamples, new int[] {LFWLoader.HEIGHT, LFWLoader.WIDTH, LFWLoader.CHANNELS}, numLabels, - false, null, train, splitTrainTest, null, new Random(System.currentTimeMillis())); - } + /** + * Loads images with given batchSize, numExamples, numLabels, train, & splitTrainTest + * returned by the generator. + */ + public LFWDataSetIterator(int batchSize, int numExamples, int numLabels, boolean train, + double splitTrainTest) { + this(batchSize, numExamples, new int[]{LFWLoader.HEIGHT, LFWLoader.WIDTH, LFWLoader.CHANNELS}, + numLabels, + false, null, train, splitTrainTest, null, new Random(System.currentTimeMillis())); + } - /** Loads images with given batchSize, numExamples, imgDim, numLabels, useSubset, train, splitTrainTest & Random returned by the generator. */ - public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numLabels, boolean useSubset, - boolean train, double splitTrainTest, Random rng) { - this(batchSize, numExamples, imgDim, numLabels, useSubset, LFWLoader.LABEL_PATTERN, train, splitTrainTest, null, - rng); - } + /** + * Loads images with given batchSize, numExamples, imgDim, numLabels, useSubset, train, + * splitTrainTest & Random returned by the generator. 
+ */ + public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numLabels, + boolean useSubset, + boolean train, double splitTrainTest, Random rng) { + this(batchSize, numExamples, imgDim, numLabels, useSubset, LFWLoader.LABEL_PATTERN, train, + splitTrainTest, null, + rng); + } - /** Loads images with given batchSize, numExamples, imgDim, numLabels, useSubset, train, splitTrainTest & Random returned by the generator. */ - public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numLabels, boolean useSubset, - PathLabelGenerator labelGenerator, boolean train, double splitTrainTest, Random rng) { - this(batchSize, numExamples, imgDim, numLabels, useSubset, labelGenerator, train, splitTrainTest, null, rng); - } + /** + * Loads images with given batchSize, numExamples, imgDim, numLabels, useSubset, train, + * splitTrainTest & Random returned by the generator. + */ + public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numLabels, + boolean useSubset, + PathLabelGenerator labelGenerator, boolean train, double splitTrainTest, Random rng) { + this(batchSize, numExamples, imgDim, numLabels, useSubset, labelGenerator, train, + splitTrainTest, null, rng); + } - /** - * Create LFW data specific iterator - * @param batchSize the batch size of the examples - * @param numExamples the overall number of examples - * @param imgDim an array of height, width and channels - * @param numLabels the overall number of examples - * @param useSubset use a subset of the LFWDataSet - * @param labelGenerator path label generator to use - * @param train true if use train value - * @param splitTrainTest the percentage to split data for train and remainder goes to test - * @param imageTransform how to transform the image - - * @param rng random number to lock in batch shuffling - * */ - public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numLabels, boolean useSubset, - PathLabelGenerator labelGenerator, boolean train, double splitTrainTest, - ImageTransform imageTransform, Random rng) { - super(new LFWLoader(imgDim, imageTransform, useSubset).getRecordReader(batchSize, numExamples, imgDim, - numLabels, labelGenerator, train, splitTrainTest, rng), batchSize, 1, numLabels); - } + /** + * Create LFW data specific iterator + * + * @param batchSize the batch size of the examples + * @param numExamples the overall number of examples + * @param imgDim an array of height, width and channels + * @param numLabels the overall number of examples + * @param useSubset use a subset of the LFWDataSet + * @param labelGenerator path label generator to use + * @param train true if use train value + * @param splitTrainTest the percentage to split data for train and remainder goes to test + * @param imageTransform how to transform the image + * @param rng random number to lock in batch shuffling + */ + public LFWDataSetIterator(int batchSize, int numExamples, int[] imgDim, int numLabels, + boolean useSubset, + PathLabelGenerator labelGenerator, boolean train, double splitTrainTest, + ImageTransform imageTransform, Random rng) { + super(new LFWLoader(imgDim, imageTransform, useSubset).getRecordReader(batchSize, numExamples, + imgDim, + numLabels, labelGenerator, train, splitTrainTest, rng), batchSize, 1, numLabels); + } } diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 2e587fa8e..e25c3e7b2 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -19,7 +19,7 @@ dependencies { //TODO for the two below.. 
either platform specific uber jars or a single big one with all platforms api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" //api group: "org.bytedeco", name: "javacpp", version: "1.5.7" - api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" + // api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT' rootProject.getAllprojects().each { Project sproj -> if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") diff --git a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java index c73f9c3bb..3658a8006 100644 --- a/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java +++ b/cavis-nd4j/cavis-nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/NDArrayMessage.java @@ -256,7 +256,7 @@ public class NDArrayMessage implements Serializable { String messageId = UUID.randomUUID().toString(); for (int i = 0; i < ret.length; i++) { //data: only grab a chunk of the data - ByteBuffer view = wholeBuffer.byteBuffer().asReadOnlyBuffer().position(i * chunkSize); + ByteBuffer view = (ByteBuffer) wholeBuffer.byteBuffer().asReadOnlyBuffer().position(i * chunkSize); view.limit(Math.min(i * chunkSize + chunkSize, wholeBuffer.capacity())); view.order(ByteOrder.nativeOrder()); view = view.slice(); diff --git a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java index 8f49c7e7a..b17164393 100644 --- a/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java +++ b/cavis-nd4j/cavis-nd4j-parameter-server/cavis-nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/RoutedTransport.java @@ -664,8 +664,8 @@ public class RoutedTransport extends BaseTransport { public static class RemoteConnectionBuilder { - private final Object locker = new Object(); - private final AtomicBoolean activated = new AtomicBoolean(); + private Object locker = new Object(); + private AtomicBoolean activated = new AtomicBoolean(); } } From d767abdebaaedb00f561d0d0d9af98fa613106db Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 11 Oct 2022 07:59:22 +0200 Subject: [PATCH 054/126] More test fixes --- .docker/Dockerfile | 2 ++ build.gradle | 24 ++----------------- build_requirements.md | 18 +------------- .../deeplearning4j/nn/layers/HelperUtils.java | 6 ++--- cavis-full/build.gradle | 9 +++---- 5 files changed, 11 insertions(+), 48 deletions(-) diff --git a/.docker/Dockerfile b/.docker/Dockerfile index 4e2c0ece8..a5508688e 100644 --- a/.docker/Dockerfile +++ b/.docker/Dockerfile @@ -11,3 +11,5 @@ RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3. 
rm cmake-3.24.2-linux-x86_64.sh +RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf + diff --git a/build.gradle b/build.gradle index fc9167f30..ab3337562 100644 --- a/build.gradle +++ b/build.gradle @@ -44,7 +44,6 @@ ext { scalaVersion = "2.12" logger.quiet("Scala main version is set to {}", scalaVersion) - logger.quiet("Running java {}", JavaVersion.current()) } configurations.all { @@ -56,7 +55,6 @@ configurations.all { } - allprojects { Project proj -> apply plugin: 'com.google.osdetector' @@ -65,8 +63,8 @@ allprojects { Project proj -> plugins.withType(JavaPlugin) { - sourceCompatibility = JavaVersion.VERSION_11 - targetCompatibility = JavaVersion.VERSION_1_8 + sourceCompatibility = 11 + targetCompatibility = 1.8 tasks.withType(JavaCompile) { options.release = 8 } @@ -163,21 +161,3 @@ allprojects { Project proj -> } } } - - -task aggregatedJavadocs(type: Javadoc, description: 'Generate javadocs from all child projects as if it was a single project', group: 'Documentation') { - subprojects.each { proj -> - proj.tasks.withType(Javadoc).each { javadocTask -> - logger.quiet("Adding javadoc for project " + proj.name) - source += javadocTask.source - classpath += javadocTask.classpath - excludes += javadocTask.excludes - includes += javadocTask.includes - } - } - destinationDir = file("$buildDir/docs/javadoc") - title = "$project.name $version API" - options.author true - options.links 'http://docs.oracle.com/javase/8/docs/api/' - options.addStringOption('Xdoclint:none', '-quiet') -} \ No newline at end of file diff --git a/build_requirements.md b/build_requirements.md index 77d54050b..db6532203 100644 --- a/build_requirements.md +++ b/build_requirements.md @@ -129,20 +129,4 @@ echo "nameserver 8.8.8.8" | sudo tee -a /etc/resolv.conf # Buildparameter: # -P\\ - CAVIS_AVX_EXTENSION = {avx2 | avx512}, default is avx2 - -# Zeppelin Spark dependencies # -3 - - -To add the dependency to the language models, use the following format in the Dependencies section of the of the Spark Interpreter configuration (Interpreters -> Spark -> Edit -> Dependencies): - -groupId:artifactId:packaging:classifier:version - -In your case it should work with - -edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0 - - -Native cpu code under linux needs libc6-dev -/lib/x86_64-linux-gnu/libm.so.6: version `GLIBC_2.29' not found \ No newline at end of file + CAVIS_AVX_EXTENSION = {avx2 | avx512}, default is avx2 \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java index dfff491e4..eb59a2c5f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java @@ -64,7 +64,7 @@ public class HelperUtils { if("CUDA".equalsIgnoreCase(backend) && cudnnHelperClassName != null && !cudnnHelperClassName.isEmpty()) { if(DL4JClassLoading.loadClassByName(cudnnHelperClassName) != null) { log.debug("Attempting to initialize cudnn helper {}",cudnnHelperClassName); - helperRet = DL4JClassLoading.createNewInstance( + helperRet = (LayerHelper) DL4JClassLoading.createNewInstance( cudnnHelperClassName, (Class) layerHelperSuperClass, new Object[]{arguments}); @@ -76,7 +76,7 @@ public class HelperUtils { ClassLoader classLoader = DL4JClassLoading.getDl4jClassloader(); DL4JClassLoading.setDl4jClassloaderFromClass(layerHelperSuperClass); try { - helperRet = 
DL4JClassLoading.createNewInstance( + helperRet = (LayerHelper) DL4JClassLoading.createNewInstance( cudnnHelperClassName, (Class) layerHelperSuperClass, arguments); @@ -99,7 +99,7 @@ public class HelperUtils { } } else if("CPU".equalsIgnoreCase(backend) && oneDnnClassName != null && !oneDnnClassName.isEmpty()) { - helperRet = DL4JClassLoading.createNewInstance( + helperRet = DL4JClassLoading.createNewInstance( oneDnnClassName, arguments); log.trace("Created oneDNN helper: {}, layer {}", oneDnnClassName,layerName); diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index e25c3e7b2..68e847fdf 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -12,15 +12,12 @@ configurations.archives.artifacts.with { archives -> dependencies { //Todo clean this api platform(project(":cavis-common-platform")) - //api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise + api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5" api 'org.slf4j:slf4j-simple:2.0.3' api 'org.slf4j:slf4j-api:2.0.3' - //TODO for the two below.. either platform specific uber jars or a single big one with all platforms - api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" - //api group: "org.bytedeco", name: "javacpp", version: "1.5.7" - // api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" - //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT' + //api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86" + rootProject.getAllprojects().each { Project sproj -> if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") && !sproj.name.equals("Cavis") From 931841d6691e13f7641e22cfdf5aae5d3d6162ec Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 12 Oct 2022 12:03:08 +0200 Subject: [PATCH 055/126] Add jenkinsfile for pipeline build and dockerfile for build --- .docker/Dockerfile | 2 -- build.gradle | 5 +++-- build_requirements.md | 14 +++++++++++++- cavis-full/build.gradle | 3 ++- cavis-native/cavis-native-lib/build.gradle | 7 +++---- 5 files changed, 21 insertions(+), 10 deletions(-) diff --git a/.docker/Dockerfile b/.docker/Dockerfile index a5508688e..4e2c0ece8 100644 --- a/.docker/Dockerfile +++ b/.docker/Dockerfile @@ -11,5 +11,3 @@ RUN wget -nv https://github.com/Kitware/CMake/releases/download/v3.24.2/cmake-3. 
rm cmake-3.24.2-linux-x86_64.sh -RUN echo "nameserver 8.8.8.8" >> /etc/resolv.conf - diff --git a/build.gradle b/build.gradle index ab3337562..cd5911461 100644 --- a/build.gradle +++ b/build.gradle @@ -44,6 +44,7 @@ ext { scalaVersion = "2.12" logger.quiet("Scala main version is set to {}", scalaVersion) + logger.quiet("Running java {}", JavaVersion.current()) } configurations.all { @@ -63,8 +64,8 @@ allprojects { Project proj -> plugins.withType(JavaPlugin) { - sourceCompatibility = 11 - targetCompatibility = 1.8 + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_1_8 tasks.withType(JavaCompile) { options.release = 8 } diff --git a/build_requirements.md b/build_requirements.md index db6532203..602190b95 100644 --- a/build_requirements.md +++ b/build_requirements.md @@ -129,4 +129,16 @@ echo "nameserver 8.8.8.8" | sudo tee -a /etc/resolv.conf # Buildparameter: # -P\\ - CAVIS_AVX_EXTENSION = {avx2 | avx512}, default is avx2 \ No newline at end of file + CAVIS_AVX_EXTENSION = {avx2 | avx512}, default is avx2 + +# Zeppelin Spark dependencies # +3 + + +To add the dependency to the language models, use the following format in the Dependencies section of the of the Spark Interpreter configuration (Interpreters -> Spark -> Edit -> Dependencies): + +groupId:artifactId:packaging:classifier:version + +In your case it should work with + +edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0 \ No newline at end of file diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 68e847fdf..659e119e2 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -16,7 +16,8 @@ dependencies { api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5" api 'org.slf4j:slf4j-simple:2.0.3' api 'org.slf4j:slf4j-api:2.0.3' - //api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86" + api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86" + api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" rootProject.getAllprojects().each { Project sproj -> if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 1d083f0ce..10648759d 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -84,7 +84,7 @@ chipList.each {thisChip -> } -//if(osdetector.os.startsWith("windows")) { +if(osdetector.os.startsWith("windows")) { sourceSets { main { java { @@ -93,7 +93,7 @@ chipList.each {thisChip -> } } } -//} +} java { @@ -212,7 +212,7 @@ tasks.withType(org.bytedeco.gradle.javacpp.BuildTask) { // Disable the standard javacpp generated tasks and use own // versions below. 
This allows to build for each variant [javacppBuildParser, javacppBuildCommand, javacppCompileJava, javacppBuildCompiler].each { - it.enabled false + it.enabled false; } chipList.each { thisChip -> @@ -488,7 +488,6 @@ chipList.each { thisChip -> publishing { publications { mavenJava(MavenPublication) { - artifact jar artifact tasks.getByName("${thisChip}SupportJar") } } From 66ed10a5e38f355b30fa5433ed13d1b902ea9919 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 14 Oct 2022 12:32:41 +0200 Subject: [PATCH 056/126] Add jenkinsfile for pipeline build and dockerfile for build --- cavis-full/build.gradle | 8 +++++--- cavis-native/cavis-native-lib/build.gradle | 5 +++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 659e119e2..2e587fa8e 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -12,13 +12,15 @@ configurations.archives.artifacts.with { archives -> dependencies { //Todo clean this api platform(project(":cavis-common-platform")) - api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise + //api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5" api 'org.slf4j:slf4j-simple:2.0.3' api 'org.slf4j:slf4j-api:2.0.3' - api group: "org.bytedeco", name: "javacpp", classifier: "linux-x64_86" + //TODO for the two below.. either platform specific uber jars or a single big one with all platforms + api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" + //api group: "org.bytedeco", name: "javacpp", version: "1.5.7" api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" - + //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT' rootProject.getAllprojects().each { Project sproj -> if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") && !sproj.name.equals("Cavis") diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 10648759d..0a638ff15 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -84,7 +84,7 @@ chipList.each {thisChip -> } -if(osdetector.os.startsWith("windows")) { +//if(osdetector.os.startsWith("windows")) { sourceSets { main { java { @@ -93,7 +93,7 @@ if(osdetector.os.startsWith("windows")) { } } } -} +//} java { @@ -488,6 +488,7 @@ chipList.each { thisChip -> publishing { publications { mavenJava(MavenPublication) { + artifact jar artifact tasks.getByName("${thisChip}SupportJar") } } From 1c2ca75308a5780602b025bacfed588d9b927b46 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 21 Oct 2022 15:19:32 +0200 Subject: [PATCH 057/126] Fix javadoc and cleanup --- build.gradle | 19 + build_requirements.md | 6 +- .../autodiff/execution/input/Operands.java | 7 +- .../debugging/ExecDebuggingListener.java | 2 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 6 +- .../internal/AbstractDependencyTracker.java | 22 +- .../samediff/internal/AbstractSession.java | 6 +- .../nd4j/autodiff/samediff/ops/SDBitwise.java | 206 +- .../nd4j/autodiff/samediff/ops/SDImage.java | 6 +- .../nd4j/autodiff/samediff/ops/SDLoss.java | 767 +- .../nd4j/autodiff/samediff/ops/SDMath.java | 2609 ++-- .../org/nd4j/autodiff/samediff/ops/SDNN.java | 831 +- 
.../nd4j/autodiff/samediff/ops/SDRandom.java | 4 +- .../evaluation/classification/Evaluation.java | 16 +- .../classification/EvaluationBinary.java | 4 +- .../nd4j/linalg/api/blas/impl/BaseLapack.java | 2 +- .../api/buffer/factory/DataBufferFactory.java | 4 +- .../nd4j/linalg/api/memory/MemoryManager.java | 1 - .../api/memory/MemoryWorkspaceManager.java | 8 +- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 10928 ++++++++-------- .../api/ndarray/BaseShapeInfoProvider.java | 2 +- .../org/nd4j/linalg/api/ndarray/INDArray.java | 22 +- .../linalg/api/ndarray/ShapeInfoProvider.java | 2 +- .../org/nd4j/linalg/api/ops/OpContext.java | 3 +- .../java/org/nd4j/linalg/api/rng/Random.java | 4 +- .../rng/distribution/BaseDistribution.java | 405 +- .../api/rng/distribution/Distribution.java | 8 +- .../impl/BinomialDistribution.java | 10 +- .../impl/ConstantDistribution.java | 12 +- .../impl/LogNormalDistribution.java | 8 +- .../distribution/impl/NormalDistribution.java | 6 - .../impl/OrthogonalDistribution.java | 411 +- .../impl/SaddlePointExpansion.java | 2 - .../impl/TruncatedNormalDistribution.java | 12 +- .../impl/UniformDistribution.java | 10 +- .../linalg/checkutil/NDArrayCreationUtil.java | 4 +- .../linalg/dataset/AsyncDataSetIterator.java | 2 +- .../dataset/AsyncMultiDataSetIterator.java | 2 +- .../java/org/nd4j/linalg/dataset/DataSet.java | 1 - .../dataset/api/iterator/KFoldIterator.java | 3 +- .../api/iterator/TestDataSetIterator.java | 1 - .../RandomProjection.java | 4 +- .../nd4j/linalg/env/EnvironmentalAction.java | 1 - .../linalg/factory/BaseNDArrayFactory.java | 1 - .../org/nd4j/linalg/factory/BlasWrapper.java | 24 +- .../nd4j/linalg/factory/NDArrayFactory.java | 10 +- .../java/org/nd4j/linalg/factory/Nd4j.java | 27 +- .../org/nd4j/linalg/factory/Nd4jBackend.java | 6 +- .../org/nd4j/linalg/factory/ops/NDBase.java | 1998 +-- .../org/nd4j/linalg/indexing/Indices.java | 14 +- .../linalg/learning/AdaBeliefUpdater.java | 1 - .../nd4j/linalg/learning/AdaDeltaUpdater.java | 128 +- .../nd4j/linalg/learning/AdaMaxUpdater.java | 1 - .../org/nd4j/linalg/learning/AdamUpdater.java | 1 - .../nd4j/linalg/learning/GradientUpdater.java | 1 - .../nd4j/linalg/learning/NadamUpdater.java | 1 - .../linalg/learning/NesterovsUpdater.java | 1 - .../collection/MultiDimensionalMap.java | 10 +- .../collection/MultiDimensionalSet.java | 16 +- .../java/org/nd4j/common/util/ArrayUtil.java | 4 +- .../deeplearning4j/nn/layers/HelperUtils.java | 6 +- cavis-full/build.gradle | 5 +- cavis-native/cavis-native-lib/build.gradle | 2 +- .../collection/MultiDimensionalMap.java | 10 +- .../collection/MultiDimensionalSet.java | 16 +- .../java/org/nd4j/common/util/ArrayUtil.java | 4 +- settings.gradle | 1 - 67 files changed, 9896 insertions(+), 8781 deletions(-) diff --git a/build.gradle b/build.gradle index cd5911461..fc9167f30 100644 --- a/build.gradle +++ b/build.gradle @@ -56,6 +56,7 @@ configurations.all { } + allprojects { Project proj -> apply plugin: 'com.google.osdetector' @@ -162,3 +163,21 @@ allprojects { Project proj -> } } } + + +task aggregatedJavadocs(type: Javadoc, description: 'Generate javadocs from all child projects as if it was a single project', group: 'Documentation') { + subprojects.each { proj -> + proj.tasks.withType(Javadoc).each { javadocTask -> + logger.quiet("Adding javadoc for project " + proj.name) + source += javadocTask.source + classpath += javadocTask.classpath + excludes += javadocTask.excludes + includes += javadocTask.includes + } + } + destinationDir = file("$buildDir/docs/javadoc") + title = 
"$project.name $version API" + options.author true + options.links 'http://docs.oracle.com/javase/8/docs/api/' + options.addStringOption('Xdoclint:none', '-quiet') +} \ No newline at end of file diff --git a/build_requirements.md b/build_requirements.md index 602190b95..77d54050b 100644 --- a/build_requirements.md +++ b/build_requirements.md @@ -141,4 +141,8 @@ groupId:artifactId:packaging:classifier:version In your case it should work with -edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0 \ No newline at end of file +edu.stanford.nlp:stanford-corenlp:jar:models:3.8.0 + + +Native cpu code under linux needs libc6-dev +/lib/x86_64-linux-gnu/libm.so.6: version `GLIBC_2.29' not found \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java index 2ea351b38..648581291 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/execution/input/Operands.java @@ -90,7 +90,7 @@ public class Operands { /** * This method returns array identified its numeric id - * @param name + * @param id * @return */ public INDArray getById(int id) { @@ -99,7 +99,8 @@ public class Operands { /** * This method returns array identified its numeric id and index - * @param name + * @param id + * @param index * @return */ public INDArray getById(int id, int index) { @@ -121,7 +122,7 @@ public class Operands { } /** - * This method returns contents of this entity as collection of key->value pairs + * This method returns contents of this entity as collection of key->value pairs * @return */ public Collection> asCollection() { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java index ab4423020..748a8a6b8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/listeners/debugging/ExecDebuggingListener.java @@ -50,7 +50,7 @@ public class ExecDebuggingListener extends BaseListener { /** * @param printMode Print mode, see {@link PrintMode} - * @param maxIterations Maximum number of iterations to print. <= 0 for "all iterations" + * @param maxIterations Maximum number of iterations to print. 
<= 0 for "all iterations" * @param logIter If true: prefix iteration/epoch, such as "(iter=1,epoch=0,op=3)" to the output */ public ExecDebuggingListener(PrintMode printMode, int maxIterations, boolean logIter){ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java index dc064e515..1584b5977 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/SameDiff.java @@ -573,7 +573,7 @@ public class SameDiff extends SDBaseOps { } /** - * Get the function by the {@link DifferentialFunction#getOwnName()} + * Get the function by the {@link org.nd4j.autodiff.functions.DifferentialFunction#getOwnName()} * * @param id the id of the function * @return the function for the given id if it exists @@ -1348,9 +1348,9 @@ public class SameDiff extends SDBaseOps { /** * Get the names of variables (if any) that have been marked as loss variables to be minimized.
* Variables can be marked as loss variables in a few different ways:
- * (a) Losses are automatically added when creating loss functions via {@link #sd()}
+ * (a) Losses are automatically added when creating loss functions via {@link SameDiff#sd}
      * (b) Via {@link #setLossVariables(String...)}, {@link #addLossVariable(String)} or {@link SDVariable#markAsLoss()}
- * (c) Via {@link TrainingConfig#setLossVariables(List)}
+ * (c) Via {@link org.nd4j.autodiff.samediff.TrainingConfig#setLossVariables(List)}
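For reference, a minimal sketch of option (b) above — explicitly marking a variable as a loss to be minimized. The variable names and the simple squared-score "loss" are illustrative assumptions, not part of this patch:

  import org.nd4j.autodiff.samediff.SDVariable;
  import org.nd4j.autodiff.samediff.SameDiff;
  import org.nd4j.linalg.api.buffer.DataType;

  public class LossVariableSketch {
    public static void main(String[] args) {
      SameDiff sd = SameDiff.create();
      SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
      SDVariable w = sd.var("w", DataType.FLOAT, 4, 1);
      SDVariable score = in.mmul(w);
      SDVariable loss = sd.math().square(score).mean();  // simple illustrative loss value
      loss.markAsLoss();                                  // option (b): mark via SDVariable
      // equivalently, by name: sd.setLossVariables(loss.name());
      System.out.println(sd.getLossVariables());
    }
  }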
*/ public List getLossVariables() { return Collections.unmodifiableList(this.lossVariables); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java index 9c3b5d917..7a3d8b01a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractDependencyTracker.java @@ -54,12 +54,12 @@ public abstract class AbstractDependencyTracker { } /** - * @return A new map where the dependents (i.e., Y in "X -> Y") are the key + * @return A new map where the dependents (i.e., Y in "X -> Y") are the key */ protected abstract Map newTMap(); /** - * @return A new set where the dependents (i.e., Y in "X -> Y") are the key + * @return A new set where the dependents (i.e., Y in "X -> Y") are the key */ protected abstract Set newTSet(); @@ -103,7 +103,7 @@ public abstract class AbstractDependencyTracker { /** * Mark the specified value as satisfied. - * For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true) + * For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true) * call, both of these dependencies are considered satisfied. * * @param x Value to mark @@ -191,7 +191,7 @@ public abstract class AbstractDependencyTracker { } /** - * Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)} + * Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)} * or {@link #addOrDependency(Object, Object, Object)} * * @param y Dependent to check @@ -207,7 +207,7 @@ public abstract class AbstractDependencyTracker { } /** - * Get all dependencies x, for x -> y, and (x1 or x2) -> y + * Get all dependencies x, for x -> y, and (x1 or x2) -> y * * @param y Dependent to get dependencies for * @return List of dependencies @@ -223,7 +223,7 @@ public abstract class AbstractDependencyTracker { } /** - * Add a dependency: y depends on x, as in x -> y + * Add a dependency: y depends on x, as in x -> y * * @param y The dependent * @param x The dependee that is required for Y @@ -302,7 +302,7 @@ public abstract class AbstractDependencyTracker { /** - * Remove a dependency (x -> y) + * Remove a dependency (x -> y) * * @param y The dependent that currently requires X * @param x The dependee that is no longer required for Y @@ -357,7 +357,7 @@ public abstract class AbstractDependencyTracker { } /** - * Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y
+ * Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y
* If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the * dependency is considered satisfied * @@ -382,16 +382,16 @@ public abstract class AbstractDependencyTracker { } /** - * @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y) + * @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y) */ public boolean hasNewAllSatisfied() { return !allSatisfiedQueue.isEmpty(); } /** - * Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)} + * Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)} * Throws an exception if {@link #hasNewAllSatisfied()} returns false.
- * Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value; + * Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value; * the value is considered "processed" at this point. * * @return The next new "all satisfied dependent" diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java index d00efcba7..ce29242a8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/internal/AbstractSession.java @@ -487,7 +487,7 @@ public abstract class AbstractSession { } /** - * Add the control dependency from Op -> variable + * Add the control dependency from Op -> variable * * @param es Execution step for the variable * @param v Variable @@ -542,7 +542,7 @@ public abstract class AbstractSession { /** * Update the descendant dependencies - * So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker + * So if the graph structure is X -> A, then add all (X,Y,Z,...) -> A to the dependency tracker * This is for a specific frame and iteration, for both sides of the dependency (in and out) * * @param justExecuted The execution step that has just completed @@ -621,7 +621,7 @@ public abstract class AbstractSession { /** * Suppose operation X has just been executed. - * For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp + * For X -> someOp, add all dependencies for someOp, i.e., all Z -> someOp * (which includes X, but may not only be X) * * @param opName Name of the op diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java index 00102c498..38e99641b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBitwise.java @@ -28,15 +28,15 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; public class SDBitwise extends SDOps { + public SDBitwise(SameDiff sameDiff) { super(sameDiff); } /** * Bitwise AND operation. Supports broadcasting.
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* Must have broadcastable shapes: isBroadcastableShapes(x, y)
* * @param x First input array (INT type) @@ -47,147 +47,155 @@ public class SDBitwise extends SDOps { SDValidation.validateInteger("and", "x", x); SDValidation.validateInteger("and", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd, x, y).outputVariable(); } /** * Bitwise AND operation. Supports broadcasting.
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* Must have broadcastable shapes: isBroadcastableShapes(x, y)
* * @param name name May be null. Name for the output variable - * @param x First input array (INT type) - * @param y Second input array (INT type) + * @param x First input array (INT type) + * @param y Second input array (INT type) * @return output Bitwise AND array (INT type) */ public SDVariable and(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("and", "x", x); SDValidation.validateInteger("and", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
+ * Roll integer bits to the left, i.e. {@code var << 4 | var >> (32 - 4)}
* - * @param x Input 1 (INT type) + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitRotl(SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitRotl", "x", x); SDValidation.validateInteger("bitRotl", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + shift).outputVariable(); } /** - * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4)
+ * Roll integer bits to the left, i.e. {@code var << 4 | var >> (32 - 4)}
* - * @param name name May be null. Name for the output variable - * @param x Input 1 (INT type) + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitRotl(String name, SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitRotl", "x", x); SDValidation.validateInteger("bitRotl", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
+ * Roll integer bits to the right, i.e. {@code var >> 4 | var << (32 - 4)}
* - * @param x Input 1 (INT type) + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitRotr(SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitRotr", "x", x); SDValidation.validateInteger("bitRotr", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + shift).outputVariable(); } /** - * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4)
+ * Roll integer bits to the right, i.e. {@code var >> 4 | var << (32 - 4)}
* - * @param name name May be null. Name for the output variable - * @param x Input 1 (INT type) + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitRotr(String name, SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitRotr", "x", x); SDValidation.validateInteger("bitRotr", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Shift integer bits to the left, i.e. var << 4
+ * Shift integer bits to the left, i.e. {@code var << 4}
* - * @param x Input 1 (INT type) + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitShift(SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitShift", "x", x); SDValidation.validateInteger("bitShift", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, + shift).outputVariable(); } /** - * Shift integer bits to the left, i.e. var << 4
+ * Shift integer bits to the left, i.e. {@code var << 4}
* - * @param name name May be null. Name for the output variable - * @param x Input 1 (INT type) + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitShift", "x", x); SDValidation.validateInteger("bitShift", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Shift integer bits to the right, i.e. var >> 4
+ * Shift integer bits to the right, i.e. {@code var >> 4}
* - * @param x Input 1 (INT type) + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitShiftRight(SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitShiftRight", "x", x); SDValidation.validateInteger("bitShiftRight", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, + shift).outputVariable(); } /** - * Shift integer bits to the right, i.e. var >> 4
+ * Shift integer bits to the right, i.e. {@code var >> 4}
* - * @param name name May be null. Name for the output variable - * @param x Input 1 (INT type) + * @param name name May be null. Name for the output variable + * @param x Input 1 (INT type) * @param shift Number of bits to shift. (INT type) * @return output SDVariable with shifted bits (INT type) */ public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { SDValidation.validateInteger("bitShiftRight", "x", x); SDValidation.validateInteger("bitShiftRight", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Bitwise Hamming distance reduction over all elements of both input arrays.
- * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ * Bitwise Hamming distance reduction over all elements of both input arrays.
For example, if + * x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at + * positions 0 and 1)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* * @param x First input array. (INT type) * @param y Second input array. (INT type) @@ -197,26 +205,28 @@ public class SDBitwise extends SDOps { SDValidation.validateInteger("bitsHammingDistance", "x", x); SDValidation.validateInteger("bitsHammingDistance", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd, x, + y).outputVariable(); } /** - * Bitwise Hamming distance reduction over all elements of both input arrays.
- * For example, if x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at positions 0 and 1)
- * - * Inputs must satisfy the following constraints:
- * Must be same types: isSameType(x, y)
+ * Bitwise Hamming distance reduction over all elements of both input arrays.
For example, if + * x=01100000 and y=1010000 then the bitwise Hamming distance is 2 (due to differences at + * positions 0 and 1)
+ *

+ * Inputs must satisfy the following constraints:
Must be same types: isSameType(x, y)
* * @param name name May be null. Name for the output variable - * @param x First input array. (INT type) - * @param y Second input array. (INT type) + * @param x First input array. (INT type) + * @param y Second input array. (INT type) * @return output bitwise Hamming distance (INT type) */ public SDVariable bitsHammingDistance(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("bitsHammingDistance", "x", x); SDValidation.validateInteger("bitsHammingDistance", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -230,27 +240,28 @@ public class SDBitwise extends SDOps { public SDVariable leftShift(SDVariable x, SDVariable y) { SDValidation.validateInteger("leftShift", "x", x); SDValidation.validateInteger("leftShift", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, y).outputVariable(); } /** * Bitwise left shift operation. Supports broadcasting.
* * @param name name May be null. Name for the output variable - * @param x Input to be bit shifted (INT type) - * @param y Amount to shift elements of x array (INT type) + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) * @return output Bitwise shifted input x (INT type) */ public SDVariable leftShift(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("leftShift", "x", x); SDValidation.validateInteger("leftShift", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Bitwise left cyclical shift operation. Supports broadcasting.
- * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":
+ * Bitwise left cyclical shift operation. Supports broadcasting.
Unlike + * {@link SDBitwise#leftShift(INDArray, INDArray)} the bits will "wrap around":
* {@code leftShiftCyclic(01110000, 2) -> 11000001}
* * @param x Input to be bit shifted (INT type) @@ -260,31 +271,32 @@ public class SDBitwise extends SDOps { public SDVariable leftShiftCyclic(SDVariable x, SDVariable y) { SDValidation.validateInteger("leftShiftCyclic", "x", x); SDValidation.validateInteger("leftShiftCyclic", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + y).outputVariable(); } /** - * Bitwise left cyclical shift operation. Supports broadcasting.
- * Unlike {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":
+ * Bitwise left cyclical shift operation. Supports broadcasting.
Unlike + * {@link #leftShift(INDArray, INDArray)} the bits will "wrap around":
* {@code leftShiftCyclic(01110000, 2) -> 11000001}
* * @param name name May be null. Name for the output variable - * @param x Input to be bit shifted (INT type) - * @param y Amount to shift elements of x array (INT type) + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) * @return output Bitwise cyclic shifted input x (INT type) */ public SDVariable leftShiftCyclic(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("leftShiftCyclic", "x", x); SDValidation.validateInteger("leftShiftCyclic", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Bitwise OR operation. Supports broadcasting.
-   *
-   * Inputs must satisfy the following constraints:<br>
-   * Must be same types: isSameType(x, y)<br>
+   * <br> <br>
+   * Inputs must satisfy the following constraints:<br> Must be same types: isSameType(x, y)<br>
    * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
* * @param x First input array (INT type) @@ -295,26 +307,26 @@ public class SDBitwise extends SDOps { SDValidation.validateInteger("or", "x", x); SDValidation.validateInteger("or", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd, x, y).outputVariable(); } /** * Bitwise OR operation. Supports broadcasting.
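A hypothetical usage sketch for the or() op above (not part of this patch; it assumes the usual SameDiff entry points such as SameDiff.create(), sd.constant(...) and sd.bitwise(), and the variable names are made up for illustration):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class BitwiseOrExample {

  public static void main(String[] args) {
    SameDiff sd = SameDiff.create();
    // Both inputs must be integer typed and of the same type; shapes only need to be broadcastable.
    SDVariable x = sd.constant(Nd4j.createFromArray(new int[]{0b0011, 0b0101}));
    SDVariable y = sd.constant(Nd4j.createFromArray(0b0110));   // scalar, broadcast against x
    SDVariable z = sd.bitwise().or("orResult", x, y);
    System.out.println(z.eval());   // expected [7, 7] (0b0111) if broadcasting behaves as documented
  }
}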
-   *
-   * Inputs must satisfy the following constraints:<br>
-   * Must be same types: isSameType(x, y)<br>
+   * <br> <br>
+   * Inputs must satisfy the following constraints:<br> Must be same types: isSameType(x, y)<br>
    * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
* * @param name name May be null. Name for the output variable - * @param x First input array (INT type) - * @param y First input array (INT type) + * @param x First input array (INT type) + * @param y First input array (INT type) * @return output Bitwise OR array (INT type) */ public SDVariable or(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("or", "x", x); SDValidation.validateInteger("or", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -328,27 +340,28 @@ public class SDBitwise extends SDOps { public SDVariable rightShift(SDVariable x, SDVariable y) { SDValidation.validateInteger("rightShift", "x", x); SDValidation.validateInteger("rightShift", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, y).outputVariable(); } /** * Bitwise right shift operation. Supports broadcasting.
* * @param name name May be null. Name for the output variable - * @param x Input to be bit shifted (INT type) - * @param y Amount to shift elements of x array (INT type) + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) * @return output Bitwise shifted input x (INT type) */ public SDVariable rightShift(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("rightShift", "x", x); SDValidation.validateInteger("rightShift", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Bitwise right cyclical shift operation. Supports broadcasting.
-   * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
+   * Bitwise right cyclical shift operation. Supports broadcasting.<br> Unlike
+   * {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
    * {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
* * @param x Input to be bit shifted (INT type) @@ -358,31 +371,32 @@ public class SDBitwise extends SDOps { public SDVariable rightShiftCyclic(SDVariable x, SDVariable y) { SDValidation.validateInteger("rightShiftCyclic", "x", x); SDValidation.validateInteger("rightShiftCyclic", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + y).outputVariable(); } /** - * Bitwise right cyclical shift operation. Supports broadcasting.
-   * Unlike {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
+   * Bitwise right cyclical shift operation. Supports broadcasting.<br> Unlike
+   * {@link #rightShift(INDArray, INDArray)} the bits will "wrap around":<br>
    * {@code rightShiftCyclic(00001110, 2) -> 10000011}<br>
* * @param name name May be null. Name for the output variable - * @param x Input to be bit shifted (INT type) - * @param y Amount to shift elements of x array (INT type) + * @param x Input to be bit shifted (INT type) + * @param y Amount to shift elements of x array (INT type) * @return output Bitwise cyclic shifted input x (INT type) */ public SDVariable rightShiftCyclic(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("rightShiftCyclic", "x", x); SDValidation.validateInteger("rightShiftCyclic", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Bitwise XOR operation (exclusive OR). Supports broadcasting.
-   *
-   * Inputs must satisfy the following constraints:<br>
-   * Must be same types: isSameType(x, y)<br>
+   * <br> <br>
+   * Inputs must satisfy the following constraints:<br> Must be same types: isSameType(x, y)<br>
    * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
* * @param x First input array (INT type) @@ -393,26 +407,26 @@ public class SDBitwise extends SDOps { SDValidation.validateInteger("xor", "x", x); SDValidation.validateInteger("xor", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd, x, y).outputVariable(); } /** * Bitwise XOR operation (exclusive OR). Supports broadcasting.
-   *
-   * Inputs must satisfy the following constraints:<br>
-   * Must be same types: isSameType(x, y)<br>
+   * <br> <br>
+   * Inputs must satisfy the following constraints:<br> Must be same types: isSameType(x, y)<br>
    * Must have broadcastable shapes: isBroadcastableShapes(x, y)<br>
* * @param name name May be null. Name for the output variable - * @param x First input array (INT type) - * @param y First input array (INT type) + * @param x First input array (INT type) + * @param y First input array (INT type) * @return output Bitwise XOR array (INT type) */ public SDVariable xor(String name, SDVariable x, SDVariable y) { SDValidation.validateInteger("xor", "x", x); SDValidation.validateInteger("xor", "y", y); Preconditions.checkArgument(isSameType(x, y), "Must be same types"); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index 6317e0941..558d095db 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -355,7 +355,8 @@ public class SDImage extends SDOps { * @param maxOutSize scalar representing the maximum number of boxes to be selected * @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU * @param scoreThreshold threshold for deciding when to remove boxes based on score - * @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type) + * @return output vectort of shape [M] representing the selected indices from the boxes tensor, + * where M <= max_output_size (NUMERIC type) */ public SDVariable nonMaxSuppression(SDVariable boxes, SDVariable scores, int maxOutSize, double iouThreshold, double scoreThreshold) { @@ -373,7 +374,8 @@ public class SDImage extends SDOps { * @param maxOutSize scalar representing the maximum number of boxes to be selected * @param iouThreshold threshold for deciding whether boxes overlap too much with respect to IOU * @param scoreThreshold threshold for deciding when to remove boxes based on score - * @return output vectort of shape [M] representing the selected indices from the boxes tensor, where M <= max_output_size (NUMERIC type) + * @return output vectort of shape [M] representing the selected indices from the boxes tensor, + * where M <= max_output_size (NUMERIC type) */ public SDVariable nonMaxSuppression(String name, SDVariable boxes, SDVariable scores, int maxOutSize, double iouThreshold, double scoreThreshold) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java index c6fef378e..5f6b76c94 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDLoss.java @@ -26,6 +26,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; public class SDLoss extends SDOps { + public SDLoss(SameDiff sameDiff) { super(sameDiff); } @@ -33,10 +34,11 @@ public class SDLoss extends SDOps { /** * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output loss variable (NUMERIC type) */ public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, @@ -44,7 +46,8 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("absoluteDifference", "label", label); SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); SDValidation.validateNumerical("absoluteDifference", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return out; } @@ -52,11 +55,12 @@ public class SDLoss extends SDOps { /** * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
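A minimal plain-Java sketch of the loss described above, assuming the documented meaning of the default LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT reduction (sum of the weighted absolute differences divided by the number of non-zero weights); illustrative only, not the nd4j implementation:

final class AbsoluteDifferenceLossSketch {

  static double absoluteDifference(double[] label, double[] predictions, double[] weights) {
    double sum = 0.0;
    int nonZeroWeights = 0;
    for (int i = 0; i < label.length; i++) {
      sum += weights[i] * Math.abs(label[i] - predictions[i]);
      if (weights[i] != 0.0) {
        nonZeroWeights++;
      }
    }
    // MEAN_BY_NONZERO_WEIGHT_COUNT: average over the entries that actually contribute.
    return nonZeroWeights == 0 ? 0.0 : sum / nonZeroWeights;
  }
}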
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output loss variable (NUMERIC type) */ public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions, @@ -64,7 +68,8 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("absoluteDifference", "label", label); SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); SDValidation.validateNumerical("absoluteDifference", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } @@ -72,9 +77,9 @@ public class SDLoss extends SDOps { /** * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output loss variable (NUMERIC type) */ public SDVariable absoluteDifference(SDVariable label, SDVariable predictions, @@ -82,7 +87,9 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("absoluteDifference", "label", label); SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); SDValidation.validateNumerical("absoluteDifference", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return out; } @@ -90,10 +97,10 @@ public class SDLoss extends SDOps { /** * Absolute difference loss: {@code sum_i abs( label[i] - predictions[i] )}
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output loss variable (NUMERIC type) */ public SDVariable absoluteDifference(String name, SDVariable label, SDVariable predictions, @@ -101,23 +108,28 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("absoluteDifference", "label", label); SDValidation.validateNumerical("absoluteDifference", "predictions", predictions); SDValidation.validateNumerical("absoluteDifference", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
-   * equivalent to cosine distance when both the predictions and labels are normalized.<br>
-   * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.<br>
-   * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)<br>
-   * along the cosine distance dimension (with keepDims=true).<br>
+   * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or
+   * {@code 1 - sum_i label[i] * prediction[i]}, which is<br> equivalent to cosine distance when
+   * both the predictions and labels are normalized.<br>
+   * Note: This loss function assumes that both the predictions and labels are normalized to
+   * have unit l2 norm.<br>
+   * If this is not the case, you should normalize them first by dividing by norm2(String,
+   * SDVariable, boolean, int...)<br> along the cosine distance dimension (with keepDims=true).<br>
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param dimension Dimension to perform the cosine distance over + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param dimension Dimension to perform the cosine distance over * @return output Cosine distance loss (NUMERIC type) */ public SDVariable cosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, @@ -125,24 +137,28 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("cosineDistance", "label", label); SDValidation.validateNumerical("cosineDistance", "predictions", predictions); SDValidation.validateNumerical("cosineDistance", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, lossReduce, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd, label, + predictions, weights, lossReduce, dimension).outputVariable(); out.markAsLoss(); return out; } /** - * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
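A plain-Java sketch of the cosine distance loss formula quoted above for a single pair of already unit-normalised vectors (illustrative only; the nd4j op additionally handles weights, reductions and the dimension argument):

final class CosineDistanceLossSketch {

  // 1 - sum_i label[i] * prediction[i], assuming both vectors have unit L2 norm.
  static double cosineDistance(double[] label, double[] predictions) {
    double dot = 0.0;
    for (int i = 0; i < label.length; i++) {
      dot += label[i] * predictions[i];
    }
    return 1.0 - dot;
  }
}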
-   * equivalent to cosine distance when both the predictions and labels are normalized.<br>
-   * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.<br>
-   * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)<br>
-   * along the cosine distance dimension (with keepDims=true).<br>
+   * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or
+   * {@code 1 - sum_i label[i] * prediction[i]}, which is<br> equivalent to cosine distance when
+   * both the predictions and labels are normalized.<br>
+   * Note: This loss function assumes that both the predictions and labels are normalized to
+   * have unit l2 norm.<br>
+   * If this is not the case, you should normalize them first by dividing by norm2(String,
+   * SDVariable, boolean, int...)<br> along the cosine distance dimension (with keepDims=true).<br>
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param dimension Dimension to perform the cosine distance over + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param dimension Dimension to perform the cosine distance over * @return output Cosine distance loss (NUMERIC type) */ public SDVariable cosineDistance(String name, SDVariable label, SDVariable predictions, @@ -150,22 +166,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("cosineDistance", "label", label); SDValidation.validateNumerical("cosineDistance", "predictions", predictions); SDValidation.validateNumerical("cosineDistance", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, lossReduce, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd, label, + predictions, weights, lossReduce, dimension).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
-   * equivalent to cosine distance when both the predictions and labels are normalized.<br>
-   * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.<br>
-   * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)<br>
-   * along the cosine distance dimension (with keepDims=true).<br>
+   * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or
+   * {@code 1 - sum_i label[i] * prediction[i]}, which is<br> equivalent to cosine distance when
+   * both the predictions and labels are normalized.<br>
+   * Note: This loss function assumes that both the predictions and labels are normalized to
+   * have unit l2 norm.<br>
+   * If this is not the case, you should normalize them first by dividing by norm2(String,
+   * SDVariable, boolean, int...)<br> along the cosine distance dimension (with keepDims=true).<br>
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) - * @param dimension Dimension to perform the cosine distance over + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param dimension Dimension to perform the cosine distance over * @return output Cosine distance loss (NUMERIC type) */ public SDVariable cosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, @@ -173,23 +192,27 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("cosineDistance", "label", label); SDValidation.validateNumerical("cosineDistance", "predictions", predictions); SDValidation.validateNumerical("cosineDistance", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd, label, + predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + dimension).outputVariable(); out.markAsLoss(); return out; } /** - * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or {@code 1 - sum_i label[i] * prediction[i]}, which is
-   * equivalent to cosine distance when both the predictions and labels are normalized.<br>
-   * Note: This loss function assumes that both the predictions and labels are normalized to have unit l2 norm.<br>
-   * If this is not the case, you should normalize them first by dividing by norm2(String, SDVariable, boolean, int...)<br>
-   * along the cosine distance dimension (with keepDims=true).<br>
+   * Cosine distance loss: {@code 1 - cosineSimilarity(x,y)} or
+   * {@code 1 - sum_i label[i] * prediction[i]}, which is<br> equivalent to cosine distance when
+   * both the predictions and labels are normalized.<br>
+   * Note: This loss function assumes that both the predictions and labels are normalized to
+   * have unit l2 norm.<br>
+   * If this is not the case, you should normalize them first by dividing by norm2(String,
+   * SDVariable, boolean, int...)<br> along the cosine distance dimension (with keepDims=true).<br>
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) - * @param dimension Dimension to perform the cosine distance over + * @param weights Weights array. May be null. If null, a weight of 1.0 is use (NUMERIC type) + * @param dimension Dimension to perform the cosine distance over * @return output Cosine distance loss (NUMERIC type) */ public SDVariable cosineDistance(String name, SDVariable label, SDVariable predictions, @@ -197,20 +220,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("cosineDistance", "label", label); SDValidation.validateNumerical("cosineDistance", "predictions", predictions); SDValidation.validateNumerical("cosineDistance", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss(sd, label, + predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + dimension).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Hinge loss: a loss function used for training classifiers.
-   * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
-   * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.<br>
+   * Hinge loss: a loss function used for training classifiers.<br> Implements
+   * {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting
+   * to {-1,1}<br> from the user specified {0,1}. Note that Labels should be provided with values
+   * {0,1}.<br>
* - * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) + * (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable (NUMERIC type) */ public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights, @@ -218,21 +246,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("hingeLoss", "label", label); SDValidation.validateNumerical("hingeLoss", "predictions", predictions); SDValidation.validateNumerical("hingeLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd, label, predictions, + weights, lossReduce).outputVariable(); out.markAsLoss(); return out; } /** - * Hinge loss: a loss function used for training classifiers.
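A plain-Java sketch of the hinge loss formula quoted above (illustrative only; weights and the configurable LossReduce reduction are omitted and a simple mean is used instead):

final class HingeLossSketch {

  static double hingeLoss(double[] label01, double[] predictions) {
    double sum = 0.0;
    for (int i = 0; i < label01.length; i++) {
      double t = label01[i] * 2.0 - 1.0;                 // map labels {0,1} -> {-1,1}
      sum += Math.max(0.0, 1.0 - t * predictions[i]);    // L = max(0, 1 - t * prediction)
    }
    return sum / label01.length;
  }
}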
-   * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
-   * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.<br>
+   * Hinge loss: a loss function used for training classifiers.<br> Implements
+   * {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting
+   * to {-1,1}<br> from the user specified {0,1}. Note that Labels should be provided with values
+   * {0,1}.<br>
* - * @param name name May be null. Name for the output variable - * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) + * (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable (NUMERIC type) */ public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions, @@ -240,39 +272,45 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("hingeLoss", "label", label); SDValidation.validateNumerical("hingeLoss", "predictions", predictions); SDValidation.validateNumerical("hingeLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd, label, predictions, + weights, lossReduce).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Hinge loss: a loss function used for training classifiers.
-   * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
-   * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.<br>
+   * Hinge loss: a loss function used for training classifiers.<br> Implements
+   * {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting
+   * to {-1,1}<br> from the user specified {0,1}. Note that Labels should be provided with values
+   * {0,1}.<br>
* - * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) + * (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ public SDVariable hingeLoss(SDVariable label, SDVariable predictions, SDVariable weights) { SDValidation.validateNumerical("hingeLoss", "label", label); SDValidation.validateNumerical("hingeLoss", "predictions", predictions); SDValidation.validateNumerical("hingeLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return out; } /** - * Hinge loss: a loss function used for training classifiers.
-   * Implements {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting to {-1,1}<br>
-   * from the user specified {0,1}. Note that Labels should be provided with values {0,1}.<br>
+   * Hinge loss: a loss function used for training classifiers.<br> Implements
+   * {@code L = max(0, 1 - t * predictions)} where t is the label values after internally converting
+   * to {-1,1}<br> from the user specified {0,1}. Note that Labels should be provided with values
+   * {0,1}.<br>
* - * @param name name May be null. Name for the output variable - * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (internally -1 to 1 is used) + * (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ public SDVariable hingeLoss(String name, SDVariable label, SDVariable predictions, @@ -280,25 +318,27 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("hingeLoss", "label", label); SDValidation.validateNumerical("hingeLoss", "predictions", predictions); SDValidation.validateNumerical("hingeLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HingeLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
-   * though is less sensitive to outliers than squared error.<br>
+   * Huber loss function, used for robust regression. It is similar to both squared error loss and
+   * absolute difference loss,<br> though is less sensitive to outliers than squared error.<br>
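The two-branch formula spelled out just below reduces to the following plain-Java sketch for a single element (illustrative only, not the nd4j implementation):

final class HuberLossSketch {

  static double huber(double label, double prediction, double delta) {
    double diff = Math.abs(label - prediction);
    return diff < delta
        ? 0.5 * diff * diff                       // quadratic near zero, like squared error
        : delta * diff - 0.5 * delta * delta;     // linear in the tails, like absolute difference
  }
}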
* Huber loss implements:
*


* {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
* {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
*

* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param delta Loss function delta value + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param delta Loss function delta value * @return output Huber loss (NUMERIC type) */ public SDVariable huberLoss(SDVariable label, SDVariable predictions, SDVariable weights, @@ -306,26 +346,28 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("huberLoss", "label", label); SDValidation.validateNumerical("huberLoss", "predictions", predictions); SDValidation.validateNumerical("huberLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, lossReduce, delta).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd, label, predictions, + weights, lossReduce, delta).outputVariable(); out.markAsLoss(); return out; } /** - * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
- * though is less sensitive to outliers than squared error.
+ * Huber loss function, used for robust regression. It is similar both squared error loss and + * absolute difference loss,
though is less sensitive to outliers than squared error.
* Huber loss implements:
*

* {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
* {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
*

* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param delta Loss function delta value + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param delta Loss function delta value * @return output Huber loss (NUMERIC type) */ public SDVariable huberLoss(String name, SDVariable label, SDVariable predictions, @@ -333,24 +375,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("huberLoss", "label", label); SDValidation.validateNumerical("huberLoss", "predictions", predictions); SDValidation.validateNumerical("huberLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, lossReduce, delta).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd, label, predictions, + weights, lossReduce, delta).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
- * though is less sensitive to outliers than squared error.
+ * Huber loss function, used for robust regression. It is similar both squared error loss and + * absolute difference loss,
though is less sensitive to outliers than squared error.
* Huber loss implements:
*

* {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
* {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
*

* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param delta Loss function delta value + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param delta Loss function delta value * @return output Huber loss (NUMERIC type) */ public SDVariable huberLoss(SDVariable label, SDVariable predictions, SDVariable weights, @@ -358,25 +401,27 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("huberLoss", "label", label); SDValidation.validateNumerical("huberLoss", "predictions", predictions); SDValidation.validateNumerical("huberLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + delta).outputVariable(); out.markAsLoss(); return out; } /** - * Huber loss function, used for robust regression. It is similar both squared error loss and absolute difference loss,
- * though is less sensitive to outliers than squared error.
+ * Huber loss function, used for robust regression. It is similar both squared error loss and + * absolute difference loss,
though is less sensitive to outliers than squared error.
* Huber loss implements:
*

* {@code L = 0.5 * (label[i] - predictions[i])^2 if abs(label[i] - predictions[i]) < delta}
* {@code L = delta * abs(label[i] - predictions[i]) - 0.5 * delta^2 otherwise}
*

* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param delta Loss function delta value + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param delta Loss function delta value * @return output Huber loss (NUMERIC type) */ public SDVariable huberLoss(String name, SDVariable label, SDVariable predictions, @@ -384,7 +429,9 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("huberLoss", "label", label); SDValidation.validateNumerical("huberLoss", "predictions", predictions); SDValidation.validateNumerical("huberLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, delta).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.HuberLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + delta).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } @@ -397,7 +444,7 @@ public class SDLoss extends SDOps { */ public SDVariable l2Loss(SDVariable var) { SDValidation.validateNumerical("l2Loss", "var", var); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd,var).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd, var).outputVariable(); out.markAsLoss(); return out; } @@ -406,25 +453,28 @@ public class SDLoss extends SDOps { * L2 loss: 1/2 * sum(x^2)
* * @param name name May be null. Name for the output variable - * @param var Variable to calculate L2 loss of (NUMERIC type) + * @param var Variable to calculate L2 loss of (NUMERIC type) * @return output L2 loss (NUMERIC type) */ public SDVariable l2Loss(String name, SDVariable var) { SDValidation.validateNumerical("l2Loss", "var", var); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd,var).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.L2Loss(sd, var).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
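A plain-Java sketch of the formula above (L2 loss = 1/2 * sum(x^2)); illustrative only:

final class L2LossSketch {

  static double l2Loss(double[] x) {
    double sumOfSquares = 0.0;
    for (double v : x) {
      sumOfSquares += v * v;
    }
    return 0.5 * sumOfSquares;
  }
}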
-   * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}<br>
+   * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification.
+   * Implements:<br>
+   * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) *
+   * log(1-predictions[i] + epsilon))}<br>
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param epsilon epsilon + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param epsilon epsilon * @return output Log loss (NUMERIC type) */ public SDVariable logLoss(SDVariable label, SDVariable predictions, SDVariable weights, @@ -432,21 +482,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logLoss", "label", label); SDValidation.validateNumerical("logLoss", "predictions", predictions); SDValidation.validateNumerical("logLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, weights, lossReduce, epsilon).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd, label, predictions, weights, + lossReduce, epsilon).outputVariable(); out.markAsLoss(); return out; } /** - * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
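A plain-Java sketch of the binary cross entropy formula quoted above, with the epsilon term guarding against log(0) (illustrative only; weights and the LossReduce reduction are omitted):

final class LogLossSketch {

  static double logLoss(double[] labels, double[] predictions, double epsilon) {
    double sum = 0.0;
    for (int i = 0; i < labels.length; i++) {
      sum += labels[i] * Math.log(predictions[i] + epsilon)
          + (1.0 - labels[i]) * Math.log(1.0 - predictions[i] + epsilon);
    }
    return -sum / labels.length;   // -1/numExamples * sum_i (...)
  }
}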
-   * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}<br>
+   * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification.
+   * Implements:<br>
+   * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) *
+   * log(1-predictions[i] + epsilon))}<br>
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param epsilon epsilon + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param epsilon epsilon * @return output Log loss (NUMERIC type) */ public SDVariable logLoss(String name, SDVariable label, SDVariable predictions, @@ -454,53 +508,61 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logLoss", "label", label); SDValidation.validateNumerical("logLoss", "predictions", predictions); SDValidation.validateNumerical("logLoss", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, weights, lossReduce, epsilon).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd, label, predictions, weights, + lossReduce, epsilon).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
-   * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}<br>
+   * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification.
+   * Implements:<br>
+   * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) *
+   * log(1-predictions[i] + epsilon))}<br>
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) * @return output Log loss (NUMERIC type) */ public SDVariable logLoss(SDVariable label, SDVariable predictions) { SDValidation.validateNumerical("logLoss", "label", label); SDValidation.validateNumerical("logLoss", "predictions", predictions); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd, label, predictions, null, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); out.markAsLoss(); return out; } /** - * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification. Implements:
-   * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) * log(1-predictions[i] + epsilon))}<br>
+   * Log loss, i.e., binary cross entropy loss, usually used for binary multi-label classification.
+   * Implements:<br>
+   * {@code -1/numExamples * sum_i (labels[i] * log(predictions[i] + epsilon) + (1-labels[i]) *
+   * log(1-predictions[i] + epsilon))}<br>
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) * @return output Log loss (NUMERIC type) */ public SDVariable logLoss(String name, SDVariable label, SDVariable predictions) { SDValidation.validateNumerical("logLoss", "label", label); SDValidation.validateNumerical("logLoss", "predictions", predictions); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd,label, predictions, null, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogLoss(sd, label, predictions, null, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Log poisson loss: a loss function used for training classifiers.
-   * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
+   * Log Poisson loss: a loss function used for training classifiers.<br> Implements
+   * {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
* - * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param full Boolean flag. true for logPoissonFull, false for logPoisson * @return output Loss variable (NUMERIC type) */ public SDVariable logPoisson(SDVariable label, SDVariable predictions, SDVariable weights, @@ -508,21 +570,23 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logPoisson", "label", label); SDValidation.validateNumerical("logPoisson", "predictions", predictions); SDValidation.validateNumerical("logPoisson", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, lossReduce, full).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd, label, predictions, + weights, lossReduce, full).outputVariable(); out.markAsLoss(); return out; } /** - * Log poisson loss: a loss function used for training classifiers.
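A plain-Java sketch of the formula quoted above, with c = log(prediction) supplied directly as the op expects (illustrative only; the "full" variant and the weight/reduction handling are omitted):

final class LogPoissonLossSketch {

  static double logPoisson(double[] labelsZ, double[] logPredictionsC) {
    double sum = 0.0;
    for (int i = 0; i < labelsZ.length; i++) {
      sum += Math.exp(logPredictionsC[i]) - labelsZ[i] * logPredictionsC[i];   // exp(c) - z * c
    }
    return sum / labelsZ.length;
  }
}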
-   * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
+   * Log Poisson loss: a loss function used for training classifiers.<br> Implements
+   * {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
* - * @param name name May be null. Name for the output variable - * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param full Boolean flag. true for logPoissonFull, false for logPoisson * @return output Loss variable (NUMERIC type) */ public SDVariable logPoisson(String name, SDVariable label, SDVariable predictions, @@ -530,19 +594,20 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logPoisson", "label", label); SDValidation.validateNumerical("logPoisson", "predictions", predictions); SDValidation.validateNumerical("logPoisson", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, lossReduce, full).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd, label, predictions, + weights, lossReduce, full).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Log poisson loss: a loss function used for training classifiers.
-   * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
+   * Log Poisson loss: a loss function used for training classifiers.<br> Implements
+   * {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
* - * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param full Boolean flag. true for logPoissonFull, false for logPoisson * @return output Loss variable (NUMERIC type) */ public SDVariable logPoisson(SDVariable label, SDVariable predictions, SDVariable weights, @@ -550,20 +615,22 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logPoisson", "label", label); SDValidation.validateNumerical("logPoisson", "predictions", predictions); SDValidation.validateNumerical("logPoisson", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + full).outputVariable(); out.markAsLoss(); return out; } /** - * Log poisson loss: a loss function used for training classifiers.
-   * Implements {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
+   * Log Poisson loss: a loss function used for training classifiers.<br> Implements
+   * {@code L = exp(c) - z * c} where c is log(predictions) and z is labels.<br>
* - * @param name name May be null. Name for the output variable - * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array. Each value should be 0.0 or 1.0 (NUMERIC type) * @param predictions Predictions array (has to be log(x) of actual predictions) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param full Boolean flag. true for logPoissonFull, false for logPoisson + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param full Boolean flag. true for logPoissonFull, false for logPoisson * @return output Loss variable (NUMERIC type) */ public SDVariable logPoisson(String name, SDVariable label, SDVariable predictions, @@ -571,21 +638,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("logPoisson", "label", label); SDValidation.validateNumerical("logPoisson", "predictions", predictions); SDValidation.validateNumerical("logPoisson", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, full).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss(sd, label, predictions, + weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + full).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Mean pairwise squared error.
- * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
- * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * Mean pairwise squared error.
MPWSE loss calculates the difference between pairs of + * consecutive elements in the predictions and labels arrays.
For example, if predictions = + * [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
* {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
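// --- Editor's worked example (illustrative only, not part of the patch): for a single example with
// predictions = [1, 2, 4] and labels = [1, 1, 2], the three pairwise terms are ((-1)-0)^2 = 1,
// ((-3)-(-1))^2 = 4 and ((-2)-(-1))^2 = 1, so MPWSE = (1 + 4 + 1) / 3 = 2.0. A call sketch with
// hypothetical variable names:
SDVariable mpwse = sd.loss().meanPairwiseSquaredError("mpwse", labels, predictions, null,
    LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);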
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either + * null, scalar, or have shape [batchSize] (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable, scalar output (NUMERIC type) */ public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions, @@ -593,22 +664,25 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return out; } /** - * Mean pairwise squared error.
- * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
- * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * Mean pairwise squared error.
MPWSE loss calculates the difference between pairs of + * consecutive elements in the predictions and labels arrays.
For example, if predictions = + * [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
* {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either + * null, scalar, or have shape [batchSize] (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable, scalar output (NUMERIC type) */ public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions, @@ -616,20 +690,22 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Mean pairwise squared error.
- * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
- * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * Mean pairwise squared error.
MPWSE loss calculates the difference between pairs of + * consecutive elements in the predictions and labels arrays.
For example, if predictions = + * [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
* {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either + * null, scalar, or have shape [batchSize] (NUMERIC type) * @return output Loss variable, scalar output (NUMERIC type) */ public SDVariable meanPairwiseSquaredError(SDVariable label, SDVariable predictions, @@ -637,21 +713,24 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return out; } /** - * Mean pairwise squared error.
- * MPWSE loss calculates the difference between pairs of consecutive elements in the predictions and labels arrays.
- * For example, if predictions = [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
+ * Mean pairwise squared error.
MPWSE loss calculates the difference between pairs of + * consecutive elements in the predictions and labels arrays.
For example, if predictions = + * [p0, p1, p2] and labels are [l0, l1, l2] then MPWSE is:
* {@code [((p0-p1) - (l0-l1))^2 + ((p0-p2) - (l0-l2))^2 + ((p1-p2) - (l1-l2))^2] / 3}
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either null, scalar, or have shape [batchSize] (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used. Must be either + * null, scalar, or have shape [batchSize] (NUMERIC type) * @return output Loss variable, scalar output (NUMERIC type) */ public SDVariable meanPairwiseSquaredError(String name, SDVariable label, SDVariable predictions, @@ -659,20 +738,24 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanPairwiseSquaredError", "label", label); SDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
- * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
- * this is the mean squared error loss function.
+ * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., + * squared error on a per-element basis.
When averaged (using + * {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the + * default))
this is the mean squared error loss function.
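// --- Editor's usage sketch (not part of the patch; names are hypothetical, LossReduce is
// org.nd4j.autodiff.loss.LossReduce) ---
SDVariable labels = sd.placeHolder("labels", DataType.FLOAT, -1, 1);
SDVariable preds = sd.placeHolder("predictions", DataType.FLOAT, -1, 1);
// per-element squared error, averaged over non-zero weights (the default reduction):
SDVariable mse = sd.loss().meanSquaredError("mse", labels, preds, null,
    LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT);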
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable (NUMERIC type) */ public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, @@ -680,21 +763,24 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanSquaredError", "label", label); SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return out; } /** - * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
- * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
- * this is the mean squared error loss function.
+ * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., + * squared error on a per-element basis.
When averaged (using + * {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the + * default))
this is the mean squared error loss function.
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} * @return output Loss variable (NUMERIC type) */ public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions, @@ -702,39 +788,44 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanSquaredError", "label", label); SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, lossReduce).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd, label, + predictions, weights, lossReduce).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
- * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
- * this is the mean squared error loss function.
+ * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., + * squared error on a per-element basis.
When averaged (using + * {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the + * default))
this is the mean squared error loss function.
* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ public SDVariable meanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights) { SDValidation.validateNumerical("meanSquaredError", "label", label); SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return out; } /** - * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., squared error on a per-element basis.
- * When averaged (using {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the default))
- * this is the mean squared error loss function.
+ * Mean squared error loss function. Implements {@code (label[i] - prediction[i])^2} - i.e., + * squared error on a per-element basis.
When averaged (using + * {@link LossReduce#MEAN_BY_WEIGHT} or {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} (the + * default))
this is the mean squared error loss function.
* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictions Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ public SDVariable meanSquaredError(String name, SDVariable label, SDVariable predictions, @@ -742,30 +833,35 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("meanSquaredError", "label", label); SDValidation.validateNumerical("meanSquaredError", "predictions", predictions); SDValidation.validateNumerical("meanSquaredError", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd,label, predictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss(sd, label, + predictions, weights, + org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
- * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
- * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
- * Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
- * though this is done in a mathematically equivalent but more numerical stable form.
+ * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input
+ * "pre-sigmoid predictions")
and implements the binary cross entropy loss function. This + * implementation is numerically more stable than using
standard (but separate) sigmoid + * activation function and log loss (binary cross entropy) loss function.
Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * + * log(1-sigmoid(logits[i])))}
though this is done in a mathematically equivalent but more
+ * numerically stable form.
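// --- Editor's usage sketch (not part of the patch): raw, pre-sigmoid logits go in; the op applies
// the sigmoid internally, so the predictions must NOT be passed through a sigmoid first.
// Names are hypothetical.
SDVariable labels = sd.placeHolder("labels", DataType.FLOAT, -1, 1);   // values in {0.0, 1.0}
SDVariable logits = sd.placeHolder("logits", DataType.FLOAT, -1, 1);   // pre-sigmoid activations
SDVariable bce = sd.loss().sigmoidCrossEntropy("bce", labels, logits, null,
    LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0);                     // no label smoothing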
*
- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*

* {@code numClasses = labels.size(1);
* label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
*

* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictionLogits Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param labelSmoothing Label smoothing value. Default value: 0 + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 * @return output Loss variable (NUMERIC type) */ public SDVariable sigmoidCrossEntropy(SDVariable label, SDVariable predictionLogits, @@ -773,31 +869,35 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd, label, + predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); out.markAsLoss(); return out; } /** - * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
- * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
- * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
- * Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
- * though this is done in a mathematically equivalent but more numerical stable form.
+ * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input
+ * "pre-sigmoid predictions")
and implements the binary cross entropy loss function. This + * implementation is numerically more stable than using
standard (but separate) sigmoid + * activation function and log loss (binary cross entropy) loss function.
Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * + * log(1-sigmoid(logits[i])))}
though this is done in a mathematically equivalent but more
+ * numerically stable form.
*
- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*

* {@code numClasses = labels.size(1);
* label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
*

* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictionLogits Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param labelSmoothing Label smoothing value. Default value: 0 + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 * @return output Loss variable (NUMERIC type) */ public SDVariable sigmoidCrossEntropy(String name, SDVariable label, SDVariable predictionLogits, @@ -805,28 +905,31 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd, label, + predictionLogits, weights, lossReduce, labelSmoothing).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
- * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
- * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
- * Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
- * though this is done in a mathematically equivalent but more numerical stable form.
+ * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input
+ * "pre-sigmoid predictions")
and implements the binary cross entropy loss function. This + * implementation is numerically more stable than using
standard (but separate) sigmoid + * activation function and log loss (binary cross entropy) loss function.
Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * + * log(1-sigmoid(logits[i])))}
though this is done in a mathematically equivalent but more
+ * numerically stable form.
*
- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*

* {@code numClasses = labels.size(1);
* label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
*

* - * @param label Label array (NUMERIC type) + * @param label Label array (NUMERIC type) * @param predictionLogits Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) * @return output Loss variable (NUMERIC type) */ public SDVariable sigmoidCrossEntropy(SDVariable label, SDVariable predictionLogits, @@ -834,29 +937,33 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd, label, + predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + 0.0).outputVariable(); out.markAsLoss(); return out; } /** - * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input "pre-sigmoid preductions")
- * and implements the binary cross entropy loss function. This implementation is numerically more stable than using
- * standard (but separate) sigmoid activation function and log loss (binary cross entropy) loss function.
- * Implements:
- * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * log(1-sigmoid(logits[i])))}
- * though this is done in a mathematically equivalent but more numerical stable form.
+ * Sigmoid cross entropy: applies the sigmoid activation function on the input logits (input
+ * "pre-sigmoid predictions")
and implements the binary cross entropy loss function. This + * implementation is numerically more stable than using
standard (but separate) sigmoid + * activation function and log loss (binary cross entropy) loss function.
Implements:
+ * {@code -1/numExamples * sum_i (labels[i] * log(sigmoid(logits[i])) + (1-labels[i]) * + * log(1-sigmoid(logits[i])))}
though this is done in a mathematically equivalent but more
+ * numerically stable form.
*
- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*

* {@code numClasses = labels.size(1);
* label = (1.0 - labelSmoothing) * label + 0.5 * labelSmoothing}
*

* - * @param name name May be null. Name for the output variable - * @param label Label array (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param label Label array (NUMERIC type) * @param predictionLogits Predictions array (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) * @return output Loss variable (NUMERIC type) */ public SDVariable sigmoidCrossEntropy(String name, SDVariable label, SDVariable predictionLogits, @@ -864,28 +971,33 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("sigmoidCrossEntropy", "label", label); SDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits); SDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd,label, predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss(sd, label, + predictionLogits, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + 0.0).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Applies the softmax activation function to the input, then implement multi-class cross entropy:
- * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
- * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
- * otherwise, the output is a scalar.
+ * Applies the softmax activation function to the input, then implements multi-class cross
+ * entropy:
{@code -sum_classes label[c] * log(p[c])} where {@code p = softmax(logits)}
If
+ * {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples,
+ * numClasses] predictions/labels;
otherwise, the output is a scalar.
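// --- Editor's usage sketch (not part of the patch): assumes one-hot labels of shape
// [minibatch, numClasses]; the names and the class count of 10 are hypothetical.
SDVariable oneHot = sd.placeHolder("labels", DataType.FLOAT, -1, 10);
SDVariable logits = sd.placeHolder("logits", DataType.FLOAT, -1, 10);  // pre-softmax activations
SDVariable ce = sd.loss().softmaxCrossEntropy("ce", oneHot, logits, null,
    LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0);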
*


- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*


* {@code numClasses = labels.size(1);
* oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
*

* - * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param oneHotLabels Label array. Should be one-hot per example and same shape as + * predictions (for example, [mb, nOut]) (NUMERIC type) * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param labelSmoothing Label smoothing value. Default value: 0 + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 * @return output Loss variable (NUMERIC type) */ public SDVariable softmaxCrossEntropy(SDVariable oneHotLabels, SDVariable logitPredictions, @@ -893,29 +1005,33 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd, oneHotLabels, + logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); out.markAsLoss(); return out; } /** - * Applies the softmax activation function to the input, then implement multi-class cross entropy:
- * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
- * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
- * otherwise, the output is a scalar.
+ * Applies the softmax activation function to the input, then implements multi-class cross
+ * entropy:
{@code -sum_classes label[c] * log(p[c])} where {@code p = softmax(logits)}
If
+ * {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples,
+ * numClasses] predictions/labels;
otherwise, the output is a scalar.
*


- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*


* {@code numClasses = labels.size(1);
* oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
*

* - * @param name name May be null. Name for the output variable - * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param oneHotLabels Label array. Should be one-hot per example and same shape as + * predictions (for example, [mb, nOut]) (NUMERIC type) * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) - * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} - * @param labelSmoothing Label smoothing value. Default value: 0 + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) + * @param lossReduce Reduction type for the loss. See {@link LossReduce} for more details. + * Default: {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} + * @param labelSmoothing Label smoothing value. Default value: 0 * @return output Loss variable (NUMERIC type) */ public SDVariable softmaxCrossEntropy(String name, SDVariable oneHotLabels, @@ -924,26 +1040,29 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd, oneHotLabels, + logitPredictions, weights, lossReduce, labelSmoothing).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * Applies the softmax activation function to the input, then implement multi-class cross entropy:
- * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
- * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
- * otherwise, the output is a scalar.
+ * Applies the softmax activation function to the input, then implements multi-class cross
+ * entropy:
{@code -sum_classes label[c] * log(p[c])} where {@code p = softmax(logits)}
If
+ * {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples,
+ * numClasses] predictions/labels;
otherwise, the output is a scalar.
*


- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*


* {@code numClasses = labels.size(1);
* oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
*

* - * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param oneHotLabels Label array. Should be one-hot per example and same shape as + * predictions (for example, [mb, nOut]) (NUMERIC type) * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) * @return output Loss variable (NUMERIC type) */ public SDVariable softmaxCrossEntropy(SDVariable oneHotLabels, SDVariable logitPredictions, @@ -951,27 +1070,31 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd, oneHotLabels, + logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + 0.0).outputVariable(); out.markAsLoss(); return out; } /** - * Applies the softmax activation function to the input, then implement multi-class cross entropy:
- * {@code -sum_classes label[i] * log(p[c])} where {@code p = softmax(logits)}
- * If {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples, numClasses] predicitons/labels;
- * otherwise, the output is a scalar.
+ * Applies the softmax activation function to the input, then implements multi-class cross
+ * entropy:
{@code -sum_classes label[c] * log(p[c])} where {@code p = softmax(logits)}
If
+ * {@link LossReduce#NONE} is used, returned shape is [numExamples] out for [numExamples,
+ * numClasses] predictions/labels;
otherwise, the output is a scalar.
*


- * When label smoothing is > 0, the following label smoothing is used:
+ * When label smoothing is > 0, the following label smoothing is used:
*


* {@code numClasses = labels.size(1);
* oneHotLabel = (1.0 - labelSmoothing) * oneHotLabels + labelSmoothing/numClasses}
*

* - * @param name name May be null. Name for the output variable - * @param oneHotLabels Label array. Should be one-hot per example and same shape as predictions (for example, [mb, nOut]) (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param oneHotLabels Label array. Should be one-hot per example and same shape as + * predictions (for example, [mb, nOut]) (NUMERIC type) * @param logitPredictions Predictions array (pre-softmax) (NUMERIC type) - * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) + * @param weights Weights array. May be null. If null, a weight of 1.0 is used (NUMERIC + * type) * @return output Loss variable (NUMERIC type) */ public SDVariable softmaxCrossEntropy(String name, SDVariable oneHotLabels, @@ -979,14 +1102,16 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels); SDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions); SDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd,oneHotLabels, logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, 0.0).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss(sd, oneHotLabels, + logitPredictions, weights, org.nd4j.autodiff.loss.LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT, + 0.0).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } /** - * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable
- * is represented as an integer array instead of the equivalent one-hot array.
+ * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels + * variable
is represented as an integer array instead of the equivalent one-hot array.
* i.e., if logits are rank N, then labels have rank N-1
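// --- Editor's usage sketch (not part of the patch): with integer class indices no one-hot encoding
// is needed. Assumes [minibatch, 10] logits and a [minibatch] label vector; names are hypothetical.
SDVariable logits = sd.placeHolder("logits", DataType.FLOAT, -1, 10);
SDVariable classIdx = sd.placeHolder("classIdx", DataType.INT, -1);    // rank N-1
SDVariable sce = sd.loss().sparseSoftmaxCrossEntropy("sparseCe", logits, classIdx);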
* * @param logits Logits array ("pre-softmax activations") (NUMERIC type) @@ -996,17 +1121,18 @@ public class SDLoss extends SDOps { public SDVariable sparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels) { SDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits); SDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(sd,logits, labels).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits( + sd, logits, labels).outputVariable(); out.markAsLoss(); return out; } /** - * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels variable
- * is represented as an integer array instead of the equivalent one-hot array.
+ * As per softmaxCrossEntropy(String, SDVariable, SDVariable, LossReduce) but the labels + * variable
is represented as an integer array instead of the equivalent one-hot array.
* i.e., if logits are rank N, then labels have rank N-1
* - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param logits Logits array ("pre-softmax activations") (NUMERIC type) * @param labels Labels array. Must be an integer type. (INT type) * @return output Softmax cross entropy (NUMERIC type) @@ -1014,7 +1140,8 @@ public class SDLoss extends SDOps { public SDVariable sparseSoftmaxCrossEntropy(String name, SDVariable logits, SDVariable labels) { SDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits); SDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits(sd,logits, labels).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits( + sd, logits, labels).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } @@ -1023,7 +1150,7 @@ public class SDLoss extends SDOps { * Weighted cross entropy loss with logits
* * @param targets targets array (NUMERIC type) - * @param inputs input array (NUMERIC type) + * @param inputs input array (NUMERIC type) * @param weights eights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ @@ -1032,7 +1159,8 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets); SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs); SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd,targets, inputs, weights).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd, targets, + inputs, weights).outputVariable(); out.markAsLoss(); return out; } @@ -1040,9 +1168,9 @@ public class SDLoss extends SDOps { /** * Weighted cross entropy loss with logits
* - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param targets targets array (NUMERIC type) - * @param inputs input array (NUMERIC type) + * @param inputs input array (NUMERIC type) * @param weights eights array. May be null. If null, a weight of 1.0 is used (NUMERIC type) * @return output Loss variable (NUMERIC type) */ @@ -1051,7 +1179,8 @@ public class SDLoss extends SDOps { SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "targets", targets); SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "inputs", inputs); SDValidation.validateNumerical("weightedCrossEntropyWithLogits", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd,targets, inputs, weights).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss(sd, targets, + inputs, weights).outputVariable(); out.markAsLoss(); return sd.updateVariableNameAndReference(out, name); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 5c3579396..bbef06cfb 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -29,6 +29,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.indexing.conditions.Condition; public class SDMath extends SDOps { + public SDMath(SameDiff sameDiff) { super(sameDiff); } @@ -36,53 +37,60 @@ public class SDMath extends SDOps { /** * Clips tensor values to a maximum average L2-norm.
* - * @param x Input variable (NUMERIC type) - * @param clipValue Value for clipping + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable clipByAvgNorm(SDVariable x, double clipValue, int... dimensions) { SDValidation.validateNumerical("ClipByAvgNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd, x, clipValue, + dimensions).outputVariable(); } /** * Clips tensor values to a maximum average L2-norm.
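// --- Editor's usage sketch (not part of the patch; hypothetical values, Nd4j array factory assumed) ---
SDVariable grads = sd.var("grads", Nd4j.rand(DataType.FLOAT, 4, 8));
// clip so that the average L2 norm over dimension 1 does not exceed 1.0:
SDVariable clipped = sd.math().clipByAvgNorm("clipped", grads, 1.0, 1);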
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param clipValue Value for clipping + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable clipByAvgNorm(String name, SDVariable x, double clipValue, int... dimensions) { SDValidation.validateNumerical("ClipByAvgNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd, x, + clipValue, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Looks up ids in a list of embedding tensors.
* - * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ public SDVariable embeddingLookup(SDVariable x, SDVariable indices, PartitionMode PartitionMode) { SDValidation.validateNumerical("EmbeddingLookup", "x", x); SDValidation.validateInteger("EmbeddingLookup", "indices", indices); - return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd, x, indices, + PartitionMode).outputVariable(); } /** * Looks up ids in a list of embedding tensors.
* - * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ @@ -90,35 +98,39 @@ public class SDMath extends SDOps { PartitionMode PartitionMode) { SDValidation.validateNumerical("EmbeddingLookup", "x", x); SDValidation.validateInteger("EmbeddingLookup", "indices", indices); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd, x, + indices, PartitionMode).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Return array of max elements indices with along tensor dimensions
* - * @param x Input tensor (NUMERIC type) + * @param x Input tensor (NUMERIC type) * @param dataType Data type * @return output Array max elements indices with along dimensions. (INT type) */ public SDVariable mergeMaxIndex(SDVariable[] x, DataType dataType) { SDValidation.validateNumerical("MergeMaxIndex", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); - return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, dataType).outputVariable(); + Preconditions.checkArgument(x.length >= 1, + "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd, x, dataType).outputVariable(); } /** * Return array of max elements indices with along tensor dimensions
* - * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) * @param dataType Data type * @return output Array max elements indices with along dimensions. (INT type) */ public SDVariable mergeMaxIndex(String name, SDVariable[] x, DataType dataType) { SDValidation.validateNumerical("MergeMaxIndex", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, dataType).outputVariable(); + Preconditions.checkArgument(x.length >= 1, + "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd, x, + dataType).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -130,21 +142,25 @@ public class SDMath extends SDOps { */ public SDVariable mergeMaxIndex(SDVariable... x) { SDValidation.validateNumerical("MergeMaxIndex", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); - return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, DataType.INT).outputVariable(); + Preconditions.checkArgument(x.length >= 1, + "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd, x, + DataType.INT).outputVariable(); } /** * Return array of max elements indices with along tensor dimensions
* * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) + * @param x Input tensor (NUMERIC type) * @return output Array max elements indices with along dimensions. (INT type) */ public SDVariable mergeMaxIndex(String name, SDVariable... x) { SDValidation.validateNumerical("MergeMaxIndex", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd,x, DataType.INT).outputVariable(); + Preconditions.checkArgument(x.length >= 1, + "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMaxIndex(sd, x, + DataType.INT).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -156,19 +172,19 @@ public class SDMath extends SDOps { */ public SDVariable abs(SDVariable x) { SDValidation.validateNumerical("abs", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd, x).outputVariable(); } /** * Elementwise absolute value operation: out = abs(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable abs(String name, SDVariable x) { SDValidation.validateNumerical("abs", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Abs(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -180,19 +196,20 @@ public class SDMath extends SDOps { */ public SDVariable acos(SDVariable x) { SDValidation.validateNumerical("acos", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd, x).outputVariable(); } /** * Elementwise acos (arccosine, inverse cosine) operation: out = arccos(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable acos(String name, SDVariable x) { SDValidation.validateNumerical("acos", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACos(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -204,28 +221,30 @@ public class SDMath extends SDOps { */ public SDVariable acosh(SDVariable x) { SDValidation.validateNumerical("acosh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd, x).outputVariable(); } /** * Elementwise acosh (inverse hyperbolic cosine) function: out = acosh(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable acosh(String name, SDVariable x) { SDValidation.validateNumerical("acosh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise addition operation, out = x + y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
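// --- Editor's sketch of the broadcasting behaviour described above (not part of the patch;
// hypothetical values) ---
SDVariable x = sd.var("x", Nd4j.ones(DataType.FLOAT, 1, 10));
SDVariable y = sd.var("y", Nd4j.ones(DataType.FLOAT, 5, 10));
SDVariable z = sd.math().add("z", x, y);   // x is broadcast along dimension 0 -> output shape [5, 10]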
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -234,169 +253,204 @@ public class SDMath extends SDOps { public SDVariable add(SDVariable x, SDVariable y) { SDValidation.validateNumerical("add", "x", x); SDValidation.validateNumerical("add", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd, x, + y).outputVariable(); } /** * Pairwise addition operation, out = x + y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable add(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("add", "x", x); SDValidation.validateNumerical("add", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar add operation, out = in + scalar
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable add(SDVariable x, double value) { SDValidation.validateNumerical("add", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd, x, value).outputVariable(); } /** * Scalar add operation, out = in + scalar
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable add(String name, SDVariable x, double value) { SDValidation.validateNumerical("add", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
+ * Absolute max array reduction operation, optionally along specified dimensions: out = + * max(abs(x))
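As a hypothetical sketch (not part of this patch) of how the dimensions argument controls the reduction described above:

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", DataType.FLOAT, 3, 4);
    SDVariable colMax = sd.math().amax(in, 0);  // reduce over dim 0: rank 2 -> rank 1, shape [4]
    SDVariable allMax = sd.math().amax(in);     // no dimensions given: full reduction to a scalar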
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amax(SDVariable in, int... dimensions) { SDValidation.validateNumerical("amax", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd, in, dimensions).outputVariable(); } /** - * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
+ * Absolute max array reduction operation, optionally along specified dimensions: out = + * max(abs(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amax(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("amax", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMax(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
+ * Absolute mean array reduction operation, optionally along specified dimensions: out = + * mean(abs(x))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amean(SDVariable in, int... dimensions) { SDValidation.validateNumerical("amean", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd, in, + dimensions).outputVariable(); } /** - * Absolute mean array reduction operation, optionally along specified dimensions: out = mean(abs(x))
+ * Absolute mean array reduction operation, optionally along specified dimensions: out = + * mean(abs(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amean(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("amean", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.AMean(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
+ * Absolute min array reduction operation, optionally along specified dimensions: out = + * min(abs(x))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amin(SDVariable in, int... dimensions) { SDValidation.validateNumerical("amin", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd, in, dimensions).outputVariable(); } /** - * Absolute min array reduction operation, optionally along specified dimensions: out = min(abs(x))
+ * Absolute min array reduction operation, optionally along specified dimensions: out = + * min(abs(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable amin(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("amin", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.AMin(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Boolean AND operation: elementwise (x != 0) && (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * Boolean AND operation: {@code elementwise (x != 0) && (y != 0)}
If x and y arrays have + * equal shape, the output shape is the same as these inputs.
Note: supports broadcasting if x + * and y have different shapes and are broadcastable.
Returns an array with values 1 where + * condition is satisfied, or value 0 otherwise.
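A hypothetical usage sketch (not part of this patch; it assumes the boolean overload of Nd4j.createFromArray):

    SameDiff sd = SameDiff.create();
    SDVariable a = sd.constant(Nd4j.createFromArray(true, false, true));
    SDVariable b = sd.constant(Nd4j.createFromArray(true, true, false));
    SDVariable both = sd.math().and(a, b);  // 1 only where both inputs are non-zero: [1, 0, 0]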
* * @param x Input 1 (BOOL type) * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable and(SDVariable x, SDVariable y) { SDValidation.validateBool("and", "x", x); SDValidation.validateBool("and", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd, x, y).outputVariable(); } /** - * Boolean AND operation: elementwise (x != 0) && (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * Boolean AND operation: {@code elementwise (x != 0) && (y != 0)}
If x and y arrays have + * equal shape, the output shape is the same as these inputs.
Note: supports broadcasting if x + * and y have different shapes and are broadcastable.
Returns an array with values 1 where + * condition is satisfied, or value 0 otherwise.
* * @param name name May be null. Name for the output variable - * @param x Input 1 (BOOL type) - * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable and(String name, SDVariable x, SDVariable y) { SDValidation.validateBool("and", "x", x); SDValidation.validateBool("and", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -408,19 +462,20 @@ public class SDMath extends SDOps { */ public SDVariable asin(SDVariable x) { SDValidation.validateNumerical("asin", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd, x).outputVariable(); } /** * Elementwise asin (arcsin, inverse sine) operation: out = arcsin(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable asin(String name, SDVariable x) { SDValidation.validateNumerical("asin", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASin(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -432,47 +487,57 @@ public class SDMath extends SDOps { */ public SDVariable asinh(SDVariable x) { SDValidation.validateNumerical("asinh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd, x).outputVariable(); } /** * Elementwise asinh (inverse hyperbolic sine) function: out = asinh(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable asinh(String name, SDVariable x) { SDValidation.validateNumerical("asinh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ASinh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
+ * Absolute sum array reduction operation, optionally along specified dimensions: out = + * sum(abs(x))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable asum(SDVariable in, int... dimensions) { SDValidation.validateNumerical("asum", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd, in, dimensions).outputVariable(); } /** - * Absolute sum array reduction operation, optionally along specified dimensions: out = sum(abs(x))
+ * Absolute sum array reduction operation, optionally along specified dimensions: out = + * sum(abs(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable asum(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("asum", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.same.ASum(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -484,25 +549,26 @@ public class SDMath extends SDOps { */ public SDVariable atan(SDVariable x) { SDValidation.validateNumerical("atan", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd, x).outputVariable(); } /** * Elementwise atan (arctangent, inverse tangent) operation: out = arctangent(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable atan(String name, SDVariable x) { SDValidation.validateNumerical("atan", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATan(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
- * Similar to atan(y/x) but sigts of x and y are used to determine the location of the result
+ * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
Similar to +   * atan(y/x) but signs of x and y are used to determine the location of the result&#13;
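A worked example of that sign handling (a hypothetical sketch, not part of this patch); note the argument order is atan2(y, x), matching the parameters below:

    // For the point (x = -1, y = 1): atan(y/x) = atan(-1) = -pi/4, but the individual signs
    // place the point in the second quadrant, so atan2 returns 3*pi/4 (~2.356).
    SameDiff sd = SameDiff.create();
    SDVariable angle = sd.math().atan2(sd.constant(Nd4j.scalar(1.0)), sd.constant(Nd4j.scalar(-1.0)));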
* * @param y Input Y variable (NUMERIC type) * @param x Input X variable (NUMERIC type) @@ -511,22 +577,23 @@ public class SDMath extends SDOps { public SDVariable atan2(SDVariable y, SDVariable x) { SDValidation.validateNumerical("atan2", "y", y); SDValidation.validateNumerical("atan2", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd,y, x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd, y, x).outputVariable(); } /** - * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
- * Similar to atan(y/x) but sigts of x and y are used to determine the location of the result
+ * Elementwise atan (arctangent, inverse tangent) operation: out = atan2(x,y).
Similar to +   * atan(y/x) but signs of x and y are used to determine the location of the result&#13;
* * @param name name May be null. Name for the output variable - * @param y Input Y variable (NUMERIC type) - * @param x Input X variable (NUMERIC type) + * @param y Input Y variable (NUMERIC type) + * @param x Input X variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable atan2(String name, SDVariable y, SDVariable x) { SDValidation.validateNumerical("atan2", "y", y); SDValidation.validateNumerical("atan2", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd,y, x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2(sd, y, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -538,220 +605,236 @@ public class SDMath extends SDOps { */ public SDVariable atanh(SDVariable x) { SDValidation.validateNumerical("atanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd, x).outputVariable(); } /** * Elementwise atanh (inverse hyperbolic tangent) function: out = atanh(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable atanh(String name, SDVariable x) { SDValidation.validateNumerical("atanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ATanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Bit shift operation
* - * @param x input (NUMERIC type) + * @param x input (NUMERIC type) * @param shift shift value (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShift(SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShift", "x", x); SDValidation.validateNumerical("bitShift", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, + shift).outputVariable(); } /** * Bit shift operation
* - * @param name name May be null. Name for the output variable - * @param x input (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x input (NUMERIC type) * @param shift shift value (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShift(String name, SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShift", "x", x); SDValidation.validateNumerical("bitShift", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Right bit shift operation
* - * @param x Input tensor (NUMERIC type) + * @param x Input tensor (NUMERIC type) * @param shift shift argument (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShiftRight(SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRight", "x", x); SDValidation.validateNumerical("bitShiftRight", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, + shift).outputVariable(); } /** * Right bit shift operation
* - * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) * @param shift shift argument (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShiftRight(String name, SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRight", "x", x); SDValidation.validateNumerical("bitShiftRight", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Cyclic bit shift operation
* - * @param x Input tensor (NUMERIC type) + * @param x Input tensor (NUMERIC type) * @param shift shift argy=ument (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShiftRotl(SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRotl", "x", x); SDValidation.validateNumerical("bitShiftRotl", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + shift).outputVariable(); } /** * Cyclic bit shift operation
* - * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) * @param shift shift argy=ument (NUMERIC type) * @return output shifted output (NUMERIC type) */ public SDVariable bitShiftRotl(String name, SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRotl", "x", x); SDValidation.validateNumerical("bitShiftRotl", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Cyclic right shift operation
* - * @param x Input tensor (NUMERIC type) + * @param x Input tensor (NUMERIC type) * @param shift Shift argument (NUMERIC type) * @return output Shifted output (NUMERIC type) */ public SDVariable bitShiftRotr(SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRotr", "x", x); SDValidation.validateNumerical("bitShiftRotr", "shift", shift); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + shift).outputVariable(); } /** * Cyclic right shift operation
* - * @param name name May be null. Name for the output variable - * @param x Input tensor (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) * @param shift Shift argument (NUMERIC type) * @return output Shifted output (NUMERIC type) */ public SDVariable bitShiftRotr(String name, SDVariable x, SDVariable shift) { SDValidation.validateNumerical("bitShiftRotr", "x", x); SDValidation.validateNumerical("bitShiftRotr", "shift", shift); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd,x, shift).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits(sd, x, + shift).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise ceiling function: out = ceil(x).
- * Rounds each value up to the nearest integer value (if not already an integer)
+ * Element-wise ceiling function: out = ceil(x).
Rounds each value up to the nearest integer + * value (if not already an integer)
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable ceil(SDVariable x) { SDValidation.validateNumerical("ceil", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd, x).outputVariable(); } /** - * Element-wise ceiling function: out = ceil(x).
- * Rounds each value up to the nearest integer value (if not already an integer)
+ * Element-wise ceiling function: out = ceil(x).
Rounds each value up to the nearest integer + * value (if not already an integer)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable ceil(String name, SDVariable x) { SDValidation.validateNumerical("ceil", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Ceil(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Clipping by L2 norm, optionally along dimension(s)
- * if l2Norm(x,dimension) < clipValue, then input is returned unmodifed
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according
- * to the corresponding l2Norm along the specified dimensions
+ * Clipping by L2 norm, optionally along dimension(s)
if l2Norm(x,dimension) < clipValue, +   * then input is returned unmodified&#13;
Otherwise, out[i] = in[i] * clipValue / l2Norm(in, + * dimensions) where each value is clipped according
to the corresponding l2Norm along the + * specified dimensions
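A small worked case (hypothetical sketch, not part of this patch): x = [3, 4] has l2Norm 5, so with clipValue = 1.0 every element is scaled by 1/5:

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant(Nd4j.createFromArray(3.0, 4.0));
    SDVariable clipped = sd.math().clipByNorm(x, 1.0);  // -> [0.6, 0.8], whose l2 norm is 1.0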
* - * @param x Input variable (NUMERIC type) - * @param clipValue Clipping value (maximum l2 norm) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param clipValue Clipping value (maximum l2 norm) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable clipByNorm(SDVariable x, double clipValue, int... dimensions) { SDValidation.validateNumerical("clipByNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd,x, clipValue, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd, x, clipValue, + dimensions).outputVariable(); } /** - * Clipping by L2 norm, optionally along dimension(s)
- * if l2Norm(x,dimension) < clipValue, then input is returned unmodifed
- * Otherwise, out[i] = in[i] * clipValue / l2Norm(in, dimensions) where each value is clipped according
- * to the corresponding l2Norm along the specified dimensions
+ * Clipping by L2 norm, optionally along dimension(s)
if l2Norm(x,dimension) < clipValue, +   * then input is returned unmodified&#13;
Otherwise, out[i] = in[i] * clipValue / l2Norm(in, + * dimensions) where each value is clipped according
to the corresponding l2Norm along the + * specified dimensions
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param clipValue Clipping value (maximum l2 norm) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param clipValue Clipping value (maximum l2 norm) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable clipByNorm(String name, SDVariable x, double clipValue, int... dimensions) { SDValidation.validateNumerical("clipByNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd,x, clipValue, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm(sd, x, clipValue, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise clipping function:
- * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
- * out[i] = clipValueMin if in[i] < clipValueMin
- * out[i] = clipValueMax if in[i] > clipValueMax
+ * Element-wise clipping function:
out[i] = in[i] if in[i] >= clipValueMin and in[i] <= + * clipValueMax
out[i] = clipValueMin if in[i] < clipValueMin
out[i] = clipValueMax if + * in[i] > clipValueMax
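A small worked case (hypothetical sketch, not part of this patch):

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant(Nd4j.createFromArray(-2.0, 0.5, 3.0));
    SDVariable clipped = sd.math().clipByValue(x, 0.0, 1.0);  // -> [0.0, 0.5, 1.0]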
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param clipValueMin Minimum value for clipping * @param clipValueMax Maximum value for clipping * @return output Output variable (NUMERIC type) */ public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) { SDValidation.validateNumerical("clipByValue", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd,x, clipValueMin, clipValueMax).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd, x, clipValueMin, + clipValueMax).outputVariable(); } /** * Element-wise clipping function:
- * out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax
- * out[i] = clipValueMin if in[i] < clipValueMin
- * out[i] = clipValueMax if in[i] > clipValueMax
+ * {@code out[i] = in[i] if in[i] >= clipValueMin and in[i] <= clipValueMax out[i] = clipValueMin + * if in[i] < clipValueMin out[i] = clipValueMax if in[i] > clipValueMax} * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param clipValueMin Minimum value for clipping * @param clipValueMax Maximum value for clipping * @return output Output variable (NUMERIC type) @@ -759,40 +842,40 @@ public class SDMath extends SDOps { public SDVariable clipByValue(String name, SDVariable x, double clipValueMin, double clipValueMax) { SDValidation.validateNumerical("clipByValue", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd,x, clipValueMin, clipValueMax).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue(sd, x, + clipValueMin, clipValueMax).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
- * [1, 0, 0]
- * [0, 1, 1]
- * [0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values. This version assumes the + * number of classes is 1 + max(max(labels), max(pred))
For example, if labels = [0, 1, 1] and + * predicted = [0, 2, 1] then output is:
[1, 0, 0]
[0, 1, 1]
[0, 0, 0]
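The example above as a hypothetical usage sketch (not part of this patch); rows index the true class and columns the predicted class:

    SameDiff sd = SameDiff.create();
    SDVariable labels = sd.constant(Nd4j.createFromArray(0, 1, 1));
    SDVariable pred = sd.constant(Nd4j.createFromArray(0, 2, 1));
    SDVariable cm = sd.math().confusionMatrix(labels, pred, DataType.INT);  // 3x3 matrix shown above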
* - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length + * as labels (NUMERIC type) * @param dataType Data type * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, DataType dataType) { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); - return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, dataType).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + dataType).outputVariable(); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1] and predicted = [0, 2, 1] then output is:
- * [1, 0, 0]
- * [0, 1, 1]
- * [0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values. This version assumes the + * number of classes is 1 + max(max(labels), max(pred))
For example, if labels = [0, 1, 1] and + * predicted = [0, 2, 1] then output is:
[1, 0, 0]
[0, 1, 1]
[0, 0, 0]
* - * @param name name May be null. Name for the output variable - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length + * as labels (NUMERIC type) * @param dataType Data type * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ @@ -800,42 +883,40 @@ public class SDMath extends SDOps { DataType dataType) { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, dataType).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + dataType).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
- * [1, 0, 0, 0]
- * [0, 1, 1, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values.
For example, if labels = + * [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
[1, 0, 0, 0]
[0, 1, + * 1, 0]
[0, 0, 0, 0]
[0, 0, 0, 0]
* - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same + * length as labels (NUMERIC type) * @param numClasses Number of classes * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, int numClasses) { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); - return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, numClasses).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + numClasses).outputVariable(); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
- * [1, 0, 0, 0]
- * [0, 1, 1, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values.
For example, if labels = + * [0, 1, 1], predicted = [0, 2, 1], and numClasses=4 then output is:
[1, 0, 0, 0]
[0, 1, + * 1, 0]
[0, 0, 0, 0]
[0, 0, 0, 0]
* - * @param name name May be null. Name for the output variable - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same + * length as labels (NUMERIC type) * @param numClasses Number of classes * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ @@ -843,42 +924,46 @@ public class SDMath extends SDOps { int numClasses) { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, numClasses).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + numClasses).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3]
- * [1, 0, 0]
- * [0, 3, 2]
- * [0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values. This version assumes the + * number of classes is 1 + max(max(labels), max(pred))
For example, if labels = [0, 1, 1], + * predicted = [0, 2, 1] and weights = [1, 2, 3]
[1, 0, 0]
[0, 3, 2]
[0, 0, 0]
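That is, each (label, prediction) pair contributes its weight rather than 1: the pair (0,0) adds weight 1 at row 0, column 0; (1,2) adds weight 2 at row 1, column 2; and (1,1) adds weight 3 at row 1, column 1, which yields the rows shown above.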
* - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length + * as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the + * weight/contribution of each prediction. Must be same length as both labels and + * predictions arrays (NUMERIC type) * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); SDValidation.validateNumerical("confusionMatrix", "weights", weights); - return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + weights).outputVariable(); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values. This version assumes the number of classes is 1 + max(max(labels), max(pred))
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1] and weights = [1, 2, 3]
- * [1, 0, 0]
- * [0, 3, 2]
- * [0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values. This version assumes the + * number of classes is 1 + max(max(labels), max(pred))
For example, if labels = [0, 1, 1], + * predicted = [0, 2, 1] and weights = [1, 2, 3]
[1, 0, 0]
[0, 3, 2]
[0, 0, 0]
* - * @param name name May be null. Name for the output variable - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same length + * as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the + * weight/contribution of each prediction. Must be same length as both labels and + * predictions arrays (NUMERIC type) * @return output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, @@ -886,23 +971,24 @@ public class SDMath extends SDOps { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); SDValidation.validateNumerical("confusionMatrix", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + weights).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
- * [1, 0, 0, 0]
- * [0, 3, 2, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values.
For example, if labels = + * [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
[1, 0, 0, 0]
+ * [0, 3, 2, 0]
[0, 0, 0, 0]
[0, 0, 0, 0]
* - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) - * @param numClasses + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same + * length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the + * weight/contribution of each prediction. Must be same length as both labels + * and predictions arrays (NUMERIC type) + * @param numClasses * @return output Output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights, @@ -910,23 +996,24 @@ public class SDMath extends SDOps { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); SDValidation.validateNumerical("confusionMatrix", "weights", weights); - return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights, numClasses).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, weights, + numClasses).outputVariable(); } /** - * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and predictions, both of
- * which are represented as integer values.
- * For example, if labels = [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
- * [1, 0, 0, 0]
- * [0, 3, 2, 0]
- * [0, 0, 0, 0]
- * [0, 0, 0, 0]
+ * Compute the 2d confusion matrix of size [numClasses, numClasses] from a pair of labels and + * predictions, both of
which are represented as integer values.
For example, if labels = + * [0, 1, 1], predicted = [0, 2, 1], numClasses = 4, and weights = [1, 2, 3]
[1, 0, 0, 0]
+ * [0, 3, 2, 0]
[0, 0, 0, 0]
[0, 0, 0, 0]
* - * @param name name May be null. Name for the output variable - * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) - * @param pred Predictions - 1D array of integer values representing predictions. Same length as labels (NUMERIC type) - * @param weights Weights - 1D array of values (may be real/decimal) representing the weight/contribution of each prediction. Must be same length as both labels and predictions arrays (NUMERIC type) - * @param numClasses + * @param name name May be null. Name for the output variable + * @param labels Labels - 1D array of integer values representing label values (NUMERIC type) + * @param pred Predictions - 1D array of integer values representing predictions. Same + * length as labels (NUMERIC type) + * @param weights Weights - 1D array of values (may be real/decimal) representing the + * weight/contribution of each prediction. Must be same length as both labels + * and predictions arrays (NUMERIC type) + * @param numClasses * @return output Output variable (2D, shape [numClasses, numClasses}) (NUMERIC type) */ public SDVariable confusionMatrix(String name, SDVariable labels, SDVariable pred, @@ -934,7 +1021,8 @@ public class SDMath extends SDOps { SDValidation.validateNumerical("confusionMatrix", "labels", labels); SDValidation.validateNumerical("confusionMatrix", "pred", pred); SDValidation.validateNumerical("confusionMatrix", "weights", weights); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd,labels, pred, weights, numClasses).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.ConfusionMatrix(sd, labels, pred, + weights, numClasses).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -946,19 +1034,19 @@ public class SDMath extends SDOps { */ public SDVariable cos(SDVariable x) { SDValidation.validateNumerical("cos", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd, x).outputVariable(); } /** * Elementwise cosine operation: out = cos(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable cos(String name, SDVariable x) { SDValidation.validateNumerical("cos", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cos(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -970,153 +1058,185 @@ public class SDMath extends SDOps { */ public SDVariable cosh(SDVariable x) { SDValidation.validateNumerical("cosh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd, x).outputVariable(); } /** * Elementwise cosh (hyperbolic cosine) operation: out = cosh(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable cosh(String name, SDVariable x) { SDValidation.validateNumerical("cosh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Cosh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Cosine distance reduction operation. The output contains the cosine distance for each
- * tensor/subset along the specified dimensions:
- * out = 1.0 - cosineSimilarity(x,y)
+ * tensor/subset along the specified dimensions:
out = 1.0 - cosineSimilarity(x,y)
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable cosineDistance(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("cosineDistance", "x", x); SDValidation.validateNumerical("cosineDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd, x, y, + dimensions).outputVariable(); } /** * Cosine distance reduction operation. The output contains the cosine distance for each
- * tensor/subset along the specified dimensions:
- * out = 1.0 - cosineSimilarity(x,y)
+ * tensor/subset along the specified dimensions:
out = 1.0 - cosineSimilarity(x,y)
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate cosineDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable cosineDistance(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("cosineDistance", "x", x); SDValidation.validateNumerical("cosineDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
- * along the specified dimensions:
- * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2)
+ * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for + * each tensor/subset
along the specified dimensions:
out = (sum_i x[i] * y[i]) / ( +   * sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2))&#13;
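A small worked case (hypothetical sketch, not part of this patch): for x = [1, 0] and y = [1, 1], the dot product is 1 and the norms are 1 and sqrt(2), so the similarity over dimension 0 is 1/sqrt(2) (about 0.707); the cosineDistance operation above is simply 1 minus this value.

    SameDiff sd = SameDiff.create();
    SDVariable x = sd.constant(Nd4j.createFromArray(1.0, 0.0));
    SDVariable y = sd.constant(Nd4j.createFromArray(1.0, 1.0));
    SDVariable sim = sd.math().cosineSimilarity(x, y, 0);  // ~0.707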
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable cosineSimilarity(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("cosineSimilarity", "x", x); SDValidation.validateNumerical("cosineSimilarity", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd, x, y, + dimensions).outputVariable(); } /** - * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for each tensor/subset
- * along the specified dimensions:
- * out = (sum_i x[i] * y[i]) / ( sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2)
+ * Cosine similarity pairwise reduction operation. The output contains the cosine similarity for + * each tensor/subset
along the specified dimensions:
out = (sum_i x[i] * y[i]) / ( +   * sqrt(sum_i x[i]^2) * sqrt(sum_i y[i]^2))&#13;
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate cosineSimilarity over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable cosineSimilarity(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("cosineSimilarity", "x", x); SDValidation.validateNumerical("cosineSimilarity", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
+ * Count non zero array reduction operation, optionally along specified dimensions: out = count(x
+ * != 0)<br>
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable countNonZero(SDVariable in, int... dimensions) { SDValidation.validateNumerical("countNonZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd, in, + dimensions).outputVariable(); } /** - * Count non zero array reduction operation, optionally along specified dimensions: out = count(x != 0)
+ * Count non zero array reduction operation, optionally along specified dimensions: out = count(x
+ * != 0)<br>
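// Illustrative sketch of countNonZero and the countZero op declared just below, assuming the
// standard SameDiff/Nd4j entry points (SameDiff.create(), sd.var(), sd.math(), Nd4j.createFromArray()):
//
//   SameDiff sd = SameDiff.create();
//   SDVariable in = sd.var("in", Nd4j.createFromArray(new float[][]{{0f, 1f, 2f}, {0f, 0f, 3f}}));
//   SDVariable nz = sd.math().countNonZero(in, 1);   // per-row counts: [2, 1]
//   SDVariable z  = sd.math().countZero(in, 1);      // per-row counts: [1, 2]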
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable countNonZero(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("countNonZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
+ * Count zero array reduction operation, optionally along specified dimensions: out = count(x ==
+ * 0)<br>
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable countZero(SDVariable in, int... dimensions) { SDValidation.validateNumerical("countZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd, in, + dimensions).outputVariable(); } /** - * Count zero array reduction operation, optionally along specified dimensions: out = count(x == 0)
+ * Count zero array reduction operation, optionally along specified dimensions: out = count(x ==
+ * 0)<br>
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable countZero(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("countZero", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta).
- * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3<br>
+ * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b||
+ * sin(theta).<br> Can take rank 1 or above inputs (of equal shapes), but note that the last
+ * dimension must have dimension 3<br>
* * @param a First input (NUMERIC type) * @param b Second input (NUMERIC type) @@ -1125,22 +1245,23 @@ public class SDMath extends SDOps { public SDVariable cross(SDVariable a, SDVariable b) { SDValidation.validateNumerical("cross", "a", a); SDValidation.validateNumerical("cross", "b", b); - return new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Cross(sd, a, b).outputVariable(); } /** - * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b|| sin(theta).
- * Can take rank 1 or above inputs (of equal shapes), but note that the last dimension must have dimension 3<br>
+ * Returns the pair-wise cross product of equal size arrays a and b: a x b = ||a||x||b||
+ * sin(theta).<br> Can take rank 1 or above inputs (of equal shapes), but note that the last
+ * dimension must have dimension 3<br>
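// Illustrative sketch of the cross product op, assuming the standard SameDiff/Nd4j entry points;
// note the requirement stated above that the last dimension of both inputs must have size 3:
//
//   SameDiff sd = SameDiff.create();
//   SDVariable a = sd.var("a", Nd4j.createFromArray(new float[]{1f, 0f, 0f}));
//   SDVariable b = sd.var("b", Nd4j.createFromArray(new float[]{0f, 1f, 0f}));
//   SDVariable c = sd.math().cross(a, b);   // expected [0, 0, 1], i.e. the unit z vector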
* * @param name name May be null. Name for the output variable - * @param a First input (NUMERIC type) - * @param b Second input (NUMERIC type) + * @param a First input (NUMERIC type) + * @param b Second input (NUMERIC type) * @return output Element-wise cross product (NUMERIC type) */ public SDVariable cross(String name, SDVariable a, SDVariable b) { SDValidation.validateNumerical("cross", "a", a); SDValidation.validateNumerical("cross", "b", b); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Cross(sd,a, b).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Cross(sd, a, b).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1152,102 +1273,91 @@ public class SDMath extends SDOps { */ public SDVariable cube(SDVariable x) { SDValidation.validateNumerical("cube", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd, x).outputVariable(); } /** * Element-wise cube function: out = x^3
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable cube(String name, SDVariable x) { SDValidation.validateNumerical("cube", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Cube(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
- * For example, if input = [1,2,3], then output is given by:<br>
- * [ 1, 0, 0]<br>
- * [ 0, 2, 0]<br>
- * [ 0, 0, 3]<br>
+ * Returns an output variable with diagonal values equal to the specified values; off-diagonal
+ * values will be set to 0<br> For example, if input = [1,2,3], then output is given by:<br> [ 1,
+ * 0, 0]<br> [ 0, 2, 0]<br> [ 0, 0, 3]<br>
 * <br>
+ * Higher input ranks are also supported: if input has shape [a,...,R-1] then
+ * output[i,...,k,i,...,k] = input[i,...,k].<br> i.e., for input rank R, output has rank 2R<br>
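// Illustrative sketch of diag (and its inverse, diagPart, declared further down), assuming the
// standard SameDiff/Nd4j entry points:
//
//   SameDiff sd = SameDiff.create();
//   SDVariable v = sd.var("v", Nd4j.createFromArray(new float[]{1f, 2f, 3f}));
//   SDVariable m = sd.math().diag(v);        // 3x3 matrix with [1,2,3] on the diagonal, 0 elsewhere
//   SDVariable d = sd.math().diagPart(m);    // recovers [1, 2, 3]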
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable diag(SDVariable x) { SDValidation.validateNumerical("diag", "x", x); - return new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Diag(sd, x).outputVariable(); } /** - * Returns an output variable with diagonal values equal to the specified values; off-diagonal values will be set to 0
- * For example, if input = [1,2,3], then output is given by:<br>
- * [ 1, 0, 0]<br>
- * [ 0, 2, 0]<br>
- * [ 0, 0, 3]<br>
+ * Returns an output variable with diagonal values equal to the specified values; off-diagonal
+ * values will be set to 0<br> For example, if input = [1,2,3], then output is given by:<br> [ 1,
+ * 0, 0]<br> [ 0, 2, 0]<br> [ 0, 0, 3]<br>
 * <br>
+ * Higher input ranks are also supported: if input has shape [a,...,R-1] then
+ * output[i,...,k,i,...,k] = input[i,...,k].<br> i.e., for input rank R, output has rank 2R<br>
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable diag(String name, SDVariable x) { SDValidation.validateNumerical("diag", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Diag(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Diag(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Extract the diagonal part from the input array.
- * If input is
- * [ 1, 0, 0]
- * [ 0, 2, 0]
- * [ 0, 0, 3]
- * then output is [1, 2, 3].
- * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
+ * Extract the diagonal part from the input array.
If input is
[ 1, 0, 0]
[ 0, 2, + * 0]
[ 0, 0, 3]
then output is [1, 2, 3].
Supports higher dimensions: in general, + * out[i,...,k] = in[i,...,k,i,...,k]
* * @param x Input variable (NUMERIC type) * @return output Diagonal part of the input (NUMERIC type) */ public SDVariable diagPart(SDVariable x) { SDValidation.validateNumerical("diagPart", "x", x); - return new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd, x).outputVariable(); } /** - * Extract the diagonal part from the input array.
- * If input is
- * [ 1, 0, 0]
- * [ 0, 2, 0]
- * [ 0, 0, 3]
- * then output is [1, 2, 3].
- * Supports higher dimensions: in general, out[i,...,k] = in[i,...,k,i,...,k]
+ * Extract the diagonal part from the input array.
If input is
[ 1, 0, 0]
[ 0, 2, + * 0]
[ 0, 0, 3]
then output is [1, 2, 3].
Supports higher dimensions: in general, + * out[i,...,k] = in[i,...,k,i,...,k]
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Diagonal part of the input (NUMERIC type) */ public SDVariable diagPart(String name, SDVariable x) { SDValidation.validateNumerical("diagPart", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.DiagPart(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise division operation, out = x / y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -1256,79 +1366,91 @@ public class SDMath extends SDOps { public SDVariable div(SDVariable x, SDVariable y) { SDValidation.validateNumerical("div", "x", x); SDValidation.validateNumerical("div", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd, x, + y).outputVariable(); } /** * Pairwise division operation, out = x / y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable div(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("div", "x", x); SDValidation.validateNumerical("div", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar division operation, out = in / scalar
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable div(SDVariable x, double value) { SDValidation.validateNumerical("div", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd, x, value).outputVariable(); } /** * Scalar division operation, out = in / scalar
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable div(String name, SDVariable x, double value) { SDValidation.validateNumerical("div", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Entropy reduction: -sum(x * log(x))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable entropy(SDVariable in, int... dimensions) { SDValidation.validateNumerical("entropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd, in, + dimensions).outputVariable(); } /** * Entropy reduction: -sum(x * log(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable entropy(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("entropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1340,19 +1462,19 @@ public class SDMath extends SDOps { */ public SDVariable erf(SDVariable x) { SDValidation.validateNumerical("erf", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd, x).outputVariable(); } /** * Element-wise Gaussian error function - out = erf(in)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable erf(String name, SDVariable x) { SDValidation.validateNumerical("erf", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erf(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1364,55 +1486,62 @@ public class SDMath extends SDOps { */ public SDVariable erfc(SDVariable x) { SDValidation.validateNumerical("erfc", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd, x).outputVariable(); } /** * Element-wise complementary Gaussian error function - out = erfc(in) = 1 - erf(in)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable erfc(String name, SDVariable x) { SDValidation.validateNumerical("erfc", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Erfc(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
- * tensor/subset along the specified dimensions:<br>
- * out = sqrt( sum_i (x[i] - y[i])^2 )<br>
+ * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the
+ * Euclidean distance for each<br> tensor/subset along the specified dimensions:<br> out = sqrt(
+ * sum_i (x[i] - y[i])^2 )<br>
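// Illustrative sketch of euclideanDistance, assuming the standard SameDiff/Nd4j entry points:
//
//   SameDiff sd = SameDiff.create();
//   SDVariable x = sd.var("x", Nd4j.createFromArray(new float[]{1f, 2f, 3f}));
//   SDVariable y = sd.var("y", Nd4j.createFromArray(new float[]{4f, 6f, 3f}));
//   // out = sqrt( sum_i (x[i] - y[i])^2 ) over dimension 0 -> sqrt(9 + 16 + 0) = 5.0
//   SDVariable dist = sd.math().euclideanDistance(x, y, 0);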
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable euclideanDistance(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("euclideanDistance", "x", x); SDValidation.validateNumerical("euclideanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd, x, y, + dimensions).outputVariable(); } /** - * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the Euclidean distance for each
- * tensor/subset along the specified dimensions:<br>
- * out = sqrt( sum_i (x[i] - y[i])^2 )<br>
+ * Euclidean distance (l2 norm, l2 distance) reduction operation. The output contains the
+ * Euclidean distance for each<br> tensor/subset along the specified dimensions:<br> out = sqrt(
+ * sum_i (x[i] - y[i])^2 )<br>
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate euclideanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable euclideanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("euclideanDistance", "x", x); SDValidation.validateNumerical("euclideanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1424,19 +1553,19 @@ public class SDMath extends SDOps { */ public SDVariable exp(SDVariable x) { SDValidation.validateNumerical("exp", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd, x).outputVariable(); } /** * Elementwise exponent function: out = exp(x) = 2.71828...^x
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable exp(String name, SDVariable x) { SDValidation.validateNumerical("exp", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Exp(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1448,19 +1577,20 @@ public class SDMath extends SDOps { */ public SDVariable expm1(SDVariable x) { SDValidation.validateNumerical("expm1", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd, x).outputVariable(); } /** * Elementwise 1.0 - exponent function: out = 1.0 - exp(x) = 1.0 - 2.71828...^x
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable expm1(String name, SDVariable x) { SDValidation.validateNumerical("expm1", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Expm1(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1471,7 +1601,7 @@ public class SDMath extends SDOps { * @return output Identity matrix (NUMERIC type) */ public SDVariable eye(int rows) { - return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows).outputVariable(); } /** @@ -1482,7 +1612,7 @@ public class SDMath extends SDOps { * @return output Identity matrix (NUMERIC type) */ public SDVariable eye(String name, int rows) { - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1494,7 +1624,7 @@ public class SDMath extends SDOps { * @return output (NUMERIC type) */ public SDVariable eye(int rows, int cols) { - return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols).outputVariable(); } /** @@ -1506,13 +1636,12 @@ public class SDMath extends SDOps { * @return output (NUMERIC type) */ public SDVariable eye(String name, int rows, int cols) { - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Generate an identity matrix with the specified number of rows and columns
- * Example:
+ * Generate an identity matrix with the specified number of rows and columns
Example:
*


* {@code INDArray eye = eye(3,2)
* eye:
@@ -1521,20 +1650,22 @@ public class SDMath extends SDOps { * [ 0, 0]}
*

* - * @param rows Number of rows - * @param cols Number of columns - * @param dataType Data type - * @param dimensions (Size: AtLeast(min=0)) + * @param rows Number of rows + * @param cols Number of columns + * @param dataType Data type + * @param dimensions (Size: AtLeast(min=0)) * @return output Identity matrix (NUMERIC type) */ public SDVariable eye(int rows, int cols, DataType dataType, int... dimensions) { - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols, dataType, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols, dataType, + dimensions).outputVariable(); } /** - * Generate an identity matrix with the specified number of rows and columns
- * Example:
+ * Generate an identity matrix with the specified number of rows and columns
Example:
*

* {@code INDArray eye = eye(3,2)
* eye:
@@ -1543,16 +1674,19 @@ public class SDMath extends SDOps { * [ 0, 0]}
*

* - * @param name name May be null. Name for the output variable - * @param rows Number of rows - * @param cols Number of columns - * @param dataType Data type - * @param dimensions (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param rows Number of rows + * @param cols Number of columns + * @param dataType Data type + * @param dimensions (Size: AtLeast(min=0)) * @return output Identity matrix (NUMERIC type) */ public SDVariable eye(String name, int rows, int cols, DataType dataType, int... dimensions) { - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols, dataType, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols, dataType, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1566,7 +1700,7 @@ public class SDMath extends SDOps { public SDVariable eye(SDVariable rows, SDVariable cols) { SDValidation.validateInteger("eye", "rows", rows); SDValidation.validateInteger("eye", "cols", cols); - return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols).outputVariable(); } /** @@ -1580,7 +1714,7 @@ public class SDMath extends SDOps { public SDVariable eye(String name, SDVariable rows, SDVariable cols) { SDValidation.validateInteger("eye", "rows", rows); SDValidation.validateInteger("eye", "cols", cols); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows, cols).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows, cols).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1592,7 +1726,7 @@ public class SDMath extends SDOps { */ public SDVariable eye(SDVariable rows) { SDValidation.validateInteger("eye", "rows", rows); - return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows).outputVariable(); } /** @@ -1604,138 +1738,149 @@ public class SDMath extends SDOps { */ public SDVariable eye(String name, SDVariable rows) { SDValidation.validateInteger("eye", "rows", rows); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd,rows).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.Eye(sd, rows).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each<br>
- * slice along the specified dimensions)<br>
- * Note that if keepDims = true, the output variable has the same rank as the input variable,<br>
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting<br>
- * the mean along a dimension).<br>
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:<br>
- * keepDims = true: [a,1,c]<br>
- * keepDims = false: [a,c]<br>
+ * First index reduction operation.<br> Returns a variable that contains the index of the first
+ * element that matches the specified condition (for each<br> slice along the specified
+ * dimensions)<br> Note that if keepDims = true, the output variable has the same rank as the
+ * input variable,<br> with the reduced dimensions having size 1. This can be useful for later
+ * broadcast operations (such as subtracting<br> the mean along a dimension).<br> Example: if
+ * input has shape [a,b,c] and dimensions=[1] then output has shape:<br> keepDims = true:
+ * [a,1,c]<br> keepDims = false: [a,c]<br>
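// Illustrative sketch of firstIndex with a Condition and the keepDims behaviour described above,
// assuming the standard SameDiff/Nd4j entry points plus org.nd4j.linalg.indexing.conditions.Conditions:
//
//   SameDiff sd = SameDiff.create();
//   SDVariable in = sd.var("in", Nd4j.createFromArray(new float[][]{{0f, 5f, 0f}, {7f, 0f, 0f}}));
//   // index of the first element > 0 in each row; input shape [2,3], dimensions=[1]
//   SDVariable idx     = sd.math().firstIndex(in, Conditions.greaterThan(0), 1);        // shape [2]: [1, 0]
//   SDVariable idxKeep = sd.math().firstIndex(in, Conditions.greaterThan(0), true, 1);  // shape [2,1]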
* - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable firstIndex(SDVariable in, Condition condition, int... dimensions) { SDValidation.validateNumerical("firstIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, false, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd, in, false, condition, + dimensions).outputVariable(); } /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each<br>
- * slice along the specified dimensions)<br>
- * Note that if keepDims = true, the output variable has the same rank as the input variable,<br>
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting<br>
- * the mean along a dimension).<br>
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:<br>
- * keepDims = true: [a,1,c]<br>
- * keepDims = false: [a,c]<br>
+ * First index reduction operation.<br> Returns a variable that contains the index of the first
+ * element that matches the specified condition (for each<br> slice along the specified
+ * dimensions)<br> Note that if keepDims = true, the output variable has the same rank as the
+ * input variable,<br> with the reduced dimensions having size 1. This can be useful for later
+ * broadcast operations (such as subtracting<br> the mean along a dimension).<br> Example: if
+ * input has shape [a,b,c] and dimensions=[1] then output has shape:<br> keepDims = true:
+ * [a,1,c]<br> keepDims = false: [a,c]<br>
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable firstIndex(String name, SDVariable in, Condition condition, int... dimensions) { SDValidation.validateNumerical("firstIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, false, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd, in, false, + condition, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each<br>
- * slice along the specified dimensions)<br>
- * Note that if keepDims = true, the output variable has the same rank as the input variable,<br>
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting<br>
- * the mean along a dimension).<br>
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:<br>
- * keepDims = true: [a,1,c]<br>
- * keepDims = false: [a,c]<br>
+ * First index reduction operation.<br> Returns a variable that contains the index of the first
+ * element that matches the specified condition (for each<br> slice along the specified
+ * dimensions)<br> Note that if keepDims = true, the output variable has the same rank as the
+ * input variable,<br> with the reduced dimensions having size 1. This can be useful for later
+ * broadcast operations (such as subtracting<br> the mean along a dimension).<br> Example: if
+ * input has shape [a,b,c] and dimensions=[1] then output has shape:<br> keepDims = true:
+ * [a,1,c]<br> keepDims = false: [a,c]<br>
* - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("firstIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd, in, keepDims, condition, + dimensions).outputVariable(); } /** - * First index reduction operation.
- * Returns a variable that contains the index of the first element that matches the specified condition (for each<br>
- * slice along the specified dimensions)<br>
- * Note that if keepDims = true, the output variable has the same rank as the input variable,<br>
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting<br>
- * the mean along a dimension).<br>
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:<br>
- * keepDims = true: [a,1,c]<br>
- * keepDims = false: [a,c]<br>
+ * First index reduction operation.<br> Returns a variable that contains the index of the first
+ * element that matches the specified condition (for each<br> slice along the specified
+ * dimensions)<br> Note that if keepDims = true, the output variable has the same rank as the
+ * input variable,<br> with the reduced dimensions having size 1. This can be useful for later
+ * broadcast operations (such as subtracting<br> the mean along a dimension).<br> Example: if
+ * input has shape [a,b,c] and dimensions=[1] then output has shape:<br> keepDims = true:
+ * [a,1,c]<br> keepDims = false: [a,c]<br>
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable firstIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("firstIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex(sd, in, keepDims, + condition, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise floor function: out = floor(x).
- * Rounds each value down to the nearest integer value (if not already an integer)<br>
+ * Element-wise floor function: out = floor(x).<br> Rounds each value down to the nearest integer
+ * value (if not already an integer)<br>
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable floor(SDVariable x) { SDValidation.validateNumerical("floor", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd, x).outputVariable(); } /** - * Element-wise floor function: out = floor(x).
- * Rounds each value down to the nearest integer value (if not already an integer)<br>
+ * Element-wise floor function: out = floor(x).<br> Rounds each value down to the nearest integer
+ * value (if not already an integer)<br>
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable floor(String name, SDVariable x) { SDValidation.validateNumerical("floor", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise floor division operation, out = floor(x / y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -1744,34 +1889,38 @@ public class SDMath extends SDOps { public SDVariable floorDiv(SDVariable x, SDVariable y) { SDValidation.validateNumerical("floorDiv", "x", x); SDValidation.validateNumerical("floorDiv", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd, x, + y).outputVariable(); } /** * Pairwise floor division operation, out = floor(x / y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable floorDiv(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("floorDiv", "x", x); SDValidation.validateNumerical("floorDiv", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd, + x, y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise Modulus division operation
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -1780,509 +1929,564 @@ public class SDMath extends SDOps { public SDVariable floorMod(SDVariable x, SDVariable y) { SDValidation.validateNumerical("floorMod", "x", x); SDValidation.validateNumerical("floorMod", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd, x, + y).outputVariable(); } /** * Pairwise Modulus division operation
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable floorMod(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("floorMod", "x", x); SDValidation.validateNumerical("floorMod", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd, + x, y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar floor modulus operation
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable floorMod(SDVariable x, double value) { SDValidation.validateNumerical("floorMod", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd, x, value).outputVariable(); } /** * Scalar floor modulus operation
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable floorMod(String name, SDVariable x, double value) { SDValidation.validateNumerical("floorMod", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Hamming distance reduction operation. The output contains the cosine distance for each
- * tensor/subset along the specified dimensions:
- * out = count( x[i] != y[i] )
+ * tensor/subset along the specified dimensions:
out = count( x[i] != y[i] )
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable hammingDistance(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("hammingDistance", "x", x); SDValidation.validateNumerical("hammingDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd, x, y, + dimensions).outputVariable(); } /** * Hamming distance reduction operation. The output contains the cosine distance for each
- * tensor/subset along the specified dimensions:
- * out = count( x[i] != y[i] )
+ * tensor/subset along the specified dimensions:
out = count( x[i] != y[i] )
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate hammingDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable hammingDistance(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("hammingDistance", "x", x); SDValidation.validateNumerical("hammingDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Index of the max absolute value: argmax(abs(in))
- * see argmax(String, INDArray, boolean, int...)
+ * Index of the max absolute value: argmax(abs(in))
see argmax(String, INDArray, boolean, + * int...)
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable iamax(SDVariable in, int... dimensions) { SDValidation.validateNumerical("iamax", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd, in, false, + dimensions).outputVariable(); } /** - * Index of the max absolute value: argmax(abs(in))
- * see argmax(String, INDArray, boolean, int...)
+ * Index of the max absolute value: argmax(abs(in))
see argmax(String, INDArray, boolean, + * int...)
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable iamax(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("iamax", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd, in, false, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Index of the max absolute value: argmax(abs(in))
- * see argmax(String, INDArray, boolean, int...)
+ * Index of the max absolute value: argmax(abs(in))
see argmax(String, INDArray, boolean, + * int...)
    *
-   * @param in Input variable (NUMERIC type)
-   * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions
-   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1))
+   * @param in         Input variable (NUMERIC type)
+   * @param keepDims   If true: keep the dimensions that are reduced on (as length 1). False: remove
+   *                   the reduction dimensions
+   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array
+   *                   reduction is performed (Size: AtLeast(min=1))
    * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type)
    */
   public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) {
     SDValidation.validateNumerical("iamax", "in", in);
-    Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable();
+    Preconditions.checkArgument(dimensions.length >= 1,
+        "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s",
+        dimensions.length);
+    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd, in, keepDims,
+        dimensions).outputVariable();
   }

   /**
-   * Index of the max absolute value: argmax(abs(in))<br>
-   * see argmax(String, INDArray, boolean, int...)<br>
+   * Index of the max absolute value: argmax(abs(in))<br> see argmax(String, INDArray, boolean,
+   * int...)<br>
    *
-   * @param name name May be null. Name for the output variable
-   * @param in Input variable (NUMERIC type)
-   * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions
-   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1))
+   * @param name       name May be null. Name for the output variable
+   * @param in         Input variable (NUMERIC type)
+   * @param keepDims   If true: keep the dimensions that are reduced on (as length 1). False: remove
+   *                   the reduction dimensions
+   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array
+   *                   reduction is performed (Size: AtLeast(min=1))
    * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type)
    */
   public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) {
     SDValidation.validateNumerical("iamax", "in", in);
-    Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable();
+    Preconditions.checkArgument(dimensions.length >= 1,
+        "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s",
+        dimensions.length);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd, in, keepDims,
+        dimensions).outputVariable();
     return sd.updateVariableNameAndReference(out, name);
   }

   /**
-   * Index of the min absolute value: argmin(abs(in))<br>
-   * see argmin(String, INDArray, boolean, int...)<br>
+   * Index of the min absolute value: argmin(abs(in))<br> see argmin(String, INDArray, boolean,
+   * int...)<br>
    *
-   * @param in Input variable (NUMERIC type)
-   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1))
+   * @param in         Input variable (NUMERIC type)
+   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array
+   *                   reduction is performed (Size: AtLeast(min=1))
    * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type)
    */
   public SDVariable iamin(SDVariable in, int... dimensions) {
     SDValidation.validateNumerical("iamin", "in", in);
-    Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable();
+    Preconditions.checkArgument(dimensions.length >= 1,
+        "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s",
+        dimensions.length);
+    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd, in, false,
+        dimensions).outputVariable();
   }

   /**
-   * Index of the min absolute value: argmin(abs(in))<br>
-   * see argmin(String, INDArray, boolean, int...)<br>
+   * Index of the min absolute value: argmin(abs(in))<br> see argmin(String, INDArray, boolean,
+   * int...)<br>
    *
-   * @param name name May be null. Name for the output variable
-   * @param in Input variable (NUMERIC type)
-   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1))
+   * @param name       name May be null. Name for the output variable
+   * @param in         Input variable (NUMERIC type)
+   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array
+   *                   reduction is performed (Size: AtLeast(min=1))
    * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type)
    */
   public SDVariable iamin(String name, SDVariable in, int... dimensions) {
     SDValidation.validateNumerical("iamin", "in", in);
-    Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable();
+    Preconditions.checkArgument(dimensions.length >= 1,
+        "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s",
+        dimensions.length);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd, in, false,
+        dimensions).outputVariable();
     return sd.updateVariableNameAndReference(out, name);
   }

   /**
-   * Index of the min absolute value: argmin(abs(in))<br>
-   * see argmin(String, INDArray, boolean, int...)<br>
+   * Index of the min absolute value: argmin(abs(in))<br> see argmin(String, INDArray, boolean,
+   * int...)<br>
    *
-   * @param in Input variable (NUMERIC type)
-   * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions
-   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1))
+   * @param in         Input variable (NUMERIC type)
+   * @param keepDims   If true: keep the dimensions that are reduced on (as length 1). False: remove
+   *                   the reduction dimensions
+   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array
+   *                   reduction is performed (Size: AtLeast(min=1))
    * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type)
    */
   public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) {
     SDValidation.validateNumerical("iamin", "in", in);
-    Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable();
+    Preconditions.checkArgument(dimensions.length >= 1,
+        "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s",
+        dimensions.length);
+    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd, in, keepDims,
+        dimensions).outputVariable();
   }

   /**
-   * Index of the min absolute value: argmin(abs(in))<br>
-   * see argmin(String, INDArray, boolean, int...)<br>
+   * Index of the min absolute value: argmin(abs(in))<br> see argmin(String, INDArray, boolean,
+   * int...)<br>
    *
-   * @param name name May be null. Name for the output variable
-   * @param in Input variable (NUMERIC type)
-   * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions
-   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1))
+   * @param name       name May be null. Name for the output variable
+   * @param in         Input variable (NUMERIC type)
+   * @param keepDims   If true: keep the dimensions that are reduced on (as length 1). False: remove
+   *                   the reduction dimensions
+   * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array
+   *                   reduction is performed (Size: AtLeast(min=1))
    * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type)
    */
   public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) {
     SDValidation.validateNumerical("iamin", "in", in);
-    Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable();
+    Preconditions.checkArgument(dimensions.length >= 1,
+        "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s",
+        dimensions.length);
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd, in, keepDims,
+        dimensions).outputVariable();
     return sd.updateVariableNameAndReference(out, name);
   }

   /**
-   * Is finite operation: elementwise isFinite(x)<br>
-   * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or<br>
-   * value 0 otherwise<br>
+   * Is finite operation: elementwise isFinite(x)<br> Returns an array with the same shape/size as
+   * the input, with values 1 where condition is satisfied, or<br> value 0 otherwise<br>
    *
    * @param x Input variable (NUMERIC type)
    * @return output Output variable (NUMERIC type)
    */
   public SDVariable isFinite(SDVariable x) {
     SDValidation.validateNumerical("isFinite", "x", x);
-    return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd,x).outputVariable();
+    return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd, x).outputVariable();
   }

   /**
-   * Is finite operation: elementwise isFinite(x)<br>
-   * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or<br>
-   * value 0 otherwise<br>
+   * Is finite operation: elementwise isFinite(x)<br> Returns an array with the same shape/size as
+   * the input, with values 1 where condition is satisfied, or<br> value 0 otherwise<br>
    *
    * @param name name May be null. Name for the output variable
-   * @param x Input variable (NUMERIC type)
+   * @param x    Input variable (NUMERIC type)
    * @return output Output variable (NUMERIC type)
    */
   public SDVariable isFinite(String name, SDVariable x) {
     SDValidation.validateNumerical("isFinite", "x", x);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd,x).outputVariable();
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite(sd,
+        x).outputVariable();
     return sd.updateVariableNameAndReference(out, name);
   }

   /**
-   * Is infinite operation: elementwise isInfinite(x)<br>
-   * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or<br>
-   * value 0 otherwise<br>
+   * Is infinite operation: elementwise isInfinite(x)<br> Returns an array with the same shape/size
+   * as the input, with values 1 where condition is satisfied, or<br> value 0 otherwise<br>
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isInfinite(SDVariable x) { SDValidation.validateNumerical("isInfinite", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd, x).outputVariable(); } /** - * Is infinite operation: elementwise isInfinite(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is infinite operation: elementwise isInfinite(x)
Returns an array with the same shape/size + * as the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isInfinite(String name, SDVariable x) { SDValidation.validateNumerical("isInfinite", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Is maximum operation: elementwise x == max(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is maximum operation: elementwise x == max(x)
Returns an array with the same shape/size as + * the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isMax(SDVariable x) { SDValidation.validateNumerical("isMax", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd, x).outputVariable(); } /** - * Is maximum operation: elementwise x == max(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is maximum operation: elementwise x == max(x)
Returns an array with the same shape/size as + * the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isMax(String name, SDVariable x) { SDValidation.validateNumerical("isMax", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.any.IsMax(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Is Not a Number operation: elementwise isNaN(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is Not a Number operation: elementwise isNaN(x)
Returns an array with the same shape/size + * as the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isNaN(SDVariable x) { SDValidation.validateNumerical("isNaN", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd, x).outputVariable(); } /** - * Is Not a Number operation: elementwise isNaN(x)
- * Returns an array with the same shape/size as the input, with values 1 where condition is satisfied, or
- * value 0 otherwise
+ * Is Not a Number operation: elementwise isNaN(x)
Returns an array with the same shape/size + * as the input, with values 1 where condition is satisfied, or
value 0 otherwise
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable isNaN(String name, SDVariable x) { SDValidation.validateNumerical("isNaN", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Is the array non decreasing?
- * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared
- * in 'c' (row major) order
+ * Is the array non decreasing?
An array is non-decreasing if for every valid i, x[i] <= + * x[i+1]. For Rank 2+ arrays, values are compared
in 'c' (row major) order
* * @param x Input variable (NUMERIC type) * @return output Scalar variable with value 1 if non-decreasing, or 0 otherwise (NUMERIC type) */ public SDVariable isNonDecreasing(SDVariable x) { SDValidation.validateNumerical("isNonDecreasing", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd, + x).outputVariable(); } /** - * Is the array non decreasing?
- * An array is non-decreasing if for every valid i, x[i] <= x[i+1]. For Rank 2+ arrays, values are compared
- * in 'c' (row major) order
+ * Is the array non decreasing?
An array is non-decreasing if for every valid i, x[i] <= + * x[i+1]. For Rank 2+ arrays, values are compared
in 'c' (row major) order
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Scalar variable with value 1 if non-decreasing, or 0 otherwise (NUMERIC type) */ public SDVariable isNonDecreasing(String name, SDVariable x) { SDValidation.validateNumerical("isNonDecreasing", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Is the array strictly increasing?
- * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared
- * in 'c' (row major) order
+ * Is the array strictly increasing?
An array is strictly increasing if for every valid i, + * x[i] < x[i+1]. For Rank 2+ arrays, values are compared
in 'c' (row major) order
* * @param x Input variable (NUMERIC type) - * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC type) + * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC + * type) */ public SDVariable isStrictlyIncreasing(SDVariable x) { SDValidation.validateNumerical("isStrictlyIncreasing", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd, + x).outputVariable(); } /** - * Is the array strictly increasing?
- * An array is strictly increasing if for every valid i, x[i] < x[i+1]. For Rank 2+ arrays, values are compared
- * in 'c' (row major) order
+ * Is the array strictly increasing?
An array is strictly increasing if for every valid i, + * x[i] < x[i+1]. For Rank 2+ arrays, values are compared
in 'c' (row major) order
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @return output Scalar variable with value 1 if strictly increasing, or 0 otherwise (NUMERIC + * type) */ public SDVariable isStrictlyIncreasing(String name, SDVariable x) { SDValidation.validateNumerical("isStrictlyIncreasing", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Jaccard similarity reduction operation. The output contains the Jaccard distance for each
- * tensor along the specified dimensions.
+ * tensor along the specified dimensions.
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable jaccardDistance(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("jaccardDistance", "x", x); SDValidation.validateNumerical("jaccardDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd, x, y, + dimensions).outputVariable(); } /** * Jaccard similarity reduction operation. The output contains the Jaccard distance for each
- * tensor along the specified dimensions.
+ * tensor along the specified dimensions.
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate jaccardDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable jaccardDistance(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("jaccardDistance", "x", x); SDValidation.validateNumerical("jaccardDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * Last index reduction operation.
Returns a variable that contains the index of the last + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable lastIndex(SDVariable in, Condition condition, int... dimensions) { SDValidation.validateNumerical("lastIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, false, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd, in, false, condition, + dimensions).outputVariable(); } /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * Last index reduction operation.
Returns a variable that contains the index of the last + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable lastIndex(String name, SDVariable in, Condition condition, int... dimensions) { SDValidation.validateNumerical("lastIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, false, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd, in, false, condition, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * Last index reduction operation.
Returns a variable that contains the index of the last + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("lastIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd, in, keepDims, condition, + dimensions).outputVariable(); } /** - * Last index reduction operation.
- * Returns a variable that contains the index of the last element that matches the specified condition (for each
- * slice along the specified dimensions)
- * Note that if keepDims = true, the output variable has the same rank as the input variable,
- * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
- * the mean along a dimension).
- * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
- * keepDims = true: [a,1,c]
- * keepDims = false: [a,c]
+ * Last index reduction operation.
Returns a variable that contains the index of the last + * element that matches the specified condition (for each
slice along the specified + * dimensions)
Note that if keepDims = true, the output variable has the same rank as the + * input variable,
with the reduced dimensions having size 1. This can be useful for later + * broadcast operations (such as subtracting
the mean along a dimension).
Example: if + * input has shape [a,b,c] and dimensions=[1] then output has shape:
keepDims = true: + * [a,1,c]
keepDims = false: [a,c]
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param condition Condition to check on input variable - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param condition Condition to check on input variable + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=1)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable lastIndex(String name, SDVariable in, Condition condition, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("lastIndex", "in", in); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd,in, keepDims, condition, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex(sd, in, keepDims, + condition, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -2295,20 +2499,21 @@ public class SDMath extends SDOps { public SDVariable[] listDiff(SDVariable x, SDVariable y) { SDValidation.validateNumerical("listDiff", "x", x); SDValidation.validateNumerical("listDiff", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd,x, y).outputVariables(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd, x, y).outputVariables(); } /** * Calculates difference between inputs X and Y.
* * @param names names May be null. Arrays of names for the output variables. - * @param x Input variable X (NUMERIC type) - * @param y Input variable Y (NUMERIC type) + * @param x Input variable X (NUMERIC type) + * @param y Input variable Y (NUMERIC type) */ public SDVariable[] listDiff(String[] names, SDVariable x, SDVariable y) { SDValidation.validateNumerical("listDiff", "x", x); SDValidation.validateNumerical("listDiff", "y", y); - SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd,x, y).outputVariables(); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff(sd, x, + y).outputVariables(); return sd.updateVariableNamesAndReferences(out, names); } @@ -2320,45 +2525,45 @@ public class SDMath extends SDOps { */ public SDVariable log(SDVariable x) { SDValidation.validateNumerical("log", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd, x).outputVariable(); } /** * Element-wise logarithm function (base e - natural logarithm): out = log(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable log(String name, SDVariable x) { SDValidation.validateNumerical("log", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Element-wise logarithm function (with specified base): out = log_{base}(x)
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param base Logarithm base * @return output Output variable (NUMERIC type) */ public SDVariable log(SDVariable x, double base) { SDValidation.validateNumerical("log", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd,x, base).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd, x, base).outputVariable(); } /** * Element-wise logarithm function (with specified base): out = log_{base}(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param base Logarithm base * @return output Output variable (NUMERIC type) */ public SDVariable log(String name, SDVariable x, double base) { SDValidation.validateNumerical("log", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd,x, base).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LogX(sd, x, base).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -2370,178 +2575,202 @@ public class SDMath extends SDOps { */ public SDVariable log1p(SDVariable x) { SDValidation.validateNumerical("log1p", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd, x).outputVariable(); } /** * Elementwise natural logarithm function: out = log_e (1 + x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable log1p(String name, SDVariable x) { SDValidation.validateNumerical("log1p", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Log entropy reduction: log(-sum(x * log(x)))
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable logEntropy(SDVariable in, int... dimensions) { SDValidation.validateNumerical("logEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd, in, + dimensions).outputVariable(); } /** * Log entropy reduction: log(-sum(x * log(x)))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable logEntropy(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("logEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Log-sum-exp reduction (optionally along dimension).
- * Computes log(sum(exp(x))
+ * Log-sum-exp reduction (optionally along dimension).
Computes log(sum(exp(x))
* - * @param input Input variable (NUMERIC type) + * @param input Input variable (NUMERIC type) * @param dimensions Optional dimensions to reduce along (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable logSumExp(SDVariable input, int... dimensions) { SDValidation.validateNumerical("logSumExp", "input", input); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd,input, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd, input, + dimensions).outputVariable(); } /** - * Log-sum-exp reduction (optionally along dimension).
- * Computes log(sum(exp(x))
+ * Log-sum-exp reduction (optionally along dimension).
Computes log(sum(exp(x))
* - * @param name name May be null. Name for the output variable - * @param input Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) * @param dimensions Optional dimensions to reduce along (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable logSumExp(String name, SDVariable input, int... dimensions) { SDValidation.validateNumerical("logSumExp", "input", input); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd,input, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp(sd, input, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
- * tensor/subset along the specified dimensions:
- * out = sum_i abs(x[i]-y[i])
+ * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the + * Manhattan distance for each
tensor/subset along the specified dimensions:
out = sum_i + * abs(x[i]-y[i])
* - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable manhattanDistance(SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("manhattanDistance", "x", x); SDValidation.validateNumerical("manhattanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd, x, y, + dimensions).outputVariable(); } /** - * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the Manhattan distance for each
- * tensor/subset along the specified dimensions:
- * out = sum_i abs(x[i]-y[i])
+ * Manhattan distance (l1 norm, l1 distance) reduction operation. The output contains the + * Manhattan distance for each
tensor/subset along the specified dimensions:
out = sum_i + * abs(x[i]-y[i])
* - * @param name name May be null. Name for the output variable - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensions Dimensions to calculate manhattanDistance over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public SDVariable manhattanDistance(String name, SDVariable x, SDVariable y, int... dimensions) { SDValidation.validateNumerical("manhattanDistance", "x", x); SDValidation.validateNumerical("manhattanDistance", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd,x, y, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance(sd, x, y, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
- * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each
- * shape [m,m] sub-matrix.
+ * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
For + * higher dimensional input with shape [..., m, m] the matrix determinant is returned for each + *
shape [m,m] sub-matrix.
* * @param in Input (NUMERIC type) * @return output Matrix determinant variable (NUMERIC type) */ public SDVariable matrixDeterminant(SDVariable in) { SDValidation.validateNumerical("matrixDeterminant", "in", in); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd,in).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd, + in).outputVariable(); } /** - * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
- * For higher dimensional input with shape [..., m, m] the matrix determinant is returned for each
- * shape [m,m] sub-matrix.
+ * Matrix determinant op. For 2D input, this returns the standard matrix determinant.
For + * higher dimensional input with shape [..., m, m] the matrix determinant is returned for each + *
shape [m,m] sub-matrix.
* * @param name name May be null. Name for the output variable - * @param in Input (NUMERIC type) + * @param in Input (NUMERIC type) * @return output Matrix determinant variable (NUMERIC type) */ public SDVariable matrixDeterminant(String name, SDVariable in) { SDValidation.validateNumerical("matrixDeterminant", "in", in); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd,in).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant(sd, + in).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
- * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each
- * shape [m,m] sub-matrix.
+ * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
For higher + * dimensional input with shape [..., m, m] the matrix inverse is returned for each
shape + * [m,m] sub-matrix.
* * @param in Input (NUMERIC type) * @return output Matrix inverse variable (NUMERIC type) */ public SDVariable matrixInverse(SDVariable in) { SDValidation.validateNumerical("matrixInverse", "in", in); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd,in).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd, + in).outputVariable(); } /** - * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
- * For higher dimensional input with shape [..., m, m] the matrix inverse is returned for each
- * shape [m,m] sub-matrix.
+ * Matrix inverse op. For 2D input, this returns the standard matrix inverse.
For higher + * dimensional input with shape [..., m, m] the matrix inverse is returned for each
shape + * [m,m] sub-matrix.
* * @param name name May be null. Name for the output variable - * @param in Input (NUMERIC type) + * @param in Input (NUMERIC type) * @return output Matrix inverse variable (NUMERIC type) */ public SDVariable matrixInverse(String name, SDVariable in) { SDValidation.validateNumerical("matrixInverse", "in", in); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd,in).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(sd, + in).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise max operation, out = max(x, y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x First input variable, x (NUMERIC type) * @param y Second input variable, y (NUMERIC type) @@ -2550,144 +2779,158 @@ public class SDMath extends SDOps { public SDVariable max(SDVariable x, SDVariable y) { SDValidation.validateNumerical("max", "x", x); SDValidation.validateNumerical("max", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd, x, y).outputVariable(); } /** * Pairwise max operation, out = max(x, y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x First input variable, x (NUMERIC type) - * @param y Second input variable, y (NUMERIC type) + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) * @return out Output (NUMERIC type) */ public SDVariable max(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("max", "x", x); SDValidation.validateNumerical("max", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
- * out = sum_i in[i]
+ * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise + * addition:
out = sum_i in[i]
* * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeAdd(SDVariable... inputs) { SDValidation.validateNumerical("mergeAdd", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd, + inputs).outputVariable(); } /** - * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
- * out = sum_i in[i]
+ * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise + * addition:
out = sum_i in[i]
* - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeAdd(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeAdd", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp(sd, + inputs).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation:
- * out = mean_i in[i]
+ * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise + * mean operation:
out = mean_i in[i]
* * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeAvg(SDVariable... inputs) { SDValidation.validateNumerical("mergeAvg", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - return new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd, inputs).outputVariable(); } /** - * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise mean operation:
- * out = mean_i in[i]
+ * Merge average function: merges an arbitrary number of equal shaped arrays using element-wise + * mean operation:
out = mean_i in[i]
* - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeAvg(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeAvg", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeAvg(sd, inputs).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
- * out = max_i in[i]
+ * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise + * maximum operation:
out = max_i in[i]
* * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeMax(SDVariable... inputs) { SDValidation.validateNumerical("mergeMax", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - return new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd, inputs).outputVariable(); } /** - * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise maximum operation:
- * out = max_i in[i]
+ * Merge max function: merges an arbitrary number of equal shaped arrays using element-wise + * maximum operation:
out = max_i in[i]
* - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param inputs Input variables (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mergeMax(String name, SDVariable... inputs) { SDValidation.validateNumerical("mergeMax", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd,inputs).outputVariable(); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.MergeMax(sd, inputs).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Broadcasts parameters for evaluation on an N-D grid.
* - * @param inputs (NUMERIC type) - * @param cartesian + * @param inputs (NUMERIC type) + * @param cartesian */ public SDVariable[] meshgrid(SDVariable[] inputs, boolean cartesian) { SDValidation.validateNumerical("meshgrid", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); - return new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd,inputs, cartesian).outputVariables(); + Preconditions.checkArgument(inputs.length >= 0, + "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); + return new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd, inputs, cartesian).outputVariables(); } /** * Broadcasts parameters for evaluation on an N-D grid.
* - * @param names names May be null. Arrays of names for the output variables. - * @param inputs (NUMERIC type) - * @param cartesian + * @param names names May be null. Arrays of names for the output variables. + * @param inputs (NUMERIC type) + * @param cartesian */ public SDVariable[] meshgrid(String[] names, SDVariable[] inputs, boolean cartesian) { SDValidation.validateNumerical("meshgrid", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 0, "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); - SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd,inputs, cartesian).outputVariables(); + Preconditions.checkArgument(inputs.length >= 0, + "inputs has incorrect size/length. Expected: inputs.length >= 0, got %s", inputs.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(sd, inputs, + cartesian).outputVariables(); return sd.updateVariableNamesAndReferences(out, names); } /** * Pairwise max operation, out = min(x, y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
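+   * <p>For illustration only, a minimal usage sketch (hypothetical inputs and variable names;
+   * assumes the usual org.nd4j imports):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable x = sd.var("x", Nd4j.createFromArray(1.0, 5.0, 3.0));
+   * SDVariable y = sd.var("y", Nd4j.createFromArray(4.0, 2.0, 3.0));
+   * SDVariable out = sd.math().min(x, y);   // element-wise minimum -> [1.0, 2.0, 3.0]
+   * }</pre>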
* * @param x First input variable, x (NUMERIC type) * @param y Second input variable, y (NUMERIC type) @@ -2696,34 +2939,37 @@ public class SDMath extends SDOps { public SDVariable min(SDVariable x, SDVariable y) { SDValidation.validateNumerical("min", "x", x); SDValidation.validateNumerical("min", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd, x, y).outputVariable(); } /** * Pairwise max operation, out = min(x, y)
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x First input variable, x (NUMERIC type) - * @param y Second input variable, y (NUMERIC type) + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) * @return out Output (NUMERIC type) */ public SDVariable min(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("min", "x", x); SDValidation.validateNumerical("min", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise modulus (remainder) operation, out = x % y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -2732,60 +2978,69 @@ public class SDMath extends SDOps { public SDVariable mod(SDVariable x, SDVariable y) { SDValidation.validateNumerical("mod", "x", x); SDValidation.validateNumerical("mod", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd, x, + y).outputVariable(); } /** * Pairwise modulus (remainder) operation, out = x % y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mod(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("mod", "x", x); SDValidation.validateNumerical("mod", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Calculate the mean and (population) variance for the input variable, for the specified axis
+ * Calculate the mean and (population) variance for the input variable, for the specified + * axis
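+   * <p>For illustration only, a minimal usage sketch (hypothetical values; assumes the usual
+   * org.nd4j imports):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable in = sd.var("in", Nd4j.createFromArray(new double[][]{{1, 2}, {3, 4}}));
+   * SDVariable[] mv = sd.math().moments(in, 0);   // statistics along dimension 0
+   * INDArray mean = mv[0].eval();                 // [2.0, 3.0]
+   * INDArray variance = mv[1].eval();             // population variance: [1.0, 1.0]
+   * }</pre>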
* * @param input Input to calculate moments for (NUMERIC type) - * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) + * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) */ public SDVariable[] moments(SDVariable input, int... axes) { SDValidation.validateNumerical("moments", "input", input); - Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); - return new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd,input, axes).outputVariables(); + Preconditions.checkArgument(axes.length >= 0, + "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); + return new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd, input, axes).outputVariables(); } /** - * Calculate the mean and (population) variance for the input variable, for the specified axis
+ * Calculate the mean and (population) variance for the input variable, for the specified + * axis
* * @param names names May be null. Arrays of names for the output variables. * @param input Input to calculate moments for (NUMERIC type) - * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) + * @param axes Dimensions to perform calculation over (Size: AtLeast(min=0)) */ public SDVariable[] moments(String[] names, SDVariable input, int... axes) { SDValidation.validateNumerical("moments", "input", input); - Preconditions.checkArgument(axes.length >= 0, "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); - SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd,input, axes).outputVariables(); + Preconditions.checkArgument(axes.length >= 0, + "axes has incorrect size/length. Expected: axes.length >= 0, got %s", axes.length); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.Moments(sd, input, + axes).outputVariables(); return sd.updateVariableNamesAndReferences(out, names); } /** * Pairwise multiplication operation, out = x * y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -2794,51 +3049,56 @@ public class SDMath extends SDOps { public SDVariable mul(SDVariable x, SDVariable y) { SDValidation.validateNumerical("mul", "x", x); SDValidation.validateNumerical("mul", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd, x, + y).outputVariable(); } /** * Pairwise multiplication operation, out = x * y
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable mul(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("mul", "x", x); SDValidation.validateNumerical("mul", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar multiplication operation, out = in * scalar
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable mul(SDVariable x, double value) { SDValidation.validateNumerical("mul", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd, x, + value).outputVariable(); } /** * Scalar multiplication operation, out = in * scalar
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable mul(String name, SDVariable x, double value) { SDValidation.validateNumerical("mul", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -2850,113 +3110,127 @@ public class SDMath extends SDOps { */ public SDVariable neg(SDVariable x) { SDValidation.validateNumerical("neg", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd, x).outputVariable(); } /** * Elementwise negative operation: out = -x
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable neg(String name, SDVariable x) { SDValidation.validateNumerical("neg", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Negative(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Calculate the mean and variance from the sufficient statistics
* - * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics (NUMERIC type) - * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type) - * @param variances Variaance sufficient statistics: this is the squared sum of all data values (NUMERIC type) - * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) + * @param counts Rank 0 (scalar) value with the total number of values used to calculate the + * sufficient statistics (NUMERIC type) + * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC + * type) + * @param variances Variaance sufficient statistics: this is the squared sum of all data values + * (NUMERIC type) + * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for + * numerical stability) */ public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) { SDValidation.validateNumerical("normalizeMoments", "counts", counts); SDValidation.validateNumerical("normalizeMoments", "means", means); SDValidation.validateNumerical("normalizeMoments", "variances", variances); - return new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd,counts, means, variances, shift).outputVariables(); + return new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd, counts, means, variances, + shift).outputVariables(); } /** * Calculate the mean and variance from the sufficient statistics
* - * @param names names May be null. Arrays of names for the output variables. - * @param counts Rank 0 (scalar) value with the total number of values used to calculate the sufficient statistics (NUMERIC type) - * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC type) - * @param variances Variaance sufficient statistics: this is the squared sum of all data values (NUMERIC type) - * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for numerical stability) + * @param names names May be null. Arrays of names for the output variables. + * @param counts Rank 0 (scalar) value with the total number of values used to calculate the + * sufficient statistics (NUMERIC type) + * @param means Mean-value sufficient statistics: this is the SUM of all data values (NUMERIC + * type) + * @param variances Variaance sufficient statistics: this is the squared sum of all data values + * (NUMERIC type) + * @param shift Shift value, possibly 0, used when calculating the sufficient statistics (for + * numerical stability) */ public SDVariable[] normalizeMoments(String[] names, SDVariable counts, SDVariable means, SDVariable variances, double shift) { SDValidation.validateNumerical("normalizeMoments", "counts", counts); SDValidation.validateNumerical("normalizeMoments", "means", means); SDValidation.validateNumerical("normalizeMoments", "variances", variances); - SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd,counts, means, variances, shift).outputVariables(); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments(sd, counts, means, + variances, shift).outputVariables(); return sd.updateVariableNamesAndReferences(out, names); } /** - * Boolean OR operation: elementwise (x != 0) || (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * Boolean OR operation: elementwise (x != 0) || (y != 0)
If x and y arrays have equal shape, + * the output shape is the same as these inputs.
Note: supports broadcasting if x and y have + * different shapes and are broadcastable.
Returns an array with values 1 where condition is + * satisfied, or value 0 otherwise.
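+   * <p>For illustration only, a minimal usage sketch (hypothetical inputs; assumes the usual
+   * org.nd4j imports):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable x = sd.var("x", Nd4j.createFromArray(true, false, false));
+   * SDVariable y = sd.var("y", Nd4j.createFromArray(false, false, true));
+   * SDVariable out = sd.math().or(x, y);   // element-wise OR -> [true, false, true]
+   * }</pre>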
* * @param x Input 1 (BOOL type) * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable or(SDVariable x, SDVariable y) { SDValidation.validateBool("or", "x", x); SDValidation.validateBool("or", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd, x, y).outputVariable(); } /** - * Boolean OR operation: elementwise (x != 0) || (y != 0)
- * If x and y arrays have equal shape, the output shape is the same as these inputs.
- * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
+ * Boolean OR operation: elementwise (x != 0) || (y != 0)
If x and y arrays have equal shape, + * the output shape is the same as these inputs.
Note: supports broadcasting if x and y have + * different shapes and are broadcastable.
Returns an array with values 1 where condition is + * satisfied, or value 0 otherwise.
* * @param name name May be null. Name for the output variable - * @param x Input 1 (BOOL type) - * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable or(String name, SDVariable x, SDVariable y) { SDValidation.validateBool("or", "x", x); SDValidation.validateBool("or", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Element-wise power function: out = x^value
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable pow(SDVariable x, double value) { SDValidation.validateNumerical("pow", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd, x, value).outputVariable(); } /** * Element-wise power function: out = x^value
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable pow(String name, SDVariable x, double value) { SDValidation.validateNumerical("pow", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Pow(sd, x, value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -2970,58 +3244,61 @@ public class SDMath extends SDOps { public SDVariable pow(SDVariable x, SDVariable y) { SDValidation.validateNumerical("pow", "x", x); SDValidation.validateNumerical("pow", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd, x, y).outputVariable(); } /** * Element-wise (broadcastable) power function: out = x[i]^y[i]
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Power (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Power (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable pow(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("pow", "x", x); SDValidation.validateNumerical("pow", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Rational Tanh Approximation elementwise function, as described in the paper:
- * Compact Convolutional Neural Network Cascade for Face Detection
- * This is a faster Tanh approximation
+ * Rational Tanh Approximation elementwise function, as described in the paper:
Compact + * Convolutional Neural Network Cascade for Face Detection
This is a faster Tanh + * approximation
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rationalTanh(SDVariable x) { SDValidation.validateNumerical("rationalTanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd, x).outputVariable(); } /** - * Rational Tanh Approximation elementwise function, as described in the paper:
- * Compact Convolutional Neural Network Cascade for Face Detection
- * This is a faster Tanh approximation
+ * Rational Tanh Approximation elementwise function, as described in the paper:
Compact + * Convolutional Neural Network Cascade for Face Detection
This is a faster Tanh + * approximation
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rationalTanh(String name, SDVariable x) { SDValidation.validateNumerical("rationalTanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise reverse division operation, out = y / x
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -3030,51 +3307,56 @@ public class SDMath extends SDOps { public SDVariable rdiv(SDVariable x, SDVariable y) { SDValidation.validateNumerical("rdiv", "x", x); SDValidation.validateNumerical("rdiv", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd, x, + y).outputVariable(); } /** * Pairwise reverse division operation, out = y / x
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rdiv(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("rdiv", "x", x); SDValidation.validateNumerical("rdiv", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar reverse division operation, out = scalar / in
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable rdiv(SDVariable x, double value) { SDValidation.validateNumerical("rdiv", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd, x, + value).outputVariable(); } /** * Scalar reverse division operation, out = scalar / in
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable rdiv(String name, SDVariable x, double value) { SDValidation.validateNumerical("rdiv", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3086,19 +3368,20 @@ public class SDMath extends SDOps { */ public SDVariable reciprocal(SDVariable x) { SDValidation.validateNumerical("reciprocal", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd, x).outputVariable(); } /** * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable reciprocal(String name, SDVariable x) { SDValidation.validateNumerical("reciprocal", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3110,45 +3393,46 @@ public class SDMath extends SDOps { */ public SDVariable rectifiedTanh(SDVariable x) { SDValidation.validateNumerical("rectifiedTanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd, x).outputVariable(); } /** * Rectified tanh operation: max(0, tanh(in))
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rectifiedTanh(String name, SDVariable x) { SDValidation.validateNumerical("rectifiedTanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise round function: out = round(x).
- * Rounds (up or down depending on value) to the nearest integer value.
+ * Element-wise round function: out = round(x).
Rounds (up or down depending on value) to the + * nearest integer value.
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable round(SDVariable x) { SDValidation.validateNumerical("round", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd, x).outputVariable(); } /** - * Element-wise round function: out = round(x).
- * Rounds (up or down depending on value) to the nearest integer value.
+ * Element-wise round function: out = round(x).
Rounds (up or down depending on value) to the + * nearest integer value.
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable round(String name, SDVariable x) { SDValidation.validateNumerical("round", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Round(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3160,28 +3444,30 @@ public class SDMath extends SDOps { */ public SDVariable rsqrt(SDVariable x) { SDValidation.validateNumerical("rsqrt", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd, x).outputVariable(); } /** * Element-wise reciprocal (inverse) of square root: out = 1.0 / sqrt(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rsqrt(String name, SDVariable x) { SDValidation.validateNumerical("rsqrt", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise reverse subtraction operation, out = y - x
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -3190,153 +3476,152 @@ public class SDMath extends SDOps { public SDVariable rsub(SDVariable x, SDVariable y) { SDValidation.validateNumerical("rsub", "x", x); SDValidation.validateNumerical("rsub", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd, x, + y).outputVariable(); } /** * Pairwise reverse subtraction operation, out = y - x
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable rsub(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("rsub", "x", x); SDValidation.validateNumerical("rsub", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar reverse subtraction operation, out = scalar - in
* - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable rsub(SDVariable x, double value) { SDValidation.validateNumerical("rsub", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd, x, + value).outputVariable(); } /** * Scalar reverse subtraction operation, out = scalar - in
* - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable rsub(String name, SDVariable x, double value) { SDValidation.validateNumerical("rsub", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Set the diagonal value to the specified values
- * If input is
- * [ a, b, c]
- * [ d, e, f]
- * [ g, h, i]
- * and diag = [ 1, 2, 3] then output is
- * [ 1, b, c]
- * [ d, 2, f]
- * [ g, h, 3]
+ * Set the diagonal value to the specified values
If input is
[ a, b, c]
[ d, e, + * f]
[ g, h, i]
and diag = [ 1, 2, 3] then output is
[ 1, b, c]
[ d, 2, f]
[ + * g, h, 3]
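+   * <p>For illustration only, a minimal usage sketch (hypothetical values; assumes the usual
+   * org.nd4j imports):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable in = sd.var("in", Nd4j.createFromArray(new double[][]{{0, 9}, {9, 0}}));
+   * SDVariable diag = sd.var("diag", Nd4j.createFromArray(1.0, 2.0));
+   * SDVariable out = sd.math().setDiag(in, diag);   // -> [[1, 9], [9, 2]]
+   * }</pre>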
* - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param diag Diagonal (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable setDiag(SDVariable in, SDVariable diag) { SDValidation.validateNumerical("setDiag", "in", in); SDValidation.validateNumerical("setDiag", "diag", diag); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd,in, diag).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd, in, + diag).outputVariable(); } /** - * Set the diagonal value to the specified values
- * If input is
- * [ a, b, c]
- * [ d, e, f]
- * [ g, h, i]
- * and diag = [ 1, 2, 3] then output is
- * [ 1, b, c]
- * [ d, 2, f]
- * [ g, h, 3]
+ * Set the diagonal value to the specified values
If input is
[ a, b, c]
[ d, e, + * f]
[ g, h, i]
and diag = [ 1, 2, 3] then output is
[ 1, b, c]
[ d, 2, f]
[ + * g, h, 3]
* * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param diag Diagonal (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable setDiag(String name, SDVariable in, SDVariable diag) { SDValidation.validateNumerical("setDiag", "in", in); SDValidation.validateNumerical("setDiag", "diag", diag); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd,in, diag).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag(sd, in, + diag).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Shannon Entropy reduction: -sum(x * log2(x))
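+   * <p>For illustration only, a minimal usage sketch (hypothetical values; assumes the usual
+   * org.nd4j imports):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable p = sd.var("p", Nd4j.createFromArray(0.25, 0.25, 0.25, 0.25));
+   * SDVariable h = sd.math().shannonEntropy(p);   // -sum(p * log2(p)) -> 2.0
+   * }</pre>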
* - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable shannonEntropy(SDVariable in, int... dimensions) { SDValidation.validateNumerical("shannonEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd, in, + dimensions).outputVariable(); } /** * Shannon Entropy reduction: -sum(x * log2(x))
* - * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param name name May be null. Name for the output variable + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public SDVariable shannonEntropy(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("shannonEntropy", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd,in, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy(sd, in, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise sign (signum) function:
- * out = -1 if in < 0
- * out = 0 if in = 0
- * out = 1 if in > 0
+ * Element-wise sign (signum) function:
out = -1 if in < 0
out = 0 if in = 0
out = 1 + * if in > 0
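+   * <p>For illustration only, a minimal usage sketch (hypothetical inputs; assumes the usual
+   * org.nd4j imports):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable x = sd.var("x", Nd4j.createFromArray(-4.2, 0.0, 7.1));
+   * SDVariable s = sd.math().sign(x);   // -> [-1.0, 0.0, 1.0]
+   * }</pre>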
* * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sign(SDVariable x) { SDValidation.validateNumerical("sign", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd, x).outputVariable(); } /** - * Element-wise sign (signum) function:
- * out = -1 if in < 0
- * out = 0 if in = 0
- * out = 1 if in > 0
+ * Element-wise sign (signum) function:
out = -1 if in < 0
out = 0 if in = 0
out = 1 + * if in > 0
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sign(String name, SDVariable x) { SDValidation.validateNumerical("sign", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Sign(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3348,19 +3633,19 @@ public class SDMath extends SDOps { */ public SDVariable sin(SDVariable x) { SDValidation.validateNumerical("sin", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd, x).outputVariable(); } /** * Elementwise sine operation: out = sin(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sin(String name, SDVariable x) { SDValidation.validateNumerical("sin", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sin(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3372,19 +3657,20 @@ public class SDMath extends SDOps { */ public SDVariable sinh(SDVariable x) { SDValidation.validateNumerical("sinh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd, x).outputVariable(); } /** * Elementwise sinh (hyperbolic sine) operation: out = sinh(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sinh(String name, SDVariable x) { SDValidation.validateNumerical("sinh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3396,19 +3682,20 @@ public class SDMath extends SDOps { */ public SDVariable sqrt(SDVariable x) { SDValidation.validateNumerical("sqrt", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd, x).outputVariable(); } /** * Element-wise square root function: out = sqrt(x)
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sqrt(String name, SDVariable x) { SDValidation.validateNumerical("sqrt", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3420,28 +3707,30 @@ public class SDMath extends SDOps { */ public SDVariable square(SDVariable x) { SDValidation.validateNumerical("square", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd, x).outputVariable(); } /** * Element-wise square function: out = x^2
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable square(String name, SDVariable x) { SDValidation.validateNumerical("square", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.same.Square(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise squared difference operation.
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
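+   * <p>For illustration only, a minimal usage sketch (hypothetical inputs; assumes the usual
+   * org.nd4j imports):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable x = sd.var("x", Nd4j.createFromArray(1.0, 2.0, 3.0));
+   * SDVariable y = sd.var("y", Nd4j.createFromArray(3.0, 2.0, 1.0));
+   * SDVariable out = sd.math().squaredDifference(x, y);   // (x - y)^2 -> [4.0, 0.0, 4.0]
+   * }</pre>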
* * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -3450,25 +3739,28 @@ public class SDMath extends SDOps { public SDVariable squaredDifference(SDVariable x, SDVariable y) { SDValidation.validateNumerical("squaredDifference", "x", x); SDValidation.validateNumerical("squaredDifference", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd, + x, y).outputVariable(); } /** * Pairwise squared difference operation.
- * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
- * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
- * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ *

+ * Note: supports broadcasting if x and y have different shapes and are broadcastable.
For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable squaredDifference(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("squaredDifference", "x", x); SDValidation.validateNumerical("squaredDifference", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp( + sd, x, y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3480,20 +3772,23 @@ public class SDMath extends SDOps { * with mean and stdev being calculated along the given dimension.
*


* For example: given x as a mini batch of the shape [numExamples, exampleLength]:
- *


    - *
  • use dimension 1 too use the statistics (mean, stdev) for each example

  • - *
  • use dimension 0 if you want to use the statistics for each column across all examples

  • - *
  • use dimensions 0,1 if you want to use the statistics across all columns and examples

  • + *
      + *
    • use dimension 1 too use the statistics (mean, stdev) for each example
    • + *
    • use dimension 0 if you want to use the statistics for each column across all examples
    • + *
    • use dimensions 0,1 if you want to use the statistics across all columns and examples
    • *

    * - * @param x Input variable (NUMERIC type) - * @param dimensions (Size: AtLeast(min=1)) + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable standardize(SDVariable x, int... dimensions) { SDValidation.validateNumerical("standardize", "x", x); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd,x, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd, x, + dimensions).outputVariable(); } /** @@ -3504,60 +3799,60 @@ public class SDMath extends SDOps { * with mean and stdev being calculated along the given dimension.
    *


    * For example: given x as a mini batch of the shape [numExamples, exampleLength]:
    - *


      - *
    • use dimension 1 too use the statistics (mean, stdev) for each example

    • - *
    • use dimension 0 if you want to use the statistics for each column across all examples

    • - *
    • use dimensions 0,1 if you want to use the statistics across all columns and examples

    • + *
        + *
      • use dimension 1 too use the statistics (mean, stdev) for each example
      • + *
      • use dimension 0 if you want to use the statistics for each column across all examples
      • + *
      • use dimensions 0,1 if you want to use the statistics across all columns and examples
      • *

      * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param dimensions (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable standardize(String name, SDVariable x, int... dimensions) { SDValidation.validateNumerical("standardize", "x", x); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd,x, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize(sd, x, + dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Elementwise step function:
      - * out(x) = 1 if x >= cutoff
      - * out(x) = 0 otherwise
      + * Elementwise step function:
      {@code out(x) = 1 if x >= cutoff
      out(x) = 0 otherwise}
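+   * <p>For illustration only, a minimal usage sketch (hypothetical cutoff and inputs; assumes the
+   * usual org.nd4j imports):
+   * <pre>{@code
+   * SameDiff sd = SameDiff.create();
+   * SDVariable x = sd.var("x", Nd4j.createFromArray(-1.0, 0.5, 2.0));
+   * SDVariable out = sd.math().step(x, 0.5);   // cutoff 0.5 -> [0.0, 1.0, 1.0]
+   * }</pre>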
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable step(SDVariable x, double value) { SDValidation.validateNumerical("step", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.Step(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.Step(sd, x, value).outputVariable(); } /** - * Elementwise step function:
      - * out(x) = 1 if x >= cutoff
      - * out(x) = 0 otherwise
      + * Elementwise step function:
      {@code out(x) = 1 if x >= cutoff
      out(x) = 0 otherwise} * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable step(String name, SDVariable x, double value) { SDValidation.validateNumerical("step", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Step(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Step(sd, x, value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Pairwise subtraction operation, out = x - y
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      * * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -3566,51 +3861,55 @@ public class SDMath extends SDOps { public SDVariable sub(SDVariable x, SDVariable y) { SDValidation.validateNumerical("sub", "x", x); SDValidation.validateNumerical("sub", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd, x, + y).outputVariable(); } /** * Pairwise subtraction operation, out = x - y
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) - * @param y Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sub(String name, SDVariable x, SDVariable y) { SDValidation.validateNumerical("sub", "x", x); SDValidation.validateNumerical("sub", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Scalar subtraction operation, out = in - scalar
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable sub(SDVariable x, double value) { SDValidation.validateNumerical("sub", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd,x, value).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd, x, value).outputVariable(); } /** * Scalar subtraction operation, out = in - scalar
      * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param value Scalar value for op * @return output Output variable (NUMERIC type) */ public SDVariable sub(String name, SDVariable x, double value) { SDValidation.validateNumerical("sub", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd,x, value).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd, x, + value).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3622,19 +3921,19 @@ public class SDMath extends SDOps { */ public SDVariable tan(SDVariable x) { SDValidation.validateNumerical("tan", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd, x).outputVariable(); } /** * Elementwise tangent operation: out = tan(x)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable tan(String name, SDVariable x) { SDValidation.validateNumerical("tan", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tan(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -3646,105 +3945,111 @@ public class SDMath extends SDOps { */ public SDVariable tanh(SDVariable x) { SDValidation.validateNumerical("tanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd, x).outputVariable(); } /** * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable tanh(String name, SDVariable x) { SDValidation.validateNumerical("tanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Matrix trace operation
- * For rank 2 matrices, the output is a scalar with the trace - i.e., sum of the main diagonal.<br>
      - * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
      + * Matrix trace operation
For rank 2 matrices, the output is a scalar with the trace - i.e., + * sum of the main diagonal.<br>
      For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
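// Minimal sketch of the trace() op described above: for a rank 2 input the result is the sum of
// the main diagonal. Concrete values and the use of SDVariable.eval() are example assumptions;
// imports are the same as in the sub() sketch earlier.
SameDiff sd = SameDiff.create();
SDVariable in = sd.var("in", Nd4j.createFromArray(new float[][]{{1f, 2f}, {3f, 4f}}));
SDVariable tr = sd.math().trace("tr", in);
System.out.println(tr.eval()); // scalar 5.0 = 1 + 4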
      * * @param in Input variable (NUMERIC type) * @return output Trace (NUMERIC type) */ public SDVariable trace(SDVariable in) { SDValidation.validateNumerical("trace", "in", in); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd,in).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd, in).outputVariable(); } /** - * Matrix trace operation
- * For rank 2 matrices, the output is a scalar with the trace - i.e., sum of the main diagonal.<br>
      - * For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
      + * Matrix trace operation
For rank 2 matrices, the output is a scalar with the trace - i.e., + * sum of the main diagonal.<br>
      For higher rank inputs, output[a,b,c] = trace(in[a,b,c,:,:])
      * * @param name name May be null. Name for the output variable - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @return output Trace (NUMERIC type) */ public SDVariable trace(String name, SDVariable in) { SDValidation.validateNumerical("trace", "in", in); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd,in).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Trace(sd, + in).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
      + * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
      If x and y arrays + * have equal shape, the output shape is the same as these inputs.
      Note: supports broadcasting + * if x and y have different shapes and are broadcastable.
      Returns an array with values 1 + * where condition is satisfied, or value 0 otherwise.
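// Sketch of the boolean xor() op documented above, using BOOL inputs of equal shape. Array
// contents are example assumptions; `sd` is a SameDiff instance as in the earlier sketches.
SDVariable a = sd.var("a", Nd4j.createFromArray(true, true, false, false));
SDVariable b = sd.var("b", Nd4j.createFromArray(false, true, false, true));
SDVariable aXorB = sd.math().xor(a, b);
System.out.println(aXorB.eval()); // true where exactly one of a and b is true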
      * * @param x Input 1 (BOOL type) * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable xor(SDVariable x, SDVariable y) { SDValidation.validateBool("xor", "x", x); SDValidation.validateBool("xor", "y", y); - return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd,x, y).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd, x, y).outputVariable(); } /** - * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * Returns an array with values 1 where condition is satisfied, or value 0 otherwise.
      + * Boolean XOR (exclusive OR) operation: elementwise (x != 0) XOR (y != 0)
      If x and y arrays + * have equal shape, the output shape is the same as these inputs.
      Note: supports broadcasting + * if x and y have different shapes and are broadcastable.
      Returns an array with values 1 + * where condition is satisfied, or value 0 otherwise.
      * * @param name name May be null. Name for the output variable - * @param x Input 1 (BOOL type) - * @param y Input 2 (BOOL type) - * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL type) + * @param x Input 1 (BOOL type) + * @param y Input 2 (BOOL type) + * @return output INDArray with values 0 and 1 based on where the condition is satisfied (BOOL + * type) */ public SDVariable xor(String name, SDVariable x, SDVariable y) { SDValidation.validateBool("xor", "x", x); SDValidation.validateBool("xor", "y", y); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd,x, y).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor(sd, x, + y).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
      + * Full array zero fraction array reduction operation, optionally along specified dimensions: out + * = (count(x == 0) / length(x))
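// Sketch of zeroFraction() as described above: the fraction of entries equal to zero.
// Values are example assumptions; `sd` is a SameDiff instance as in the earlier sketches.
SDVariable zfIn = sd.var("zfIn", Nd4j.createFromArray(0f, 1f, 0f, 3f));
SDVariable zf = sd.math().zeroFraction(zfIn);
System.out.println(zf.eval()); // 0.5, since 2 of the 4 entries are zero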
      * * @param input Input variable (NUMERIC type) * @return output Reduced array of rank 0 (scalar) (NUMERIC type) */ public SDVariable zeroFraction(SDVariable input) { SDValidation.validateNumerical("zeroFraction", "input", input); - return new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd,input).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd, input).outputVariable(); } /** - * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x))
      + * Full array zero fraction array reduction operation, optionally along specified dimensions: out + * = (count(x == 0) / length(x))
      * - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param input Input variable (NUMERIC type) * @return output Reduced array of rank 0 (scalar) (NUMERIC type) */ public SDVariable zeroFraction(String name, SDVariable input) { SDValidation.validateNumerical("zeroFraction", "input", input); - SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd,input).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction(sd, + input).outputVariable(); return sd.updateVariableNameAndReference(out, name); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 846291e47..b617d2865 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -27,47 +27,53 @@ import org.nd4j.common.base.Preconditions; import org.nd4j.enums.PadMode; public class SDNN extends SDOps { + public SDNN(SameDiff sameDiff) { super(sameDiff); } /** - * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
      + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which + * selects only the negative part of the activation. Note that as a result this non-linearity + * doubles the depth of the activations.
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable cReLU(SDVariable x) { SDValidation.validateNumerical("CReLU", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd, x).outputVariable(); } /** - * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
      + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which + * selects only the negative part of the activation. Note that as a result this non-linearity + * doubles the depth of the activations.
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable cReLU(String name, SDVariable x) { SDValidation.validateNumerical("CReLU", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Neural network batch normalization operation.
      - * For details, see https://arxiv.org/abs/1502.03167
      + * Neural network batch normalization operation.
      For details, see https://arxiv.org/abs/1502.03167
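// Sketch of batchNorm() for NCHW activations (axis = 1), following the parameter descriptions in
// this Javadoc. Shapes, statistics and the epsilon value are example assumptions; `sd` is a
// SameDiff instance as in the earlier sketches and DataType is org.nd4j.linalg.api.buffer.DataType.
SDVariable act = sd.var("act", Nd4j.rand(DataType.FLOAT, 2, 3, 4, 4)); // [minibatch, channels, h, w]
SDVariable mean = sd.var("mean", Nd4j.zeros(DataType.FLOAT, 3));
SDVariable variance = sd.var("variance", Nd4j.ones(DataType.FLOAT, 3));
SDVariable gamma = sd.var("gamma", Nd4j.ones(DataType.FLOAT, 3));
SDVariable beta = sd.var("beta", Nd4j.zeros(DataType.FLOAT, 3));
SDVariable bn = sd.nn().batchNorm(act, mean, variance, gamma, beta, 1e-5, 1);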
      * - * @param input Input variable. (NUMERIC type) - * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param input Input variable. (NUMERIC type) + * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) * @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) - * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations. - * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC - * For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) + * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) + * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format + * activations. For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC For 1d/RNN + * activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) * @return output variable for batch normalization (NUMERIC type) */ public SDVariable batchNorm(SDVariable input, SDVariable mean, SDVariable variance, @@ -77,24 +83,26 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("batchNorm", "variance", variance); SDValidation.validateNumerical("batchNorm", "gamma", gamma); SDValidation.validateNumerical("batchNorm", "beta", beta); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd,input, mean, variance, gamma, beta, epsilon, axis).outputVariable(); + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd, input, mean, variance, + gamma, beta, epsilon, axis).outputVariable(); } /** - * Neural network batch normalization operation.
      - * For details, see https://arxiv.org/abs/1502.03167
      + * Neural network batch normalization operation.
      For details, see https://arxiv.org/abs/1502.03167
      * - * @param name name May be null. Name for the output variable - * @param input Input variable. (NUMERIC type) - * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input variable. (NUMERIC type) + * @param mean Mean value. For 1d axis, this should match input.size(axis) (NUMERIC type) * @param variance Variance value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) - * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) - * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format activations. - * For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC - * For 1d/RNN activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) + * @param gamma Gamma value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param beta Beta value. For 1d axis, this should match input.size(axis) (NUMERIC type) + * @param epsilon Epsilon constant for numerical stability (to avoid division by 0) + * @param axis For 2d CNN activations: 1 for NCHW format activations, or 3 for NHWC format + * activations. For 3d CNN activations: 1 for NCDHW format, 4 for NDHWC For 1d/RNN + * activations: 1 for NCW format, 2 for NWC (Size: AtLeast(min=1)) * @return output variable for batch normalization (NUMERIC type) */ public SDVariable batchNorm(String name, SDVariable input, SDVariable mean, SDVariable variance, @@ -104,73 +112,82 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("batchNorm", "variance", variance); SDValidation.validateNumerical("batchNorm", "gamma", gamma); SDValidation.validateNumerical("batchNorm", "beta", beta); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd,input, mean, variance, gamma, beta, epsilon, axis).outputVariable(); + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm(sd, input, mean, + variance, gamma, beta, epsilon, axis).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector
      + * Bias addition operation: a special case of addition, typically used with CNN 4D activations and + * a 1D bias vector
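// Sketch of biasAdd() with NCHW activations and one bias value per channel, as described above.
// Shapes are example assumptions; `sd` is a SameDiff instance as in the earlier sketches.
SDVariable activations = sd.var("baIn", Nd4j.rand(DataType.FLOAT, 2, 3, 4, 4)); // [mb, channels, h, w]
SDVariable channelBias = sd.var("baBias", Nd4j.createFromArray(0.1f, 0.2f, 0.3f));
SDVariable withBias = sd.nn().biasAdd(activations, channelBias, true); // nchw = true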
      * * @param input 4d input variable (NUMERIC type) - * @param bias 1d bias (NUMERIC type) - * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. - * Unused for 2d inputs + * @param bias 1d bias (NUMERIC type) + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; + * nchw=false - [minibatch, height, width, channels]. Unused for 2d inputs * @return output Output variable, after applying bias add operation (NUMERIC type) */ public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { SDValidation.validateNumerical("biasAdd", "input", input); SDValidation.validateNumerical("biasAdd", "bias", bias); - return new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd,input, bias, nchw).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd, input, bias, + nchw).outputVariable(); } /** - * Bias addition operation: a special case of addition, typically used with CNN 4D activations and a 1D bias vector
      + * Bias addition operation: a special case of addition, typically used with CNN 4D activations and + * a 1D bias vector
      * - * @param name name May be null. Name for the output variable + * @param name name May be null. Name for the output variable * @param input 4d input variable (NUMERIC type) - * @param bias 1d bias (NUMERIC type) - * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels]. - * Unused for 2d inputs + * @param bias 1d bias (NUMERIC type) + * @param nchw The format - nchw=true means [minibatch, channels, height, width] format; + * nchw=false - [minibatch, height, width, channels]. Unused for 2d inputs * @return output Output variable, after applying bias add operation (NUMERIC type) */ public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) { SDValidation.validateNumerical("biasAdd", "input", input); SDValidation.validateNumerical("biasAdd", "bias", bias); - SDVariable out = new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd,input, bias, nchw).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd(sd, input, bias, + nchw).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * This operation performs dot product attention on the given timeseries input with the given queries
      - * out = sum(similarity(k_i, q) * v_i)
      + * This operation performs dot product attention on the given timeseries input with the given + * queries
      out = sum(similarity(k_i, q) * v_i)
      *
* similarity(k, q) = softmax(k * q) where k * q is the dot product of k and q<br>
      *
      - * Optionally with normalization step:
- * similarity(k, q) = softmax(k * q / sqrt(size(q)))<br>
      + * Optionally with normalization step:
similarity(k, q) = softmax(k * q / sqrt(size(q)))<br>
      *
      * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
      *
      - * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
      - * be 3D but can have queryCount = 1
      + * Note: This supports multiple queries at once, if only one query is available the queries vector + * still has to
      be 3D but can have queryCount = 1
      *
- * Note: keys and values usually are the same array. If you want to use it as the same array, simply pass it for<br>
      - * both.
+ * Note: keys and values usually are the same array. If you want to use it as the same array, + * simply pass it for<br>
      both.
      *
      - * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
      - * output rank will depend on the input rank.
      + * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them + * doesn't work. The
      output rank will depend on the input rank.
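// Sketch of dotProductAttention() with rank 3 inputs. The mask is documented as optional, so null
// is passed here, and scaled = true applies the 1/sqrt(size(q)) normalization. Shapes are example
// assumptions; `sd` is a SameDiff instance as in the earlier sketches.
SDVariable queries = sd.var("q", Nd4j.rand(DataType.FLOAT, 1, 4, 2)); // [batchSize, featureKeys, queryCount]
SDVariable keys = sd.var("k", Nd4j.rand(DataType.FLOAT, 1, 4, 3));    // [batchSize, featureKeys, timesteps]
SDVariable values = sd.var("v", Nd4j.rand(DataType.FLOAT, 1, 5, 3));  // [batchSize, featureValues, timesteps]
SDVariable attention = sd.nn().dotProductAttention(queries, keys, values, null, true);
// attention has shape [batchSize, featureValues, queryCount] = [1, 5, 2]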
      * - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] - * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], - * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or 4D + * array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array + * of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] or 4D + * array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, + * timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply + * normalization + * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or + * [batchSize, numHeads, featureValues, queryCount], (optionally) Attention Weights of shape + * [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC + * type) */ public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled) { @@ -178,40 +195,44 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("dotProductAttention", "keys", keys); SDValidation.validateNumerical("dotProductAttention", "values", values); SDValidation.validateNumerical("dotProductAttention", "mask", mask); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd,queries, keys, values, mask, scaled, false).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd, queries, keys, + values, mask, scaled, false).outputVariable(); } /** - * This operation performs dot product attention on the given timeseries input with the given queries
      - * out = sum(similarity(k_i, q) * v_i)
      + * This operation performs dot product attention on the given timeseries input with the given + * queries
      out = sum(similarity(k_i, q) * v_i)
      *
* similarity(k, q) = softmax(k * q) where k * q is the dot product of k and q<br>
      *
      - * Optionally with normalization step:
- * similarity(k, q) = softmax(k * q / sqrt(size(q)))<br>
      + * Optionally with normalization step:
similarity(k, q) = softmax(k * q / sqrt(size(q)))<br>
      *
      * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, p. 4, eq. 1)
      *
      - * Note: This supports multiple queries at once, if only one query is available the queries vector still has to
      - * be 3D but can have queryCount = 1
      + * Note: This supports multiple queries at once, if only one query is available the queries vector + * still has to
      be 3D but can have queryCount = 1
      *
- * Note: keys and values usually are the same array. If you want to use it as the same array, simply pass it for<br>
      - * both.
+ * Note: keys and values usually are the same array. If you want to use it as the same array, + * simply pass it for<br>
      both.
      *
      - * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them doesn't work. The
      - * output rank will depend on the input rank.
      + * Note: Queries, keys and values must either be all rank 3 or all rank 4 arrays. Mixing them + * doesn't work. The
      output rank will depend on the input rank.
      * - * @param name name May be null. Name for the output variable - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] - * or 4D array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] - * or 4D array of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] - * or 4D array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or [batchSize, numHeads, featureValues, queryCount], - * (optionally) Attention Weights of shape [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] or 4D + * array of shape [batchSize, numHeads, featureKeys, queryCount] (NUMERIC type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] or 4D array + * of shape [batchSize, numHeads, featureKeys, timesteps] (NUMERIC type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] or 4D + * array of shape [batchSize, numHeads, featureValues, timesteps] (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, + * timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply + * normalization + * @return output Attention result arrays of shape [batchSize, featureValues, queryCount] or + * [batchSize, numHeads, featureValues, queryCount], (optionally) Attention Weights of shape + * [batchSize, timesteps, queryCount] or [batchSize, numHeads, timesteps, queryCount] (NUMERIC + * type) */ public SDVariable dotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled) { @@ -219,41 +240,44 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("dotProductAttention", "keys", keys); SDValidation.validateNumerical("dotProductAttention", "values", values); SDValidation.validateNumerical("dotProductAttention", "mask", mask); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd,queries, keys, values, mask, scaled, false).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention(sd, + queries, keys, values, mask, scaled, false).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Dropout operation
      * - * @param input Input array (NUMERIC type) - * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) + * @param input Input array (NUMERIC type) + * @param inputRetainProbability Probability of retaining an input (set to 0 with probability + * 1-p) * @return output Output (NUMERIC type) */ public SDVariable dropout(SDVariable input, double inputRetainProbability) { SDValidation.validateNumerical("dropout", "input", input); - return new org.nd4j.linalg.api.ops.random.impl.DropOut(sd,input, inputRetainProbability).outputVariable(); + return new org.nd4j.linalg.api.ops.random.impl.DropOut(sd, input, + inputRetainProbability).outputVariable(); } /** * Dropout operation
      * - * @param name name May be null. Name for the output variable - * @param input Input array (NUMERIC type) - * @param inputRetainProbability Probability of retaining an input (set to 0 with probability 1-p) + * @param name name May be null. Name for the output variable + * @param input Input array (NUMERIC type) + * @param inputRetainProbability Probability of retaining an input (set to 0 with probability + * 1-p) * @return output Output (NUMERIC type) */ public SDVariable dropout(String name, SDVariable input, double inputRetainProbability) { SDValidation.validateNumerical("dropout", "input", input); - SDVariable out = new org.nd4j.linalg.api.ops.random.impl.DropOut(sd,input, inputRetainProbability).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.random.impl.DropOut(sd, input, + inputRetainProbability).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Element-wise exponential linear unit (ELU) function:
      - * out = x if x > 0
      - * out = a * (exp(x) - 1) if x <= 0
      - * with constant a = 1.0
      + * {@code out = x if x > 0 out = a * (exp(x) - 1) if x <= 0 with constant a = 1.0} *


      * See: https://arxiv.org/abs/1511.07289
      * @@ -262,112 +286,107 @@ public class SDNN extends SDOps { */ public SDVariable elu(SDVariable x) { SDValidation.validateNumerical("elu", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd, x).outputVariable(); } /** - * Element-wise exponential linear unit (ELU) function:
      - * out = x if x > 0
      - * out = a * (exp(x) - 1) if x <= 0
      - * with constant a = 1.0
      + * Element-wise exponential linear unit (ELU) function:
      out = x if x > 0
      out = a * (exp(x) + * - 1) if x <= 0
      with constant a = 1.0
      *


      * See: https://arxiv.org/abs/1511.07289
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable elu(String name, SDVariable x) { SDValidation.validateNumerical("elu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.ELU(sd, x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * GELU activation function - Gaussian Error Linear Units
      - * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      - * This method uses the sigmoid approximation
      + * GELU activation function - Gaussian Error Linear Units
      For more details, see Gaussian + * Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      This method + * uses the sigmoid approximation
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable gelu(SDVariable x) { SDValidation.validateNumerical("gelu", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd, x).outputVariable(); } /** - * GELU activation function - Gaussian Error Linear Units
      - * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      - * This method uses the sigmoid approximation
      + * GELU activation function - Gaussian Error Linear Units
      For more details, see Gaussian + * Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      This method + * uses the sigmoid approximation
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable gelu(String name, SDVariable x) { SDValidation.validateNumerical("gelu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.GELU(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise hard sigmoid function:
      - * out[i] = 0 if in[i] <= -2.5
- * out[i] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5<br>
      - * out[i] = 1 if in[i] >= 2.5
      + * Element-wise hard sigmoid function:
      out[i] = 0 if in[i] <= -2.5
out[i] = 0.2*in[i]+0.5 + * if -2.5 < in[i] < 2.5<br>
      out[i] = 1 if in[i] >= 2.5
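// Sketch of hardSigmoid(): the piecewise linear sigmoid approximation defined above. Input values
// are example assumptions; `sd` is a SameDiff instance as in the earlier sketches.
SDVariable z = sd.var("z", Nd4j.createFromArray(-3f, 0f, 3f));
SDVariable hs = sd.nn().hardSigmoid(z);
System.out.println(hs.eval()); // [0.0, 0.5, 1.0] per the three cases above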
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable hardSigmoid(SDVariable x) { SDValidation.validateNumerical("hardSigmoid", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd, x).outputVariable(); } /** - * Element-wise hard sigmoid function:
      - * out[i] = 0 if in[i] <= -2.5
- * out[i] = 0.2*in[i]+0.5 if -2.5 < in[i] < 2.5<br>
      - * out[i] = 1 if in[i] >= 2.5
      + * Element-wise hard sigmoid function:
      out[i] = 0 if in[i] <= -2.5
out[i] = 0.2*in[i]+0.5 + * if -2.5 < in[i] < 2.5<br>
      out[i] = 1 if in[i] >= 2.5
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable hardSigmoid(String name, SDVariable x) { SDValidation.validateNumerical("hardSigmoid", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise hard tanh function:
      - * out[i] = -1 if in[i] <= -1
- * out[i] = in[i] if -1 < in[i] < 1<br>
      - * out[i] = 1 if in[i] >= 1
      + * Element-wise hard tanh function:
      out[i] = -1 if in[i] <= -1
out[i] = in[i] if -1 < + * in[i] < 1<br>
      out[i] = 1 if in[i] >= 1
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable hardTanh(SDVariable x) { SDValidation.validateNumerical("hardTanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd, x).outputVariable(); } /** - * Element-wise hard tanh function:
      - * out[i] = -1 if in[i] <= -1
- * out[i] = in[i] if -1 < in[i] < 1<br>
      - * out[i] = 1 if in[i] >= 1
      + * Element-wise hard tanh function:
      out[i] = -1 if in[i] <= -1
out[i] = in[i] if -1 < + * in[i] < 1<br>
      out[i] = 1 if in[i] >= 1
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable hardTanh(String name, SDVariable x) { SDValidation.validateNumerical("hardTanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -379,19 +398,21 @@ public class SDNN extends SDOps { */ public SDVariable hardTanhDerivative(SDVariable x) { SDValidation.validateNumerical("hardTanhDerivative", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd, + x).outputVariable(); } /** * Derivative (dOut/dIn) of the element-wise hard Tanh function - hardTanh(INDArray)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable hardTanhDerivative(String name, SDVariable x) { SDValidation.validateNumerical("hardTanhDerivative", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -400,11 +421,13 @@ public class SDNN extends SDOps { *
      * y = gain * standardize(x) + bias
      * - * @param input Input variable (NUMERIC type) - * @param gain Gain (NUMERIC type) - * @param bias Bias (NUMERIC type) - * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data - * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param bias Bias (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), + * false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, + * dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, @@ -412,8 +435,11 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("layerNorm", "input", input); SDValidation.validateNumerical("layerNorm", "gain", gain); SDValidation.validateNumerical("layerNorm", "bias", bias); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, bias, channelsFirst, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd, input, gain, bias, + channelsFirst, dimensions).outputVariable(); } /** @@ -421,12 +447,14 @@ public class SDNN extends SDOps { *
      * y = gain * standardize(x) + bias
      * - * @param name name May be null. Name for the output variable - * @param input Input variable (NUMERIC type) - * @param gain Gain (NUMERIC type) - * @param bias Bias (NUMERIC type) - * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data - * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param bias Bias (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), + * false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, + * dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, SDVariable bias, @@ -434,8 +462,11 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("layerNorm", "input", input); SDValidation.validateNumerical("layerNorm", "gain", gain); SDValidation.validateNumerical("layerNorm", "bias", bias); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, bias, channelsFirst, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd, input, gain, + bias, channelsFirst, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -444,18 +475,23 @@ public class SDNN extends SDOps { *
      * y = gain * standardize(x) + bias
      * - * @param input Input variable (NUMERIC type) - * @param gain Gain (NUMERIC type) - * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data - * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), + * false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, + * dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { SDValidation.validateNumerical("layerNorm", "input", input); SDValidation.validateNumerical("layerNorm", "gain", gain); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, null, channelsFirst, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd, input, gain, null, + channelsFirst, dimensions).outputVariable(); } /** @@ -463,111 +499,115 @@ public class SDNN extends SDOps { *
      * y = gain * standardize(x) + bias
      * - * @param name name May be null. Name for the output variable - * @param input Input variable (NUMERIC type) - * @param gain Gain (NUMERIC type) - * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), false for NHWC data - * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) + * @param name name May be null. Name for the output variable + * @param input Input variable (NUMERIC type) + * @param gain Gain (NUMERIC type) + * @param channelsFirst For 2D input - unused. True for NCHW (minibatch, channels, height, width), + * false for NHWC data + * @param dimensions Dimensions to perform layer norm over - dimension=1 for 2d/MLP data, + * dimension=1,2,3 for CNNs (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public SDVariable layerNorm(String name, SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { SDValidation.validateNumerical("layerNorm", "input", input); SDValidation.validateNumerical("layerNorm", "gain", gain); - Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd,input, gain, null, channelsFirst, dimensions).outputVariable(); + Preconditions.checkArgument(dimensions.length >= 1, + "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", + dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm(sd, input, gain, + null, channelsFirst, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise leaky ReLU function:
      - * out = x if x >= 0.0
      - * out = alpha * x if x < cutoff
      + * Element-wise leaky ReLU function:
      out = x if x >= 0.0
      out = alpha * x if x < cutoff
      * Alpha value is most commonly set to 0.01
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ public SDVariable leakyRelu(SDVariable x, double alpha) { SDValidation.validateNumerical("leakyRelu", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd,x, alpha).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd, x, alpha).outputVariable(); } /** - * Element-wise leaky ReLU function:
      - * out = x if x >= 0.0
      - * out = alpha * x if x < cutoff
      + * Element-wise leaky ReLU function:
      out = x if x >= 0.0
      out = alpha * x if x < cutoff
      * Alpha value is most commonly set to 0.01
      * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ public SDVariable leakyRelu(String name, SDVariable x, double alpha) { SDValidation.validateNumerical("leakyRelu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd,x, alpha).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU(sd, x, + alpha).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Leaky ReLU derivative: dOut/dIn given input.
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ public SDVariable leakyReluDerivative(SDVariable x, double alpha) { SDValidation.validateNumerical("leakyReluDerivative", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd,x, alpha).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd, x, + alpha).outputVariable(); } /** * Leaky ReLU derivative: dOut/dIn given input.
      * - * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) * @param alpha Cutoff - commonly 0.01 * @return output Output variable (NUMERIC type) */ public SDVariable leakyReluDerivative(String name, SDVariable x, double alpha) { SDValidation.validateNumerical("leakyReluDerivative", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd,x, alpha).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative(sd, x, + alpha).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Linear layer operation: out = mmul(in,w) + bias
      - * Note that bias array is optional
      + * Linear layer operation: out = mmul(in,w) + bias
      Note that bias array is optional
      * - * @param input Input data (NUMERIC type) + * @param input Input data (NUMERIC type) * @param weights Weights variable, shape [nIn, nOut] (NUMERIC type) - * @param bias Optional bias variable (may be null) (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable linear(SDVariable input, SDVariable weights, SDVariable bias) { SDValidation.validateNumerical("linear", "input", input); SDValidation.validateNumerical("linear", "weights", weights); SDValidation.validateNumerical("linear", "bias", bias); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd,input, weights, bias).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd, input, weights, + bias).outputVariable(); } /** - * Linear layer operation: out = mmul(in,w) + bias
      - * Note that bias array is optional
      + * Linear layer operation: out = mmul(in,w) + bias
      Note that bias array is optional
      * - * @param name name May be null. Name for the output variable - * @param input Input data (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) * @param weights Weights variable, shape [nIn, nOut] (NUMERIC type) - * @param bias Optional bias variable (may be null) (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable linear(String name, SDVariable input, SDVariable weights, SDVariable bias) { SDValidation.validateNumerical("linear", "input", input); SDValidation.validateNumerical("linear", "weights", weights); SDValidation.validateNumerical("linear", "bias", bias); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd,input, weights, bias).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB(sd, input, weights, + bias).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -579,95 +619,108 @@ public class SDNN extends SDOps { */ public SDVariable logSigmoid(SDVariable x) { SDValidation.validateNumerical("logSigmoid", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd, x).outputVariable(); } /** * Element-wise sigmoid function: out[i] = log(sigmoid(in[i]))
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable logSigmoid(String name, SDVariable x) { SDValidation.validateNumerical("logSigmoid", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Log softmax activation
      * - * @param x (NUMERIC type) + * @param x (NUMERIC type) * @return output (NUMERIC type) */ public SDVariable logSoftmax(SDVariable x) { SDValidation.validateNumerical("logSoftmax", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd, x).outputVariable(); } /** * Log softmax activation
      * * @param name name May be null. Name for the output variable - * @param x (NUMERIC type) + * @param x (NUMERIC type) * @return output (NUMERIC type) */ public SDVariable logSoftmax(String name, SDVariable x) { SDValidation.validateNumerical("logSoftmax", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Log softmax activation
      * - * @param x Input (NUMERIC type) + * @param x Input (NUMERIC type) * @param dimension Dimension along which to apply log softmax * @return output Output - log(softmax(input)) (NUMERIC type) */ public SDVariable logSoftmax(SDVariable x, int dimension) { SDValidation.validateNumerical("logSoftmax", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x, dimension).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd, x, + dimension).outputVariable(); } /** * Log softmax activation
      * - * @param name name May be null. Name for the output variable - * @param x Input (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) * @param dimension Dimension along which to apply log softmax * @return output Output - log(softmax(input)) (NUMERIC type) */ public SDVariable logSoftmax(String name, SDVariable x, int dimension) { SDValidation.validateNumerical("logSoftmax", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd,x, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax(sd, x, + dimension).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * This performs multi-headed dot product attention on the given timeseries input
      - * out = concat(head_1, head_2, ..., head_n) * Wo
      - * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)
      + * This performs multi-headed dot product attention on the given timeseries input
      out = + * concat(head_1, head_2, ..., head_n) * Wo
      head_i = dot_product_attention(Wq_i*q, Wk_i*k, + * Wv_i*v)
      *
      * Optionally with normalization when calculating the attention for each head.
      *
      - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")
      + * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 + * Multi-Head Attention")
      *
      - * This makes use of dot_product_attention OP support for rank 4 inputs.
      - * see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
      + * This makes use of dot_product_attention OP support for rank 4 inputs.
      see + * dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
      * - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type) - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type) - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type) - * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) - * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) - * @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type) - * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type) - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @return output Attention result arrays of shape [batchSize, outSize, queryCount] - * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC + * type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC + * type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC + * type) + * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] + * (NUMERIC type) + * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] + * (NUMERIC type) + * @param Wv input value projection weights of shape [numHeads, projectedValues, + * featureValues] (NUMERIC type) + * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] + * (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, + * timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, outSize, queryCount] (optionally) + * Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) */ public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, @@ -680,33 +733,43 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); SDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); SDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(sd,queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(sd, + queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); } /** - * This performs multi-headed dot product attention on the given timeseries input
      - * out = concat(head_1, head_2, ..., head_n) * Wo
      - * head_i = dot_product_attention(Wq_i*q, Wk_i*k, Wv_i*v)
+ * This performs multi-headed dot product attention on the given timeseries input<br> out =
+ * concat(head_1, head_2, ..., head_n) * Wo<br> head_i = dot_product_attention(Wq_i*q, Wk_i*k,
+ * Wv_i*v)<br>
      *
      * Optionally with normalization when calculating the attention for each head.
      *
      - * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2 Multi-Head Attention")
+ * See also "Attention is all you need" (https://arxiv.org/abs/1706.03762, pp. 4,5, "3.2.2
+ * Multi-Head Attention")<br>
      *
      - * This makes use of dot_product_attention OP support for rank 4 inputs.
      - * see dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
+ * This makes use of dot_product_attention OP support for rank 4 inputs.<br> see
+ * dotProductAttention(INDArray, INDArray, INDArray, INDArray, boolean, boolean)
      * - * @param name name May be null. Name for the output variable - * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC type) - * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC type) - * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC type) - * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) - * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] (NUMERIC type) - * @param Wv input value projection weights of shape [numHeads, projectedValues, featureValues] (NUMERIC type) - * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] (NUMERIC type) - * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, timesteps] (NUMERIC type) - * @param scaled normalization, false -> do not apply normalization, true -> apply normalization - * @return output Attention result arrays of shape [batchSize, outSize, queryCount] - * (optionally) Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param queries input 3D array "queries" of shape [batchSize, featureKeys, queryCount] (NUMERIC + * type) + * @param keys input 3D array "keys" of shape [batchSize, featureKeys, timesteps] (NUMERIC + * type) + * @param values input 3D array "values" of shape [batchSize, featureValues, timesteps] (NUMERIC + * type) + * @param Wq input query projection weights of shape [numHeads, projectedKeys, featureKeys] + * (NUMERIC type) + * @param Wk input key projection weights of shape [numHeads, projectedKeys, featureKeys] + * (NUMERIC type) + * @param Wv input value projection weights of shape [numHeads, projectedValues, + * featureValues] (NUMERIC type) + * @param Wo output projection weights of shape [numHeads * projectedValues, outSize] + * (NUMERIC type) + * @param mask OPTIONAL; array that defines which values should be skipped of shape [batchSize, + * timesteps] (NUMERIC type) + * @param scaled normalization, false -> do not apply normalization, true -> apply normalization + * @return output Attention result arrays of shape [batchSize, outSize, queryCount] (optionally) + * Attention Weights of shape [batchSize, numHeads, timesteps, queryCount] (NUMERIC type) */ public SDVariable multiHeadDotProductAttention(String name, SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, @@ -719,32 +782,34 @@ public class SDNN extends SDOps { SDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", Wv); SDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", Wo); SDValidation.validateNumerical("multiHeadDotProductAttention", "mask", mask); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(sd,queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention( + sd, queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Padding operation
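As a quick orientation for the multiHeadDotProductAttention overloads documented above, here is a minimal usage sketch. It is illustrative only and not part of the patch; the shapes follow the javadoc, and sd.nn() is assumed as the SDNN entry point on a SameDiff instance.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.factory.Nd4j;

    // batchSize=10, featureKeys=featureValues=4, timesteps=queryCount=5, numHeads=2, projected sizes=3, outSize=8
    SameDiff sd = SameDiff.create();
    SDVariable q    = sd.var("q",    Nd4j.rand(DataType.FLOAT, 10, 4, 5)); // [batchSize, featureKeys, queryCount]
    SDVariable k    = sd.var("k",    Nd4j.rand(DataType.FLOAT, 10, 4, 5)); // [batchSize, featureKeys, timesteps]
    SDVariable v    = sd.var("v",    Nd4j.rand(DataType.FLOAT, 10, 4, 5)); // [batchSize, featureValues, timesteps]
    SDVariable wq   = sd.var("wq",   Nd4j.rand(DataType.FLOAT, 2, 3, 4));  // [numHeads, projectedKeys, featureKeys]
    SDVariable wk   = sd.var("wk",   Nd4j.rand(DataType.FLOAT, 2, 3, 4));  // [numHeads, projectedKeys, featureKeys]
    SDVariable wv   = sd.var("wv",   Nd4j.rand(DataType.FLOAT, 2, 3, 4));  // [numHeads, projectedValues, featureValues]
    SDVariable wo   = sd.var("wo",   Nd4j.rand(DataType.FLOAT, 6, 8));     // [numHeads * projectedValues, outSize]
    SDVariable mask = sd.var("mask", Nd4j.ones(DataType.FLOAT, 10, 5));    // optional [batchSize, timesteps] mask
    SDVariable attn = sd.nn().multiHeadDotProductAttention("mha", q, k, v, wq, wk, wv, wo, mask, true);
    System.out.println(attn.eval().shapeInfoToString()); // [batchSize, outSize, queryCount] = [10, 8, 5]

The shorter sketches below assume the same imports and a SameDiff instance named sd.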
      * - * @param input Input tensor (NUMERIC type) - * @param padding Padding value (NUMERIC type) - * @param PadMode Padding format + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format * @param constant Padding constant * @return output Padded input (NUMERIC type) */ public SDVariable pad(SDVariable input, SDVariable padding, PadMode PadMode, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode, constant).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd, input, padding, PadMode, + constant).outputVariable(); } /** * Padding operation
      * - * @param name name May be null. Name for the output variable - * @param input Input tensor (NUMERIC type) - * @param padding Padding value (NUMERIC type) - * @param PadMode Padding format + * @param name name May be null. Name for the output variable + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format * @param constant Padding constant * @return output Padded input (NUMERIC type) */ @@ -752,233 +817,247 @@ public class SDNN extends SDOps { double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode, constant).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd, input, padding, PadMode, + constant).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Padding operation
      * - * @param input Input tensor (NUMERIC type) - * @param padding Padding value (NUMERIC type) + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) * @param constant Padding constant * @return output Padded input (NUMERIC type) */ public SDVariable pad(SDVariable input, SDVariable padding, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode.CONSTANT, constant).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd, input, padding, PadMode.CONSTANT, + constant).outputVariable(); } /** * Padding operation
      * - * @param name name May be null. Name for the output variable - * @param input Input tensor (NUMERIC type) - * @param padding Padding value (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) * @param constant Padding constant * @return output Padded input (NUMERIC type) */ public SDVariable pad(String name, SDVariable input, SDVariable padding, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode.CONSTANT, constant).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd, input, padding, + PadMode.CONSTANT, constant).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * GELU activation function - Gaussian Error Linear Units
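The pad overloads above can be sketched as follows (illustrative only; the padding array is assumed to hold per-dimension [amountBefore, amountAfter] pairs, and the overloads without a PadMode default to PadMode.CONSTANT):

    SDVariable in      = sd.var("in", Nd4j.rand(DataType.FLOAT, 2, 3));
    // pad 1 row before/after dimension 0 and 2 columns before/after dimension 1
    SDVariable padding = sd.constant(Nd4j.createFromArray(new int[][]{{1, 1}, {2, 2}}));
    SDVariable padded  = sd.nn().pad("padded", in, padding, 0.0);  // constant fill value 0.0
    System.out.println(padded.eval().shapeInfoToString());         // [4, 7]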
      - * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      - * This method uses the precise method
+ * GELU activation function - Gaussian Error Linear Units<br> For more details, see Gaussian
+ * Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415<br> This method
+ * uses the precise method<br>
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable preciseGelu(SDVariable x) { SDValidation.validateNumerical("preciseGelu", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd, x).outputVariable(); } /** - * GELU activation function - Gaussian Error Linear Units
      - * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
      - * This method uses the precise method
+ * GELU activation function - Gaussian Error Linear Units<br> For more details, see Gaussian
+ * Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415<br> This method
+ * uses the precise method<br>
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable preciseGelu(String name, SDVariable x) { SDValidation.validateNumerical("preciseGelu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
      - * out[i] = in[i] if in[i] >= 0
      - * out[i] = in[i] * alpha[i] otherwise
+ * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable
+ * alpha:<br> out[i] = in[i] if in[i] >= 0<br> out[i] = in[i] * alpha[i] otherwise<br>
      *
      - * sharedAxes allows you to share learnable parameters along axes.
      - * For example, if the input has shape [batchSize, channels, height, width]
      - * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
      - * alpha with shape [channels].
+ * sharedAxes allows you to share learnable parameters along axes.<br> For example, if the input
+ * has shape [batchSize, channels, height, width]<br> and you want each channel to have its own
+ * cutoff, use sharedAxes = [2, 3] and an<br> alpha with shape [channels].<br>
      * - * @param input Input data (NUMERIC type) - * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type) + * @param input Input data (NUMERIC type) + * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is + * batch or not) should not be part of alpha. (NUMERIC type) * @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1)) * @return output Output (NUMERIC type) */ public SDVariable prelu(SDVariable input, SDVariable alpha, int... sharedAxes) { SDValidation.validateNumerical("prelu", "input", input); SDValidation.validateNumerical("prelu", "alpha", alpha); - Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length); - return new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd,input, alpha, sharedAxes).outputVariable(); + Preconditions.checkArgument(sharedAxes.length >= 1, + "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", + sharedAxes.length); + return new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd, input, alpha, + sharedAxes).outputVariable(); } /** - * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable alpha:
      - * out[i] = in[i] if in[i] >= 0
      - * out[i] = in[i] * alpha[i] otherwise
+ * PReLU (Parameterized Rectified Linear Unit) operation. Like LeakyReLU with a learnable
+ * alpha:<br> out[i] = in[i] if in[i] >= 0<br> out[i] = in[i] * alpha[i] otherwise<br>
      *
      - * sharedAxes allows you to share learnable parameters along axes.
      - * For example, if the input has shape [batchSize, channels, height, width]
      - * and you want each channel to have its own cutoff, use sharedAxes = [2, 3] and an
      - * alpha with shape [channels].
+ * sharedAxes allows you to share learnable parameters along axes.<br> For example, if the input
+ * has shape [batchSize, channels, height, width]<br> and you want each channel to have its own
+ * cutoff, use sharedAxes = [2, 3] and an<br> alpha with shape [channels].<br>
      * - * @param name name May be null. Name for the output variable - * @param input Input data (NUMERIC type) - * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is batch or not) should not be part of alpha. (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) + * @param alpha The cutoff variable. Note that the batch dimension (the 0th, whether it is + * batch or not) should not be part of alpha. (NUMERIC type) * @param sharedAxes Which axes to share cutoff parameters along. (Size: AtLeast(min=1)) * @return output Output (NUMERIC type) */ public SDVariable prelu(String name, SDVariable input, SDVariable alpha, int... sharedAxes) { SDValidation.validateNumerical("prelu", "input", input); SDValidation.validateNumerical("prelu", "alpha", alpha); - Preconditions.checkArgument(sharedAxes.length >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxes.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd,input, alpha, sharedAxes).outputVariable(); + Preconditions.checkArgument(sharedAxes.length >= 1, + "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", + sharedAxes.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.PRelu(sd, input, alpha, + sharedAxes).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Element-wise rectified linear function with specified cutoff:
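To make the sharedAxes semantics above concrete, a hedged sketch: for NCHW input, sharing the learnable cutoff over the spatial axes 2 and 3 leaves one alpha value per channel, exactly as the javadoc describes.

    SDVariable in    = sd.var("in",    Nd4j.rand(DataType.FLOAT, 8, 16, 28, 28)); // [batchSize, channels, height, width]
    SDVariable alpha = sd.var("alpha", Nd4j.rand(DataType.FLOAT, 16));            // one learnable cutoff per channel
    SDVariable out   = sd.nn().prelu("prelu", in, alpha, 2, 3);                   // share alpha along height and width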
      - * out[i] = in[i] if in[i] >= cutoff
      - * out[i] = 0 otherwise
      + * {@code out[i] = in[i] if in[i] >= cutoff out[i] = 0 otherwise} * - * @param x Input (NUMERIC type) - * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0 + * @param x Input (NUMERIC type) + * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0 * @return output Output (NUMERIC type) */ public SDVariable relu(SDVariable x, double cutoff) { SDValidation.validateNumerical("relu", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd,x, cutoff).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd, x, cutoff).outputVariable(); } /** - * Element-wise rectified linear function with specified cutoff:
      - * out[i] = in[i] if in[i] >= cutoff
      - * out[i] = 0 otherwise
+ * Element-wise rectified linear function with specified cutoff:<br> out[i] = in[i] if in[i] >=
+ * cutoff<br> out[i] = 0 otherwise<br>
      * - * @param name name May be null. Name for the output variable - * @param x Input (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) * @param cutoff Cutoff value for ReLU operation - x > cutoff ? x : 0. Usually 0 * @return output Output (NUMERIC type) */ public SDVariable relu(String name, SDVariable x, double cutoff) { SDValidation.validateNumerical("relu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd,x, cutoff).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear(sd, x, + cutoff).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise "rectified linear 6" function with specified cutoff:
      - * out[i] = min(max(in, cutoff), 6)
+ * Element-wise "rectified linear 6" function with specified cutoff:<br> out[i] = min(max(in,
+ * cutoff), 6)<br>
      * - * @param x Input (NUMERIC type) + * @param x Input (NUMERIC type) * @param cutoff Cutoff value for ReLU operation. Usually 0 * @return output Output (NUMERIC type) */ public SDVariable relu6(SDVariable x, double cutoff) { SDValidation.validateNumerical("relu6", "x", x); - return new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd,x, cutoff).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd, x, cutoff).outputVariable(); } /** - * Element-wise "rectified linear 6" function with specified cutoff:
      - * out[i] = min(max(in, cutoff), 6)
+ * Element-wise "rectified linear 6" function with specified cutoff:<br> out[i] = min(max(in,
+ * cutoff), 6)<br>
      * - * @param name name May be null. Name for the output variable - * @param x Input (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) * @param cutoff Cutoff value for ReLU operation. Usually 0 * @return output Output (NUMERIC type) */ public SDVariable relu6(String name, SDVariable x, double cutoff) { SDValidation.validateNumerical("relu6", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd,x, cutoff).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.Relu6(sd, x, cutoff).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
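A small sketch of the two clipped linear activations above (illustrative only): relu(x, cutoff) zeroes everything below the cutoff, while relu6(x, cutoff) additionally caps the result at 6.

    SDVariable x  = sd.var("x", Nd4j.createFromArray(-2.0f, 0.5f, 3.0f, 9.0f));
    SDVariable r  = sd.nn().relu("r", x, 0.0);    // [0.0, 0.5, 3.0, 9.0]
    SDVariable r6 = sd.nn().relu6("r6", x, 0.0);  // [0.0, 0.5, 3.0, 6.0]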
      - * Note that bias array is optional
+ * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)<br> Note that bias
+ * array is optional<br>
      * - * @param input Input data (NUMERIC type) + * @param input Input data (NUMERIC type) * @param weights Weights variable (NUMERIC type) - * @param bias Optional bias variable (may be null) (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable reluLayer(SDVariable input, SDVariable weights, SDVariable bias) { SDValidation.validateNumerical("reluLayer", "input", input); SDValidation.validateNumerical("reluLayer", "weights", weights); SDValidation.validateNumerical("reluLayer", "bias", bias); - return new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd,input, weights, bias).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd, input, weights, + bias).outputVariable(); } /** - * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)
      - * Note that bias array is optional
+ * ReLU (Rectified Linear Unit) layer operation: out = relu(mmul(in,w) + bias)<br> Note that bias
+ * array is optional<br>
      * - * @param name name May be null. Name for the output variable - * @param input Input data (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param input Input data (NUMERIC type) * @param weights Weights variable (NUMERIC type) - * @param bias Optional bias variable (may be null) (NUMERIC type) + * @param bias Optional bias variable (may be null) (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable reluLayer(String name, SDVariable input, SDVariable weights, SDVariable bias) { SDValidation.validateNumerical("reluLayer", "input", input); SDValidation.validateNumerical("reluLayer", "weights", weights); SDValidation.validateNumerical("reluLayer", "bias", bias); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd,input, weights, bias).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.ReluLayer(sd, input, weights, + bias).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
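The reluLayer op above is simply relu(mmul(in, w) + b); as a sketch (illustrative, and the bias may also be passed as null):

    SDVariable in  = sd.var("in", Nd4j.rand(DataType.FLOAT, 32, 10)); // [minibatch, nIn]
    SDVariable w   = sd.var("w",  Nd4j.rand(DataType.FLOAT, 10, 4));  // [nIn, nOut]
    SDVariable b   = sd.var("b",  Nd4j.rand(DataType.FLOAT, 4));      // optional bias of shape [nOut]
    SDVariable out = sd.nn().reluLayer("out", in, w, b);              // output shape [32, 4]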
      + * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
      *
      - * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
      - * Uses default scale and alpha values.
+ * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0<br> Uses default scale
+ * and alpha values.<br>
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable selu(SDVariable x) { SDValidation.validateNumerical("selu", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd, x).outputVariable(); } /** - * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
      + * Element-wise SeLU function - Scaled exponential Lineal Unit: see Self-Normalizing Neural Networks
      *
      - * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0
      - * Uses default scale and alpha values.
+ * out[i] = scale * alpha * (exp(in[i])-1) if in[i]>0, or 0 if in[i] <= 0<br> Uses default scale
+ * and alpha values.<br>
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable selu(String name, SDVariable x) { SDValidation.validateNumerical("selu", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SELU(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -990,73 +1069,78 @@ public class SDNN extends SDOps { */ public SDVariable sigmoid(SDVariable x) { SDValidation.validateNumerical("sigmoid", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd, x).outputVariable(); } /** * Element-wise sigmoid function: out[i] = 1.0/(1+exp(-in[i]))
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable sigmoid(String name, SDVariable x) { SDValidation.validateNumerical("sigmoid", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut
      * - * @param x Input Variable (NUMERIC type) + * @param x Input Variable (NUMERIC type) * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type) * @return output Output (gradient at input of sigmoid) (NUMERIC type) */ public SDVariable sigmoidDerivative(SDVariable x, SDVariable wrt) { SDValidation.validateNumerical("sigmoidDerivative", "x", x); SDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt); - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd,x, wrt).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd, x, + wrt).outputVariable(); } /** * Element-wise sigmoid function derivative: dL/dIn given input and dL/dOut
      * * @param name name May be null. Name for the output variable - * @param x Input Variable (NUMERIC type) - * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type) + * @param x Input Variable (NUMERIC type) + * @param wrt Gradient at the output - dL/dOut. Must have same shape as the input (NUMERIC type) * @return output Output (gradient at input of sigmoid) (NUMERIC type) */ public SDVariable sigmoidDerivative(String name, SDVariable x, SDVariable wrt) { SDValidation.validateNumerical("sigmoidDerivative", "x", x); SDValidation.validateNumerical("sigmoidDerivative", "wrt", wrt); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd,x, wrt).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative(sd, x, + wrt).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Softmax activation, along the specified dimension
      * - * @param x Input (NUMERIC type) + * @param x Input (NUMERIC type) * @param dimension Dimension along which to apply softmax * @return output Output variable (NUMERIC type) */ public SDVariable softmax(SDVariable x, int dimension) { SDValidation.validateNumerical("softmax", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, dimension).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd, x, + dimension).outputVariable(); } /** * Softmax activation, along the specified dimension
      * - * @param name name May be null. Name for the output variable - * @param x Input (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Input (NUMERIC type) * @param dimension Dimension along which to apply softmax * @return output Output variable (NUMERIC type) */ public SDVariable softmax(String name, SDVariable x, int dimension) { SDValidation.validateNumerical("softmax", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd, x, + dimension).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1068,49 +1152,52 @@ public class SDNN extends SDOps { */ public SDVariable softmax(SDVariable x) { SDValidation.validateNumerical("softmax", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, -1).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd, x, -1).outputVariable(); } /** * Softmax activation, along the specified dimension
      * * @param name name May be null. Name for the output variable - * @param x Input (NUMERIC type) + * @param x Input (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable softmax(String name, SDVariable x) { SDValidation.validateNumerical("softmax", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd,x, -1).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax(sd, x, + -1).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** * Softmax derivative function
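For the softmax overloads above, the dimension argument selects the axis that is normalised to sum to 1; the single-argument overload defaults to the last axis (-1). A sketch:

    SDVariable logits = sd.var("logits", Nd4j.rand(DataType.FLOAT, 3, 5));
    SDVariable rows   = sd.nn().softmax("rows", logits, 1);  // every row of the [3, 5] result sums to 1.0
    SDVariable last   = sd.nn().softmax(logits);             // same result here: the default dimension is -1
    System.out.println(rows.eval().sum(1));                  // approximately [1.0, 1.0, 1.0]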
      * - * @param x Softmax input (NUMERIC type) - * @param wrt Gradient at output, dL/dx (NUMERIC type) + * @param x Softmax input (NUMERIC type) + * @param wrt Gradient at output, dL/dx (NUMERIC type) * @param dimension Softmax dimension * @return output (NUMERIC type) */ public SDVariable softmaxDerivative(SDVariable x, SDVariable wrt, int dimension) { SDValidation.validateNumerical("softmaxDerivative", "x", x); SDValidation.validateNumerical("softmaxDerivative", "wrt", wrt); - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd,x, wrt, dimension).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd, x, wrt, + dimension).outputVariable(); } /** * Softmax derivative function
      * - * @param name name May be null. Name for the output variable - * @param x Softmax input (NUMERIC type) - * @param wrt Gradient at output, dL/dx (NUMERIC type) + * @param name name May be null. Name for the output variable + * @param x Softmax input (NUMERIC type) + * @param wrt Gradient at output, dL/dx (NUMERIC type) * @param dimension Softmax dimension * @return output (NUMERIC type) */ public SDVariable softmaxDerivative(String name, SDVariable x, SDVariable wrt, int dimension) { SDValidation.validateNumerical("softmaxDerivative", "x", x); SDValidation.validateNumerical("softmaxDerivative", "wrt", wrt); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd,x, wrt, dimension).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp(sd, x, wrt, + dimension).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1122,19 +1209,20 @@ public class SDNN extends SDOps { */ public SDVariable softplus(SDVariable x) { SDValidation.validateNumerical("softplus", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd, x).outputVariable(); } /** * Element-wise softplus function: out = log(exp(x) + 1)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable softplus(String name, SDVariable x) { SDValidation.validateNumerical("softplus", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1146,19 +1234,20 @@ public class SDNN extends SDOps { */ public SDVariable softsign(SDVariable x) { SDValidation.validateNumerical("softsign", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd, x).outputVariable(); } /** * Element-wise softsign function: out = x / (abs(x) + 1)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable softsign(String name, SDVariable x) { SDValidation.validateNumerical("softsign", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1170,45 +1259,48 @@ public class SDNN extends SDOps { */ public SDVariable softsignDerivative(SDVariable x) { SDValidation.validateNumerical("softsignDerivative", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd, + x).outputVariable(); } /** * Element-wise derivative (dOut/dIn) of the softsign function softsign(INDArray)
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output (NUMERIC type) */ public SDVariable softsignDerivative(String name, SDVariable x) { SDValidation.validateNumerical("softsignDerivative", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } /** - * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
      - * See: https://arxiv.org/abs/1710.05941
      + * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
      See: https://arxiv.org/abs/1710.05941
      * * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable swish(SDVariable x) { SDValidation.validateNumerical("swish", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd, x).outputVariable(); } /** - * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
      - * See: https://arxiv.org/abs/1710.05941
      + * Element-wise "swish" function: out = x * sigmoid(b*x) with b=1.0
      See: https://arxiv.org/abs/1710.05941
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable swish(String name, SDVariable x) { SDValidation.validateNumerical("swish", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Swish(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1220,19 +1312,20 @@ public class SDNN extends SDOps { */ public SDVariable tanh(SDVariable x) { SDValidation.validateNumerical("tanh", "x", x); - return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd, x).outputVariable(); } /** * Elementwise tanh (hyperbolic tangent) operation: out = tanh(x)
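The remaining element-wise activations above (sigmoid, softplus, softsign, swish, tanh and their derivatives) all follow the same single-input pattern; a combined sketch:

    SDVariable x   = sd.var("x", Nd4j.createFromArray(-1.0f, 0.0f, 2.0f));
    SDVariable sig = sd.nn().sigmoid("sig", x);   // 1 / (1 + exp(-x))
    SDVariable sp  = sd.nn().softplus("sp", x);   // log(exp(x) + 1)
    SDVariable ss  = sd.nn().softsign("ss", x);   // x / (abs(x) + 1)
    SDVariable sw  = sd.nn().swish("sw", x);      // x * sigmoid(x)
    SDVariable th  = sd.nn().tanh("th", x);       // tanh(x)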
      * * @param name name May be null. Name for the output variable - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ public SDVariable tanh(String name, SDVariable x) { SDValidation.validateNumerical("tanh", "x", x); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd,x).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh(sd, + x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java index d57afe876..9c53bda9f 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDRandom.java @@ -100,7 +100,7 @@ public class SDRandom extends SDOps { * P(x) = lambda * exp(-lambda * x)
      * * Inputs must satisfy the following constraints:
      - * Must be positive: lambda > 0
      + * Must be positive: lambda > 0
      * * @param lambda lambda parameter * @param datatype Data type of the output variable @@ -118,7 +118,7 @@ public class SDRandom extends SDOps { * P(x) = lambda * exp(-lambda * x)
      * * Inputs must satisfy the following constraints:
      - * Must be positive: lambda > 0
      + * Must be positive: lambda > 0
      * * @param name name May be null. Name for the output variable * @param lambda lambda parameter diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java index 9766a5b7c..a5fe6b751 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/Evaluation.java @@ -829,9 +829,9 @@ public class Evaluation extends BaseEvaluation { * Precision based on guesses so far.
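The density documented above is the standard exponential distribution. Purely as an illustration of that formula (plain Java, not the SDRandom API), inverse-transform sampling draws x = -ln(u) / lambda from a uniform u in (0, 1):

    // P(x) = lambda * exp(-lambda * x), defined for lambda > 0
    double lambda = 2.0;
    java.util.Random rng = new java.util.Random(42);
    double[] samples = new double[1000];
    for (int i = 0; i < samples.length; i++) {
        samples[i] = -Math.log(1.0 - rng.nextDouble()) / lambda; // inverse CDF of the exponential
    }
    // the sample mean approaches 1 / lambda = 0.5 as the sample count grows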
      * Note: value returned will differ depending on number of classes and settings.
      * 1. For binary classification, if the positive class is set (via default value of 1, via constructor, - * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class + * or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
      - * 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged + * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged precision, equivalent to {@code precision(EvaluationAveraging.Macro)}
      * * @return the total precision based on guesses so far @@ -977,7 +977,7 @@ public class Evaluation extends BaseEvaluation { * 1. For binary classification, if the positive class is set (via default value of 1, via constructor, * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
      - * 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged + * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged recall, equivalent to {@code recall(EvaluationAveraging.Macro)}
      * * @return the recall for the outcomes @@ -1173,12 +1173,12 @@ public class Evaluation extends BaseEvaluation { /** * False Alarm Rate (FAR) reflects rate of misclassified to classified records - * {@link }http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw}
+ * see http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw<br>
      * Note: value returned will differ depending on number of classes and settings.
      * 1. For binary classification, if the positive class is set (via default value of 1, via constructor, - * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class + * or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
      - * 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged + * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged false alarm rate) * * @return the fpr for the outcomes @@ -1243,9 +1243,9 @@ public class Evaluation extends BaseEvaluation { *
      * Note: value returned will differ depending on number of classes and settings.
      * 1. For binary classification, if the positive class is set (via default value of 1, via constructor, - * or via {@link #setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class + * or via {@link Evaluation#setBinaryPositiveClass(Integer)}), the returned value will be for the specified positive class * only.
      - * 2. For the multi-class case, or when {@link #getBinaryPositiveClass()} is null, the returned value is macro-averaged + * 2. For the multi-class case, or when {@link Evaluation#getBinaryPositiveClass()} is null, the returned value is macro-averaged * across all classes. i.e., is macro-averaged f1, equivalent to {@code f1(EvaluationAveraging.Macro)}
      * * @return the f1 score or harmonic mean of precision and recall based on current guesses diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java index c3a66aa93..49a2f8e3d 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/evaluation/classification/EvaluationBinary.java @@ -584,7 +584,7 @@ public class EvaluationBinary extends BaseEvaluation { /** * False Alarm Rate (FAR) reflects rate of misclassified to classified records - * http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw
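To make the positive-class versus macro-averaging distinction above concrete, a usage sketch (illustrative; labels and predictions stand for one-hot label and probability INDArrays of shape [minibatch, 2]):

    // assumes org.nd4j.evaluation.classification.Evaluation and org.nd4j.evaluation.EvaluationAveraging
    Evaluation eval = new Evaluation(2);                        // binary case, positive class defaults to 1
    eval.eval(labels, predictions);                             // accumulate a batch of results
    double p      = eval.precision();                           // precision of the configured positive class only
    double macroP = eval.precision(EvaluationAveraging.Macro);  // macro-averaged over both classes
    double f1     = eval.f1();                                  // harmonic mean of precision and recall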
      + * http://ro.ecu.edu.au/cgi/viewcontent.cgi?article=1058&context=isw
      * * @param outputNum Class index to calculate False Alarm Rate (FAR) * @return The FAR for the outcomes @@ -611,7 +611,7 @@ public class EvaluationBinary extends BaseEvaluation { StringBuilder sb = new StringBuilder(); - //Report: Accuracy, precision, recall, F1. Then: confusion matrix + //Report: Accuracy, precision, recall, F1. Then: confusion matrix] int maxLabelsLength = 15; if (labels != null) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java index 3fc641421..5503fdcf0 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/blas/impl/BaseLapack.java @@ -202,7 +202,7 @@ public abstract class BaseLapack implements Lapack { * * @param jobz 'N' - no eigen vectors, 'V' - return eigenvectors * @param uplo upper or lower part of symmetric matrix to use - * @param N the number of rows & cols in the matrix A + * @param N the number of rows & cols in the matrix A * @param A the matrix to calculate eigenvectors * @param R an output array for eigenvalues ( may be null ) */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java index e21b2cd51..8789de3ff 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java @@ -74,14 +74,14 @@ public interface DataBufferFactory { DataBuffer createDouble(long offset, int length); /** - * This method will create new DataBuffer of the same dataType & same length + * This method will create new DataBuffer of the same dataType & same length * @param buffer * @return */ DataBuffer createSame(DataBuffer buffer, boolean init); /** - * This method will create new DataBuffer of the same dataType & same length + * This method will create new DataBuffer of the same dataType & same length * @param buffer * @return */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryManager.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryManager.java index bdf18ed35..32f5503ea 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryManager.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryManager.java @@ -132,7 +132,6 @@ public interface MemoryManager { * * @param pointer * @param kind - * @return */ void release(Pointer pointer, MemoryKind kind); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java index bade5db59..b9a2a0622 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java @@ -137,7 +137,7 @@ public interface MemoryWorkspaceManager { void destroyWorkspace(MemoryWorkspace workspace); /** - * This method destroys & deallocates all Workspaces for a calling Thread + * This method destroys & deallocates all Workspaces for a calling Thread * * PLEASE NOTE: This method is NOT safe */ @@ -149,21 +149,21 @@ public interface 
MemoryWorkspaceManager { void destroyWorkspace(); /** - * This method gets & activates default workspace + * This method gets and activates default workspace * * @return */ MemoryWorkspace getAndActivateWorkspace(); /** - * This method gets & activates workspace with a given Id + * This method gets and activates workspace with a given Id * * @return */ MemoryWorkspace getAndActivateWorkspace(String id); /** - * This method gets & activates default with a given configuration and Id + * This method gets and activates default with a given configuration and Id * * @return */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 9887ddadb..53ce1beba 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -94,5537 +94,5899 @@ import static org.nd4j.linalg.factory.Nd4j.*; @Slf4j public abstract class BaseNDArray implements INDArray, Iterable { - private static final long serialVersionUID = 3285982317165542614L; + private static final long serialVersionUID = 3285982317165542614L; - protected transient volatile DataBuffer shapeInformation; - protected transient volatile DataBuffer data; - //protected transient DataBuffer shape; - //protected transient DataBuffer stride; - protected transient boolean compressed = false; + protected transient volatile DataBuffer shapeInformation; + protected transient volatile DataBuffer data; + //protected transient DataBuffer shape; + //protected transient DataBuffer stride; + protected transient boolean compressed = false; - protected transient boolean released = false; + protected transient boolean released = false; - // this field holds jvm copy of shapeInfo - protected transient JvmShapeInfo jvmShapeInfo; + // this field holds jvm copy of shapeInfo + protected transient JvmShapeInfo jvmShapeInfo; - private static final AtomicLong arrayCounter = new AtomicLong(0); - protected transient final long arrayId = arrayCounter.getAndIncrement(); + private static final AtomicLong arrayCounter = new AtomicLong(0); + protected transient final long arrayId = arrayCounter.getAndIncrement(); - //Precalculate these arrays (like [3,2,1,0], [2,1,0], [1,0], [0] etc) for use in TAD, to avoid creating same int[]s over and over - private static final int[][] tadFinalPermuteDimensions; - static { - tadFinalPermuteDimensions = new int[32][0]; - tadFinalPermuteDimensions[1] = new int[] {1, 0}; //Edge case for 1d tensors: selectively apply to column vectors - for (int i = 2; i < 32; i++) { - tadFinalPermuteDimensions[i] = new int[i]; - for (int k = i - 1, j = 0; k >= 0; k--, j++) - tadFinalPermuteDimensions[i][j] = k; - } - val t =1; - } + //Precalculate these arrays (like [3,2,1,0], [2,1,0], [1,0], [0] etc) for use in TAD, to avoid creating same int[]s over and over + private static final int[][] tadFinalPermuteDimensions; - public BaseNDArray() { - } - - @Override - public boolean isCompressed() { - return compressed; - } - - @Override - public void markAsCompressed(boolean reallyCompressed) { - this.compressed = reallyCompressed; - } - - /** - * - * @param buffer - */ - public BaseNDArray(DataBuffer buffer) { - this.data = buffer; - if (buffer.length() >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE"); - long[] shape = {1, (int) buffer.length()}; - long[] stride = 
Nd4j.getStrides(shape); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, 1, Nd4j.order(), buffer.dataType(), false)); - init(shape, stride); - } - - /** - * - * @param buffer - * @param shape - * @param stride - * @param offset - * @param ordering - */ - public BaseNDArray(DataBuffer buffer, int[] shape, int[] stride, long offset, char ordering) { - Shape.assertValidOrder(ordering); - this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, buffer.dataType(), false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering) { - this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, char ordering) { - Shape.assertValidOrder(ordering); - this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, ews, ordering, buffer.dataType(), false )); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering, DataType dataType) { - this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, dataType); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, char ordering, DataType dataType) { - this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, ews, ordering, dataType, false)); - init(shape, stride); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type) { - this.data = buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type, MemoryWorkspace workspace) { - this.data = buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - } - - public BaseNDArray(DataBuffer buffer, DataType dataType, long[] shape, long[] stride, long offset, char ordering) { - this.data = offset > 0 ? 
Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) : buffer; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, dataType, false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - } - - /** - * Initialize the ndarray as a matrix - * with the given data (indices preserved) - * @param data - */ - public BaseNDArray(double[][] data) { - this(data, Nd4j.order()); - } - - /** - * - * @param data - * @param ordering - */ - public BaseNDArray(double[][] data, char ordering) { - this(internalCreateBuffer(ordering == 'c' ? ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)), - new int[] {data.length, data[0].length}, - Nd4j.getStrides(new int[] {data.length, data[0].length}, ordering), 0, ordering); - - int c = columns(); - for (int r = 0; r < rows(); r++) { - Preconditions.checkState(data[r].length == c, "data[%s].length=%s must be equal to number of columns %s", r, data[r].length, c ); + static { + tadFinalPermuteDimensions = new int[32][0]; + tadFinalPermuteDimensions[1] = new int[]{1, + 0}; //Edge case for 1d tensors: selectively apply to column vectors + for (int i = 2; i < 32; i++) { + tadFinalPermuteDimensions[i] = new int[i]; + for (int k = i - 1, j = 0; k >= 0; k--, j++) { + tadFinalPermuteDimensions[i][j] = k; } } + val t = 1; + } + public BaseNDArray() { + } - /** - * Create with the specified shape and buffer - * - * @param shape the shape - * @param buffer the buffer - */ - public BaseNDArray(int[] shape, DataBuffer buffer) { - this.data = buffer; - init(shape, Nd4j.getStrides(shape)); - } - - /** - * Create this ndarray with the given data and shape and 0 offset - * - * @param data the data to use - * @param shape the shape of the ndarray - */ - public BaseNDArray(float[] data, int[] shape, char ordering) { - this(data, shape, 0, ordering); - } - - /** - * @param data the data to use - * @param shape the shape of the ndarray - * @param offset the desired offset - * @param ordering the ordering of the ndarray - */ - public BaseNDArray(float[] data, int[] shape, long offset, char ordering) { - this(data, shape, Nd4j.getStrides(shape, ordering), offset); - } - - public BaseNDArray(double[] data, long[] shape, long offset, char ordering) { - this(data, shape, Nd4j.getStrides(shape, ordering), offset); - } - - public BaseNDArray(float[] data, long[] shape, long offset, char ordering) { - this(data, shape, Nd4j.getStrides(shape, ordering), offset); + @Override + public boolean isCompressed() { + return compressed; + } + + @Override + public void markAsCompressed(boolean reallyCompressed) { + this.compressed = reallyCompressed; + } + + /** + * @param buffer + */ + public BaseNDArray(DataBuffer buffer) { + this.data = buffer; + if (buffer.length() >= Integer.MAX_VALUE) { + throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE"); + } + long[] shape = {1, (int) buffer.length()}; + long[] stride = Nd4j.getStrides(shape); + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(shape, stride, 1, Nd4j.order(), buffer.dataType(), false)); + init(shape, stride); + } + + /** + * @param buffer + * @param shape + * @param stride + * @param offset + * @param ordering + */ + public BaseNDArray(DataBuffer buffer, int[] shape, int[] stride, long offset, char ordering) { + Shape.assertValidOrder(ordering); + this.data = offset > 0 ? 
Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) + : buffer; + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, buffer.dataType(), + false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering) { + this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), + ordering); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, + char ordering) { + Shape.assertValidOrder(ordering); + this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) + : buffer; + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(shape, stride, ews, ordering, buffer.dataType(), false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, char ordering, + DataType dataType) { + this(buffer, shape, stride, offset, Shape.elementWiseStride(shape, stride, ordering == 'f'), + ordering, dataType); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, long offset, long ews, + char ordering, DataType dataType) { + this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) + : buffer; + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(shape, stride, ews, ordering, dataType, false)); + init(shape, stride); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type) { + this.data = buffer; + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, long[] stride, char ordering, DataType type, + MemoryWorkspace workspace) { + this.data = buffer; + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, type, false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + } + + public BaseNDArray(DataBuffer buffer, DataType dataType, long[] shape, long[] stride, long offset, + char ordering) { + this.data = offset > 0 ? 
Nd4j.createBuffer(buffer, offset, Shape.lengthOfBuffer(shape, stride)) + : buffer; + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, dataType, false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + } + + /** + * Initialize the ndarray as a matrix with the given data (indices preserved) + * + * @param data + */ + public BaseNDArray(double[][] data) { + this(data, Nd4j.order()); + } + + /** + * @param data + * @param ordering + */ + public BaseNDArray(double[][] data, char ordering) { + this(internalCreateBuffer(ordering == 'c' ? ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)), + new int[]{data.length, data[0].length}, + Nd4j.getStrides(new int[]{data.length, data[0].length}, ordering), 0, ordering); + + int c = columns(); + for (int r = 0; r < rows(); r++) { + Preconditions.checkState(data[r].length == c, + "data[%s].length=%s must be equal to number of columns %s", r, data[r].length, c); } + } - /** - * Construct an ndarray of the specified shape - * with an empty data array - * - * @param shape the shape of the ndarray - * @param stride the stride of the ndarray - * @param offset the desired offset - * @param ordering the ordering of the ndarray - */ - public BaseNDArray(int[] shape, int[] stride, long offset, char ordering) { - this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering); - } + /** + * Create with the specified shape and buffer + * + * @param shape the shape + * @param buffer the buffer + */ + public BaseNDArray(int[] shape, DataBuffer buffer) { + this.data = buffer; + init(shape, Nd4j.getStrides(shape)); + } - public BaseNDArray(long[] shape, long[] stride, long offset, char ordering) { - this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering); - } + /** + * Create this ndarray with the given data and shape and 0 offset + * + * @param data the data to use + * @param shape the shape of the ndarray + */ + public BaseNDArray(float[] data, int[] shape, char ordering) { + this(data, shape, 0, ordering); + } - /** - * Construct an ndarray of the specified shape. - * - * @param shape the shape of the ndarray - * @param stride the stride of the ndarray - * @param offset the desired offset - * @param ordering the ordering of the ndarray - * @param initialize Whether to initialize the INDArray. If true: initialize. If false: don't. - */ - public BaseNDArray(int[] shape, int[] stride, long offset, char ordering, boolean initialize) { - this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering); - } + /** + * @param data the data to use + * @param shape the shape of the ndarray + * @param offset the desired offset + * @param ordering the ordering of the ndarray + */ + public BaseNDArray(float[] data, int[] shape, long offset, char ordering) { + this(data, shape, Nd4j.getStrides(shape, ordering), offset); + } - public BaseNDArray(long[] shape, long[] stride, long offset, char ordering, boolean initialize) { - this(Nd4j.createBuffer(shape.length == 0 ? 
1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering); - } + public BaseNDArray(double[] data, long[] shape, long offset, char ordering) { + this(data, shape, Nd4j.getStrides(shape, ordering), offset); + } - public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize) { - this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), type, shape, stride, offset, ordering); - } + public BaseNDArray(float[] data, long[] shape, long offset, char ordering) { + this(data, shape, Nd4j.getStrides(shape, ordering), offset); + } - public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize, MemoryWorkspace workspace) { - this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize, workspace), type, shape, stride, offset, ordering); - } - public BaseNDArray(DataType type, long[] shape, long[] paddings, long[] paddingOffsets, char ordering, MemoryWorkspace workspace) { + /** + * Construct an ndarray of the specified shape with an empty data array + * + * @param shape the shape of the ndarray + * @param stride the stride of the ndarray + * @param offset the desired offset + * @param ordering the ordering of the ndarray + */ + public BaseNDArray(int[] shape, int[] stride, long offset, char ordering) { + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, + offset, ordering); + } - //calculate strides with paddings - int rank = shape.length; - if(paddings == null || paddings.length != rank ) throw new IllegalArgumentException("The length of Padding should be equal to the length of Shape"); - long [] paddedShape = new long[rank]; - boolean empty = false; - boolean zeroOffset = paddingOffsets == null || paddingOffsets.length == 0; - boolean paddingOffsetsInvalid = paddingOffsets != null && paddingOffsets.length != rank ; - long ews = 1; - if(!paddingOffsetsInvalid){ - for(int i=0; i<rank; i++){ - paddedShape[i] = shape[i] + paddings[i]; - if(paddings[i]!=0) ews = 0; - if(shape[i]==0) empty = true; - if(paddingOffsets[i]>paddings[i]){ - paddingOffsetsInvalid = true; - break; - } - } + public BaseNDArray(long[] shape, long[] stride, long offset, char ordering) { + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, + offset, ordering); + } + + /** + * Construct an ndarray of the specified shape. + * + * @param shape the shape of the ndarray + * @param stride the stride of the ndarray + * @param offset the desired offset + * @param ordering the ordering of the ndarray + * @param initialize Whether to initialize the INDArray. If true: initialize. If false: don't. + */ + public BaseNDArray(int[] shape, int[] stride, long offset, char ordering, boolean initialize) { + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, + stride, offset, ordering); + } + + public BaseNDArray(long[] shape, long[] stride, long offset, char ordering, boolean initialize) { + this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, + stride, offset, ordering); + } + + public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, + boolean initialize) { + this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), + type, shape, stride, offset, ordering); + } + + public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, + boolean initialize, MemoryWorkspace workspace) { + this(Nd4j.createBuffer(type, shape.length == 0 ?
1 : ArrayUtil.prodLong(shape), initialize, + workspace), type, shape, stride, offset, ordering); + } + + public BaseNDArray(DataType type, long[] shape, long[] paddings, long[] paddingOffsets, + char ordering, MemoryWorkspace workspace) { + + //calculate strides with paddings + int rank = shape.length; + if (paddings == null || paddings.length != rank) { + throw new IllegalArgumentException( + "The length of Padding should be equal to the length of Shape"); + } + long[] paddedShape = new long[rank]; + boolean empty = false; + boolean zeroOffset = paddingOffsets == null || paddingOffsets.length == 0; + boolean paddingOffsetsInvalid = paddingOffsets != null && paddingOffsets.length != rank; + long ews = 1; + if (!paddingOffsetsInvalid) { + for (int i = 0; i < rank; i++) { + paddedShape[i] = shape[i] + paddings[i]; + if (paddings[i] != 0) { + ews = 0; + } + if (shape[i] == 0) { + empty = true; + } + if (paddingOffsets[i] > paddings[i]) { + paddingOffsetsInvalid = true; + break; } - if(!zeroOffset && paddingOffsetsInvalid) throw new IllegalArgumentException("If PaddingOffsets is not empty or zero length then its length should match the length of Paddings and also its elements should not be greater"); - - long[] paddedStride = ordering == 'c' ? ArrayUtil.calcStrides(paddedShape,1): ArrayUtil.calcStridesFortran(paddedShape,1); - long paddedAllocSize = ordering == 'c' ? paddedShape[0] * paddedStride[0] : paddedShape[rank-1] * paddedStride[rank-1]; - - long offset = (empty || ews == 1 || zeroOffset) ? 0 : ArrayUtil.calcOffset(paddedShape, paddingOffsets, paddedStride); - DataBuffer buffer = Nd4j.createBuffer(type, paddedAllocSize, false, workspace); - this.data = offset > 0 ? Nd4j.createBuffer(buffer, offset, paddedAllocSize - offset) : buffer ; - long extras = ArrayOptionsHelper.setOptionBit(0, type); - if(empty) extras = ArrayOptionsHelper.setOptionBit(extras, ArrayOptionsHelper.ATYPE_EMPTY_BIT); - else if(ews!=1) extras = ArrayOptionsHelper.setOptionBit(extras, ArrayOptionsHelper.HAS_PADDED_BUFFER); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, paddedStride, ews, ordering, extras)); + } } + if (!zeroOffset && paddingOffsetsInvalid) { + throw new IllegalArgumentException( + "If PaddingOffsets is not empty or zero length then its length should match the length of Paddings and also its elements should not be greater"); + } - /** - * Create the ndarray with - * the specified shape and stride and an offset of 0 - * - * @param shape the shape of the ndarray - * @param stride the stride of the ndarray - * @param ordering the ordering of the ndarray - */ - public BaseNDArray(int[] shape, int[] stride, char ordering) { - this(shape, stride, 0, ordering); + long[] paddedStride = ordering == 'c' ? ArrayUtil.calcStrides(paddedShape, 1) + : ArrayUtil.calcStridesFortran(paddedShape, 1); + long paddedAllocSize = ordering == 'c' ? paddedShape[0] * paddedStride[0] + : paddedShape[rank - 1] * paddedStride[rank - 1]; + + long offset = (empty || ews == 1 || zeroOffset) ? 0 + : ArrayUtil.calcOffset(paddedShape, paddingOffsets, paddedStride); + DataBuffer buffer = Nd4j.createBuffer(type, paddedAllocSize, false, workspace); + this.data = offset > 0 ? 
Nd4j.createBuffer(buffer, offset, paddedAllocSize - offset) : buffer; + long extras = ArrayOptionsHelper.setOptionBit(0, type); + if (empty) { + extras = ArrayOptionsHelper.setOptionBit(extras, ArrayOptionsHelper.ATYPE_EMPTY_BIT); + } else if (ews != 1) { + extras = ArrayOptionsHelper.setOptionBit(extras, ArrayOptionsHelper.HAS_PADDED_BUFFER); + } + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(shape, paddedStride, ews, ordering, extras)); + } + + /** + * Create the ndarray with the specified shape and stride and an offset of 0 + * + * @param shape the shape of the ndarray + * @param stride the stride of the ndarray + * @param ordering the ordering of the ndarray + */ + public BaseNDArray(int[] shape, int[] stride, char ordering) { + this(shape, stride, 0, ordering); + } + + + /** + * @param shape + * @param offset + * @param ordering + */ + public BaseNDArray(int[] shape, long offset, char ordering) { + this(shape, Nd4j.getStrides(shape, ordering), offset, ordering); + } + + public BaseNDArray(long[] shape, long offset, char ordering) { + this(shape, Nd4j.getStrides(shape, ordering), offset, ordering); + } + + + /** + * Create an ndarray with the given shape + * + * @param shape + */ + public BaseNDArray(int[] shape) { + this(shape, 0, Nd4j.order()); + } + + public BaseNDArray(long[] shape) { + this(shape, 0, Nd4j.order()); + } + + + /** + * Creates a new n times m DoubleMatrix. + * + * @param newRows the number of rows (n) of the new matrix. + * @param newColumns the number of columns (m) of the new matrix. + */ + public BaseNDArray(int newRows, int newColumns, char ordering) { + Shape.assertValidOrder(ordering); + this.data = Nd4j.createBuffer((long) newRows * newColumns); + val shape = new long[]{newRows, newColumns}; + val stride = Nd4j.getStrides(shape, ordering); + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, Nd4j.dataType(), false)); + init(shape, stride); + } + + public BaseNDArray(long newRows, long newColumns, char ordering) { + Shape.assertValidOrder(ordering); + this.data = Nd4j.createBuffer(newRows * newColumns); + long[] shape = new long[]{newRows, newColumns}; + long[] stride = Nd4j.getStrides(shape, ordering); + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, Nd4j.dataType(), false)); + init(shape, stride); + } + + + /** + * Create an ndarray from the specified slices. This will go through and merge all of the data + * from each slice in to one ndarray which will then take the specified shape + * + * @param slices the slices to merge + * @param shape the shape of the ndarray + */ + public BaseNDArray(List slices, int[] shape, char ordering) { + this(slices, shape, Nd4j.getStrides(shape, ordering), ordering); + } + + public BaseNDArray(List slices, long[] shape, char ordering) { + this(slices, shape, Nd4j.getStrides(shape, ordering), ordering); + } + + + /** + * Create an ndarray from the specified slices. This will go through and merge all of the data + * from each slice in to one ndarray which will then take the specified shape + * + * @param slices the slices to merge + * @param shape the shape of the ndarray + */ + public BaseNDArray(List slices, int[] shape, int[] stride, char ordering) { + Shape.assertValidOrder(ordering); + DataBuffer ret = slices.get(0).data().dataType() == (DataType.FLOAT) + ? 
Nd4j.createBuffer(new float[ArrayUtil.prod(shape)]) + : Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]); + this.data = ret; + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, + slices.get(0).dataType(), false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + + if (slices.get(0).isScalar()) { + for (int i = 0; i < length(); i++) { + putScalar(i, slices.get(i).getDouble(0)); + } + } else { + for (int i = 0; i < slices(); i++) { + putSlice(i, slices.get(i)); + } } + } - /** - * - * @param shape - * @param offset - * @param ordering - */ - public BaseNDArray(int[] shape, long offset, char ordering) { - this(shape, Nd4j.getStrides(shape, ordering), offset, ordering); - } - - public BaseNDArray(long[] shape, long offset, char ordering) { - this(shape, Nd4j.getStrides(shape, ordering), offset, ordering); - } - - - /** - * Create an ndarray - * with the given shape - * @param shape - */ - public BaseNDArray(int[] shape) { - this(shape, 0, Nd4j.order()); - } - - public BaseNDArray(long[] shape) { - this(shape, 0, Nd4j.order()); - } - - - /** - * Creates a new n times m DoubleMatrix. - * - * @param newRows the number of rows (n) of the new matrix. - * @param newColumns the number of columns (m) of the new matrix. - */ - public BaseNDArray(int newRows, int newColumns, char ordering) { - Shape.assertValidOrder(ordering); - this.data = Nd4j.createBuffer((long) newRows * newColumns); - val shape = new long[] {newRows, newColumns}; - val stride = Nd4j.getStrides(shape, ordering); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, Nd4j.dataType(), false)); - init(shape, stride); - } - - public BaseNDArray(long newRows, long newColumns, char ordering) { - Shape.assertValidOrder(ordering); - this.data = Nd4j.createBuffer(newRows * newColumns); - long[] shape = new long[] {newRows, newColumns}; - long[] stride = Nd4j.getStrides(shape, ordering); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, Nd4j.dataType(), false)); - init(shape, stride); - } - - - /** - * Create an ndarray from the specified slices. - * This will go through and merge all of the - * data from each slice in to one ndarray - * which will then take the specified shape - * - * @param slices the slices to merge - * @param shape the shape of the ndarray - */ - public BaseNDArray(List slices, int[] shape, char ordering) { - this(slices, shape, Nd4j.getStrides(shape, ordering), ordering); - } - - public BaseNDArray(List slices, long[] shape, char ordering) { - this(slices, shape, Nd4j.getStrides(shape, ordering), ordering); - } - - - /** - * Create an ndarray from the specified slices. - * This will go through and merge all of the - * data from each slice in to one ndarray - * which will then take the specified shape - * - * @param slices the slices to merge - * @param shape the shape of the ndarray - */ - public BaseNDArray(List slices, int[] shape, int[] stride, char ordering) { - Shape.assertValidOrder(ordering); - DataBuffer ret = slices.get(0).data().dataType() == (DataType.FLOAT) - ? 
Nd4j.createBuffer(new float[ArrayUtil.prod(shape)]) - : Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]); - this.data = ret; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, slices.get(0).dataType(), false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - - if (slices.get(0).isScalar()) { - for (int i = 0; i < length(); i++) { - putScalar(i, slices.get(i).getDouble(0)); - } - } else { - for (int i = 0; i < slices(); i++) { - putSlice(i, slices.get(i)); - } - } - } - - - public BaseNDArray(List slices, long[] shape, long[] stride, char ordering) { - DataBuffer ret = Nd4j.createBuffer(slices.get(0).dataType(), Shape.lengthOf(shape), false); /*slices.get(0).data().dataType() == (DataType.FLOAT) + public BaseNDArray(List slices, long[] shape, long[] stride, char ordering) { + DataBuffer ret = Nd4j.createBuffer(slices.get(0).dataType(), Shape.lengthOf(shape), false); /*slices.get(0).data().dataType() == (DataType.FLOAT) ? Nd4j.createBuffer(new float[ArrayUtil.prod(shape)]) : Nd4j.createBuffer(new double[ArrayUtil.prod(shape)]); */ - this.data = ret; - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, slices.get(0).dataType(), false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); + this.data = ret; + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, slices.get(0).dataType(), + false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, ordering == 'f')); - if (slices.get(0).isScalar()) { - for (int i = 0; i < length(); i++) { - putScalar(i, slices.get(i).getDouble(0)); - } - } else { - for (int i = 0; i < slices(); i++) { - putSlice(i, slices.get(i)); - } + if (slices.get(0).isScalar()) { + for (int i = 0; i < length(); i++) { + putScalar(i, slices.get(i).getDouble(0)); + } + } else { + for (int i = 0; i < slices(); i++) { + putSlice(i, slices.get(i)); + } + } + } + + /** + * @param data + * @param shape + * @param stride + * @param ordering + */ + public BaseNDArray(float[] data, int[] shape, int[] stride, char ordering) { + this(data, shape, stride, 0, ordering); + } + + /** + * @param data + * @param shape + * @param stride + * @param offset + * @param ordering + */ + public BaseNDArray(float[] data, int[] shape, int[] stride, long offset, char ordering) { + Shape.assertValidOrder(ordering); + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, + data == null || data.length <= 0)); + if (data != null && data.length > 0) { + + val perfD = PerformanceTracker.getInstance().helperStartTransaction(); + + this.data = internalCreateBuffer(data, offset); + + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, + (long) data.length * Nd4j.sizeOfDataType(DataType.FLOAT), MemcpyDirection.HOST_TO_HOST); + + if (offset >= data.length) { + throw new IllegalArgumentException("invalid offset: must be < data.length"); } } - /** - * - * @param 
data - * @param shape - * @param stride - * @param ordering - */ - public BaseNDArray(float[] data, int[] shape, int[] stride, char ordering) { - this(data, shape, stride, 0, ordering); - } - - /** - * - * @param data - * @param shape - * @param stride - * @param offset - * @param ordering - */ - public BaseNDArray(float[] data, int[] shape, int[] stride, long offset, char ordering) { - Shape.assertValidOrder(ordering); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data == null || data.length <= 0)); - if (data != null && data.length > 0) { - - val perfD = PerformanceTracker.getInstance().helperStartTransaction(); - - this.data = internalCreateBuffer(data, offset); - - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfD, (long) data.length * Nd4j.sizeOfDataType(DataType.FLOAT), MemcpyDirection.HOST_TO_HOST); - - if (offset >= data.length) - throw new IllegalArgumentException("invalid offset: must be < data.length"); - } - - init(shape, stride); - } - - public BaseNDArray(float[] data, long[] shape, long[] stride, long offset, char ordering) { - Shape.assertValidOrder(ordering); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, data == null || data.length <= 0)); - if (data != null && data.length > 0) { - this.data = Nd4j.createTypedBuffer(data, DataType.FLOAT); - if (offset >= data.length) - throw new IllegalArgumentException("invalid offset: must be < data.length"); - } - - init(shape, stride); - } - - public BaseNDArray(double[] data, long[] shape, long[] stride, long offset, char ordering) { - Shape.assertValidOrder(ordering); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, - Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.DOUBLE, data == null || data.length <= 0)); - if (data != null && data.length > 0) { - this.data = Nd4j.createBuffer(data, offset); - if (offset >= data.length) - throw new IllegalArgumentException("invalid offset: must be < data.length"); - } - - init(shape, stride); - } - - /** - * - * @param data - * @param shape - * @param stride - * @param offset - */ - public BaseNDArray(DataBuffer data, int[] shape, int[] stride, long offset) { - this.data = Nd4j.createBuffer(data, offset, ArrayUtil.prodLong(shape)); - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), - Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f'), Nd4j.order(), data.dataType(), false)); - init(shape, stride); - // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f')); - - - } - - /** - * - * @param data - * @param shape - * @param strides - */ - public BaseNDArray(int[] data, int[] shape, int[] strides) { - this(internalCreateBuffer(data), shape, strides); - } - - /** - * - * @param data - * @param shape - */ - public BaseNDArray(DataBuffer data, int[] shape) { - this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order()); - } - - public BaseNDArray(DataBuffer data, long[] shape) { - this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order()); - } - - - /** - * - * @param buffer - * @param shape - * @param offset - */ - public BaseNDArray(DataBuffer buffer, int[] 
shape, long offset) { - this(Nd4j.createBuffer(buffer, offset, ArrayUtil.prodLong(shape)), shape, Nd4j.getStrides(shape), offset, - Nd4j.order()); - } - - /** - * - * @param buffer - * @param shape - * @param ordering - */ - public BaseNDArray(DataBuffer buffer, int[] shape, char ordering) { - this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering); - } - - public BaseNDArray(DataBuffer buffer, long[] shape, char ordering) { - this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering); - } - - /** - * - * @param data - * @param shape - * @param ordering - */ - public BaseNDArray(double[] data, int[] shape, char ordering) { - this(Nd4j.createBuffer(data), shape, ordering); - } - - public BaseNDArray(double[] data, long[] shape, char ordering) { - this(Nd4j.createBuffer(data), shape, ordering); - } - - public BaseNDArray(float[] data, long[] shape, char ordering) { - this(Nd4j.createBuffer(data), shape, ordering); - } - - /** - * - * @param data - * @param shape - * @param stride - * @param offset - * @param ordering - */ - public BaseNDArray(double[] data, int[] shape, int[] stride, long offset, char ordering) { - this(internalCreateBuffer(data, offset), shape, stride, offset, ordering); - } - - /** - * - * @param data - * @param order - */ - public BaseNDArray(float[] data, char order) { - this(internalCreateBuffer(data), order); - } - - protected static DataBuffer internalCreateBuffer(float[] data) { - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - - val buffer = Nd4j.createBuffer(data); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); - - return buffer; - } - - protected static DataBuffer internalCreateBuffer(double[] data) { - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - - val buffer = Nd4j.createBuffer(data); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); - - return buffer; - } - - protected static DataBuffer internalCreateBuffer(int[] data) { - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - - val buffer = Nd4j.createBuffer(data); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); - - return buffer; - } - - protected static DataBuffer internalCreateBuffer(float[] data, long offset) { - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - - val buffer = Nd4j.createBuffer(data, offset); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); - - return buffer; - } - - protected static DataBuffer internalCreateBuffer(double[] data, long offset) { - val perfX = PerformanceTracker.getInstance().helperStartTransaction(); - - val buffer = Nd4j.createBuffer(data, offset); - PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); - - return buffer; - } - - /** - * - * @param floatBuffer - * @param order - */ - public BaseNDArray(DataBuffer floatBuffer, char order) { - this(floatBuffer, new int[] {(int) floatBuffer.length()}, - Nd4j.getStrides(new int[] {(int) floatBuffer.length()}, order), 0, order); - Shape.assertValidOrder(order); - 
if (floatBuffer.length() >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE"); - } - - /** - * - * @param buffer - * @param shape - * @param strides - */ - public BaseNDArray(DataBuffer buffer, int[] shape, int[] strides) { - this(buffer, shape, strides, 0, Nd4j.order()); - } - - - /** - * Create this ndarray with the given data and shape and 0 offset - * - * @param data the data to use - * @param shape the shape of the ndarray - */ - public BaseNDArray(float[] data, int[] shape) { - this(data, shape, 0); - } - - - /** - * - * @param data - * @param shape - * @param offset - */ - public BaseNDArray(float[] data, int[] shape, long offset) { - this(data, shape, offset, Nd4j.order()); - - } - - /** - * Construct an ndarray of the specified shape - * with an empty data array - * - * @param shape the shape of the ndarray - * @param stride the stride of the ndarray - * @param offset the desired offset - */ - public BaseNDArray(int[] shape, int[] stride, long offset) { - this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order()); - } - - public BaseNDArray(long[] shape, long[] stride, long offset) { - this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order()); - } - - /** - * Create the ndarray with - * the specified shape and stride and an offset of 0 - * - * @param shape the shape of the ndarray - * @param stride the stride of the ndarray - */ - public BaseNDArray(int[] shape, int[] stride) { - this(shape, stride, 0); - } - - /** - * - * @param shape - * @param offset - */ - public BaseNDArray(int[] shape, long offset) { - this(shape, Nd4j.getStrides(shape), offset); - } - - /** - * - * @param shape - * @param ordering - */ - public BaseNDArray(int[] shape, char ordering) { - this(shape, 0, ordering); - } - - - /** - * Creates a new n times m DoubleMatrix. - * - * @param newRows the number of rows (n) of the new matrix. - * @param newColumns the number of columns (m) of the new matrix. - */ - public BaseNDArray(int newRows, int newColumns) { - this(newRows, newColumns, Nd4j.order()); - } - - public BaseNDArray(long newRows, long newColumns) { - this(newRows, newColumns, Nd4j.order()); - } - - - /** - * Create an ndarray from the specified slices. - * This will go through and merge all of the - * data from each slice in to one ndarray - * which will then take the specified shape - * - * @param slices the slices to merge - * @param shape the shape of the ndarray - */ - public BaseNDArray(List slices, int[] shape) { - this(slices, shape, Nd4j.order()); - } - - public BaseNDArray(List slices, long[] shape) { - this(slices, shape, Nd4j.order()); - } - - /** - * Create an ndarray from the specified slices. 
- * This will go through and merge all of the - * data from each slice in to one ndarray - * which will then take the specified shape - * - * @param slices the slices to merge - * @param shape the shape of the ndarray - */ - public BaseNDArray(List slices, int[] shape, int[] stride) { - this(slices, shape, stride, Nd4j.order()); - } - - public BaseNDArray(List slices, long[] shape, long[] stride) { - this(slices, shape, stride, Nd4j.order()); - } - - /** - * - * @param data - * @param shape - * @param stride - */ - public BaseNDArray(float[] data, int[] shape, int[] stride) { - this(data, shape, stride, Nd4j.order()); - } - - - /** - * - * @param data - * @param shape - * @param stride - * @param offset - */ - public BaseNDArray(float[] data, int[] shape, int[] stride, long offset) { - this(data, shape, stride, offset, Nd4j.order()); - } - - public BaseNDArray(double[] data, long[] shape, long[] stride, long offset) { - this(data, shape, stride, offset, Nd4j.order()); - } - - public BaseNDArray(float[] data, long[] shape, long[] stride, long offset) { - this(data, shape, stride, offset, Nd4j.order()); - } - - /** - * - * @param data - */ - public BaseNDArray(float[] data) { - this(Nd4j.createBuffer(data)); - } - - - /** - * Initialize the ndarray - * with the given data - * @param data - */ - public BaseNDArray(float[][] data) { - this(data, Nd4j.order()); - } - - /** - * - * @param data - * @param ordering - */ - public BaseNDArray(float[][] data, char ordering) { - this(internalCreateBuffer(ordering == 'c' ? ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)), - new int[] {data.length, data[0].length}, - Nd4j.getStrides(new int[] {data.length, data[0].length}, ordering), 0, ordering); - - int c = columns(); - for (int r = 0; r < rows(); r++) { - Preconditions.checkState(data[r].length == c, "data[%s].length=%s must be equal to number of columns %s", r, data[r].length, c ); + init(shape, stride); + } + + public BaseNDArray(float[] data, long[] shape, long[] stride, long offset, char ordering) { + Shape.assertValidOrder(ordering); + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.FLOAT, + data == null || data.length <= 0)); + if (data != null && data.length > 0) { + this.data = Nd4j.createTypedBuffer(data, DataType.FLOAT); + if (offset >= data.length) { + throw new IllegalArgumentException("invalid offset: must be < data.length"); } } + init(shape, stride); + } - - /** - * Constructor for stride and offset - * - * @param buffer - * @param shape - * @param offset - * @param ordering - */ - public BaseNDArray(DataBuffer buffer, int[] shape, long offset, char ordering) { - this(buffer, shape, Nd4j.getStrides(shape, ordering), offset, ordering); + public BaseNDArray(double[] data, long[] shape, long[] stride, long offset, char ordering) { + Shape.assertValidOrder(ordering); + setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape, stride, + Shape.elementWiseStride(shape, stride, ordering == 'f'), ordering, DataType.DOUBLE, + data == null || data.length <= 0)); + if (data != null && data.length > 0) { + this.data = Nd4j.createBuffer(data, offset); + if (offset >= data.length) { + throw new IllegalArgumentException("invalid offset: must be < data.length"); + } } - public BaseNDArray(double[] data, int[] shape, int[] stride, long offset) { - this(data, shape, stride, offset, Nd4j.order()); + init(shape, stride); + } + + /** + * @param data + * @param shape + * 
@param stride + * @param offset + */ + public BaseNDArray(DataBuffer data, int[] shape, int[] stride, long offset) { + this.data = Nd4j.createBuffer(data, offset, ArrayUtil.prodLong(shape)); + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), + Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f'), Nd4j.order(), + data.dataType(), false)); + init(shape, stride); + // Shape.setElementWiseStride(this.shapeInfo(),Shape.elementWiseStride(shape, stride, Nd4j.order() == 'f')); + + } + + /** + * @param data + * @param shape + * @param strides + */ + public BaseNDArray(int[] data, int[] shape, int[] strides) { + this(internalCreateBuffer(data), shape, strides); + } + + /** + * @param data + * @param shape + */ + public BaseNDArray(DataBuffer data, int[] shape) { + this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order()); + } + + public BaseNDArray(DataBuffer data, long[] shape) { + this(data, shape, Nd4j.getStrides(shape, Nd4j.order()), 0, Nd4j.order()); + } + + + /** + * @param buffer + * @param shape + * @param offset + */ + public BaseNDArray(DataBuffer buffer, int[] shape, long offset) { + this(Nd4j.createBuffer(buffer, offset, ArrayUtil.prodLong(shape)), shape, + Nd4j.getStrides(shape), offset, + Nd4j.order()); + } + + /** + * @param buffer + * @param shape + * @param ordering + */ + public BaseNDArray(DataBuffer buffer, int[] shape, char ordering) { + this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering); + } + + public BaseNDArray(DataBuffer buffer, long[] shape, char ordering) { + this(buffer, shape, Nd4j.getStrides(shape, ordering), 0, ordering); + } + + /** + * @param data + * @param shape + * @param ordering + */ + public BaseNDArray(double[] data, int[] shape, char ordering) { + this(Nd4j.createBuffer(data), shape, ordering); + } + + public BaseNDArray(double[] data, long[] shape, char ordering) { + this(Nd4j.createBuffer(data), shape, ordering); + } + + public BaseNDArray(float[] data, long[] shape, char ordering) { + this(Nd4j.createBuffer(data), shape, ordering); + } + + /** + * @param data + * @param shape + * @param stride + * @param offset + * @param ordering + */ + public BaseNDArray(double[] data, int[] shape, int[] stride, long offset, char ordering) { + this(internalCreateBuffer(data, offset), shape, stride, offset, ordering); + } + + /** + * @param data + * @param order + */ + public BaseNDArray(float[] data, char order) { + this(internalCreateBuffer(data), order); + } + + protected static DataBuffer internalCreateBuffer(float[] data) { + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + val buffer = Nd4j.createBuffer(data); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, + (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + + return buffer; + } + + protected static DataBuffer internalCreateBuffer(double[] data) { + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + val buffer = Nd4j.createBuffer(data); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, + (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + + return buffer; + } + + protected static DataBuffer internalCreateBuffer(int[] data) { + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + val buffer = Nd4j.createBuffer(data); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, + (long) 
data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + + return buffer; + } + + protected static DataBuffer internalCreateBuffer(float[] data, long offset) { + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + val buffer = Nd4j.createBuffer(data, offset); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, + (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + + return buffer; + } + + protected static DataBuffer internalCreateBuffer(double[] data, long offset) { + val perfX = PerformanceTracker.getInstance().helperStartTransaction(); + + val buffer = Nd4j.createBuffer(data, offset); + PerformanceTracker.getInstance().helperRegisterTransaction(0, perfX, + (long) data.length * Nd4j.sizeOfDataType(buffer.dataType()), MemcpyDirection.HOST_TO_HOST); + + return buffer; + } + + /** + * @param floatBuffer + * @param order + */ + public BaseNDArray(DataBuffer floatBuffer, char order) { + this(floatBuffer, new int[]{(int) floatBuffer.length()}, + Nd4j.getStrides(new int[]{(int) floatBuffer.length()}, order), 0, order); + Shape.assertValidOrder(order); + if (floatBuffer.length() >= Integer.MAX_VALUE) { + throw new IllegalArgumentException("Length of buffer can not be >= Integer.MAX_VALUE"); + } + } + + /** + * @param buffer + * @param shape + * @param strides + */ + public BaseNDArray(DataBuffer buffer, int[] shape, int[] strides) { + this(buffer, shape, strides, 0, Nd4j.order()); + } + + + /** + * Create this ndarray with the given data and shape and 0 offset + * + * @param data the data to use + * @param shape the shape of the ndarray + */ + public BaseNDArray(float[] data, int[] shape) { + this(data, shape, 0); + } + + + /** + * @param data + * @param shape + * @param offset + */ + public BaseNDArray(float[] data, int[] shape, long offset) { + this(data, shape, offset, Nd4j.order()); + + } + + /** + * Construct an ndarray of the specified shape with an empty data array + * + * @param shape the shape of the ndarray + * @param stride the stride of the ndarray + * @param offset the desired offset + */ + public BaseNDArray(int[] shape, int[] stride, long offset) { + this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order()); + } + + public BaseNDArray(long[] shape, long[] stride, long offset) { + this(new float[ArrayUtil.prod(shape)], shape, stride, offset, Nd4j.order()); + } + + /** + * Create the ndarray with the specified shape and stride and an offset of 0 + * + * @param shape the shape of the ndarray + * @param stride the stride of the ndarray + */ + public BaseNDArray(int[] shape, int[] stride) { + this(shape, stride, 0); + } + + /** + * @param shape + * @param offset + */ + public BaseNDArray(int[] shape, long offset) { + this(shape, Nd4j.getStrides(shape), offset); + } + + /** + * @param shape + * @param ordering + */ + public BaseNDArray(int[] shape, char ordering) { + this(shape, 0, ordering); + } + + + /** + * Creates a new n times m DoubleMatrix. + * + * @param newRows the number of rows (n) of the new matrix. + * @param newColumns the number of columns (m) of the new matrix. + */ + public BaseNDArray(int newRows, int newColumns) { + this(newRows, newColumns, Nd4j.order()); + } + + public BaseNDArray(long newRows, long newColumns) { + this(newRows, newColumns, Nd4j.order()); + } + + + /** + * Create an ndarray from the specified slices. 
This will go through and merge all of the data + * from each slice in to one ndarray which will then take the specified shape + * + * @param slices the slices to merge + * @param shape the shape of the ndarray + */ + public BaseNDArray(List slices, int[] shape) { + this(slices, shape, Nd4j.order()); + } + + public BaseNDArray(List slices, long[] shape) { + this(slices, shape, Nd4j.order()); + } + + /** + * Create an ndarray from the specified slices. This will go through and merge all of the data + * from each slice in to one ndarray which will then take the specified shape + * + * @param slices the slices to merge + * @param shape the shape of the ndarray + */ + public BaseNDArray(List slices, int[] shape, int[] stride) { + this(slices, shape, stride, Nd4j.order()); + } + + public BaseNDArray(List slices, long[] shape, long[] stride) { + this(slices, shape, stride, Nd4j.order()); + } + + /** + * @param data + * @param shape + * @param stride + */ + public BaseNDArray(float[] data, int[] shape, int[] stride) { + this(data, shape, stride, Nd4j.order()); + } + + + /** + * @param data + * @param shape + * @param stride + * @param offset + */ + public BaseNDArray(float[] data, int[] shape, int[] stride, long offset) { + this(data, shape, stride, offset, Nd4j.order()); + } + + public BaseNDArray(double[] data, long[] shape, long[] stride, long offset) { + this(data, shape, stride, offset, Nd4j.order()); + } + + public BaseNDArray(float[] data, long[] shape, long[] stride, long offset) { + this(data, shape, stride, offset, Nd4j.order()); + } + + /** + * @param data + */ + public BaseNDArray(float[] data) { + this(Nd4j.createBuffer(data)); + } + + + /** + * Initialize the ndarray with the given data + * + * @param data + */ + public BaseNDArray(float[][] data) { + this(data, Nd4j.order()); + } + + /** + * @param data + * @param ordering + */ + public BaseNDArray(float[][] data, char ordering) { + this(internalCreateBuffer(ordering == 'c' ? ArrayUtil.flatten(data) : ArrayUtil.flattenF(data)), + new int[]{data.length, data[0].length}, + Nd4j.getStrides(new int[]{data.length, data[0].length}, ordering), 0, ordering); + + int c = columns(); + for (int r = 0; r < rows(); r++) { + Preconditions.checkState(data[r].length == c, + "data[%s].length=%s must be equal to number of columns %s", r, data[r].length, c); + } + } + + + /** + * Constructor for stride and offset + * + * @param buffer + * @param shape + * @param offset + * @param ordering + */ + public BaseNDArray(DataBuffer buffer, int[] shape, long offset, char ordering) { + this(buffer, shape, Nd4j.getStrides(shape, ordering), offset, ordering); + } + + public BaseNDArray(double[] data, int[] shape, int[] stride, long offset) { + this(data, shape, stride, offset, Nd4j.order()); + } + + + /** + * Returns whether the ndarray is valid or not + * + * @return true if the ndarray is valid false otherwise + */ + @Deprecated + public boolean isValid() { + try { + linearIndex(length() - 1); + } catch (Exception e) { + return false; + } + return true; + } + + protected INDArray create(DataBuffer data, int[] shape, long offset) { + return Nd4j.create(data, shape, offset); + } + + @Override + public int elementWiseStride() { + return Shape.elementWiseStride(shapeInfoDataBuffer()); + } + + @Override + public long tensorsAlongDimension(int... 
dimension) { + if (dimension == null || dimension.length == 0) { + throw new IllegalArgumentException( + "Invalid input: dimensions not specified (null or length 0)"); + } + if (dimension.length >= rank() + || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) { + return 1; + } + for (int i = 0; i < dimension.length; i++) { + if (dimension[i] < 0) { + dimension[i] += rank(); + } + } + long[] tensorShape = ArrayUtil.keep(shape(), dimension); + long len = ArrayUtil.prodLong(tensorShape); + if (len == 0) { + throw new IllegalStateException("Illegal length found after removing index"); + } + long length = length(); + if (length / len >= Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Tensors along dimension can not be >= Integer.MAX_VALUE"); + } + return length / len; + } + + @Override + public INDArray tensorAlongDimension(long index, int... dimension) { + if (dimension == null || dimension.length == 0) { + throw new IllegalArgumentException( + "Invalid input: dimensions not specified (null or length 0)"); + } + + Preconditions.checkArgument(!this.isEmpty(), + "tensorAlongDimension(...) can't be used on empty tensors"); + + if (dimension.length >= rank() + || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) { + return this; + } + for (int i = 0; i < dimension.length; i++) { + if (dimension[i] < 0) { + dimension[i] += rank(); + } + } + + //dedup + if (dimension.length > 1) { + dimension = Ints.toArray(new ArrayList<>(new TreeSet<>(Ints.asList(dimension)))); + } + + if (dimension.length > 1) { + Arrays.sort(dimension); } + long tads = tensorsAlongDimension(dimension); + if (index >= tads) { + throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads); + } - /** - * Returns whether the ndarray is valid or not - * @return true if the ndarray is valid - * false otherwise - */ - @Deprecated - public boolean isValid() { - try { - linearIndex(length() - 1); - } catch (Exception e) { - return false; - } - return true; + if (dimension.length == 1) { + if (dimension[0] == 0 && isColumnVector()) { + return this.transpose(); + } else if (dimension[0] == 1 && isRowVector()) { + return this; + } } - protected INDArray create(DataBuffer data, int[] shape, long offset) { - return Nd4j.create(data, shape, offset); + Pair tadInfo = Nd4j.getExecutioner().getTADManager() + .getTADOnlyShapeInfo(this, dimension); + DataBuffer shapeInfo = tadInfo.getFirst(); + val jShapeInfo = shapeInfo.asLong(); + val shape = Shape.shape(jShapeInfo); + val stride = Shape.stride(jShapeInfo); + long offset = offset() + tadInfo.getSecond().getLong(index); + val ews = shapeInfo.getLong(jShapeInfo[0] * 2 + 2); + char tadOrder = (char) shapeInfo.getInt(jShapeInfo[0] * 2 + 3); + val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder); + return toTad; + } + + private void setShapeInformation(Pair shapeInfo) { + this.shapeInformation = shapeInfo.getFirst(); + this.jvmShapeInfo = new JvmShapeInfo(shapeInfo.getSecond()); + } + + + private INDArray doTad(int index, int... 
dimension) { + if (dimension == null || dimension.length == 0) { + throw new IllegalArgumentException( + "Invalid input: dimensions not specified (null or length 0)"); + } + + if (dimension.length >= rank()) { + return this; + } + for (int i = 0; i < dimension.length; i++) { + if (dimension[i] < 0) { + dimension[i] += rank(); + } + } + + if (dimension.length > 1) { + Arrays.sort(dimension); + } + + long tads = tensorsAlongDimension(dimension); + if (index >= tads) { + throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads); + } + + if (dimension.length == 1) { + if (dimension[0] == 0 && isColumnVector()) { + return this.transpose(); + } else if (dimension[0] == 1 && isRowVector()) { + return this; + } } - @Override - public int elementWiseStride() { - return Shape.elementWiseStride(shapeInfoDataBuffer()); - } + long[] tensorShape = ArrayUtil.keep(shape(), dimension); + int[] reverseDimensions = ArrayUtil.reverseCopy(dimension); + int[] remove = ArrayUtil.removeIndex(ArrayUtil.range(0, rank()), dimension); + int[] newPermuteDims = Ints.concat(remove, reverseDimensions); + int[] finalPermuteDims = tadFinalPermuteDimensions[dimension.length]; - @Override - public long tensorsAlongDimension(int... dimension) { - if (dimension == null || dimension.length == 0) - throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)"); - if (dimension.length >= rank() || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) - return 1; - for (int i = 0; i < dimension.length; i++) - if (dimension[i] < 0) - dimension[i] += rank(); - long[] tensorShape = ArrayUtil.keep(shape(), dimension); - long len = ArrayUtil.prodLong(tensorShape); - if (len == 0) - throw new IllegalStateException("Illegal length found after removing index"); - long length = length(); - if (length / len >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Tensors along dimension can not be >= Integer.MAX_VALUE"); - return length / len; - } + INDArray permuted = permute(newPermuteDims); + long sliceIdx = NDArrayMath.sliceOffsetForTensor(index, permuted, tensorShape); - @Override - public INDArray tensorAlongDimension(long index, int... dimension) { - if (dimension == null || dimension.length == 0) - throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)"); - - Preconditions.checkArgument(!this.isEmpty(), "tensorAlongDimension(...) 
can't be used on empty tensors"); - - if (dimension.length >= rank() || dimension.length == 1 && dimension[0] == Integer.MAX_VALUE) - return this; - for (int i = 0; i < dimension.length; i++) - if (dimension[i] < 0) - dimension[i] += rank(); - - //dedup - if (dimension.length > 1) - dimension = Ints.toArray(new ArrayList<>(new TreeSet<>(Ints.asList(dimension)))); - - if (dimension.length > 1) { - Arrays.sort(dimension); - } - - long tads = tensorsAlongDimension(dimension); - if (index >= tads) - throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads); - - - if (dimension.length == 1) { - if (dimension[0] == 0 && isColumnVector()) { - return this.transpose(); - } else if (dimension[0] == 1 && isRowVector()) { - return this; - } - } - - - Pair tadInfo = Nd4j.getExecutioner().getTADManager().getTADOnlyShapeInfo(this, dimension); - DataBuffer shapeInfo = tadInfo.getFirst(); - val jShapeInfo = shapeInfo.asLong(); - val shape = Shape.shape(jShapeInfo); - val stride = Shape.stride(jShapeInfo); - long offset = offset() + tadInfo.getSecond().getLong(index); - val ews = shapeInfo.getLong(jShapeInfo[0] * 2 + 2); - char tadOrder = (char) shapeInfo.getInt(jShapeInfo[0] * 2 + 3); - val toTad = Nd4j.create(data(), shape, stride, offset, ews, tadOrder); - return toTad; - } - - private void setShapeInformation(Pair shapeInfo) { - this.shapeInformation = shapeInfo.getFirst(); - this.jvmShapeInfo = new JvmShapeInfo(shapeInfo.getSecond()); - } - - - private INDArray doTad(int index, int... dimension) { - if (dimension == null || dimension.length == 0) - throw new IllegalArgumentException("Invalid input: dimensions not specified (null or length 0)"); - - if (dimension.length >= rank()) - return this; - for (int i = 0; i < dimension.length; i++) - if (dimension[i] < 0) - dimension[i] += rank(); - - if (dimension.length > 1) - Arrays.sort(dimension); - - long tads = tensorsAlongDimension(dimension); - if (index >= tads) - throw new IllegalArgumentException("Illegal index " + index + " out of tads " + tads); - - - if (dimension.length == 1) { - if (dimension[0] == 0 && isColumnVector()) { - return this.transpose(); - } else if (dimension[0] == 1 && isRowVector()) { - return this; - } - } - - - long[] tensorShape = ArrayUtil.keep(shape(), dimension); - int[] reverseDimensions = ArrayUtil.reverseCopy(dimension); - int[] remove = ArrayUtil.removeIndex(ArrayUtil.range(0, rank()), dimension); - int[] newPermuteDims = Ints.concat(remove, reverseDimensions); - int[] finalPermuteDims = tadFinalPermuteDimensions[dimension.length]; - - INDArray permuted = permute(newPermuteDims); - long sliceIdx = NDArrayMath.sliceOffsetForTensor(index, permuted, tensorShape); - - INDArray ret2 = permuted.slice(sliceIdx); - if (dimension.length == tensorShape.length && ArrayUtil.prodLong(tensorShape) == ret2.length()) { - if (dimension.length == 1 && ret2.isRowVector()) - return ret2; - if (finalPermuteDims.length != ret2.rank()) { - finalPermuteDims = new int[ret2.rank()]; - int count = 0; - for (int i = finalPermuteDims.length - 1; i >= 0; i--) - finalPermuteDims[count++] = i; - } - return ret2.permutei(finalPermuteDims); - } - - - int length = ArrayUtil.prod(tensorShape); - int tensorLength = ArrayUtil.prod(tensorShape); - long offset = (long) index * tensorLength / NDArrayMath.lengthPerSlice(ret2); - - if (sliceIdx == 0 && length == NDArrayMath.lengthPerSlice(ret2)) { - if (offset > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - ret2 = ret2.slice((int) offset); - if (dimension.length == 1 && 
ret2.isRowVectorOrScalar()) - return ret2; - return ret2.permutei(finalPermuteDims); - } - - else if (length == NDArrayMath.lengthPerSlice(ret2)) { - offset -= ret2.slices() * (offset / ret2.slices()); - - if (offset > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - ret2 = ret2.slice((int) offset); - if (dimension.length == 1 && ret2.isRowVectorOrScalar()) - return ret2; - return ret2.permutei(finalPermuteDims); - } - - while (ret2.length() > length) { - sliceIdx = NDArrayMath.sliceOffsetForTensor(index, ret2, tensorShape); - sliceIdx -= ret2.slices() * (sliceIdx / ret2.slices()); - ret2 = ret2.slice(sliceIdx); - } - - if (dimension.length == 1 && ret2.isRowVectorOrScalar()) + INDArray ret2 = permuted.slice(sliceIdx); + if (dimension.length == tensorShape.length + && ArrayUtil.prodLong(tensorShape) == ret2.length()) { + if (dimension.length == 1 && ret2.isRowVector()) { return ret2; - - return ret2.permutei(finalPermuteDims); + } + if (finalPermuteDims.length != ret2.rank()) { + finalPermuteDims = new int[ret2.rank()]; + int count = 0; + for (int i = finalPermuteDims.length - 1; i >= 0; i--) { + finalPermuteDims[count++] = i; + } + } + return ret2.permutei(finalPermuteDims); } - @Override - public long vectorsAlongDimension(int dimension) { - if (dimension == 0 && isVector() || isRowVectorOrScalar()) - return 1; - if (size(dimension) == 1 && !isVector()) { - for (int i = dimension; i < rank(); i++) { - if (size(i) != 1) - return vectorsAlongDimension(i); - } + int length = ArrayUtil.prod(tensorShape); + int tensorLength = ArrayUtil.prod(tensorShape); + long offset = (long) index * tensorLength / NDArrayMath.lengthPerSlice(ret2); - return length(); + if (sliceIdx == 0 && length == NDArrayMath.lengthPerSlice(ret2)) { + if (offset > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + ret2 = ret2.slice((int) offset); + if (dimension.length == 1 && ret2.isRowVectorOrScalar()) { + return ret2; + } + return ret2.permutei(finalPermuteDims); + } else if (length == NDArrayMath.lengthPerSlice(ret2)) { + offset -= ret2.slices() * (offset / ret2.slices()); - } else if (size(0) == 1 && !isVectorOrScalar()) { - int realDimension = rank() - getLeadingOnes(); - long length = length(); - if (length / size(realDimension) >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE"); - return length / size(realDimension); + if (offset > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + ret2 = ret2.slice((int) offset); + if (dimension.length == 1 && ret2.isRowVectorOrScalar()) { + return ret2; + } + return ret2.permutei(finalPermuteDims); + } + + while (ret2.length() > length) { + sliceIdx = NDArrayMath.sliceOffsetForTensor(index, ret2, tensorShape); + sliceIdx -= ret2.slices() * (sliceIdx / ret2.slices()); + ret2 = ret2.slice(sliceIdx); + } + + if (dimension.length == 1 && ret2.isRowVectorOrScalar()) { + return ret2; + } + + return ret2.permutei(finalPermuteDims); + } + + @Override + public long vectorsAlongDimension(int dimension) { + if (dimension == 0 && isVector() || isRowVectorOrScalar()) { + return 1; + } + if (size(dimension) == 1 && !isVector()) { + for (int i = dimension; i < rank(); i++) { + if (size(i) != 1) { + return vectorsAlongDimension(i); + } + } + + return length(); + + } else if (size(0) == 1 && !isVectorOrScalar()) { + int realDimension = rank() - getLeadingOnes(); + long length = length(); + if (length / size(realDimension) >= Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Vectors 
along dimension can not be >= Integer.MAX_VALUE"); + } + return length / size(realDimension); + } + + long length = length(); + + if (dimension >= jvmShapeInfo.rank) { + if (length / size(jvmShapeInfo.rank - 1) >= Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Vectors along dimension can not be >= Integer.MAX_VALUE"); + } + return (int) (length / size(jvmShapeInfo.rank - 1)); + } + if (length / size(dimension) >= Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Vectors along dimension can not be >= Integer.MAX_VALUE"); + } + return length / size(dimension); + } + + @Override + public INDArray vectorAlongDimension(int index, int dimension) { + if (dimension < 0) { + dimension = jvmShapeInfo.getRank() + dimension; + } + + //return the whole thing + if (dimension == jvmShapeInfo.getRank() - 1 && size(dimension) == 1 && rank() > 2 + || rank() > 2 && dimension == 0 && size(dimension) == 1) { + return this; + } + + return tensorAlongDimension(index, dimension); + } + + @Override + public void setOrder(char order) { + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), + isEmpty())); + } + + @Override + public void setShapeAndStride(int[] shape, int[] stride) { + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, + ordering(), this.dataType(), false)); + } + + @Override + public INDArray cumsumi(int dimension) { + validateNumericalArray("cumsumi", true); + + if (isScalar() || isEmpty()) { + return this; + } + + if (isVector()) { + double s = 0.0; + for (int i = 0; i < length(); i++) { + s += getDouble(i); + putScalar(i, s); + } + } else if (dimension == Integer.MAX_VALUE) { + INDArray flattened = ravel(); + double prevVal = flattened.getDouble(0); + for (int i = 1; i < flattened.length(); i++) { + double d = prevVal + flattened.getDouble(i); + flattened.putScalar(i, d); + prevVal = d; + } + + return flattened; + } else { + for (int i = 0; i < vectorsAlongDimension(dimension); i++) { + INDArray vec = vectorAlongDimension(i, dimension); + vec.cumsumi(0); + + } + } + + return this; + } + + @Override + public Number normmaxNumber() { + return normmax(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number norm2Number() { + return norm2(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number norm1Number() { + return norm1(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number stdNumber() { + return std(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number prodNumber() { + if (isScalar()) { + return getNumber(0); + } + return prod(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number meanNumber() { + validateNumericalArray("meanNumber", false); + if (isScalar()) { + return getNumber(0); + } + return mean(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number ameanNumber() { + return amean(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number varNumber() { + return var(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number maxNumber() { + if (isScalar()) { + return getNumber(0); + } + return max(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number amaxNumber() { + return amax(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number minNumber() { + if (isScalar()) { + return getNumber(0); + } + return min(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number aminNumber() { + 
return amin(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number scan(Condition condition) { + MatchCondition op = new MatchCondition(this, condition); + return Nd4j.getExecutioner().exec(op).getDouble(0); + } + + @Override + public Number sumNumber() { + validateNumericalArray("sum", false); + if (isScalar()) { + return getNumber(0); + } + val scalar = sum(Integer.MAX_VALUE); + Nd4j.getExecutioner().commit(); + return scalar.getDouble(0); + } + + @Override + public Number entropyNumber() { + return entropy(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number shannonEntropyNumber() { + return shannonEntropy(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public Number logEntropyNumber() { + return logEntropy(Integer.MAX_VALUE).getDouble(0); + } + + @Override + public INDArray cumsum(int dimension) { + validateNumericalArray("cumsum", true); + return dup().cumsumi(dimension); + } + + @Override + public INDArray assign(final INDArray arr) { + Preconditions.checkState( + (this.isScalar() && arr.isScalar()) || (this.isVector() && arr.isVector()) + || Shape.shapeEqualWithSqueeze(this.shape(), arr.shape()), + "Cannot assign arrays: arrays must both be scalars, both vectors, or shapes must be equal other than size 1 dimensions. Attempting to do x.assign(y)" + + + " with x.shape=%ndShape and y.shape=%ndShape", this, arr); + + Preconditions.checkArgument(this.length() == arr.length(), + "Length of both arrays must be equal"); + + Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.any.Assign(arr, this)); + return this; + } + + @Override + public INDArray putScalar(long i, double value) { + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + if (i < 0) { + i += rank(); + } + + // TODO: i'm not sure that rank == 1 has fair shortcut here + if (isScalar()) { + autoProcessScalarCall(); + data.put(i, value); + return this; + } else if (rank() == 1) { + data.put(i * stride(0), value); + return this; + } + + // we cant raise rank here, if original rank is 1 + if (isRowVector() && rank() == 2) { + return putScalar(0, i, value); + } else if (isColumnVector() && rank() == 2) { + return putScalar(i, 0, value); + } + long[] indexes = ordering() == 'c' ? 
Shape.ind2subC(this, i) : Shape.ind2sub(this, i); + return putScalar(indexes, value); + } + + @Override + public INDArray putScalar(long i, float value) { + return putScalar(i, (double) value); + } + + @Override + public INDArray putScalar(long i, int value) { + return putScalar(i, (double) value); + } + + @Override + public INDArray putScalar(int[] indexes, double value) { + Nd4j.getCompressor().autoDecompress(this); + + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] < 0) { + indexes[i] += this.size(i); + } + } + + if (indexes.length == 1) { + return putScalar(indexes[0], value); + } else if (indexes.length == 2) { + return putScalar(indexes[0], indexes[1], value); + } else if (indexes.length == 3) { + return putScalar(indexes[0], indexes[1], indexes[2], value); + } else if (indexes.length == 4) { + return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value); + } else { + autoProcessScalarCall(); + long offset = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); + data.put(offset, value); + } + return this; + } + + @Override + public INDArray putScalar(long[] indexes, double value) { + Nd4j.getCompressor().autoDecompress(this); + + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] < 0) { + indexes[i] += size(i); + } + } + + if (indexes.length == 1) { + return putScalar(indexes[0], value); + } else if (indexes.length == 2) { + return putScalar(indexes[0], indexes[1], value); + } else if (indexes.length == 3) { + return putScalar(indexes[0], indexes[1], indexes[2], value); + } else if (indexes.length == 4) { + return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value); + } else { + autoProcessScalarCall(); + long offset = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); + data.put(offset, value); + } + return this; + } + + @Override + public INDArray putScalar(long[] indexes, float value) { + return putScalar(indexes, (double) value); + } + + @Override + public INDArray putScalar(long row, long col, double value) { + Nd4j.getCompressor().autoDecompress(this); + autoProcessScalarCall(); + + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + + if (rank() > 2) { + throw new IllegalStateException( + "Cannot use putScalar(int,int,double) on a rank " + rank() + " INDArray"); + } + long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, row, col); + data.put(offset, value); + return this; + } + + @Override + public INDArray putScalar(long dim0, long dim1, long dim2, double value) { + Nd4j.getCompressor().autoDecompress(this); + autoProcessScalarCall(); + + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + + if (rank() != 3) { + throw new IllegalStateException( + "Cannot use putScalar(int,int,int,double) on a rank " + rank() + " INDArray"); + } + long offset = 0; // 
Shape.getOffsetUnsafe(javaShapeInformation, dim0, dim1, dim2); + long size_0 = jvmShapeInfo.javaShapeInformation[1]; + long size_1 = jvmShapeInfo.javaShapeInformation[1 + 1]; + long size_2 = jvmShapeInfo.javaShapeInformation[1 + 2]; + + if (size_0 != 1) { + offset += dim0 * jvmShapeInfo.javaShapeInformation[1 + 3]; + } + if (size_1 != 1) { + offset += dim1 * jvmShapeInfo.javaShapeInformation[1 + 1 + 3]; + } + if (size_2 != 1) { + offset += dim2 * jvmShapeInfo.javaShapeInformation[1 + 2 + 3]; + } + + data.put(offset, value); + return this; + } + + @Override + public INDArray putScalar(long dim0, long dim1, long dim2, long dim3, double value) { + Nd4j.getCompressor().autoDecompress(this); + autoProcessScalarCall(); + Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, + "Cannot put value %s into boolean array" + + " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); + + if (rank() != 4) { + throw new IllegalStateException( + "Cannot use putScalar(int,int,int,int,double) on a rank " + rank() + " INDArray"); + } + long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, dim0, dim1, dim2, dim3); + data.put(offset, value); + return this; + } + + @Override + public INDArray putScalar(int[] indexes, float value) { + return putScalar(indexes, (double) value); + } + + @Override + public INDArray putScalar(int[] indexes, int value) { + return putScalar(indexes, (double) value); + } + + @Override + public INDArray putScalar(long[] indexes, int value) { + return putScalar(indexes, (double) value); + } + + @Override + public INDArray eps(Number other) { + validateNumericalArray("eps", true); + return Nd4j.getExecutioner().exec( + new ScalarEps(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), + other)); + } + + @Override + public INDArray eps(INDArray other) { + validateNumericalArray("eps", true); + return Nd4j.getExecutioner().exec(new Eps(this, other, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()))); + } + + @Override + public INDArray lt(Number other) { + validateNumericalArray("less than (lt)", false); + return Nd4j.getExecutioner().exec(new ScalarLessThan(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public INDArray lte(Number other) { + validateNumericalArray("less than or equals (lte)", false); + return Nd4j.getExecutioner().exec(new ScalarLessThanOrEqual(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public INDArray eq(Number other) { + Preconditions.checkArgument( + dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, + "Scalar equality on boolean arrays can only be applied with values 0 or 1: got value %s", + other); + return Nd4j.getExecutioner().exec(new ScalarEquals(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public INDArray gt(Number other) { + validateNumericalArray("greater than (gt)", false); + return Nd4j.getExecutioner().exec(new ScalarGreaterThan(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public INDArray gte(Number other) { + validateNumericalArray("greater than or equals (gte)", false); + return Nd4j.getExecutioner().exec(new ScalarGreaterThanOrEqual(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public 
INDArray lt(INDArray other) { + validateNumericalArray("less than (lt)", false); + if (Shape.shapeEquals(this.shape(), other.shape())) { + return Nd4j.getExecutioner().exec(new LessThan(this, other, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; + } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return Nd4j.exec(new LessThan(new INDArray[]{this, other}, new INDArray[]{ + Nd4j.createUninitialized(DataType.BOOL, + Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; + } else { + throw new IllegalArgumentException("Shapes must be broadcastable"); + } + } + + @Override + public INDArray neq(Number other) { + Preconditions.checkArgument( + dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, + "Scalar non-equality on boolean arrays can only be applied with values 0 or 1: got value %s", + other); + Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array"); + return Nd4j.getExecutioner().exec(new ScalarNotEquals(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); + } + + @Override + public INDArray neq(INDArray other) { + Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array"); + return Nd4j.getExecutioner().exec(new NotEqualTo(this, other, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; + } + + @Override + public INDArray eq(INDArray other) { + if (Shape.shapeEquals(this.shape(), other.shape())) { + return Nd4j.getExecutioner().exec(new EqualTo(this, other, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; + } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return Nd4j.exec(new EqualTo(new INDArray[]{this, other}, new INDArray[]{ + Nd4j.createUninitialized(DataType.BOOL, + Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; + } else { + throw new IllegalArgumentException("Shapes must be broadcastable"); + } + } + + @Override + public INDArray gt(INDArray other) { + validateNumericalArray("greater than (gt)", false); + if (Shape.shapeEquals(this.shape(), other.shape())) { + return Nd4j.getExecutioner().exec(new GreaterThan(this, other, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; + } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return Nd4j.exec(new GreaterThan(new INDArray[]{this, other}, new INDArray[]{ + Nd4j.createUninitialized(DataType.BOOL, + Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; + } else { + throw new IllegalArgumentException("Shapes must be broadcastable"); + } + } + + @Override + public INDArray isInfinite() { + validateNumericalArray("isInfinite", true); + if (isEmpty()) { + return Nd4j.empty(DataType.BOOL); + } + return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), + Conditions.isInfinite())); + } + + @Override + public INDArray isNaN() { + validateNumericalArray("isNaN", true); + if (isEmpty()) { + return Nd4j.empty(DataType.BOOL); + } + return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, + Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), + Conditions.isNan())); + } + + @Override + public INDArray neg() { + validateNumericalArray("negative (neg)", true); + if (isEmpty()) { + return this; + } + return Nd4j.getExecutioner().exec(new Negative(this, + 
Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()))); + } + + @Override + public INDArray negi() { + validateNumericalArray("negative (negi)", true); + if (isEmpty()) { + return this; + } + Nd4j.getExecutioner().exec(new Negative(this)); + return this; + } + + @Override + public INDArray rdiv(Number n, INDArray result) { + return rdivi(n, result); + } + + @Override + public INDArray rdivi(Number n, INDArray result) { + validateNumericalArray("rdivi", false); + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + Nd4j.getExecutioner().exec(new ScalarReverseDivision(this, null, result, n)); + return result; + } + + @Override + public INDArray rsub(Number n, INDArray result) { + return rsubi(n, result); + } + + @Override + public INDArray rsubi(Number n, INDArray result) { + validateNumericalArray("rsubi", false); + + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + + Nd4j.getExecutioner().exec(new ScalarReverseSubtraction(this, result, n)); + return result; + } + + @Override + public INDArray div(Number n, INDArray result) { + return divi(n, result); + } + + @Override + public INDArray divi(Number n, INDArray result) { + validateNumericalArray("divi", false); + + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + Nd4j.getExecutioner().exec(new ScalarDivision(this, null, result, n)); + return result; + } + + @Override + public INDArray mul(Number n, INDArray result) { + return muli(n, result); + } + + @Override + public INDArray muli(Number n, INDArray result) { + validateNumericalArray("muli", false); + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + Nd4j.getExecutioner().exec(new ScalarMultiplication(this, null, result, n)); + return result; + } + + @Override + public INDArray sub(Number n, INDArray result) { + return subi(n, result); + } + + @Override + public INDArray subi(Number n, INDArray result) { + validateNumericalArray("subi", false); + + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + + Nd4j.getExecutioner().exec(new ScalarSubtraction(this, null, result, n)); + return result; + } + + @Override + public INDArray add(Number n, INDArray result) { + return addi(n, result); + } + + @Override + public INDArray addi(Number n, INDArray result) { + validateNumericalArray("addi", false); + if (Double.isNaN(n.doubleValue())) { + n = Nd4j.EPS_THRESHOLD; + } + + Nd4j.getExecutioner().exec(new ScalarAdd(this, null, result, n)); + return result; + } + + @Override + public INDArray getScalar(long row, long column) { + return getScalar(new long[]{row, column}); + } + + @Override + public INDArray dup() { + return dup(Nd4j.order()); + } + + @Override + public INDArray dup(char order) { + WorkspaceUtils.assertValidArray(this, "Cannot duplicate INDArray"); + if (this.isCompressed() && this.ordering() == order) { + INDArray ret = Nd4j.createArrayFromShapeBuffer(data().dup(), this.shapeInfoDataBuffer()); + ret.markAsCompressed(true); + return ret; + } + if (isEmpty()) { + return this; + } + + Nd4j.getCompressor().autoDecompress(this); + + // fixme: eventually it would be nice to have this in native code + if (isS()) { + val list = new ArrayList(); + for (int e = 0; e < this.length(); e++) { + list.add(this.getString(e)); } - long length = length(); + return Nd4j.create(list, this.shape(), this.ordering()); + } - if (dimension >= jvmShapeInfo.rank) { - if (length / size(jvmShapeInfo.rank - 1) >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Vectors along dimension can not be 
>= Integer.MAX_VALUE"); - return (int) (length / size(jvmShapeInfo.rank - 1)); + val z = Nd4j.createUninitialized(this.dataType(), this.shape(), order); + z.assign(this); + return z; + } + + @Override + public int getInt(int... indices) { + return (int) getDouble(indices); + } + + @Override + public long getLong(long index) { + Nd4j.getCompressor().autoDecompress(this); + Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); + + if (index >= length()) { + throw new IllegalArgumentException( + "Unable to get linear index " + index + ": values is greater than length (" + length() + + ")"); + } + + autoProcessScalarCall(); + + if (index == 0) { + return data().getLong(index); + } + + long[] dimensions = + ordering() == 'c' ? Shape.ind2subC(this, index) : Shape.ind2sub(this, index); + Shape.assertShapeLessThan(dimensions, shape()); + return getLong(dimensions); + } + + @Override + public long getLong(long... indices) { + if (isScalar()) { + return data().getLong(0); + } + return Shape.getLong(this, indices); + } + + @Override + public double getDouble(int... indices) { + autoProcessScalarCall(); + Nd4j.getCompressor().autoDecompress(this); + Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); + + for (int i = 0; i < indices.length; i++) { + if (indices[i] < 0) { + indices[i] += rank(); } - if (length / size(dimension) >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Vectors along dimension can not be >= Integer.MAX_VALUE"); - return length / size(dimension); } - - @Override - public INDArray vectorAlongDimension(int index, int dimension) { - if (dimension < 0) { - dimension = jvmShapeInfo.getRank() + dimension; + if (indices.length == 1) { + if (rank() == 1) { + return Shape.getDouble(this, indices[0]); + } else if (isRowVector()) { + return Shape.getDouble(this, 0, indices[0]); + } else if (isColumnVector()) { + return Shape.getDouble(this, indices[0], 0); + } else if ((isScalar() || length() == 1) && indices[0] == 0) { + return data().getDouble(0); } + } + return Shape.getDouble(this, indices); + } - //return the whole thing - if (dimension == jvmShapeInfo.getRank() - 1 && size(dimension) == 1 && rank() > 2 - || rank() > 2 && dimension == 0 && size(dimension) == 1) { - return this; + @Override + public double getDouble(long... 
indices) { + autoProcessScalarCall(); + Nd4j.getCompressor().autoDecompress(this); + Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); + + for (int i = 0; i < indices.length; i++) { + if (indices[i] < 0) { + indices[i] += rank(); } - - return tensorAlongDimension(index, dimension); } - - @Override - public void setOrder(char order) { - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(shape(), stride(), elementWiseStride(), order, this.dataType(), isEmpty())); - } - - @Override - public void setShapeAndStride(int[] shape, int[] stride) { - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 0, ordering(), this.dataType(), false)); - } - - @Override - public INDArray cumsumi(int dimension) { - validateNumericalArray("cumsumi", true); - - if(isScalar() || isEmpty()) - return this; - - if (isVector()) { - double s = 0.0; - for (int i = 0; i < length(); i++) { - s += getDouble(i); - putScalar(i, s); - } - } else if (dimension == Integer.MAX_VALUE) { - INDArray flattened = ravel(); - double prevVal = flattened.getDouble(0); - for (int i = 1; i < flattened.length(); i++) { - double d = prevVal + flattened.getDouble(i); - flattened.putScalar(i, d); - prevVal = d; - } - - return flattened; + if (indices.length == 1) { + if (rank() == 1) { + return Shape.getDouble(this, indices[0]); + } else if (isRowVector()) { + return Shape.getDouble(this, 0, indices[0]); + } else if (isColumnVector()) { + return Shape.getDouble(this, indices[0], 0); + } else if (isScalar() && indices[0] == 0) { + return data().getDouble(0); } else { - for (int i = 0; i < vectorsAlongDimension(dimension); i++) { - INDArray vec = vectorAlongDimension(i, dimension); - vec.cumsumi(0); - - } - } - - return this; - } - - @Override - public Number normmaxNumber() { - return normmax(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number norm2Number() { - return norm2(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number norm1Number() { - return norm1(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number stdNumber() { - return std(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number prodNumber() { - if(isScalar()) - return getNumber(0); - return prod(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number meanNumber() { - validateNumericalArray("meanNumber", false); - if(isScalar()) - return getNumber(0); - return mean(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number ameanNumber() { - return amean(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number varNumber() { - return var(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number maxNumber() { - if(isScalar()) - return getNumber(0); - return max(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number amaxNumber() { - return amax(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number minNumber() { - if(isScalar()) - return getNumber(0); - return min(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number aminNumber() { - return amin(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number scan(Condition condition) { - MatchCondition op = new MatchCondition(this, condition); - return Nd4j.getExecutioner().exec(op).getDouble(0); - } - - @Override - public Number sumNumber() { - validateNumericalArray("sum", false); - if(isScalar()) - return getNumber(0); - val scalar = sum(Integer.MAX_VALUE); - 
Nd4j.getExecutioner().commit(); - return scalar.getDouble(0); - } - - @Override - public Number entropyNumber() { - return entropy(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number shannonEntropyNumber() { - return shannonEntropy(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public Number logEntropyNumber() { - return logEntropy(Integer.MAX_VALUE).getDouble(0); - } - - @Override - public INDArray cumsum(int dimension) { - validateNumericalArray("cumsum", true); - return dup().cumsumi(dimension); - } - - @Override - public INDArray assign(final INDArray arr) { - Preconditions.checkState((this.isScalar() && arr.isScalar()) || (this.isVector() && arr.isVector()) || Shape.shapeEqualWithSqueeze(this.shape(), arr.shape()), - "Cannot assign arrays: arrays must both be scalars, both vectors, or shapes must be equal other than size 1 dimensions. Attempting to do x.assign(y)" + - " with x.shape=%ndShape and y.shape=%ndShape", this, arr ); - - Preconditions.checkArgument(this.length() == arr.length(), "Length of both arrays must be equal"); - - Nd4j.getExecutioner().exec(new org.nd4j.linalg.api.ops.impl.transforms.any.Assign(arr, this)); - return this; - } - - @Override - public INDArray putScalar(long i, double value) { - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - if (i < 0) - i += rank(); - - // TODO: i'm not sure that rank == 1 has fair shortcut here - if (isScalar()) { - autoProcessScalarCall(); - data.put(i, value); - return this; - } else if (rank() == 1) { - data.put(i * stride(0), value); - return this; - } - - // we cant raise rank here, if original rank is 1 - if (isRowVector() && rank() == 2) { - return putScalar(0, i, value); - } else if (isColumnVector() && rank() == 2) { - return putScalar(i, 0, value); - } - long[] indexes = ordering() == 'c' ? 
Shape.ind2subC(this, i) : Shape.ind2sub(this, i); - return putScalar(indexes, value); - } - - @Override - public INDArray putScalar(long i, float value) { - return putScalar(i, (double) value); - } - - @Override - public INDArray putScalar(long i, int value) { - return putScalar(i, (double) value); - } - - @Override - public INDArray putScalar(int[] indexes, double value) { - Nd4j.getCompressor().autoDecompress(this); - - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - - for (int i = 0; i < indexes.length; i++) { - if (indexes[i] < 0) - indexes[i] += this.size(i); - } - - if (indexes.length == 1) { - return putScalar(indexes[0], value); - } else if (indexes.length == 2) { - return putScalar(indexes[0], indexes[1], value); - } else if (indexes.length == 3) { - return putScalar(indexes[0], indexes[1], indexes[2], value); - } else if (indexes.length == 4) { - return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value); - } else { - autoProcessScalarCall(); - long offset = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); - data.put(offset, value); - } - return this; - } - - @Override - public INDArray putScalar(long[] indexes, double value) { - Nd4j.getCompressor().autoDecompress(this); - - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - - - for (int i = 0; i < indexes.length; i++) { - if (indexes[i] < 0) - indexes[i] += size(i); - } - - if (indexes.length == 1) { - return putScalar(indexes[0], value); - } else if (indexes.length == 2) { - return putScalar(indexes[0], indexes[1], value); - } else if (indexes.length == 3) { - return putScalar(indexes[0], indexes[1], indexes[2], value); - } else if (indexes.length == 4) { - return putScalar(indexes[0], indexes[1], indexes[2], indexes[3], value); - } else { - autoProcessScalarCall(); - long offset = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); - data.put(offset, value); - } - return this; - } - - @Override - public INDArray putScalar(long[] indexes, float value) { - return putScalar(indexes, (double) value); - } - - @Override - public INDArray putScalar(long row, long col, double value) { - Nd4j.getCompressor().autoDecompress(this); - autoProcessScalarCall(); - - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - - if (rank() > 2) - throw new IllegalStateException("Cannot use putScalar(int,int,double) on a rank " + rank() + " INDArray"); - long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, row, col); - data.put(offset, value); - return this; - } - - @Override - public INDArray putScalar(long dim0, long dim1, long dim2, double value) { - Nd4j.getCompressor().autoDecompress(this); - autoProcessScalarCall(); - - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - - if (rank() != 3) throw new IllegalStateException( - "Cannot use putScalar(int,int,int,double) on a rank " + rank() + " INDArray"); - long offset = 0; // 
Shape.getOffsetUnsafe(javaShapeInformation, dim0, dim1, dim2); - long size_0 = jvmShapeInfo.javaShapeInformation[1]; - long size_1 = jvmShapeInfo.javaShapeInformation[1 + 1]; - long size_2 = jvmShapeInfo.javaShapeInformation[1 + 2]; - - if (size_0 != 1) - offset += dim0 * jvmShapeInfo.javaShapeInformation[1 + 3]; - if (size_1 != 1) - offset += dim1 * jvmShapeInfo.javaShapeInformation[1 + 1 + 3]; - if (size_2 != 1) - offset += dim2 * jvmShapeInfo.javaShapeInformation[1 + 2 + 3]; - - data.put(offset, value); - return this; - } - - @Override - public INDArray putScalar(long dim0, long dim1, long dim2, long dim3, double value) { - Nd4j.getCompressor().autoDecompress(this); - autoProcessScalarCall(); - Preconditions.checkArgument(dataType() != DataType.BOOL || value == 0.0 || value == 1.0, "Cannot put value %s into boolean array" + - " - only putScalar with values 0 or 1 is allowed on boolean arrays", value); - - if (rank() != 4) - throw new IllegalStateException( - "Cannot use putScalar(int,int,int,int,double) on a rank " + rank() + " INDArray"); - long offset = Shape.getOffsetUnsafe(jvmShapeInfo.javaShapeInformation, dim0, dim1, dim2, dim3); - data.put(offset, value); - return this; - } - - @Override - public INDArray putScalar(int[] indexes, float value) { - return putScalar(indexes, (double) value); - } - - @Override - public INDArray putScalar(int[] indexes, int value) { - return putScalar(indexes, (double) value); - } - - @Override - public INDArray putScalar(long[] indexes, int value) { - return putScalar(indexes, (double) value); - } - - @Override - public INDArray eps(Number other) { - validateNumericalArray("eps", true); - return Nd4j.getExecutioner().exec(new ScalarEps(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray eps(INDArray other) { - validateNumericalArray("eps", true); - return Nd4j.getExecutioner().exec(new Eps(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()))); - } - - @Override - public INDArray lt(Number other) { - validateNumericalArray("less than (lt)", false); - return Nd4j.getExecutioner().exec(new ScalarLessThan(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray lte(Number other) { - validateNumericalArray("less than or equals (lte)", false); - return Nd4j.getExecutioner().exec(new ScalarLessThanOrEqual(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray eq(Number other) { - Preconditions.checkArgument(dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, "Scalar equality on boolean arrays can only be applied with values 0 or 1: got value %s",other); - return Nd4j.getExecutioner().exec(new ScalarEquals(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray gt(Number other) { - validateNumericalArray("greater than (gt)", false); - return Nd4j.getExecutioner().exec(new ScalarGreaterThan(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray gte(Number other) { - validateNumericalArray("greater than or equals (gte)", false); - return Nd4j.getExecutioner().exec(new ScalarGreaterThanOrEqual(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray lt(INDArray other) { - 
validateNumericalArray("less than (lt)", false); - if (Shape.shapeEquals(this.shape(), other.shape())) { - return Nd4j.getExecutioner().exec(new LessThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; - } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return Nd4j.exec(new LessThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; - } else - throw new IllegalArgumentException("Shapes must be broadcastable"); - } - - @Override - public INDArray neq(Number other) { - Preconditions.checkArgument(dataType() != DataType.BOOL || other.doubleValue() == 0.0 || other.doubleValue() == 1.0, "Scalar non-equality on boolean arrays can only be applied with values 0 or 1: got value %s",other); - Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array"); - return Nd4j.getExecutioner().exec(new ScalarNotEquals(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), other)); - } - - @Override - public INDArray neq(INDArray other) { - Preconditions.checkState(!isEmpty(), "Cannot perform operation neq (not equal) on empty array"); - return Nd4j.getExecutioner().exec(new NotEqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; - } - - @Override - public INDArray eq(INDArray other) { - if (Shape.shapeEquals(this.shape(), other.shape())) { - return Nd4j.getExecutioner().exec(new EqualTo(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; - } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return Nd4j.exec(new EqualTo(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; - } else - throw new IllegalArgumentException("Shapes must be broadcastable"); - } - - @Override - public INDArray gt(INDArray other) { - validateNumericalArray("greater than (gt)", false); - if (Shape.shapeEquals(this.shape(), other.shape())) { - return Nd4j.getExecutioner().exec(new GreaterThan(this, other, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering())))[0]; - } else if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return Nd4j.exec(new GreaterThan(new INDArray[]{this, other}, new INDArray[]{Nd4j.createUninitialized(DataType.BOOL, Shape.broadcastOutputShape(this.shape(), other.shape()))}))[0]; - } else - throw new IllegalArgumentException("Shapes must be broadcastable"); - } - - @Override - public INDArray isInfinite(){ - validateNumericalArray("isInfinite", true); - if(isEmpty()) - return Nd4j.empty(DataType.BOOL); - return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), Conditions.isInfinite())); - } - - @Override - public INDArray isNaN(){ - validateNumericalArray("isNaN", true); - if(isEmpty()) - return Nd4j.empty(DataType.BOOL); - return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, Nd4j.createUninitialized(DataType.BOOL, this.shape(), this.ordering()), Conditions.isNan())); - } - - @Override - public INDArray neg() { - validateNumericalArray("negative (neg)", true); - if(isEmpty()) - return this; - return Nd4j.getExecutioner().exec(new Negative(this, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()))); - } - - @Override - public INDArray negi() { - 
validateNumericalArray("negative (negi)", true); - if(isEmpty()) - return this; - Nd4j.getExecutioner().exec(new Negative(this)); - return this; - } - - @Override - public INDArray rdiv(Number n, INDArray result) { - return rdivi(n, result); - } - - @Override - public INDArray rdivi(Number n, INDArray result) { - validateNumericalArray("rdivi", false); - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - Nd4j.getExecutioner().exec(new ScalarReverseDivision(this, null, result, n)); - return result; - } - - @Override - public INDArray rsub(Number n, INDArray result) { - return rsubi(n, result); - } - - @Override - public INDArray rsubi(Number n, INDArray result) { - validateNumericalArray("rsubi", false); - - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - - Nd4j.getExecutioner().exec(new ScalarReverseSubtraction(this, result, n)); - return result; - } - - @Override - public INDArray div(Number n, INDArray result) { - return divi(n, result); - } - - @Override - public INDArray divi(Number n, INDArray result) { - validateNumericalArray("divi", false); - - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - Nd4j.getExecutioner().exec(new ScalarDivision(this, null, result, n)); - return result; - } - - @Override - public INDArray mul(Number n, INDArray result) { - return muli(n, result); - } - - @Override - public INDArray muli(Number n, INDArray result) { - validateNumericalArray("muli", false); - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - Nd4j.getExecutioner().exec(new ScalarMultiplication(this, null, result, n)); - return result; - } - - @Override - public INDArray sub(Number n, INDArray result) { - return subi(n, result); - } - - @Override - public INDArray subi(Number n, INDArray result) { - validateNumericalArray("subi", false); - - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - - Nd4j.getExecutioner().exec(new ScalarSubtraction(this, null, result, n)); - return result; - } - - @Override - public INDArray add(Number n, INDArray result) { - return addi(n, result); - } - - @Override - public INDArray addi(Number n, INDArray result) { - validateNumericalArray("addi", false); - if (Double.isNaN(n.doubleValue())) - n = Nd4j.EPS_THRESHOLD; - - Nd4j.getExecutioner().exec(new ScalarAdd(this, null, result, n)); - return result; - } - - @Override - public INDArray getScalar(long row, long column) { - return getScalar(new long[] {row, column}); - } - - @Override - public INDArray dup() { - return dup(Nd4j.order()); - } - - @Override - public INDArray dup(char order) { - WorkspaceUtils.assertValidArray(this, "Cannot duplicate INDArray"); - if (this.isCompressed() && this.ordering() == order) { - INDArray ret = Nd4j.createArrayFromShapeBuffer(data().dup(), this.shapeInfoDataBuffer()); - ret.markAsCompressed(true); - return ret; + "Indexes length must be > 1 for non vectors and scalars"); } - if(isEmpty()) - return this; + } + return Shape.getDouble(this, indices); + } - Nd4j.getCompressor().autoDecompress(this); + @Override + public float getFloat(int... indices) { + return (float) getDouble(indices); + } - // fixme: eventually it would be nice to have this in native code - if (isS()) { - val list = new ArrayList(); - for (int e = 0; e < this.length(); e++) - list.add(this.getString(e)); + @Override + public float getFloat(long... 
indices) { + return (float) getDouble(indices); + } - return Nd4j.create(list, this.shape(), this.ordering()); + @Override + public boolean isScalar() { + if (isEmpty()) { + return false; + } + + if (jvmShapeInfo.rank == 0) { + return true; + } else if (jvmShapeInfo.rank > 2) { + return false; + } else if (jvmShapeInfo.rank == 1) { + return shape()[0] == 1; + } else if (jvmShapeInfo.rank == 2) { + return shape()[0] == 1 && shape()[1] == 1 || length() == 1; + } else { + return false; + } + + } + + @Override + public INDArray put(int[] indices, INDArray element) { + Nd4j.getCompressor().autoDecompress(this); + if (!element.isScalar()) { + throw new IllegalArgumentException("Unable to insert anything but a scalar"); + } + if (isRowVector() && indices[0] == 0 && indices.length == 2) { + int ix = 0; + for (int i = 1; i < indices.length; i++) { + ix += indices[i] * stride(i); } - - val z = Nd4j.createUninitialized(this.dataType(), this.shape(), order); - z.assign(this); - return z; - } - - @Override - public int getInt(int... indices) { - return (int) getDouble(indices); - } - - @Override - public long getLong(long index) { - Nd4j.getCompressor().autoDecompress(this); - Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); - - if (index >= length()) { - throw new IllegalArgumentException("Unable to get linear index " + index + ": values is greater than length (" + length() + ")"); + if (ix >= data.length()) { + throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); } - - autoProcessScalarCall(); - - if (index == 0) - return data().getLong(index); - - long[] dimensions = ordering() == 'c' ? Shape.ind2subC(this, index) : Shape.ind2sub(this, index); - Shape.assertShapeLessThan(dimensions, shape()); - return getLong(dimensions); - } - - @Override - public long getLong(long... indices) { - if(isScalar()) - return data().getLong(0); - return Shape.getLong(this, indices); - } - - @Override - public double getDouble(int... indices) { - autoProcessScalarCall(); - Nd4j.getCompressor().autoDecompress(this); - Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); - + data.put(ix, element.getDouble(0)); + } else { + int ix = 0; for (int i = 0; i < indices.length; i++) { - if (indices[i] < 0) - indices[i] += rank(); - } - if (indices.length == 1) { - if (rank() == 1) - return Shape.getDouble(this, indices[0]); - else if (isRowVector()) - return Shape.getDouble(this, 0, indices[0]); - else if (isColumnVector()) - return Shape.getDouble(this, indices[0], 0); - else if ((isScalar() || length() == 1) && indices[0] == 0) - return data().getDouble(0); - } - return Shape.getDouble(this, indices); - } - - @Override - public double getDouble(long... indices) { - autoProcessScalarCall(); - Nd4j.getCompressor().autoDecompress(this); - Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); - - for (int i = 0; i < indices.length; i++) { - if (indices[i] < 0) - indices[i] += rank(); - } - if (indices.length == 1) { - if (rank() == 1) - return Shape.getDouble(this, indices[0]); - else if (isRowVector()) - return Shape.getDouble(this, 0, indices[0]); - else if (isColumnVector()) - return Shape.getDouble(this, indices[0], 0); - else if (isScalar() && indices[0] == 0) - return data().getDouble(0); - else - throw new IllegalStateException("Indexes length must be > 1 for non vectors and scalars"); - } - return Shape.getDouble(this, indices); - } - - @Override - public float getFloat(int... 
indices) { - return (float) getDouble(indices); - } - - @Override - public float getFloat(long... indices) { - return (float) getDouble(indices); - } - - @Override - public boolean isScalar() { - if (isEmpty()) - return false; - - if (jvmShapeInfo.rank == 0) { - return true; - } else if (jvmShapeInfo.rank > 2) { - return false; - } else if (jvmShapeInfo.rank == 1) { - return shape()[0] == 1; - } else if (jvmShapeInfo.rank == 2) { - return shape()[0] == 1 && shape()[1] == 1 || length() == 1; - } - - else - return false; - - } - - @Override - public INDArray put(int[] indices, INDArray element) { - Nd4j.getCompressor().autoDecompress(this); - if (!element.isScalar()) - throw new IllegalArgumentException("Unable to insert anything but a scalar"); - if (isRowVector() && indices[0] == 0 && indices.length == 2) { - int ix = 0; - for (int i = 1; i < indices.length; i++) + if (size(i) != 1) { ix += indices[i] * stride(i); - if (ix >= data.length()) - throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); - data.put(ix, element.getDouble(0)); - } else { - int ix = 0; - for (int i = 0; i < indices.length; i++) - if (size(i) != 1) - ix += indices[i] * stride(i); - if (ix >= data.length()) - throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); - data.put(ix, element.getDouble(0)); + } } - return this; - } - - @Override - public INDArray match(INDArray comp, Condition condition) { - // TODO: obviously, we can make this broadcastable, eventually. But this will require new CustomOp based on MatchCondition - Preconditions.checkArgument(Arrays.equals(this.shape(), comp.shape()), "Shapes must be equal"); - Preconditions.checkArgument(this.dataType() == comp.dataType(), "Data types bmust be equal"); - return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, comp, Nd4j.createUninitialized(DataType.BOOL, this.shape()), condition)); - } - - @Override - public INDArray match(Number comp, Condition condition) { - return Nd4j.getExecutioner().exec(new MatchConditionTransform(this,comp.doubleValue(), condition)); - } - - @Override - public INDArray getWhere(INDArray comp, Condition condition) { - return BooleanIndexing.chooseFrom(new INDArray[]{this,comp},condition); - } - - @Override - public INDArray getWhere(Number comp, Condition condition) { - return BooleanIndexing.chooseFrom(new INDArray[]{this}, Collections.singletonList(comp.doubleValue()),Collections.emptyList(),condition); - } - - @Override - public INDArray putWhere(INDArray comp, INDArray put, Condition condition) { - Nd4j.getCompressor().autoDecompress(this); - MatchConditionTransform matchCondition = new MatchConditionTransform(this,comp,condition); - Nd4j.getExecutioner().exec(matchCondition); - return putWhereWithMask(matchCondition.z(),put); - } - - @Override - public INDArray putWhere(Number comp, INDArray put, Condition condition) { - return putWhere(Nd4j.scalar(comp),put,condition); - } - - @Override - public INDArray putWhere(Number comp, Number put, Condition condition) { - return putWhere(Nd4j.scalar(comp),Nd4j.scalar(put),condition); - } - - - @Override - public INDArray putWhereWithMask(INDArray mask, INDArray put) { - INDArray output = dup(); - Nd4j.getExecutioner().execAndReturn(new Where(new INDArray[]{mask,this,put},new INDArray[]{output})); - return output; - } - - @Override - public INDArray putWhereWithMask(INDArray mask, Number put) { - return putWhereWithMask(mask,Nd4j.scalar(put)); - } - - @Override - public INDArray put(int i, int j, INDArray element) { - 
return put(new int[] {i, j}, element); - } - - @Override - public INDArray put(int i, int j, Number element) { - return putScalar(new int[] {i, j}, element.doubleValue()); - } - - @Override - public INDArray putSlice(int slice, INDArray put) { - Nd4j.getCompressor().autoDecompress(this); - - - if (isScalar()) { - Preconditions.checkState(put.isScalar(), "Invalid dimension. Can only insert a scalar in to another scalar"); - put(0, put.getScalar(0)); - return this; - } else if (isVector()) { - Preconditions.checkState(put.isVectorOrScalar() && put.length() == length(), - "Invalid dimension on insertion. Can only insert scalars/vectors into other scalar/vectors"); - if (put.isScalar()) - putScalar(slice, put.getDouble(0)); - else - for (int i = 0; i < length(); i++) - putScalar(i, put.getDouble(i)); - return this; + if (ix >= data.length()) { + throw new IllegalArgumentException("Illegal indices " + Arrays.toString(indices)); } + data.put(ix, element.getDouble(0)); + } + return this; + } - assertSlice(put, slice); + @Override + public INDArray match(INDArray comp, Condition condition) { + // TODO: obviously, we can make this broadcastable, eventually. But this will require new CustomOp based on MatchCondition + Preconditions.checkArgument(Arrays.equals(this.shape(), comp.shape()), "Shapes must be equal"); + Preconditions.checkArgument(this.dataType() == comp.dataType(), "Data types bmust be equal"); + return Nd4j.getExecutioner().exec(new MatchConditionTransform(this, comp, + Nd4j.createUninitialized(DataType.BOOL, this.shape()), condition)); + } + + @Override + public INDArray match(Number comp, Condition condition) { + return Nd4j.getExecutioner() + .exec(new MatchConditionTransform(this, comp.doubleValue(), condition)); + } + + @Override + public INDArray getWhere(INDArray comp, Condition condition) { + return BooleanIndexing.chooseFrom(new INDArray[]{this, comp}, condition); + } + + @Override + public INDArray getWhere(Number comp, Condition condition) { + return BooleanIndexing.chooseFrom(new INDArray[]{this}, + Collections.singletonList(comp.doubleValue()), Collections.emptyList(), condition); + } + + @Override + public INDArray putWhere(INDArray comp, INDArray put, Condition condition) { + Nd4j.getCompressor().autoDecompress(this); + MatchConditionTransform matchCondition = new MatchConditionTransform(this, comp, condition); + Nd4j.getExecutioner().exec(matchCondition); + return putWhereWithMask(matchCondition.z(), put); + } + + @Override + public INDArray putWhere(Number comp, INDArray put, Condition condition) { + return putWhere(Nd4j.scalar(comp), put, condition); + } + + @Override + public INDArray putWhere(Number comp, Number put, Condition condition) { + return putWhere(Nd4j.scalar(comp), Nd4j.scalar(put), condition); + } - INDArray view = slice(slice); + @Override + public INDArray putWhereWithMask(INDArray mask, INDArray put) { + INDArray output = dup(); + Nd4j.getExecutioner() + .execAndReturn(new Where(new INDArray[]{mask, this, put}, new INDArray[]{output})); + return output; + } - if (put.length() == 1) { + @Override + public INDArray putWhereWithMask(INDArray mask, Number put) { + return putWhereWithMask(mask, Nd4j.scalar(put)); + } + + @Override + public INDArray put(int i, int j, INDArray element) { + return put(new int[]{i, j}, element); + } + + @Override + public INDArray put(int i, int j, Number element) { + return putScalar(new int[]{i, j}, element.doubleValue()); + } + + @Override + public INDArray putSlice(int slice, INDArray put) { + 
Nd4j.getCompressor().autoDecompress(this); + + if (isScalar()) { + Preconditions.checkState(put.isScalar(), + "Invalid dimension. Can only insert a scalar in to another scalar"); + put(0, put.getScalar(0)); + return this; + } else if (isVector()) { + Preconditions.checkState(put.isVectorOrScalar() && put.length() == length(), + "Invalid dimension on insertion. Can only insert scalars/vectors into other scalar/vectors"); + if (put.isScalar()) { putScalar(slice, put.getDouble(0)); } else { - if(!(view.isVector() && put.isVector() && view.length() == put.length()) && !view.equalShapes(put)){ - throw new IllegalStateException("Cannot put slice: array to be put (" + Arrays.toString(put.shape()) + - ") and slice array (" + Arrays.toString(view.shape()) + ") have different shapes"); - } - view.assign(put); - } - return this; - } - - protected void assertSlice(INDArray put, long slice) { - Preconditions.checkArgument(slice < slices(), "Invalid slice specified: slice %s must be in range 0 (inclusive) to numSlices=%s (exclusive)", slice, slices()); - long[] sliceShape = put.shape(); - if (Shape.isRowVectorShape(sliceShape)) { - } else { - long[] requiredShape = ArrayUtil.removeIndex(shape(), 0); - - //no need to compare for scalar; primarily due to shapes either being [1] or length 0 - if (put.isScalar()) - return; - - if (isVector() && put.isVector() && put.length() < length()) - return; - //edge case for column vectors - if (Shape.isColumnVectorShape(sliceShape)) - return; - if (!Shape.shapeEquals(sliceShape, requiredShape) && !Shape.isRowVectorShape(requiredShape) - && !Shape.isRowVectorShape(sliceShape)) - throw new IllegalStateException(String.format("Invalid shape size of %s . Should have been %s ", - Arrays.toString(sliceShape), Arrays.toString(requiredShape))); - } - } - - public boolean isMatrix() { - return rank() == 2; - } - - protected INDArray newShape(long[] newShape, char ordering) { - - return Nd4j.create(data(), newShape, stride(), 0, ordering); - } - - protected INDArray create(DataBuffer data, int[] newShape, int[] newStrides, long offset, char ordering) { - return Nd4j.create(data, newShape, newStrides, offset, ordering); - } - - protected INDArray create(DataBuffer data, long[] newShape, long[] newStrides, long offset, char ordering) { - return Nd4j.create(data, newShape, newStrides, offset, ordering); - } - - protected INDArray create(DataBuffer data, int[] newShape, int[] newStrides, long offset) { - return Nd4j.create(data, newShape, newStrides, offset); - } - - protected INDArray create(int[] shape) { - return Nd4j.create(shape, getStrides(shape, Nd4j.order()), 0); - } - - protected INDArray create(int[] shape, int[] strides, long offset) { - return Nd4j.create(shape, strides, offset); - } - - protected int[] getStrides(int[] shape, char ordering) { - return Nd4j.getStrides(shape, ordering); - } - - @Override - public double squaredDistance(INDArray other) { - validateNumericalArray("squaredDistance", false); - double d2 = distance2(other); - return d2 * d2; - } - - @Override - public double distance2(INDArray other) { - validateNumericalArray("distance2", false); - Nd4j.getCompressor().autoDecompress(this); - return Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this, other)).getFinalResult().doubleValue(); - } - - @Override - public double distance1(INDArray other) { - validateNumericalArray("distance1", false); - Nd4j.getCompressor().autoDecompress(this); - return Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this, 
other)).getFinalResult().doubleValue(); - } - - @Override - public INDArray get(INDArray indices) { - if(indices.rank() > 2) { - throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); - } - - if (rank() == 1) { - Preconditions.checkArgument(indices.rank() <= 1, "For 1D vector indices must be either scalar or vector as well"); - val ret = Nd4j.createUninitialized(this.dataType(), indices.length()); - for (int e = 0; e < indices.length(); e++) { - val idx = indices.getLong(e); - val value = getDouble(idx); - ret.putScalar(e, value); - } - - return ret; - } else if(indices.rows() == rank()) { - INDArray ret = Nd4j.create(this.dataType(), indices.columns()); - - for(int i = 0; i < indices.columns(); i++) { - int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); - val v = getDouble(specifiedIndex); - ret.putScalar(i, v); - } - - return ret; - } - else { - List arrList = new ArrayList<>(); - - if(indices.isMatrix() || indices.isColumnVector() - || (indices.isScalar() && indices.rank() == 2)) { // we need this for compatibility with legacy code - for(int i = 0; i < indices.rows(); i++) { - if(i == 0) { - INDArray row = indices.getRow(i); - for(int j = 0; j < row.length(); j++) { - arrList.add(slice(row.getInt(j))); - } - } - else { - INDArray row = indices.slice(i); - for(int j = 0; j < row.length(); j++) { - INDArray put = arrList.get(j).slice(row.getInt(j)); - put = put.reshape(Longs.concat(new long[]{1},put.shape())); - arrList.set(j,put); - } - } - - } - } - else if(indices.isRowVector()) { - for(int i = 0; i < indices.length(); i++) { - INDArray add = slice(indices.getInt(i)); - add = add.reshape(Longs.concat(new long[] {1,},add.shape())); - arrList.add(add); - } - } - - return Nd4j.concat(0,arrList.toArray(new INDArray[arrList.size()])); - - } - - - } - - @Override - public INDArray put(INDArray indices, INDArray element) { - if(indices.rank() > 2) { - throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); - } - - if(indices.rows() == rank()) { - NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape()); - for(int i = 0; i < indices.columns(); i++) { - int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); - putScalar(specifiedIndex,element.getDouble(ndIndexIterator.next())); + for (int i = 0; i < length(); i++) { + putScalar(i, put.getDouble(i)); } } - else { - List arrList = new ArrayList<>(); - - if(indices.isMatrix() || indices.isColumnVector()) { - for(int i = 0; i < indices.rows(); i++) { - INDArray row = indices.getRow(i); - for(int j = 0; j < row.length(); j++) { - INDArray slice = slice(row.getInt(j)); - Nd4j.getExecutioner().execAndReturn(new Assign(new INDArray[]{slice,element},new INDArray[]{slice})); - arrList.add(slice(row.getInt(j))); - } - } - } - else if(indices.isRowVector()) { - for(int i = 0; i < indices.length(); i++) { - arrList.add(slice(indices.getInt(i))); - } - } - } - return this; + return this; } - @Override - public INDArray put(INDArrayIndex[] indices, INDArray element) { - Nd4j.getCompressor().autoDecompress(this); - boolean isSpecifiedIndex = false; - for(INDArrayIndex idx : indices){ - if(idx instanceof SpecifiedIndex){ - isSpecifiedIndex = true; - break; - } + assertSlice(put, slice); + + INDArray view = slice(slice); + + if (put.length() == 1) { + putScalar(slice, put.getDouble(0)); + } else { + if (!(view.isVector() && put.isVector() && view.length() == put.length()) + && !view.equalShapes(put)) { + throw new IllegalStateException( + "Cannot put slice: array to be put (" + 
Arrays.toString(put.shape()) + + ") and slice array (" + Arrays.toString(view.shape()) + ") have different shapes"); + } + view.assign(put); + } + return this; + } + + protected void assertSlice(INDArray put, long slice) { + Preconditions.checkArgument(slice < slices(), + "Invalid slice specified: slice %s must be in range 0 (inclusive) to numSlices=%s (exclusive)", + slice, slices()); + long[] sliceShape = put.shape(); + if (Shape.isRowVectorShape(sliceShape)) { + } else { + long[] requiredShape = ArrayUtil.removeIndex(shape(), 0); + + //no need to compare for scalar; primarily due to shapes either being [1] or length 0 + if (put.isScalar()) { + return; } - if(!isSpecifiedIndex){ - return get(indices).assign(element); - } else { - //Can't get a view, so we'll do it in subsets instead - // This is inefficient, but it is correct... - int numSpecified = 0; - List specifiedIdxs = new ArrayList<>(); - List specifiedIdxDims = new ArrayList<>(); - - INDArrayIndex[] destinationIndices = indices.clone(); //Shallow clone - INDArrayIndex[] sourceIndices = indices.clone(); - for( int i=0; i can't use point(1) on [1,x,y] - sourceIndices[i] = NDArrayIndex.point(0); - } - } - int[] counts = new int[specifiedIdxs.size()]; - int[] dims = new int[specifiedIdxDims.size()]; - for( int i=0; i 2) { + throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); } - @Override - public INDArray put(INDArrayIndex[] indices, Number element) { - Nd4j.getCompressor().autoDecompress(this); - INDArray get = get(indices); - for (int i = 0; i < get.length(); i++) - get.putScalar(i, element.doubleValue()); - return this; - } + if (rank() == 1) { + Preconditions.checkArgument(indices.rank() <= 1, + "For 1D vector indices must be either scalar or vector as well"); + val ret = Nd4j.createUninitialized(this.dataType(), indices.length()); + for (int e = 0; e < indices.length(); e++) { + val idx = indices.getLong(e); + val value = getDouble(idx); + ret.putScalar(e, value); + } + + return ret; + } else if (indices.rows() == rank()) { + INDArray ret = Nd4j.create(this.dataType(), indices.columns()); + + for (int i = 0; i < indices.columns(); i++) { + int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); + val v = getDouble(specifiedIndex); + ret.putScalar(i, v); + } + + return ret; + } else { + List arrList = new ArrayList<>(); + + if (indices.isMatrix() || indices.isColumnVector() + || (indices.isScalar() + && indices.rank() == 2)) { // we need this for compatibility with legacy code + for (int i = 0; i < indices.rows(); i++) { + if (i == 0) { + INDArray row = indices.getRow(i); + for (int j = 0; j < row.length(); j++) { + arrList.add(slice(row.getInt(j))); + } + } else { + INDArray row = indices.slice(i); + for (int j = 0; j < row.length(); j++) { + INDArray put = arrList.get(j).slice(row.getInt(j)); + put = put.reshape(Longs.concat(new long[]{1}, put.shape())); + arrList.set(j, put); + } + } + + } + } else if (indices.isRowVector()) { + for (int i = 0; i < indices.length(); i++) { + INDArray add = slice(indices.getInt(i)); + add = add.reshape(Longs.concat(new long[]{1,}, add.shape())); + arrList.add(add); + } + } + + return Nd4j.concat(0, arrList.toArray(new INDArray[arrList.size()])); - @Override - public INDArray swapAxes(int dimension, int with) { - int[] shape = ArrayUtil.range(0, shape().length); - shape[dimension] = with; - shape[with] = dimension; - return permute(shape); } - @Override - public boolean isView() { + } + + @Override + public INDArray put(INDArray indices, INDArray element) { 
+ if (indices.rank() > 2) { + throw new ND4JIllegalArgumentException("Indices must be a vector or matrix."); + } + + if (indices.rows() == rank()) { + NdIndexIterator ndIndexIterator = new NdIndexIterator(element.shape()); + for (int i = 0; i < indices.columns(); i++) { + int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); + putScalar(specifiedIndex, element.getDouble(ndIndexIterator.next())); + } + } else { + List arrList = new ArrayList<>(); + + if (indices.isMatrix() || indices.isColumnVector()) { + for (int i = 0; i < indices.rows(); i++) { + INDArray row = indices.getRow(i); + for (int j = 0; j < row.length(); j++) { + INDArray slice = slice(row.getInt(j)); + Nd4j.getExecutioner() + .execAndReturn(new Assign(new INDArray[]{slice, element}, new INDArray[]{slice})); + arrList.add(slice(row.getInt(j))); + } + } + } else if (indices.isRowVector()) { + for (int i = 0; i < indices.length(); i++) { + arrList.add(slice(indices.getInt(i))); + } + } + } + return this; + } + + @Override + public INDArray put(INDArrayIndex[] indices, INDArray element) { + Nd4j.getCompressor().autoDecompress(this); + boolean isSpecifiedIndex = false; + for (INDArrayIndex idx : indices) { + if (idx instanceof SpecifiedIndex) { + isSpecifiedIndex = true; + break; + } + } + + if (!isSpecifiedIndex) { + return get(indices).assign(element); + } else { + //Can't get a view, so we'll do it in subsets instead + // This is inefficient, but it is correct... + int numSpecified = 0; + List specifiedIdxs = new ArrayList<>(); + List specifiedIdxDims = new ArrayList<>(); + + INDArrayIndex[] destinationIndices = indices.clone(); //Shallow clone + INDArrayIndex[] sourceIndices = indices.clone(); + for (int i = 0; i < indices.length; i++) { + INDArrayIndex idx = indices[i]; + if (idx instanceof SpecifiedIndex) { + numSpecified++; + long[] idxs = ((SpecifiedIndex) idx).getIndexes(); + specifiedIdxs.add(idxs); + specifiedIdxDims.add(i); + } else if (idx instanceof PointIndex) { + //Example: [2,3,3].put(point(1), ..., [1,x,y]) -> can't use point(1) on [1,x,y] + sourceIndices[i] = NDArrayIndex.point(0); + } + } + int[] counts = new int[specifiedIdxs.size()]; + int[] dims = new int[specifiedIdxDims.size()]; + for (int i = 0; i < specifiedIdxs.size(); i++) { + counts[i] = specifiedIdxs.get(i).length; + dims[i] = specifiedIdxDims.get(i); + } + + NdIndexIterator iter = new NdIndexIterator(counts); + while (iter.hasNext()) { + long[] iterationIdxs = iter.next(); + for (int i = 0; i < iterationIdxs.length; i++) { + long[] indicesForDim = specifiedIdxs.get(i); + destinationIndices[dims[i]] = NDArrayIndex.point(indicesForDim[(int) iterationIdxs[i]]); + sourceIndices[dims[i]] = NDArrayIndex.point(iterationIdxs[i]); + } + + INDArray sourceView = element.get(sourceIndices); + INDArray destinationView = this.get(destinationIndices); + destinationView.assign(sourceView); + } + } + return this; + } + + @Override + public INDArray put(INDArrayIndex[] indices, Number element) { + Nd4j.getCompressor().autoDecompress(this); + INDArray get = get(indices); + for (int i = 0; i < get.length(); i++) { + get.putScalar(i, element.doubleValue()); + } + return this; + } + + @Override + public INDArray swapAxes(int dimension, int with) { + int[] shape = ArrayUtil.range(0, shape().length); + shape[dimension] = with; + shape[with] = dimension; + return permute(shape); + } + + + @Override + public boolean isView() { /* We don't really use Shape offset value anywhere And it's possible to be not a view, and have non-empty originalBuffer */ - // 
length/data.length can be different in case of Threshold conversion - if(isEmpty() || isS()) - return false; + // length/data.length can be different in case of Threshold conversion + if (isEmpty() || isS()) { + return false; + } - val c2 = (length() < data().length() && data.dataType() != DataType.INT); - val c3 = (data().originalDataBuffer() != null && data != data.originalDataBuffer()); + val c2 = (length() < data().length() && data.dataType() != DataType.INT); + val c3 = (data().originalDataBuffer() != null && data != data.originalDataBuffer()); - return c2 || c3; + return c2 || c3; + } + + @Override + public boolean isSparse() { + return false; + } + + @Override + public DataBuffer data() { + return data; + } + + @Override + public void setData(DataBuffer data) { + this.data = data; + } + + @Override + public long slices() { + return size(0); + } + + protected INDArray create(DataBuffer buffer) { + return Nd4j.create(buffer); + } + + @Override + public INDArray cond(Condition condition) { + if (isEmpty()) { + return Nd4j.empty(DataType.BOOL); + } + INDArray ret = Nd4j.createUninitialized(DataType.BOOL, this.shape()); + Nd4j.getExecutioner().exec(new MatchConditionTransform(this, ret, condition)); + return ret; + } + + protected void init(int[] shape, int[] stride) { + //null character + if (shapeInformation == null || jvmShapeInfo == null || ordering() == '\u0000') { + //Shape.setOrder(shapeInfo(), Nd4j.order()); + val si = Nd4j.getShapeInfoProvider() + .createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 1, + Nd4j.order(), this.dataType(), false); + setShapeInformation(si); } - @Override - public boolean isSparse() { - return false; + } + + protected void init(long[] shape, long[] stride) { + //null character + if (shapeInformation == null || jvmShapeInfo == null || ordering() == '\u0000') { + val si = Nd4j.getShapeInfoProvider() + .createShapeInformation(shape, stride, 1, Nd4j.order(), this.dataType(), false); + setShapeInformation(si); } - @Override - public DataBuffer data() { - return data; + } + + @Override + public INDArray getScalar(long i) { + if (i >= this.length()) { + throw new ND4JIllegalStateException("Index can't be greater then array length"); + } + + if (i < 0) { + i += this.length(); + } + + long idx = this.isScalar() ? 
0 + : Shape.getOffset(jvmShapeInfo.javaShapeInformation, Shape.ind2subC(this.shape(), i)); + val buffer = Nd4j.createBuffer(this.data(), this.data().originalOffset() + idx, 1); + val shape = Nd4j.getShapeInfoProvider() + .createShapeInformation(new long[0], new long[0], 1, 'c', dataType(), false); + return Nd4j.createArrayFromShapeBuffer(buffer, shape); + } + + /** + * Do a row wise op (a,s,m,d) a : add s : subtract m : multiply d : divide h : reverse subtraction + * t : reverse division + * + * @param columnVector the column vector + * @param operation the operation + * @return + */ + protected INDArray doColumnWise(INDArray columnVector, char operation) { + Nd4j.getCompressor().autoDecompress(this); + if (columnVector.isScalar()) { + switch (operation) { + case 'a': + addi(columnVector.getDouble(0)); + break; + case 'p': + assign(columnVector.getDouble(0)); + break; + case 's': + subi(columnVector.getDouble(0)); + break; + case 'm': + muli(columnVector.getDouble(0)); + break; + case 'd': + divi(columnVector.getDouble(0)); + break; + case 'h': + rsubi(columnVector.getDouble(0)); + break; + case 't': + rdivi(columnVector.getDouble(0)); + break; + + } + + return this; + } else if (isScalar()) { + switch (operation) { + case 'a': + return columnVector.addi(getDouble(0)); + case 'p': + return columnVector.assign(getDouble(0)); + case 's': + return columnVector.subi(getDouble(0)); + case 'm': + return columnVector.muli(getDouble(0)); + case 'd': + return columnVector.divi(getDouble(0)); + case 'h': + return columnVector.rsubi(getDouble(0)); + case 't': + return columnVector.rdivi(getDouble(0)); + + } } - @Override - public void setData(DataBuffer data) { - this.data = data; + //Input validation: require (a) columnVector to actually be a column vector, and (b) this.size(0) to match columnVector.size(0) + //Or, simply require it to be a rank 1 vector + if ((!columnVector.isColumnVector() && columnVector.rank() > 1) + || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) { + throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) + + ", column vector shape =" + Arrays.toString(columnVector.shape()) + ")"); } - @Override - public long slices() { - return size(0); + if (columnVector.data().sameUnderlyingData(data())) { + return doColumnWise(columnVector.dup(), operation); + } + if (equalShapes(columnVector)) { + switch (operation) { + case 'a': + addi(columnVector); + break; + case 'p': + assign(columnVector); + break; + case 's': + subi(columnVector); + break; + case 'm': + muli(columnVector); + break; + case 'd': + divi(columnVector); + break; + case 'h': + rsubi(columnVector); + break; + case 't': + rdivi(columnVector); + break; + } + + return this; } - - protected INDArray create(DataBuffer buffer) { - return Nd4j.create(buffer); - } - - @Override - public INDArray cond(Condition condition) { - if(isEmpty()) - return Nd4j.empty(DataType.BOOL); - INDArray ret = Nd4j.createUninitialized(DataType.BOOL, this.shape()); - Nd4j.getExecutioner().exec(new MatchConditionTransform(this,ret, condition)); - return ret; - } - - protected void init(int[] shape, int[] stride) { - //null character - if (shapeInformation == null || jvmShapeInfo == null || ordering() == '\u0000') { - //Shape.setOrder(shapeInfo(), Nd4j.order()); - val si = Nd4j.getShapeInfoProvider().createShapeInformation(ArrayUtil.toLongArray(shape), ArrayUtil.toLongArray(stride), 1, Nd4j.order(), this.dataType(), false); - setShapeInformation(si); - } - - } - - protected void init(long[] shape, 
long[] stride) { - //null character - if (shapeInformation == null || jvmShapeInfo == null || ordering() == '\u0000') { - val si = Nd4j.getShapeInfoProvider().createShapeInformation(shape,stride, 1, Nd4j.order(), this.dataType(), false); - setShapeInformation(si); - } - - } - - @Override - public INDArray getScalar(long i) { - if (i >= this.length()) - throw new ND4JIllegalStateException("Index can't be greater then array length"); - - if (i < 0) - i += this.length(); - - long idx = this.isScalar() ? 0 : Shape.getOffset(jvmShapeInfo.javaShapeInformation, Shape.ind2subC(this.shape(), i)); - val buffer = Nd4j.createBuffer( this.data(), this.data().originalOffset() + idx, 1); - val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1,'c', dataType(), false); - return Nd4j.createArrayFromShapeBuffer(buffer, shape); - } - - /** - * Do a row wise op (a,s,m,d) - * a : add - * s : subtract - * m : multiply - * d : divide - * h : reverse subtraction - * t : reverse division - * - * @param columnVector the column vector - * @param operation the operation - * @return - */ - protected INDArray doColumnWise(INDArray columnVector, char operation) { - Nd4j.getCompressor().autoDecompress(this); - if(columnVector.isScalar()) { - switch (operation) { - case 'a': - addi(columnVector.getDouble(0)); - break; - case 'p': - assign(columnVector.getDouble(0)); - break; - case 's': - subi(columnVector.getDouble(0)); - break; - case 'm': - muli(columnVector.getDouble(0)); - break; - case 'd': - divi(columnVector.getDouble(0)); - break; - case 'h': - rsubi(columnVector.getDouble(0)); - break; - case 't': - rdivi(columnVector.getDouble(0)); - break; - - } - - return this; - } - - else if(isScalar()) { - switch (operation) { - case 'a': - return columnVector.addi(getDouble(0)); - case 'p': - return columnVector.assign(getDouble(0)); - case 's': - return columnVector.subi(getDouble(0)); - case 'm': - return columnVector.muli(getDouble(0)); - case 'd': - return columnVector.divi(getDouble(0)); - case 'h': - return columnVector.rsubi(getDouble(0)); - case 't': - return columnVector.rdivi(getDouble(0)); - - } - } - - //Input validation: require (a) columnVector to actually be a column vector, and (b) this.size(0) to match columnVector.size(0) - //Or, simply require it to be a rank 1 vector - if ((!columnVector.isColumnVector() && columnVector.rank() > 1) || this.size(0) != columnVector.size(0) || columnVector.length() <= 1) { - throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) - + ", column vector shape =" + Arrays.toString(columnVector.shape()) + ")"); - } - - if (columnVector.data().sameUnderlyingData(data())) - return doColumnWise(columnVector.dup(), operation); - if (equalShapes(columnVector)) { - switch (operation) { - case 'a': - addi(columnVector); - break; - case 'p': - assign(columnVector); - break; - case 's': - subi(columnVector); - break; - case 'm': - muli(columnVector); - break; - case 'd': - divi(columnVector); - break; - case 'h': - rsubi(columnVector); - break; - case 't': - rdivi(columnVector); - break; - } - - return this; - } - if (rows() == 1 && columnVector.isScalar()) { - applyScalarOp(columnVector, operation); - } else { - // special optimization case, broadcast turns into ScalarOp Along Dimension - if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'c' && columnVector.elementWiseStride() == 1) { - switch (operation) { - case 'a': { - ScalarAdd op = new ScalarAdd(this, columnVector, this, 0.0); - op.setDimension(1); - 
Nd4j.getExecutioner().exec(op); - break; - } - case 'p': { - ScalarSet op = new ScalarSet(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - case 's': { - ScalarSubtraction op = new ScalarSubtraction(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - case 'm': { - ScalarMultiplication op = - new ScalarMultiplication(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - case 'd': { - ScalarDivision op = new ScalarDivision(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - case 'h': { - ScalarReverseSubtraction op = - new ScalarReverseSubtraction(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - case 't': { - ScalarReverseDivision op = - new ScalarReverseDivision(this, columnVector, this, 0.0); - op.setDimension(1); - Nd4j.getExecutioner().exec(op); - break; - } - } - } else { - applyBroadcastOp(columnVector, operation); - } - - } - - return this; - - } - - /** - * Do a row wise op (a,s,m,d) - * a : add - * s : subtract - * m : multiply - * d : divide - * h : reverse subtraction - * t : reverse division - * - * @param rowVector the row vector - * @param operation the operation - * @return - */ - protected INDArray doRowWise(INDArray rowVector, final char operation) { - Nd4j.getCompressor().autoDecompress(this); - - - if(rowVector.isScalar()) { - switch (operation) { - case 'a': - addi(rowVector.getDouble(0)); - break; - case 'p': - assign(rowVector.getDouble(0)); - break; - case 's': - subi(rowVector.getDouble(0)); - break; - case 'm': - muli(rowVector.getDouble(0)); - break; - case 'd': - divi(rowVector.getDouble(0)); - break; - case 'h': - rsubi(rowVector.getDouble(0)); - break; - case 't': - rdivi(rowVector.getDouble(0)); - break; - - } - - return this; - } - else if(isScalar()) { - switch (operation) { - case 'a': - return rowVector.addi(getDouble(0)); - case 'p': - return rowVector.assign(getDouble(0)); - case 's': - return rowVector.subi(getDouble(0)); - case 'm': - return rowVector.muli(getDouble(0)); - case 'd': - return rowVector.divi(getDouble(0)); - case 'h': - return rowVector.rsubi(getDouble(0)); - case 't': - return rowVector.rdivi(getDouble(0)); - - } - } - - //Input validation: require (a) rowVector to actually be a row vector, and (b) this.size(1) to match rowVector.size(1) - if (!rowVector.isRowVector() || this.rank() > 1 && rowVector.rank() > 1 && this.size(1) != rowVector.size(1) || rowVector.length() <= 1) { - throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) - + ", row vector shape =" + Arrays.toString(rowVector.shape()) + ")"); - } - - if (rowVector.data().sameUnderlyingData(data())) - return doRowWise(rowVector.dup(), operation); - - if (isVector()) { - switch (operation) { - case 'a': - addi(rowVector); - break; - case 'p': - assign(rowVector); - break; - case 's': - subi(rowVector); - break; - case 'm': - muli(rowVector); - break; - case 'd': - divi(rowVector); - break; - case 'h': - rsubi(rowVector); - break; - case 't': - rdivi(rowVector); - break; - } - - return this; - } - - if (rank() == 2 && columns() == 1 && rowVector.isScalar()) { - applyScalarOp(rowVector, operation); - } else { - // special optimization case, broadcast turns into ScalarOp Along Dimension - if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'f' && rowVector.elementWiseStride() == 1) { - 
switch (operation) { - case 'a': { - ScalarAdd op = new ScalarAdd(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 'p': { - ScalarSet op = new ScalarSet(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 's': { - ScalarSubtraction op = new ScalarSubtraction(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 'm': { - ScalarMultiplication op = new ScalarMultiplication(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 'd': { - ScalarDivision op = new ScalarDivision(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 'h': { - ScalarReverseSubtraction op = - new ScalarReverseSubtraction(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - case 't': { - ScalarReverseDivision op = new ScalarReverseDivision(this, rowVector, this, 0.0); - op.setDimension(0); - Nd4j.getExecutioner().exec(op); - break; - } - - } - } else { - applyBroadcastOp(rowVector, operation); - } - } - - return this; - } - - - private void applyBroadcastOp(INDArray vector, final char operation) { - Nd4j.getCompressor().autoDecompress(this); - int alongDimension = Shape.isRowVectorShape(vector.shape()) ? 1 : 0; - - // FIXME: probably this is wrong, because strict equality is always false in current DataBuffer mechanics - if (this.data() == vector.data()) - vector = vector.dup(); + if (rows() == 1 && columnVector.isScalar()) { + applyScalarOp(columnVector, operation); + } else { + // special optimization case, broadcast turns into ScalarOp Along Dimension + if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'c' + && columnVector.elementWiseStride() == 1) { switch (operation) { - case 'a': - Nd4j.getExecutioner().exec(new BroadcastAddOp(this, vector, this, alongDimension)); - return; - case 's': - Nd4j.getExecutioner().exec(new BroadcastSubOp(this, vector, this, alongDimension)); - return; - case 'm': - Nd4j.getExecutioner().exec(new BroadcastMulOp(this, vector, this, alongDimension)); - return; - case 'd': - Nd4j.getExecutioner().exec(new BroadcastDivOp(this, vector, this, alongDimension)); - return; - case 'h': - Nd4j.getExecutioner().exec(new BroadcastRSubOp(this, vector, this, alongDimension)); - return; - case 't': - Nd4j.getExecutioner().exec(new BroadcastRDivOp(this, vector, this, alongDimension)); - return; - case 'p': - Nd4j.getExecutioner().exec(new BroadcastCopyOp(this, vector, this, alongDimension)); - return; - default: - throw new UnsupportedOperationException("Unknown operation: " + operation); + case 'a': { + ScalarAdd op = new ScalarAdd(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } + case 'p': { + ScalarSet op = new ScalarSet(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } + case 's': { + ScalarSubtraction op = new ScalarSubtraction(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } + case 'm': { + ScalarMultiplication op = + new ScalarMultiplication(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } + case 'd': { + ScalarDivision op = new ScalarDivision(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } + case 'h': { + 
ScalarReverseSubtraction op = + new ScalarReverseSubtraction(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } + case 't': { + ScalarReverseDivision op = + new ScalarReverseDivision(this, columnVector, this, 0.0); + op.setDimension(1); + Nd4j.getExecutioner().exec(op); + break; + } } + } else { + applyBroadcastOp(columnVector, operation); + } + } - private void applyScalarOp(INDArray vector, char operation) { - Nd4j.getCompressor().autoDecompress(this); + return this; + + } + + /** + * Do a row wise op (a,s,m,d) a : add s : subtract m : multiply d : divide h : reverse subtraction + * t : reverse division + * + * @param rowVector the row vector + * @param operation the operation + * @return + */ + protected INDArray doRowWise(INDArray rowVector, final char operation) { + Nd4j.getCompressor().autoDecompress(this); + + if (rowVector.isScalar()) { + switch (operation) { + case 'a': + addi(rowVector.getDouble(0)); + break; + case 'p': + assign(rowVector.getDouble(0)); + break; + case 's': + subi(rowVector.getDouble(0)); + break; + case 'm': + muli(rowVector.getDouble(0)); + break; + case 'd': + divi(rowVector.getDouble(0)); + break; + case 'h': + rsubi(rowVector.getDouble(0)); + break; + case 't': + rdivi(rowVector.getDouble(0)); + break; + + } + + return this; + } else if (isScalar()) { + switch (operation) { + case 'a': + return rowVector.addi(getDouble(0)); + case 'p': + return rowVector.assign(getDouble(0)); + case 's': + return rowVector.subi(getDouble(0)); + case 'm': + return rowVector.muli(getDouble(0)); + case 'd': + return rowVector.divi(getDouble(0)); + case 'h': + return rowVector.rsubi(getDouble(0)); + case 't': + return rowVector.rdivi(getDouble(0)); + + } + } + + //Input validation: require (a) rowVector to actually be a row vector, and (b) this.size(1) to match rowVector.size(1) + if (!rowVector.isRowVector() + || this.rank() > 1 && rowVector.rank() > 1 && this.size(1) != rowVector.size(1) + || rowVector.length() <= 1) { + throw new IllegalStateException("Mismatched shapes (shape = " + Arrays.toString(shape()) + + ", row vector shape =" + Arrays.toString(rowVector.shape()) + ")"); + } + + if (rowVector.data().sameUnderlyingData(data())) { + return doRowWise(rowVector.dup(), operation); + } + + if (isVector()) { + switch (operation) { + case 'a': + addi(rowVector); + break; + case 'p': + assign(rowVector); + break; + case 's': + subi(rowVector); + break; + case 'm': + muli(rowVector); + break; + case 'd': + divi(rowVector); + break; + case 'h': + rsubi(rowVector); + break; + case 't': + rdivi(rowVector); + break; + } + + return this; + } + + if (rank() == 2 && columns() == 1 && rowVector.isScalar()) { + applyScalarOp(rowVector, operation); + } else { + // special optimization case, broadcast turns into ScalarOp Along Dimension + if (rank() == 2 && elementWiseStride() == 1 && ordering() == 'f' + && rowVector.elementWiseStride() == 1) { switch (operation) { - case 'a': - addi(vector.getDouble(0)); - break; - case 's': - subi(vector.getDouble(0)); - break; - case 'm': - muli(vector.getDouble(0)); - break; - case 'd': - divi(vector.getDouble(0)); - break; - case 'h': - rsubi(vector.getDouble(0)); - break; - case 't': - rdivi(vector.getDouble(0)); - break; + case 'a': { + ScalarAdd op = new ScalarAdd(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + case 'p': { + ScalarSet op = new ScalarSet(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + 
} + case 's': { + ScalarSubtraction op = new ScalarSubtraction(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + case 'm': { + ScalarMultiplication op = new ScalarMultiplication(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + case 'd': { + ScalarDivision op = new ScalarDivision(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + case 'h': { + ScalarReverseSubtraction op = + new ScalarReverseSubtraction(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + case 't': { + ScalarReverseDivision op = new ScalarReverseDivision(this, rowVector, this, 0.0); + op.setDimension(0); + Nd4j.getExecutioner().exec(op); + break; + } + + } + } else { + applyBroadcastOp(rowVector, operation); + } + } + + return this; + } + + + private void applyBroadcastOp(INDArray vector, final char operation) { + Nd4j.getCompressor().autoDecompress(this); + int alongDimension = Shape.isRowVectorShape(vector.shape()) ? 1 : 0; + + // FIXME: probably this is wrong, because strict equality is always false in current DataBuffer mechanics + if (this.data() == vector.data()) { + vector = vector.dup(); + } + switch (operation) { + case 'a': + Nd4j.getExecutioner().exec(new BroadcastAddOp(this, vector, this, alongDimension)); + return; + case 's': + Nd4j.getExecutioner().exec(new BroadcastSubOp(this, vector, this, alongDimension)); + return; + case 'm': + Nd4j.getExecutioner().exec(new BroadcastMulOp(this, vector, this, alongDimension)); + return; + case 'd': + Nd4j.getExecutioner().exec(new BroadcastDivOp(this, vector, this, alongDimension)); + return; + case 'h': + Nd4j.getExecutioner().exec(new BroadcastRSubOp(this, vector, this, alongDimension)); + return; + case 't': + Nd4j.getExecutioner().exec(new BroadcastRDivOp(this, vector, this, alongDimension)); + return; + case 'p': + Nd4j.getExecutioner().exec(new BroadcastCopyOp(this, vector, this, alongDimension)); + return; + default: + throw new UnsupportedOperationException("Unknown operation: " + operation); + } + } + + private void applyScalarOp(INDArray vector, char operation) { + Nd4j.getCompressor().autoDecompress(this); + switch (operation) { + case 'a': + addi(vector.getDouble(0)); + break; + case 's': + subi(vector.getDouble(0)); + break; + case 'm': + muli(vector.getDouble(0)); + break; + case 'd': + divi(vector.getDouble(0)); + break; + case 'h': + rsubi(vector.getDouble(0)); + break; + case 't': + rdivi(vector.getDouble(0)); + break; + } + } + + protected DataBuffer shapeOf() { + // if (shape == null) + // shape = Shape.shapeOf(shapeInfoDataBuffer()); + // return shape; + + return Shape.shapeOf(shapeInfoDataBuffer()); + } + + protected DataBuffer strideOf() { + // if (stride == null) + // stride = Shape.stride(shapeInfoDataBuffer()); + // return stride; + return Shape.stride(shapeInfoDataBuffer()); + } + + @Override + public int stride(int dimension) { + int rank = jvmShapeInfo.rank; + Preconditions.checkArgument(dimension < rank, + "Cannot get stride for dimension %s from rank %s array: " + + "dimension indices must be in range -rank <= dimension < rank", dimension, rank); + if (dimension < 0) { + return (int) stride()[dimension + rank]; + } + return (int) stride()[dimension]; + } + + @Override + public INDArray rdiviColumnVector(INDArray columnVector) { + validateNumericalArray("rdiviColumnVector", false); + return doColumnWise(columnVector, 't'); + } + + @Override + public 
INDArray rdivColumnVector(INDArray columnVector) { + validateNumericalArray("rdivColumnVector", false); + return dup().rdiviColumnVector(columnVector); + } + + @Override + public INDArray rdiviRowVector(INDArray rowVector) { + validateNumericalArray("rdiviRowVector", false); + return doRowWise(rowVector, 't'); + } + + @Override + public INDArray rdivRowVector(INDArray rowVector) { + validateNumericalArray("rdivRowVector", false); + return dup().rdiviRowVector(rowVector); + } + + @Override + public INDArray rsubiColumnVector(INDArray columnVector) { + validateNumericalArray("rsubiColumnVector", false); + return doColumnWise(columnVector, 'h'); + } + + @Override + public INDArray rsubColumnVector(INDArray columnVector) { + validateNumericalArray("rsubColumnVector", false); + return dup().rsubiColumnVector(columnVector); + } + + @Override + public INDArray rsubiRowVector(INDArray rowVector) { + validateNumericalArray("rsubiRowVector", false); + return doRowWise(rowVector, 'h'); + } + + @Override + public INDArray rsubRowVector(INDArray rowVector) { + validateNumericalArray("rsubRowVector", false); + return dup().rsubiRowVector(rowVector); + } + + @Override + public INDArray put(int i, INDArray element) { + Preconditions.checkArgument(element.isScalar(), + "Element must be a scalar: element has shape %ndShape", element); + return putScalar(i, element.getDouble(0)); + } + + @Override + public INDArray diviColumnVector(INDArray columnVector) { + validateNumericalArray("diviColumnVector", false); + return doColumnWise(columnVector, 'd'); + } + + @Override + public INDArray divColumnVector(INDArray columnVector) { + validateNumericalArray("divColumnVector", false); + return dup().diviColumnVector(columnVector); + } + + @Override + public INDArray diviRowVector(INDArray rowVector) { + validateNumericalArray("diviRowVector", false); + return doRowWise(rowVector, 'd'); + } + + @Override + public INDArray divRowVector(INDArray rowVector) { + validateNumericalArray("divRowVector", false); + return dup().diviRowVector(rowVector); + } + + @Override + public INDArray muliColumnVector(INDArray columnVector) { + validateNumericalArray("muliColumnVector", false); + return doColumnWise(columnVector, 'm'); + } + + @Override + public INDArray mulColumnVector(INDArray columnVector) { + validateNumericalArray("mulColumnVector", false); + return dup().muliColumnVector(columnVector); + } + + @Override + public INDArray muliRowVector(INDArray rowVector) { + validateNumericalArray("muliRowVector", false); + return doRowWise(rowVector, 'm'); + } + + @Override + public INDArray mulRowVector(INDArray rowVector) { + validateNumericalArray("mulRowVector", false); + return dup().muliRowVector(rowVector); + } + + @Override + public INDArray subiColumnVector(INDArray columnVector) { + validateNumericalArray("subiColumnVector", false); + return doColumnWise(columnVector, 's'); + } + + @Override + public INDArray subColumnVector(INDArray columnVector) { + validateNumericalArray("subColumnVector", false); + return dup().subiColumnVector(columnVector); + } + + @Override + public INDArray subiRowVector(INDArray rowVector) { + validateNumericalArray("subiRowVector", false); + return doRowWise(rowVector, 's'); + } + + @Override + public INDArray subRowVector(INDArray rowVector) { + validateNumericalArray("subRowVector", false); + return dup().subiRowVector(rowVector); + } + + @Override + public INDArray addiColumnVector(INDArray columnVector) { + validateNumericalArray("addiColumnVector", false); + return doColumnWise(columnVector, 
'a'); + } + + @Override + public INDArray putiColumnVector(INDArray columnVector) { + return doColumnWise(columnVector, 'p'); + } + + @Override + public INDArray addColumnVector(INDArray columnVector) { + validateNumericalArray("addColumnVector", false); + return dup().addiColumnVector(columnVector); + } + + @Override + public INDArray addiRowVector(INDArray rowVector) { + validateNumericalArray("addiRowVector", false); + return doRowWise(rowVector, 'a'); + } + + @Override + public INDArray putiRowVector(INDArray rowVector) { + validateNumericalArray("putiRowVector", false); + return doRowWise(rowVector, 'p'); + } + + @Override + public INDArray addRowVector(INDArray rowVector) { + validateNumericalArray("addRowVector", false); + return dup().addiRowVector(rowVector); + } + + @Override + public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) { + return mMulTranspose.exec(this, other, result); + } + + @Override + public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) { + return mMulTranspose.exec(this, other, null); + } + + @Override + public INDArray mmul(INDArray other, char resultOrder) { + Preconditions.checkArgument(resultOrder == 'c' || resultOrder == 'f', + "Order must be either 'c' or 'f', but [" + resultOrder + "] was given"); + Preconditions.checkState(this.dataType() == other.dataType(), + "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), + other.dataType()); + // FIXME: add support for 3D+ here? + long[] shape = other.rank() == 1 ? new long[]{rows()} : new long[]{rows(), other.columns()}; + INDArray result = createUninitialized(this.dataType(), shape, resultOrder); + if (result.isScalar()) { + return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1); + } + return mmuli(other, result); + } + + @Override + public INDArray mmul(INDArray other) { + return mmul(other, + (this.ordering() == 'f' && other.ordering() == 'f' && other.rank() != 1) ? 'f' : 'c'); + } + + protected INDArray create(int[] shape, char ordering) { + return Nd4j.create(shape, ordering); + } + + @Override + public double[][] toDoubleMatrix() { + if (!isMatrix()) { + throw new ND4JIllegalStateException( + "Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort( + this)); + } + + if (this.size(0) > Integer.MAX_VALUE || this.size(1) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + + double[][] ret = new double[rows()][columns()]; + for (int i = 0; i < ret.length; i++) { + ret[i] = getRow(i).dup().data().asDouble(); + } + + return ret; + } + + @Override + public double[] toDoubleVector() { + if (!isVectorOrScalar()) { + throw new ND4JIllegalStateException( + "Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort( + this)); + } + return dup().data().asDouble(); + } + + @Override + public float[] toFloatVector() { + if (!isVectorOrScalar()) { + throw new ND4JIllegalStateException( + "Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort( + this)); + } + return dup().data().asFloat(); + } + + @Override + public float[][] toFloatMatrix() { + if (!isMatrix()) { + throw new ND4JIllegalStateException( + "Unable to create a 2d array from a non matrix! 
Shape: " + Shape.shapeToStringShort( + this)); + } + + if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + + float[][] ret = new float[(int) rows()][(int) columns()]; + for (int i = 0; i < ret.length; i++) { + ret[i] = getRow(i).dup().data().asFloat(); + } + + return ret; + } + + @Override + public int[] toIntVector() { + if (isEmpty()) { + return new int[0]; + } + + if (!isVectorOrScalar()) { + throw new ND4JIllegalStateException( + "Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort( + this)); + } + if (isView() || elementWiseStride() != 1) { + return dup().data().asInt(); + } + return data().asInt(); + } + + @Override + public long[] toLongVector() { + if (!isVectorOrScalar()) { + throw new ND4JIllegalStateException( + "Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort( + this)); + } + if (isView() || elementWiseStride() != 1) { + return dup().data().asLong(); + } + return data().asLong(); + } + + @Override + public long[][] toLongMatrix() { + if (!isMatrix()) { + throw new ND4JIllegalStateException( + "Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort( + this)); + } + + if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + + long[][] ret = new long[(int) rows()][(int) columns()]; + for (int i = 0; i < ret.length; i++) { + ret[i] = getRow(i).dup().data().asLong(); + } + + return ret; + } + + @Override + public int[][] toIntMatrix() { + if (!isMatrix()) { + throw new ND4JIllegalStateException( + "Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort( + this)); + } + + if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + + int[][] ret = new int[(int) rows()][(int) columns()]; + for (int i = 0; i < ret.length; i++) { + ret[i] = getRow(i).dup().data().asInt(); + } + + return ret; + } + + /** + * Perform an copy matrix multiplication + * + * @param other the other matrix to perform matrix multiply with + * @param result the result ndarray + * @return the result of the matrix multiplication + */ + @Override + public INDArray mmul(INDArray other, INDArray result) { + return mmuli(other, result); + } + + @Override + public INDArray div(INDArray other) { + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return divi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + return divi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + this.shape(), this.ordering())); + } + } + + @Override + public INDArray div(INDArray other, INDArray result) { + validateNumericalArray("div", true); + return divi(other, result); + } + + @Override + public INDArray mul(INDArray other) { + validateNumericalArray("mul", false); + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return muli(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + val z = Nd4j.createUninitialized( + Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), + this.ordering()); + return muli(other, z); + } + } + + @Override + public INDArray mul(INDArray 
other, INDArray result) { + return muli(other, result); + } + + @Override + public INDArray sub(INDArray other) { + validateNumericalArray("sub", false); + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return subi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + return subi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + this.shape(), this.ordering())); + } + } + + @Override + public INDArray sub(INDArray other, INDArray result) { + return subi(other, result); + } + + @Override + public INDArray add(INDArray other) { + validateNumericalArray("add", false); + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return addi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + return addi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + this.shape(), this.ordering())); + } + } + + @Override + public INDArray add(INDArray other, INDArray result) { + validateNumericalArray("add", false); + return addi(other, result); + } + + @Override + public INDArray mmuli(INDArray other, MMulTranspose transpose) { + validateNumericalArray("mmuli", false); + return dup().mmuli(other, this, transpose); + } + + @Override + public INDArray mmuli(INDArray other) { + validateNumericalArray("mmuli", false); + return dup().mmuli(other, this); + } + + @Override + public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) { + return transpose.exec(this, other, result); + } + + @Override + public INDArray mmuli(INDArray other, INDArray result) { + validateNumericalArray("mmuli", false); + LinAlgExceptions.assertMultiplies(this, other); + if (other.rank() == 1) { + //GEMV edge case + Preconditions.checkState(result.length() == this.size(0) && this.size(1) == other.size(0), + "Invalid matrix multiplication: %ndShape x %ndShape with result shape %ndShape", this, + other, result); + } else { + //Standard case + Preconditions.checkState( + result.rank() == 2 && result.size(0) == this.size(0) && result.size(1) == other.size(1), + "Invalid result array shape: expected shape [%s,%s], got shape %ndShape result array for %ndShape x %ndShape", + this.size(0), other.size(1), result, + this, other); + } + + if (other.isScalar()) { + return muli(other.getDouble(0), result); + } + if (isScalar()) { + return other.muli(getDouble(0), result); + } + + /* check sizes and resize if necessary */ + + if (result == this || result == other) { + /* actually, blas cannot do multiplications in-place. Therefore, we will fake by + * allocating a temporary object on the side and copy the result later. 
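+       * The temporary is created in column-major ('f') order, matching what the BLAS gemv/gemm
+       * calls below produce, and its contents are copied back into 'result' via assign().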
+ */ + INDArray temp = Nd4j.create(result.dataType(), result.shape(), + Nd4j.getStrides(result.shape(), 'f'), 'f'); + + if (other.columns() == 1 || other.rank() == 1) { + Nd4j.getBlasWrapper().level2().gemv(BlasBufferUtil.getCharForTranspose(result), + BlasBufferUtil.getCharForTranspose(this), 1.0, this, other, 0.0, temp); + } else { + Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(result), + BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(temp), 1.0, + this, other, 0.0, temp); + } + + result.assign(temp); + + + } else { + + //We require that the result array is 'f' (fortran) order + // However, user might have called mmuli with a c order array for the result + // In which case, we need to allocate a temporary f order array, and later do an assign to the real result array + + boolean requiresTemp = + result.ordering() != 'f' || result.isView() || !Shape.hasDefaultStridesForShape(result); + INDArray gemmResultArr; + if (requiresTemp) { + //Can use createUninitialized due to beta==0.0 parameter in gemm + gemmResultArr = Nd4j.createUninitialized(result.dataType(), result.shape(), 'f'); + } else { + gemmResultArr = result; + } + + if (other.columns() == 1 || other.rank() == 1) { + Nd4j.getBlasWrapper().level2().gemv( + ordering(), + BlasBufferUtil.getCharForTranspose(other), + 1.0, + this, + other, + 0.0, + gemmResultArr); + } else { + //gemm doesn't support strides so vectors and views + //don't work + Nd4j.getBlasWrapper().level3().gemm(ordering(), + BlasBufferUtil.getCharForTranspose(other), + BlasBufferUtil.getCharForTranspose(gemmResultArr), + 1.0, + this, + other, + 0.0, + gemmResultArr); + } + + if (requiresTemp) { + result.assign(gemmResultArr); + } + } + + // 1D edge case: reshape back to vector + if (other.rank() == 1) { + result = result.reshape(result.length()); + } + return result; + } + + private INDArray create(int[] shape, int[] stride) { + return Nd4j.create(shape, stride); + } + + @Override + public INDArray divi(INDArray other) { + return divi(other, this); + } + + @Override + public INDArray divi(INDArray other, INDArray result) { + validateNumericalArray("divi", false); + Shape.assertBroadcastable("divi", this, other, result); + Nd4j.exec(new DivOp(this, other, result)); + return result; + } + + @Override + public INDArray muli(INDArray other) { + return muli(other, this); + } + + @Override + public INDArray muli(INDArray other, INDArray result) { + validateNumericalArray("muli", false); + Shape.assertBroadcastable("muli", this, other, result); + Nd4j.exec(new MulOp(this, other, result)); + return result; + } + + @Override + public INDArray subi(INDArray other) { + return subi(other, this); + } + + /** + * in place subtraction of two matrices + * + * @param other the second ndarray to subtract + * @param result the result ndarray + * @return the result of the subtraction + */ + @Override + public INDArray subi(INDArray other, INDArray result) { + validateNumericalArray("subi", false); + Shape.assertBroadcastable("subi", this, other, result); + Nd4j.exec(new SubOp(this, other, result)); + return result; + } + + @Override + public INDArray addi(INDArray other) { + return addi(other, this); + } + + @Override + public INDArray addi(INDArray other, INDArray result) { + validateNumericalArray("addi", false); + Shape.assertBroadcastable("addi", this, other, result); + Nd4j.exec(new AddOp(this, other, result)); + return result; + } + + @Override + public INDArray normmax(boolean keepDims, int... 
dimension) { + validateNumericalArray("normmax", false); + return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension)); + } + + @Override + public INDArray normmax(int... dimension) { + return normmax(false, dimension); + } + + @Override + public INDArray rdiv(INDArray other) { + validateNumericalArray("rdiv", false); + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return rdivi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + return rdivi(other, this.ulike()); + } + } + + @Override + public INDArray rdivi(INDArray other) { + return rdivi(other, this); + } + + @Override + public INDArray rdiv(INDArray other, INDArray result) { + validateNumericalArray("rdiv", false); + return dup().rdivi(other, result); + } + + @Override + public INDArray rdivi(INDArray other, INDArray result) { + validateNumericalArray("rdivi", false); + Shape.assertBroadcastable("rdivi", this, other, result); + Nd4j.exec(new RDivOp(this, other, result)); + return result; + } + + @Override + public INDArray rsub(INDArray other, INDArray result) { + validateNumericalArray("rsub", false); + return rsubi(other, result); + } + + @Override + public INDArray rsub(INDArray other) { + validateNumericalArray("rsub", false); + if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { + return rsubi(other, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), + Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + } else { + return rsubi(other, this.ulike()); + } + } + + @Override + public INDArray rsubi(INDArray other) { + return rsubi(other, this); + } + + @Override + public INDArray rsubi(INDArray other, INDArray result) { + validateNumericalArray("rsubi", false); + Shape.assertBroadcastable("rsubi", this, other, result); + Nd4j.exec(new RSubOp(this, other, result)); + return result; + } + + @Override + public INDArray assign(Number value) { + Preconditions.checkState( + dataType() != DataType.BOOL || value.doubleValue() == 0.0 || value.doubleValue() == 1.0, + "Only values 0 or 1 are allowed for scalar " + + "assign on boolean arrays: got value %s on to assign to boolean array with shape %ndShape", + value, this); + Nd4j.getExecutioner().exec(new ScalarSet(this, value)); + return this; + } + + @Override + public INDArray assign(boolean value) { + return assign(value ? 1 : 0); + } + + @Override + public INDArray assignIf(INDArray arr, Condition condition) { + BooleanIndexing.assignIf(this, arr, condition); + return this; + } + + @Override + public INDArray replaceWhere(INDArray arr, Condition condition) { + Nd4j.getCompressor().autoDecompress(this); + BooleanIndexing.replaceWhere(this, arr, condition); + return this; + } + + @Override + @Deprecated //TODO: Investigate. Not deprecated in the base interface. 
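+  // Intended to map a flat (linearised) element index to an offset into the underlying
+  // buffer using the array's strides; see the TODO above regarding its deprecation status.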
+ public long linearIndex(long i) { + long idx = i; + for (int j = 0; j < jvmShapeInfo.rank - 1; j++) { + if (size((int) i) == 1) { + continue; + } + idx += i * stride(j); + } + return Shape.offset(jvmShapeInfo.javaShapeInformation) + (idx); + } + + @Override + public INDArray slice(long slice) { + Nd4j.getCompressor().autoDecompress(this); + + long slices = slices(); + if (slice >= slices) { + throw new IllegalArgumentException("Illegal slice " + slice); + } + + if (jvmShapeInfo.rank == 0) { + throw new IllegalArgumentException("Can't slice a 0-d NDArray"); + } + + if (slice < 0) { + slice += rank(); + } + INDArrayIndex[] indexes = new INDArrayIndex[rank()]; + indexes[0] = NDArrayIndex.point(slice); + for (int i = 1; i < rank(); i++) { + indexes[i] = NDArrayIndex.all(); + } + return get(indexes); + } + + + protected INDArray createScalarForIndex(long i, boolean applyOffset) { + if (isVector()) { + return getScalar(i); + } + return Nd4j.create(data(), new long[]{1, 1}, new long[]{1, 1}, i); + } + + protected INDArray createScalar(double d) { + return Nd4j.scalar(d); + } + + @Override + public int getTrailingOnes() { + int numLeadingOnes = 0; + for (int i = rank() - 1; i > 0; i--) { + if (size(i) == 1) { + numLeadingOnes++; } } - protected DataBuffer shapeOf() { - // if (shape == null) - // shape = Shape.shapeOf(shapeInfoDataBuffer()); - // return shape; + return numLeadingOnes; + } - return Shape.shapeOf(shapeInfoDataBuffer()); - } - - protected DataBuffer strideOf() { - // if (stride == null) - // stride = Shape.stride(shapeInfoDataBuffer()); - // return stride; - return Shape.stride(shapeInfoDataBuffer()); - } - - @Override - public int stride(int dimension) { - int rank = jvmShapeInfo.rank; - Preconditions.checkArgument(dimension < rank, "Cannot get stride for dimension %s from rank %s array: " + - "dimension indices must be in range -rank <= dimension < rank", dimension, rank); - if (dimension < 0) - return (int) stride()[dimension + rank]; - return (int) stride()[dimension]; - } - - @Override - public INDArray rdiviColumnVector(INDArray columnVector) { - validateNumericalArray("rdiviColumnVector", false); - return doColumnWise(columnVector, 't'); - } - - @Override - public INDArray rdivColumnVector(INDArray columnVector) { - validateNumericalArray("rdivColumnVector", false); - return dup().rdiviColumnVector(columnVector); - } - - @Override - public INDArray rdiviRowVector(INDArray rowVector) { - validateNumericalArray("rdiviRowVector", false); - return doRowWise(rowVector, 't'); - } - - @Override - public INDArray rdivRowVector(INDArray rowVector) { - validateNumericalArray("rdivRowVector", false); - return dup().rdiviRowVector(rowVector); - } - - @Override - public INDArray rsubiColumnVector(INDArray columnVector) { - validateNumericalArray("rsubiColumnVector", false); - return doColumnWise(columnVector, 'h'); - } - - @Override - public INDArray rsubColumnVector(INDArray columnVector) { - validateNumericalArray("rsubColumnVector", false); - return dup().rsubiColumnVector(columnVector); - } - - @Override - public INDArray rsubiRowVector(INDArray rowVector) { - validateNumericalArray("rsubiRowVector", false); - return doRowWise(rowVector, 'h'); - } - - @Override - public INDArray rsubRowVector(INDArray rowVector) { - validateNumericalArray("rsubRowVector", false); - return dup().rsubiRowVector(rowVector); - } - - @Override - public INDArray put(int i, INDArray element) { - Preconditions.checkArgument(element.isScalar(), "Element must be a scalar: element has shape %ndShape", element); 
- return putScalar(i, element.getDouble(0)); - } - - @Override - public INDArray diviColumnVector(INDArray columnVector) { - validateNumericalArray("diviColumnVector", false); - return doColumnWise(columnVector, 'd'); - } - - @Override - public INDArray divColumnVector(INDArray columnVector) { - validateNumericalArray("divColumnVector", false); - return dup().diviColumnVector(columnVector); - } - - @Override - public INDArray diviRowVector(INDArray rowVector) { - validateNumericalArray("diviRowVector", false); - return doRowWise(rowVector, 'd'); - } - - @Override - public INDArray divRowVector(INDArray rowVector) { - validateNumericalArray("divRowVector", false); - return dup().diviRowVector(rowVector); - } - - @Override - public INDArray muliColumnVector(INDArray columnVector) { - validateNumericalArray("muliColumnVector", false); - return doColumnWise(columnVector, 'm'); - } - - @Override - public INDArray mulColumnVector(INDArray columnVector) { - validateNumericalArray("mulColumnVector", false); - return dup().muliColumnVector(columnVector); - } - - @Override - public INDArray muliRowVector(INDArray rowVector) { - validateNumericalArray("muliRowVector", false); - return doRowWise(rowVector, 'm'); - } - - @Override - public INDArray mulRowVector(INDArray rowVector) { - validateNumericalArray("mulRowVector", false); - return dup().muliRowVector(rowVector); - } - - @Override - public INDArray subiColumnVector(INDArray columnVector) { - validateNumericalArray("subiColumnVector", false); - return doColumnWise(columnVector, 's'); - } - - @Override - public INDArray subColumnVector(INDArray columnVector) { - validateNumericalArray("subColumnVector", false); - return dup().subiColumnVector(columnVector); - } - - @Override - public INDArray subiRowVector(INDArray rowVector) { - validateNumericalArray("subiRowVector", false); - return doRowWise(rowVector, 's'); - } - - @Override - public INDArray subRowVector(INDArray rowVector) { - validateNumericalArray("subRowVector", false); - return dup().subiRowVector(rowVector); - } - - @Override - public INDArray addiColumnVector(INDArray columnVector) { - validateNumericalArray("addiColumnVector", false); - return doColumnWise(columnVector, 'a'); - } - - @Override - public INDArray putiColumnVector(INDArray columnVector) { - return doColumnWise(columnVector, 'p'); - } - - @Override - public INDArray addColumnVector(INDArray columnVector) { - validateNumericalArray("addColumnVector", false); - return dup().addiColumnVector(columnVector); - } - - @Override - public INDArray addiRowVector(INDArray rowVector) { - validateNumericalArray("addiRowVector", false); - return doRowWise(rowVector, 'a'); - } - - @Override - public INDArray putiRowVector(INDArray rowVector) { - validateNumericalArray("putiRowVector", false); - return doRowWise(rowVector, 'p'); - } - - @Override - public INDArray addRowVector(INDArray rowVector) { - validateNumericalArray("addRowVector", false); - return dup().addiRowVector(rowVector); - } - - @Override - public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) { - return mMulTranspose.exec(this, other, result); - } - - @Override - public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) { - return mMulTranspose.exec(this, other, null); - } - - @Override - public INDArray mmul(INDArray other, char resultOrder) { - Preconditions.checkArgument(resultOrder == 'c' || resultOrder == 'f', "Order must be either 'c' or 'f', but [" + resultOrder + "] was given"); - Preconditions.checkState(this.dataType() == 
other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType()); - // FIXME: add support for 3D+ here? - long[] shape = other.rank() == 1 ? new long[]{rows()} : new long[]{rows(), other.columns()}; - INDArray result = createUninitialized(this.dataType(), shape, resultOrder); - if (result.isScalar()) - return Nd4j.scalar(this.dataType(), Nd4j.getBlasWrapper().dot(this, other)).reshape(1, 1); - return mmuli(other, result); - } - - @Override - public INDArray mmul(INDArray other) { - return mmul(other, (this.ordering() == 'f' && other.ordering() == 'f' && other.rank() != 1) ? 'f' : 'c'); - } - - protected INDArray create(int[] shape, char ordering) { - return Nd4j.create(shape, ordering); - } - - @Override - public double[][] toDoubleMatrix() { - if(!isMatrix()) { - throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort(this)); + @Override + public int getLeadingOnes() { + int numLeadingOnes = 0; + for (int i = 0; i < rank(); i++) { + if (size(i) == 1) { + numLeadingOnes++; } - - if (this.size(0) > Integer.MAX_VALUE || this.size(1) > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - - double[][] ret = new double[rows()][columns()]; - for(int i = 0; i < ret.length; i++) { - ret[i] = getRow(i).dup().data().asDouble(); - } - - return ret; } - @Override - public double[] toDoubleVector() { - if(!isVectorOrScalar()) { - throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); - } - return dup().data().asDouble(); - } + return numLeadingOnes; + } - @Override - public float[] toFloatVector() { - if(!isVectorOrScalar()) { - throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); - } - return dup().data().asFloat(); - } + @Override + public INDArray slice(long slice, int dimension) { + Nd4j.getCompressor().autoDecompress(this); - @Override - public float[][] toFloatMatrix() { - if(!isMatrix()) { - throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort(this)); - } + long slices = size(dimension); + if (slice >= slices) { + throw new IllegalArgumentException("Illegal slice " + slice); + } - if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - - float[][] ret = new float[(int) rows()][ (int) columns()]; - for(int i = 0; i < ret.length; i++) { - ret[i] = getRow(i).dup().data().asFloat(); - } - - return ret; - } - - @Override - public int[] toIntVector() { - if (isEmpty()) - return new int[0]; - - if(!isVectorOrScalar()) { - throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); - } - if(isView() || elementWiseStride() != 1){ - return dup().data().asInt(); - } - return data().asInt(); - } - - @Override - public long[] toLongVector() { - if(!isVectorOrScalar()) { - throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); - } - if(isView() || elementWiseStride() != 1){ - return dup().data().asLong(); - } - return data().asLong(); - } - - @Override - public long[][] toLongMatrix() { - if(!isMatrix()) { - throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! 
Shape: " + Shape.shapeToStringShort(this)); - } - - if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - - long[][] ret = new long[(int) rows()][(int) columns()]; - for(int i = 0; i < ret.length; i++) { - ret[i] = getRow(i).dup().data().asLong(); - } - - return ret; - } - - @Override - public int[][] toIntMatrix() { - if(!isMatrix()) { - throw new ND4JIllegalStateException("Unable to create a 2d array from a non matrix! Shape: " + Shape.shapeToStringShort(this)); - } - - if (this.rows() > Integer.MAX_VALUE || this.columns() > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - - int[][] ret = new int[(int) rows()][(int) columns()]; - for(int i = 0; i < ret.length; i++) { - ret[i] = getRow(i).dup().data().asInt(); - } - - return ret; - } - - /** - * Perform an copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @param result the result ndarray - * @return the result of the matrix multiplication - */ - @Override - public INDArray mmul(INDArray other, INDArray result) { - return mmuli(other, result); - } - - @Override - public INDArray div(INDArray other) { - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return divi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); + if (jvmShapeInfo.rank == 0) { + if (slice == 0) { + return createScalarForIndex(slice, true); } else { - return divi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering())); - } - } - - @Override - public INDArray div(INDArray other, INDArray result) { - validateNumericalArray("div", true); - return divi(other, result); - } - - @Override - public INDArray mul(INDArray other) { - validateNumericalArray("mul", false); - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return muli(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); - } else { - val z = Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering()); - return muli(other, z); - } - } - - @Override - public INDArray mul(INDArray other, INDArray result) { - return muli(other, result); - } - - @Override - public INDArray sub(INDArray other) { - validateNumericalArray("sub", false); - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return subi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); - } else { - return subi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), this.shape(), this.ordering())); - } - } - - @Override - public INDArray sub(INDArray other, INDArray result) { - return subi(other, result); - } - - @Override - public INDArray add(INDArray other) { - validateNumericalArray("add", false); - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return addi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); - } else { - return addi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), 
this.shape(), this.ordering())); - } - } - - @Override - public INDArray add(INDArray other, INDArray result) { - validateNumericalArray("add", false); - return addi(other, result); - } - - @Override - public INDArray mmuli(INDArray other, MMulTranspose transpose) { - validateNumericalArray("mmuli", false); - return dup().mmuli(other, this,transpose); - } - - @Override - public INDArray mmuli(INDArray other) { - validateNumericalArray("mmuli", false); - return dup().mmuli(other, this); - } - - @Override - public INDArray mmuli(INDArray other, INDArray result, MMulTranspose transpose) { - return transpose.exec(this, other, result); - } - - @Override - public INDArray mmuli(INDArray other, INDArray result) { - validateNumericalArray("mmuli", false); - LinAlgExceptions.assertMultiplies(this, other); - if(other.rank() == 1){ - //GEMV edge case - Preconditions.checkState(result.length() == this.size(0) && this.size(1) == other.size(0), - "Invalid matrix multiplication: %ndShape x %ndShape with result shape %ndShape", this, other, result); - } else { - //Standard case - Preconditions.checkState( - result.rank() == 2 && result.size(0) == this.size(0) && result.size(1) == other.size(1), - "Invalid result array shape: expected shape [%s,%s], got shape %ndShape result array for %ndShape x %ndShape", this.size(0), other.size(1), result, - this, other); - } - - if (other.isScalar()) { - return muli(other.getDouble(0), result); - } - if (isScalar()) { - return other.muli(getDouble(0), result); - } - - /* check sizes and resize if necessary */ - - - if (result == this || result == other) { - /* actually, blas cannot do multiplications in-place. Therefore, we will fake by - * allocating a temporary object on the side and copy the result later. - */ - INDArray temp = Nd4j.create(result.dataType(), result.shape(), Nd4j.getStrides(result.shape(), 'f'), 'f'); - - if (other.columns() == 1 || other.rank() == 1) { - Nd4j.getBlasWrapper().level2().gemv(BlasBufferUtil.getCharForTranspose(result), - BlasBufferUtil.getCharForTranspose(this), 1.0, this, other, 0.0, temp); - } - - else { - Nd4j.getBlasWrapper().level3().gemm(BlasBufferUtil.getCharForTranspose(result), - BlasBufferUtil.getCharForTranspose(this), BlasBufferUtil.getCharForTranspose(temp), 1.0, - this, other, 0.0, temp); - } - - result.assign(temp); - - - } else { - - //We require that the result array is 'f' (fortran) order - // However, user might have called mmuli with a c order array for the result - // In which case, we need to allocate a temporary f order array, and later do an assign to the real result array - - boolean requiresTemp = result.ordering() != 'f' || result.isView() || !Shape.hasDefaultStridesForShape(result); - INDArray gemmResultArr; - if (requiresTemp) { - //Can use createUninitialized due to beta==0.0 parameter in gemm - gemmResultArr = Nd4j.createUninitialized(result.dataType(), result.shape(), 'f'); - } else { - gemmResultArr = result; - } - - if (other.columns() == 1 || other.rank() == 1) { - Nd4j.getBlasWrapper().level2().gemv( - ordering(), - BlasBufferUtil.getCharForTranspose(other), - 1.0, - this, - other, - 0.0, - gemmResultArr); - } else { - //gemm doesn't support strides so vectors and views - //don't work - Nd4j.getBlasWrapper().level3().gemm(ordering(), - BlasBufferUtil.getCharForTranspose(other), - BlasBufferUtil.getCharForTranspose(gemmResultArr), - 1.0, - this, - other, - 0.0, - gemmResultArr); - } - - if (requiresTemp) { - result.assign(gemmResultArr); - } - } - - // 1D edge case: reshape back to vector - if 
(other.rank() == 1) - result = result.reshape(result.length()); - return result; - } - - private INDArray create(int[] shape, int[] stride) { - return Nd4j.create(shape, stride); - } - - @Override - public INDArray divi(INDArray other) { - return divi(other, this); - } - - @Override - public INDArray divi(INDArray other, INDArray result) { - validateNumericalArray("divi", false); - Shape.assertBroadcastable("divi", this, other, result); - Nd4j.exec(new DivOp(this, other, result)); - return result; - } - - @Override - public INDArray muli(INDArray other) { - return muli(other, this); - } - - @Override - public INDArray muli(INDArray other, INDArray result) { - validateNumericalArray("muli", false); - Shape.assertBroadcastable("muli", this, other, result); - Nd4j.exec(new MulOp(this, other, result)); - return result; - } - - @Override - public INDArray subi(INDArray other) { - return subi(other, this); - } - - /** - * in place subtraction of two matrices - * - * @param other the second ndarray to subtract - * @param result the result ndarray - * @return the result of the subtraction - */ - @Override - public INDArray subi(INDArray other, INDArray result) { - validateNumericalArray("subi", false); - Shape.assertBroadcastable("subi", this, other, result); - Nd4j.exec(new SubOp(this, other, result)); - return result; - } - - @Override - public INDArray addi(INDArray other) { - return addi(other, this); - } - - @Override - public INDArray addi(INDArray other, INDArray result) { - validateNumericalArray("addi", false); - Shape.assertBroadcastable("addi", this, other, result); - Nd4j.exec(new AddOp(this, other, result)); - return result; - } - - @Override - public INDArray normmax(boolean keepDims, int... dimension) { - validateNumericalArray("normmax", false); - return Nd4j.getExecutioner().exec(new NormMax(this, keepDims, dimension)); - } - - @Override - public INDArray normmax(int... 
dimension) { - return normmax(false, dimension); - } - - @Override - public INDArray rdiv(INDArray other) { - validateNumericalArray("rdiv", false); - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return rdivi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); - } else { - return rdivi(other, this.ulike()); - } - } - - @Override - public INDArray rdivi(INDArray other) { - return rdivi(other, this); - } - - @Override - public INDArray rdiv(INDArray other, INDArray result) { - validateNumericalArray("rdiv", false); - return dup().rdivi(other, result); - } - - @Override - public INDArray rdivi(INDArray other, INDArray result) { - validateNumericalArray("rdivi", false); - Shape.assertBroadcastable("rdivi", this, other, result); - Nd4j.exec(new RDivOp(this, other, result)); - return result; - } - - @Override - public INDArray rsub(INDArray other, INDArray result) { - validateNumericalArray("rsub", false); - return rsubi(other, result); - } - - @Override - public INDArray rsub(INDArray other) { - validateNumericalArray("rsub", false); - if (Shape.areShapesBroadcastable(this.shape(), other.shape())) { - return rsubi(other, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), other.dataType()), Shape.broadcastOutputShape(this.shape(), other.shape()), this.ordering())); - } else { - return rsubi(other, this.ulike()); - } - } - - @Override - public INDArray rsubi(INDArray other) { - return rsubi(other, this); - } - - @Override - public INDArray rsubi(INDArray other, INDArray result) { - validateNumericalArray("rsubi", false); - Shape.assertBroadcastable("rsubi", this, other, result); - Nd4j.exec(new RSubOp(this, other, result)); - return result; - } - - @Override - public INDArray assign(Number value) { - Preconditions.checkState(dataType() != DataType.BOOL || value.doubleValue() == 0.0 || value.doubleValue() == 1.0, "Only values 0 or 1 are allowed for scalar " + - "assign on boolean arrays: got value %s on to assign to boolean array with shape %ndShape", value, this); - Nd4j.getExecutioner().exec(new ScalarSet(this, value)); - return this; - } - - @Override - public INDArray assign(boolean value) { - return assign(value ? 1 : 0); - } - - @Override - public INDArray assignIf(INDArray arr, Condition condition) { - BooleanIndexing.assignIf(this, arr, condition); - return this; - } - - @Override - public INDArray replaceWhere(INDArray arr, Condition condition) { - Nd4j.getCompressor().autoDecompress(this); - BooleanIndexing.replaceWhere(this, arr, condition); - return this; - } - - @Override - @Deprecated //TODO: Investigate. Not deprecated in the base interface. 
- public long linearIndex(long i) { - long idx = i; - for (int j = 0; j < jvmShapeInfo.rank - 1; j++) { - if (size((int) i) == 1) - continue; - idx += i * stride(j); - } - return Shape.offset(jvmShapeInfo.javaShapeInformation) + (idx); - } - - @Override - public INDArray slice(long slice) { - Nd4j.getCompressor().autoDecompress(this); - - - long slices = slices(); - if (slice >= slices) - throw new IllegalArgumentException("Illegal slice " + slice); - - if (jvmShapeInfo.rank == 0 ) { throw new IllegalArgumentException("Can't slice a 0-d NDArray"); } + } - if (slice < 0) - slice += rank(); - INDArrayIndex[] indexes = new INDArrayIndex[rank()]; - indexes[0] = NDArrayIndex.point(slice); - for (int i = 1; i < rank(); i++) { + if (slice < 0) { + slice += rank(); + } + INDArrayIndex[] indexes = new INDArrayIndex[rank()]; + indexes[dimension] = NDArrayIndex.point(slice); + for (int i = 0; i < rank(); i++) { + if (i != dimension) { indexes[i] = NDArrayIndex.all(); } - return get(indexes); } + return get(indexes); + } + @Override + public INDArray getScalar(int[] indexes) { + if (indexes.length > rank()) { + throw new ND4JIllegalStateException("Indexes can't be longer then array rank"); + } - protected INDArray createScalarForIndex(long i, boolean applyOffset) { - if(isVector()) - return getScalar(i); - return Nd4j.create(data(), new long[] {1, 1}, new long[] {1, 1}, i); - } - - protected INDArray createScalar(double d) { - return Nd4j.scalar(d); - } - - @Override - public int getTrailingOnes() { - int numLeadingOnes = 0; - for (int i = rank() - 1; i > 0; i--) { - if (size(i) == 1) - numLeadingOnes++; + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] < 0) { + indexes[i] += this.size(i); } - - return numLeadingOnes; } + long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); + val buffer = Nd4j.createBuffer(this.data(), idx, 1); + val shape = Nd4j.getShapeInfoProvider() + .createShapeInformation(new long[0], new long[0], 1, 'c', this.dataType(), false); + return Nd4j.createArrayFromShapeBuffer(buffer, shape); + } - @Override - public int getLeadingOnes() { - int numLeadingOnes = 0; - for (int i = 0; i < rank(); i++) { - if (size(i) == 1) - numLeadingOnes++; - } + @Override + public INDArray getScalar(long... 
indexes) { + if (indexes.length > rank()) { + throw new ND4JIllegalStateException("Indexes can't be longer then array rank"); + } - return numLeadingOnes; - } - - @Override - public INDArray slice(long slice, int dimension) { - Nd4j.getCompressor().autoDecompress(this); - - long slices = size(dimension); - if (slice >= slices) - throw new IllegalArgumentException("Illegal slice " + slice); - - if (jvmShapeInfo.rank == 0) { - if (slice == 0) - return createScalarForIndex(slice, true); - else - throw new IllegalArgumentException("Can't slice a 0-d NDArray"); - - } - - - if (slice < 0) - slice += rank(); - INDArrayIndex[] indexes = new INDArrayIndex[rank()]; - indexes[dimension] = NDArrayIndex.point(slice); - for (int i = 0; i < rank(); i++) { - if (i != dimension) - indexes[i] = NDArrayIndex.all(); - } - return get(indexes); - - } - - @Override - public INDArray getScalar(int[] indexes) { - if (indexes.length > rank()) - throw new ND4JIllegalStateException("Indexes can't be longer then array rank"); - - for (int i = 0; i < indexes.length; i++) { - if (indexes[i] < 0) - indexes[i] += this.size(i); - } - long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); - val buffer = Nd4j.createBuffer(this.data(), idx, 1); - val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1, 'c', this.dataType(), false); - return Nd4j.createArrayFromShapeBuffer(buffer, shape); - } - - @Override - public INDArray getScalar(long... indexes) { - if (indexes.length > rank()) - throw new ND4JIllegalStateException("Indexes can't be longer then array rank"); - - for (int i = 0; i < indexes.length; i++) { - if (indexes[i] < 0) - indexes[i] += this.size(i); - } - - long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); - val buffer = Nd4j.createBuffer(this.data(), idx, 1); - val shape = Nd4j.getShapeInfoProvider().createShapeInformation(new long[0], new long[0],1,'c', this.dataType(), false); - return Nd4j.createArrayFromShapeBuffer(buffer, shape); - } - - @Override - public INDArray rdiv(Number n) { - //return dup().rdivi(n); - return rdivi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), this.ordering())); - } - - @Override - public INDArray rdivi(Number n) { - return rdivi(n, this); - } - - @Override - public INDArray rsub(Number n) { - validateNumericalArray("rsub", false); - return rsubi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n),this.shape(), this.ordering())); - } - - @Override - public INDArray rsubi(Number n) { - return rsubi(n, this); - } - - @Override - public INDArray div(Number n) { - validateNumericalArray("div", false); - return divi(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n),this.shape(), this.ordering())); - } - - @Override - public INDArray divi(Number n) { - return divi(n, this); - } - - @Override - public INDArray mul(Number n) { - validateNumericalArray("mul", false); - return muli(n, Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), this.ordering())); - } - - @Override - public INDArray muli(Number n) { - return muli(n, this); - } - - @Override - public INDArray sub(Number n) { - validateNumericalArray("sub", false); - return subi(n, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering())); - } - - @Override - public INDArray subi(Number n) { - return subi(n, this); - } - - @Override - public INDArray add(Number n) { - validateNumericalArray("add", false); - return addi(n, 
Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n),this.shape(), this.ordering())); - } - - @Override - public INDArray addi(Number n) { - return addi(n, this); - } - - @Override - public INDArray repmat(long[] shape) { - Nd4j.getCompressor().autoDecompress(this); - long rows = rows() * shape[0]; - long cols = columns() * shape[1]; - INDArray ret = reshape(1, length()).repeat(0, shape[0]).reshape(rows, columns()).repeat(0, shape[1]); - return ret.reshape(rows, cols); - } - - @Deprecated - @Override - public INDArray repmat(int[] shape) { - long[] longShape = ArrayUtil.toLongArray(shape); - return repmat(longShape); - } - - @Override - public INDArray repeat(int dimension, long... repeats) { - Nd4j.getCompressor().autoDecompress(this); - CustomOp op = DynamicCustomOp.builder("repeat") - .addInputs(this) - .addIntegerArguments(ArrayUtil.toInts(repeats)) //TODO int cast - .build(); - op.addIArgument(dimension); //Native op: last iarg is dimension - - LongShapeDescriptor l = op.calculateOutputShape().get(0); - INDArray out = Nd4j.create(l); - op.addOutputArgument(out); - Nd4j.exec(op); - return out; - } - - @Override - public INDArray putRow(long row, INDArray toPut) { - if (isRowVector() && toPut.isVector()) { - return assign(toPut); - } - return put(new INDArrayIndex[] {NDArrayIndex.point(row), NDArrayIndex.all()}, toPut); - } - - @Override - public INDArray putColumn(int column, INDArray toPut) { - Nd4j.getCompressor().autoDecompress(this); - - if (isColumnVector() && toPut.isVector()) { - return assign(toPut); - } - return put(new INDArrayIndex[] {NDArrayIndex.all(), NDArrayIndex.point(column)}, toPut); - } - - @Override - public Number getNumber(long i){ - switch (dataType()){ - case DOUBLE: - case FLOAT: - case HALF: - case BFLOAT16: - return getDouble(i); - case LONG: - case INT: - case SHORT: - case UBYTE: - case BYTE: - case BOOL: - case UINT64: - case UINT32: - case UINT16: - return getLong(i); - case UTF8: - case COMPRESSED: - case UNKNOWN: - default: - throw new UnsupportedOperationException("Cannot get number from array of datatype: " + dataType()); + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] < 0) { + indexes[i] += this.size(i); } } - @Override - public Number getNumber(long... 
idx){ - switch (dataType()){ - case DOUBLE: - case FLOAT: - case HALF: - return getDouble(idx); - case LONG: - case INT: - case SHORT: - case UBYTE: - case BYTE: - case BOOL: - return getLong(idx); - case UTF8: - case COMPRESSED: - case UNKNOWN: - default: - throw new UnsupportedOperationException("Cannot get number from array of datatype: " + dataType()); - } + long idx = Shape.getOffset(jvmShapeInfo.javaShapeInformation, indexes); + val buffer = Nd4j.createBuffer(this.data(), idx, 1); + val shape = Nd4j.getShapeInfoProvider() + .createShapeInformation(new long[0], new long[0], 1, 'c', this.dataType(), false); + return Nd4j.createArrayFromShapeBuffer(buffer, shape); + } + + @Override + public INDArray rdiv(Number n) { + //return dup().rdivi(n); + return rdivi(n, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), + this.ordering())); + } + + @Override + public INDArray rdivi(Number n) { + return rdivi(n, this); + } + + @Override + public INDArray rsub(Number n) { + validateNumericalArray("rsub", false); + return rsubi(n, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), + this.ordering())); + } + + @Override + public INDArray rsubi(Number n) { + return rsubi(n, this); + } + + @Override + public INDArray div(Number n) { + validateNumericalArray("div", false); + return divi(n, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), + this.ordering())); + } + + @Override + public INDArray divi(Number n) { + return divi(n, this); + } + + @Override + public INDArray mul(Number n) { + validateNumericalArray("mul", false); + return muli(n, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), + this.ordering())); + } + + @Override + public INDArray muli(Number n) { + return muli(n, this); + } + + @Override + public INDArray sub(Number n) { + validateNumericalArray("sub", false); + return subi(n, Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering())); + } + + @Override + public INDArray subi(Number n) { + return subi(n, this); + } + + @Override + public INDArray add(Number n) { + validateNumericalArray("add", false); + return addi(n, + Nd4j.createUninitialized(Shape.pickPairwiseDataType(this.dataType(), n), this.shape(), + this.ordering())); + } + + @Override + public INDArray addi(Number n) { + return addi(n, this); + } + + @Override + public INDArray repmat(long[] shape) { + Nd4j.getCompressor().autoDecompress(this); + long rows = rows() * shape[0]; + long cols = columns() * shape[1]; + INDArray ret = reshape(1, length()).repeat(0, shape[0]).reshape(rows, columns()) + .repeat(0, shape[1]); + return ret.reshape(rows, cols); + } + + @Deprecated + @Override + public INDArray repmat(int[] shape) { + long[] longShape = ArrayUtil.toLongArray(shape); + return repmat(longShape); + } + + @Override + public INDArray repeat(int dimension, long... 
repeats) { + Nd4j.getCompressor().autoDecompress(this); + CustomOp op = DynamicCustomOp.builder("repeat") + .addInputs(this) + .addIntegerArguments(ArrayUtil.toInts(repeats)) //TODO int cast + .build(); + op.addIArgument(dimension); //Native op: last iarg is dimension + + LongShapeDescriptor l = op.calculateOutputShape().get(0); + INDArray out = Nd4j.create(l); + op.addOutputArgument(out); + Nd4j.exec(op); + return out; + } + + @Override + public INDArray putRow(long row, INDArray toPut) { + if (isRowVector() && toPut.isVector()) { + return assign(toPut); + } + return put(new INDArrayIndex[]{NDArrayIndex.point(row), NDArrayIndex.all()}, toPut); + } + + @Override + public INDArray putColumn(int column, INDArray toPut) { + Nd4j.getCompressor().autoDecompress(this); + + if (isColumnVector() && toPut.isVector()) { + return assign(toPut); + } + return put(new INDArrayIndex[]{NDArrayIndex.all(), NDArrayIndex.point(column)}, toPut); + } + + @Override + public Number getNumber(long i) { + switch (dataType()) { + case DOUBLE: + case FLOAT: + case HALF: + case BFLOAT16: + return getDouble(i); + case LONG: + case INT: + case SHORT: + case UBYTE: + case BYTE: + case BOOL: + case UINT64: + case UINT32: + case UINT16: + return getLong(i); + case UTF8: + case COMPRESSED: + case UNKNOWN: + default: + throw new UnsupportedOperationException( + "Cannot get number from array of datatype: " + dataType()); + } + } + + @Override + public Number getNumber(long... idx) { + switch (dataType()) { + case DOUBLE: + case FLOAT: + case HALF: + return getDouble(idx); + case LONG: + case INT: + case SHORT: + case UBYTE: + case BYTE: + case BOOL: + return getLong(idx); + case UTF8: + case COMPRESSED: + case UNKNOWN: + default: + throw new UnsupportedOperationException( + "Cannot get number from array of datatype: " + dataType()); + } + } + + @Override + public double getDouble(long i) { + Nd4j.getCompressor().autoDecompress(this); + Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); + + if (i >= length()) { + throw new IllegalArgumentException( + "Unable to get linear index " + i + ": values is greater than length (" + length() + ")"); } - @Override - public double getDouble(long i) { - Nd4j.getCompressor().autoDecompress(this); - Preconditions.checkState(!isEmpty(), "Unable to get value from empty array"); + autoProcessScalarCall(); - if (i >= length()) { - throw new IllegalArgumentException("Unable to get linear index " + i + ": values is greater than length (" + length() + ")"); - } + if (i == 0) { + return data().getDouble(i); + } - autoProcessScalarCall(); + long[] dimensions = ordering() == 'c' ? Shape.ind2subC(this, i) : Shape.ind2sub(this, i); + Shape.assertShapeLessThan(dimensions, shape()); + return getDouble(dimensions); - if (i == 0) - return data().getDouble(i); + } - long[] dimensions = ordering() == 'c' ? 
Shape.ind2subC(this, i) : Shape.ind2sub(this, i); - Shape.assertShapeLessThan(dimensions, shape()); - return getDouble(dimensions); + @Override + public double getDouble(long i, long j) { + return getDouble(new long[]{i, j}); + } + @Override + public float getFloat(long i) { + return (float) getDouble(i); + } + + @Override + public float getFloat(long i, long j) { + return (float) getDouble(i, j); + } + + @Override + public INDArray transpose() { + Preconditions.checkState(rank() >= 2, + "Can't transpose array with rank < 2: array shape %ndShape", this); + + return permute(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank()))); + } + + /** + * Return transposed version of this matrix. + *
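+   * A minimal usage sketch (illustrative only; the input array below is a
+   * hypothetical example, not something defined by this class):
+   *
+   *     INDArray a = Nd4j.linspace(1, 6, 6).reshape(2, 3);
+   *     INDArray t = a.transposei();   // result has shape [3, 2]
+   *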

      + * PLEASE NOTE: This method is NOT in place, it will return transposed copy instead. + */ + @Override + public INDArray transposei() { + Preconditions.checkState(rank() >= 2, + "Can't transpose array with rank < 2: array shape %ndShape", this); + + return permutei(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank()))); + } + + protected INDArray create(DataBuffer data, int[] shape, int[] strides) { + return Nd4j.create(data, shape, strides, 0, ordering()); + } + + @Deprecated + @Override + public INDArray reshape(char order, int... newShape) { + return reshape(order, ArrayUtil.toLongArray(newShape)); + } + + @Override + public INDArray reshape(char order, long... newShape) { + return reshape(order, false, newShape); + } + + @Override + public INDArray reshape(char order, boolean enforceView, long... newShape) { + Nd4j.getCompressor().autoDecompress(this); + + // special case for empty reshape + if (this.length() == 1 && (newShape == null || newShape.length == 0) + && this.elementWiseStride() == 1) { + return Nd4j.create(this.data(), new int[0], new int[0], 0); } - @Override - public double getDouble(long i, long j) { - return getDouble(new long[] {i, j}); - } - - @Override - public float getFloat(long i) { - return (float) getDouble(i); - } - - @Override - public float getFloat(long i, long j) { - return (float) getDouble(i, j); - } - - @Override - public INDArray transpose() { - Preconditions.checkState(rank() >= 2, "Can't transpose array with rank < 2: array shape %ndShape", this); - - return permute(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank()))); - } - - /** - * - * Return transposed version of this matrix. - * - * PLEASE NOTE: This method is NOT in place, it will return transposed copy instead. - */ - @Override - public INDArray transposei() { - Preconditions.checkState(rank() >= 2, "Can't transpose array with rank < 2: array shape %ndShape", this); - - return permutei(ArrayUtil.reverseCopy(ArrayUtil.range(0, rank()))); - } - - protected INDArray create(DataBuffer data, int[] shape, int[] strides) { - return Nd4j.create(data, shape, strides, 0, ordering()); - } - - @Deprecated - @Override - public INDArray reshape(char order, int... newShape) { - return reshape(order, ArrayUtil.toLongArray(newShape)); - } - - @Override - public INDArray reshape(char order, long... newShape) { - return reshape(order, false, newShape); - } - - @Override - public INDArray reshape(char order, boolean enforceView, long... newShape){ - Nd4j.getCompressor().autoDecompress(this); - - // special case for empty reshape - if (this.length() == 1 && (newShape == null || newShape.length == 0) && this.elementWiseStride() == 1) { - return Nd4j.create(this.data(), new int[0], new int[0], 0); - } - - if (newShape == null || newShape.length < 1) - throw new ND4JIllegalStateException( - "Can't reshape(long...) without shape arguments. Got empty shape instead."); - - // TODO: maybe toFlatten() makes more sense here? - // reshape(-1) special case - if (newShape.length == 1 && newShape[0] == -1) - newShape[0] = this.length(); - - int numberNegativesOnes = 0; - long[] shape = ArrayUtil.copy(newShape); - - - for (int i = 0; i < shape.length; i++) { - if (shape[i] < 0) { - if (numberNegativesOnes >= 1) - throw new IllegalArgumentException("Only one dimension can be negative ones. 
Got shape " - + Arrays.toString(newShape)); - - numberNegativesOnes++; - - int shapeLength = 1; - for (int j = 0; j < shape.length; j++) - if (shape[j] >= 1) - shapeLength *= shape[j]; - long realShape = Math.abs(length() / shapeLength); - long[] thisNewShape = new long[shape.length]; - for (int j = 0; j < shape.length; j++) { - if (i != j) { - thisNewShape[j] = shape[j]; - } else - thisNewShape[j] = realShape; - } - - shape = thisNewShape; - break; - - } - } - - long prod = ArrayUtil.prodLong(shape); - - if (prod != this.length()) { - throw new ND4JIllegalStateException("New shape length doesn't match original length: [" + prod + "] vs [" + this.length() + "]. Original shape: "+Arrays.toString(this.shape())+" New Shape: "+Arrays.toString(newShape)); - } - - - - - - INDArray reshapeAttempt = Shape.newShapeNoCopy(this, shape, order == 'f'); - if (reshapeAttempt != null) { - // kinda strange get/set usage - // reshapeAttempt.setOrder(Shape.getOrder(reshapeAttempt)); - return reshapeAttempt; - } - - if(enforceView){ - throw new ND4JIllegalStateException("Unable to reshape array as view, called with enforceView=true. " + - "Use enforceView=false to return a copy instead, or call reshape on a non-strided array. Array shape info: " + this.shapeInfoToString().replaceAll("\n","")); - } - - - if (order != ordering()) { - INDArray ret = Nd4j.createUninitialized(this.dataType(), shape, order); - ret.setData(dup(order).data()); - return ret; - } else if (this.isEmpty()) { - return Nd4j.create(this.dataType(), shape); - } else { - INDArray ret = this.dup(order); - return Nd4j.create(ret.data(), shape); - } - } - - @Override - public double getDoubleUnsafe(long offset) { - return data().getDouble(offset); - } - - @Override - public INDArray putScalarUnsafe(long offset, double value) { - autoProcessScalarCall(); - data().put(offset, value); - return this; - } - - @Override - public INDArray reshape(char order, int rows, int columns) { - return reshape(order, new long[] {rows, columns}); - } - - /** - * Reshape the ndarray in to the specified dimensions, - * possible errors being thrown for invalid shapes - * - * Note here that one dimension can be -1. - * The dimension that is -1 will be inferred from the shape and - * the length of the ndarray - * - * @param shape the shape of the ndarray. - * @return the new reshaped nd array - */ - - @Override - public INDArray reshape(int[] shape) { - return reshape(Nd4j.order(), shape); - } - - @Override - public INDArray reshape(long... shape) { - return reshape(Nd4j.order(), shape); - } - - @Override - public INDArray prod(boolean keepDims, int... dimension) { - validateNumericalArray("prod", false); - return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension)); - } - - @Override - public INDArray prod(int... dimension) { - return prod(false, dimension); - } - - @Override - public INDArray mean(boolean keepDims, int... dimension) { - validateNumericalArray("mean", false); - return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension)); - } - - @Override - public INDArray mean(int... dimension) { - return mean(false, dimension); - } - - @Override - public INDArray amean(int... dimension) { - validateNumericalArray("amean", false); - return Nd4j.getExecutioner().exec(new AMean(this, dimension)); - } - - @Override - public INDArray mean(@NonNull INDArray result, boolean keepDims, int... 
dimension) { - validateNumericalArray("mean", false); - return Nd4j.getExecutioner().exec(new Mean(this, result, keepDims, dimension)); - } - - @Override - public INDArray mean(@NonNull INDArray result, int... dimension) { - return mean(result, false, dimension); - } - - @Override - public INDArray var(int... dimension) { - validateNumericalArray("var", false); - return Nd4j.getExecutioner().exec(new Variance(this, dimension)); - } - - @Override - public INDArray var(boolean biasCorrected, int... dimension) { - validateNumericalArray("var", false); - return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension)); - } - - @Override - public INDArray max(boolean keepDims, int... dimension) { - validateNumericalArray("max", false); - return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension)); - } - - @Override - public INDArray max(int... dimension) { - return max(false, dimension); - } - - @Override - public INDArray amax(int... dimension) { - validateNumericalArray("amax", false); - return Nd4j.getExecutioner().exec(new AMax(this, dimension)); - } - - @Override - public INDArray min(boolean keepDims, int... dimension) { - validateNumericalArray("min", false); - return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension)); - } - - @Override - public INDArray min(int... dimension) { - return min(false, dimension); - } - - @Override - public INDArray amin(int... dimension) { - validateNumericalArray("amin", false); - return Nd4j.getExecutioner().exec(new AMin(this, dimension)); - } - - @Override - public INDArray sum(int... dimension) { - validateNumericalArray("sum", true); - return Nd4j.getExecutioner().exec(new Sum(this, dimension)); - } - - @Override - public INDArray sum(boolean keepDim, int... dimension) { - validateNumericalArray("sum", true); - return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension)); - } - - @Override - public INDArray entropy(int... dimension) { - validateNumericalArray("entropy", false); - return Nd4j.getExecutioner().exec(new Entropy(this, dimension)); - } - - @Override - public INDArray shannonEntropy(int... dimension) { - validateNumericalArray("shannonEntropy", false); - return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension)); - } - - @Override - public INDArray logEntropy(int... dimension) { - validateNumericalArray("logEntropy", false); - return Nd4j.getExecutioner().exec(new LogEntropy(this, dimension)); - } - - @Override - public INDArray sum(@NonNull INDArray result, boolean keepDims, int... dimension) { - validateNumericalArray("sum", true); - return Nd4j.getExecutioner().exec(new Sum(this, result, keepDims, dimension)); - } - - @Override - public INDArray sum(@NonNull INDArray result, int... dimension) { - return sum(result, false, dimension); - } - - @Override - public INDArray norm1(int... dimension) { - return norm1(false, dimension); - } - - @Override - public INDArray norm1(boolean keepDims, int... dimension) { - validateNumericalArray("norm1", false); - return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension)); - } - - @Override - public INDArray std(int... dimension) { - return std(true, dimension); - } - - @Override - public INDArray std(boolean biasCorrected, int... dimension) { - return std(biasCorrected, false, dimension); - } - - @Override - public INDArray std(boolean biasCorrected, boolean keepDims, int... 
dimension) { - validateNumericalArray("std", false); - return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected, keepDims, dimension)); - } - - @Override - public Number stdNumber(boolean biasCorrected) { - validateNumericalArray("stdNumber", false); - return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0); - } - - @Override - public INDArray norm2(boolean keepDims, int... dimension) { - validateNumericalArray("norm2", false); - return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension)); - } - - @Override - public INDArray norm2(int... dimension) { - return norm2(false, dimension); - } - - @Override - public int columns() { - if (isMatrix()) - return (int) size(1); - else if (Shape.isColumnVectorShape(shape())) { - return 1; - } else if (Shape.isRowVectorShape(shape())) { - return (int) length(); - } - throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid"); - - - } - - @Override - public int rows() { - if (isMatrix()) - return (int) size(0); - else if (Shape.isRowVectorShape(shape())) { - return 1; - } else if (Shape.isColumnVectorShape(shape())) { - return (int) length(); - } - - throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid"); - } - - @Override - public INDArray ravel(char ordering) { - Nd4j.getCompressor().autoDecompress(this); - if(ordering == this.ordering() && Shape.hasDefaultStridesForShape(this)){ - return reshape(ordering, length()); - } - return dup(ordering).reshape(ordering, length()); - } - - @Override - public INDArray ravel() { - return reshape(length()); - } - - @Override - public void sliceVectors(List list) { - if (isVector()) - list.add(this); - else { - for (int i = 0; i < slices(); i++) { - slice(i).sliceVectors(list); - } - } - } - - @Override - public INDArray reshape(long newRows, long newColumns) { - return reshape(new long[] {newRows, newColumns}); - } - - @Override - public INDArray getColumn(long c) { - Nd4j.getCompressor().autoDecompress(this); - - if (isColumnVector() && c == 0) - return this; - else if (isColumnVector() && c > 0) - throw new IllegalArgumentException("Illegal index for column"); - Preconditions.checkArgument(this.rank() == 2, "getColumn() can be called on 2D arrays only"); - return tensorAlongDimension(c, 0); - } - - @Override - public INDArray getColumn(long c, boolean keepDim) { - INDArray col = getColumn(c); - if(!keepDim) - return col; - return col.reshape(col.length(), 1); - } - - @Override - public INDArray getRows(int[] rindices) { - Nd4j.getCompressor().autoDecompress(this); - - if (!isMatrix() && !isVector()) - throw new IllegalArgumentException("Unable to get columns from a non matrix or vector"); - if (isVector()) - return Nd4j.pullRows(this, 1, rindices); - else { - INDArray ret = Nd4j.createUninitialized(this.dataType(), rindices.length, columns()); - for (int i = 0; i < rindices.length; i++) - ret.putRow(i, getRow(rindices[i])); - return ret; - } - } - - @Override - public INDArray get(INDArrayIndex... 
indexes) { - Nd4j.getCompressor().autoDecompress(this); - - int numPoint = 0; - int numInterval = 0; - int numAll = 0; - int numNewAxis = 0; - int numSpecified = 0; - for(INDArrayIndex i : indexes){ - if(i instanceof PointIndex){ - numPoint++; - } else if(i instanceof NDArrayIndexAll){ - numAll++; - } else if(i instanceof IntervalIndex){ - numInterval++; - } else if(i instanceof NewAxis){ - numNewAxis++; - } else if(i instanceof SpecifiedIndex){ - numSpecified++; + if (newShape == null || newShape.length < 1) { + throw new ND4JIllegalStateException( + "Can't reshape(long...) without shape arguments. Got empty shape instead."); + } + + // TODO: maybe toFlatten() makes more sense here? + // reshape(-1) special case + if (newShape.length == 1 && newShape[0] == -1) { + newShape[0] = this.length(); + } + + int numberNegativesOnes = 0; + long[] shape = ArrayUtil.copy(newShape); + + for (int i = 0; i < shape.length; i++) { + if (shape[i] < 0) { + if (numberNegativesOnes >= 1) { + throw new IllegalArgumentException( + "Only one dimension can be negative ones. Got shape " + + Arrays.toString(newShape)); + } + + numberNegativesOnes++; + + int shapeLength = 1; + for (int j = 0; j < shape.length; j++) { + if (shape[j] >= 1) { + shapeLength *= shape[j]; + } + } + long realShape = Math.abs(length() / shapeLength); + long[] thisNewShape = new long[shape.length]; + for (int j = 0; j < shape.length; j++) { + if (i != j) { + thisNewShape[j] = shape[j]; } else { - throw new IllegalStateException("Unknown index: " + i); + thisNewShape[j] = realShape; } } - // Padding remaining dimensions with all() index if too few indices provided - if (indexes.length - numNewAxis < this.rank()) { - val newIndexes = new INDArrayIndex[this.rank() + numNewAxis]; - System.arraycopy(indexes, 0, newIndexes, 0, indexes.length); + shape = thisNewShape; + break; - for (int e = indexes.length; e < newIndexes.length; e++) { - numAll++; - newIndexes[e] = NDArrayIndex.all(); - } - - indexes = newIndexes; - } - - Preconditions.checkState((numPoint + numInterval + numAll + numSpecified) == rank(), "Illegal set of indices for array: need at least" + - " %s point/interval/all/specified indices for rank %s array (%ndShape), got indices %s", rank(), rank(), this, indexes); - - int outRank = rank() + numNewAxis - numPoint; - Preconditions.checkState(outRank >= 0, "Illegal set of indices for array: %ndShape, %s", this, indexes); - - - //To work out sub-array, we need to work out 3 things: offset, shape and strides. We calculate all of these - long[] outShape = new long[outRank]; - long[] outStrides = new long[outRank]; - long offset = offset(); //Start with existing offset if view - - int outIdx = 0; //Axis number counter for output array - int inIdx = 0; //Axis number counter for input array - for( int i=0; i= size(inIdx)) { - throw new IllegalStateException("Indices are out of range: Cannot get interval index " + indexes[i] + - " on array with size(" + inIdx + ")=" + size(inIdx) + ". Array shape: " + Arrays.toString(shape()) + - ", indices: " + Arrays.toString(indexes)); - } - long stride = ii.stride(); - long length = (endInc - start)/stride + 1; - - offset += ii.offset() * stride(inIdx); - outShape[outIdx] = length; - outStrides[outIdx] = ii.stride() * stride(inIdx); - inIdx++; - outIdx++; - } else if(indexes[i] instanceof NewAxis) { - //New axis: appends a 1 in shape. Axis not present in input, but is present in output - outShape[outIdx] = 1; - if (outIdx > 0) { //Stride doesn't matter for 1 size axis anyway... 
- outStrides[outIdx] = outStrides[outIdx - 1]; - } else { - outStrides[outIdx] = 1; - } - outIdx++; - } else if(indexes[i] instanceof SpecifiedIndex){ - //Specified index: axis present in both input and output - SpecifiedIndex si = (SpecifiedIndex)indexes[i]; - outShape[outIdx++] = si.length(); - inIdx++; - //Don't care about strides for specified index, as result won't be a view - } else { - throw new IllegalStateException("Unknown index type: " + i); //Should never happen - } - } - - - //Note: If we have specified indices, we can't return a view. Instead, we copy the specified sub-arrays from - // the input array to the output array. - //How? Create the output array, then do loop over the specified indices only, and copy sub-arrays for all other axes - if (numSpecified > 0) { - INDArray out = Nd4j.create(dataType(), outShape); - - //Need to copy subsets here - long[] specifiedSizes = new long[numSpecified]; - SpecifiedIndex[] si = new SpecifiedIndex[numSpecified]; - int j=0; - for( int i=0; i replace with loop + point - // ii. new axis indices -> ignore/exclude (don't appear in input) - // iii. interval indices -> replace with all - //(2) Get from output: requested indices, except for: - // i. point indices -> ignore/exclude (don't appear in output) - // ii. new axis indices -> replace with point(0) - - - INDArrayIndex[] pointIdxsIn = new INDArrayIndex[indexes.length - numNewAxis]; //Indices for source (this array) - int[] specifiedAxisIn = new int[numSpecified]; - int specCount = 0; - j = 0; - for( int i=0; i + * Note here that one dimension can be -1. The dimension that is -1 will be inferred from the + * shape and the length of the ndarray + * + * @param shape the shape of the ndarray. + * @return the new reshaped nd array + */ + + @Override + public INDArray reshape(int[] shape) { + return reshape(Nd4j.order(), shape); + } + + @Override + public INDArray reshape(long... shape) { + return reshape(Nd4j.order(), shape); + } + + @Override + public INDArray prod(boolean keepDims, int... dimension) { + validateNumericalArray("prod", false); + return Nd4j.getExecutioner().exec(new Prod(this, keepDims, dimension)); + } + + @Override + public INDArray prod(int... dimension) { + return prod(false, dimension); + } + + @Override + public INDArray mean(boolean keepDims, int... dimension) { + validateNumericalArray("mean", false); + return Nd4j.getExecutioner().exec(new Mean(this, keepDims, dimension)); + } + + @Override + public INDArray mean(int... dimension) { + return mean(false, dimension); + } + + @Override + public INDArray amean(int... dimension) { + validateNumericalArray("amean", false); + return Nd4j.getExecutioner().exec(new AMean(this, dimension)); + } + + @Override + public INDArray mean(@NonNull INDArray result, boolean keepDims, int... dimension) { + validateNumericalArray("mean", false); + return Nd4j.getExecutioner().exec(new Mean(this, result, keepDims, dimension)); + } + + @Override + public INDArray mean(@NonNull INDArray result, int... dimension) { + return mean(result, false, dimension); + } + + @Override + public INDArray var(int... dimension) { + validateNumericalArray("var", false); + return Nd4j.getExecutioner().exec(new Variance(this, dimension)); + } + + @Override + public INDArray var(boolean biasCorrected, int... dimension) { + validateNumericalArray("var", false); + return Nd4j.getExecutioner().exec(new Variance(this, biasCorrected, dimension)); + } + + @Override + public INDArray max(boolean keepDims, int... 
dimension) { + validateNumericalArray("max", false); + return Nd4j.getExecutioner().exec(new Max(this, keepDims, dimension)); + } + + @Override + public INDArray max(int... dimension) { + return max(false, dimension); + } + + @Override + public INDArray amax(int... dimension) { + validateNumericalArray("amax", false); + return Nd4j.getExecutioner().exec(new AMax(this, dimension)); + } + + @Override + public INDArray min(boolean keepDims, int... dimension) { + validateNumericalArray("min", false); + return Nd4j.getExecutioner().exec(new Min(this, keepDims, dimension)); + } + + @Override + public INDArray min(int... dimension) { + return min(false, dimension); + } + + @Override + public INDArray amin(int... dimension) { + validateNumericalArray("amin", false); + return Nd4j.getExecutioner().exec(new AMin(this, dimension)); + } + + @Override + public INDArray sum(int... dimension) { + validateNumericalArray("sum", true); + return Nd4j.getExecutioner().exec(new Sum(this, dimension)); + } + + @Override + public INDArray sum(boolean keepDim, int... dimension) { + validateNumericalArray("sum", true); + return Nd4j.getExecutioner().exec(new Sum(this, null, keepDim, dimension)); + } + + @Override + public INDArray entropy(int... dimension) { + validateNumericalArray("entropy", false); + return Nd4j.getExecutioner().exec(new Entropy(this, dimension)); + } + + @Override + public INDArray shannonEntropy(int... dimension) { + validateNumericalArray("shannonEntropy", false); + return Nd4j.getExecutioner().exec(new ShannonEntropy(this, dimension)); + } + + @Override + public INDArray logEntropy(int... dimension) { + validateNumericalArray("logEntropy", false); + return Nd4j.getExecutioner().exec(new LogEntropy(this, dimension)); + } + + @Override + public INDArray sum(@NonNull INDArray result, boolean keepDims, int... dimension) { + validateNumericalArray("sum", true); + return Nd4j.getExecutioner().exec(new Sum(this, result, keepDims, dimension)); + } + + @Override + public INDArray sum(@NonNull INDArray result, int... dimension) { + return sum(result, false, dimension); + } + + @Override + public INDArray norm1(int... dimension) { + return norm1(false, dimension); + } + + @Override + public INDArray norm1(boolean keepDims, int... dimension) { + validateNumericalArray("norm1", false); + return Nd4j.getExecutioner().exec(new Norm1(this, keepDims, dimension)); + } + + @Override + public INDArray std(int... dimension) { + return std(true, dimension); + } + + @Override + public INDArray std(boolean biasCorrected, int... dimension) { + return std(biasCorrected, false, dimension); + } + + @Override + public INDArray std(boolean biasCorrected, boolean keepDims, int... dimension) { + validateNumericalArray("std", false); + return Nd4j.getExecutioner() + .exec(new StandardDeviation(this, biasCorrected, keepDims, dimension)); + } + + @Override + public Number stdNumber(boolean biasCorrected) { + validateNumericalArray("stdNumber", false); + return Nd4j.getExecutioner().exec(new StandardDeviation(this, biasCorrected)).getDouble(0); + } + + @Override + public INDArray norm2(boolean keepDims, int... dimension) { + validateNumericalArray("norm2", false); + return Nd4j.getExecutioner().exec(new Norm2(this, keepDims, dimension)); + } + + @Override + public INDArray norm2(int... 
dimension) { + return norm2(false, dimension); + } + + @Override + public int columns() { + if (isMatrix()) { + return (int) size(1); + } else if (Shape.isColumnVectorShape(shape())) { + return 1; + } else if (Shape.isRowVectorShape(shape())) { + return (int) length(); + } + throw new IllegalStateException("Rank is [" + rank() + "]; columns() call is not valid"); + + + } + + @Override + public int rows() { + if (isMatrix()) { + return (int) size(0); + } else if (Shape.isRowVectorShape(shape())) { + return 1; + } else if (Shape.isColumnVectorShape(shape())) { + return (int) length(); + } + + throw new IllegalStateException("Rank is " + rank() + " rows() call is not valid"); + } + + @Override + public INDArray ravel(char ordering) { + Nd4j.getCompressor().autoDecompress(this); + if (ordering == this.ordering() && Shape.hasDefaultStridesForShape(this)) { + return reshape(ordering, length()); + } + return dup(ordering).reshape(ordering, length()); + } + + @Override + public INDArray ravel() { + return reshape(length()); + } + + @Override + public void sliceVectors(List list) { + if (isVector()) { + list.add(this); + } else { + for (int i = 0; i < slices(); i++) { + slice(i).sliceVectors(list); + } + } + } + + @Override + public INDArray reshape(long newRows, long newColumns) { + return reshape(new long[]{newRows, newColumns}); + } + + @Override + public INDArray getColumn(long c) { + Nd4j.getCompressor().autoDecompress(this); + + if (isColumnVector() && c == 0) { + return this; + } else if (isColumnVector() && c > 0) { + throw new IllegalArgumentException("Illegal index for column"); + } + Preconditions.checkArgument(this.rank() == 2, "getColumn() can be called on 2D arrays only"); + return tensorAlongDimension(c, 0); + } + + @Override + public INDArray getColumn(long c, boolean keepDim) { + INDArray col = getColumn(c); + if (!keepDim) { + return col; + } + return col.reshape(col.length(), 1); + } + + @Override + public INDArray getRows(int[] rindices) { + Nd4j.getCompressor().autoDecompress(this); + + if (!isMatrix() && !isVector()) { + throw new IllegalArgumentException("Unable to get columns from a non matrix or vector"); + } + if (isVector()) { + return Nd4j.pullRows(this, 1, rindices); + } else { + INDArray ret = Nd4j.createUninitialized(this.dataType(), rindices.length, columns()); + for (int i = 0; i < rindices.length; i++) { + ret.putRow(i, getRow(rindices[i])); + } + return ret; + } + } + + @Override + public INDArray get(INDArrayIndex... 
indexes) { + Nd4j.getCompressor().autoDecompress(this); + + int numPoint = 0; + int numInterval = 0; + int numAll = 0; + int numNewAxis = 0; + int numSpecified = 0; + for (INDArrayIndex i : indexes) { + if (i instanceof PointIndex) { + numPoint++; + } else if (i instanceof NDArrayIndexAll) { + numAll++; + } else if (i instanceof IntervalIndex) { + numInterval++; + } else if (i instanceof NewAxis) { + numNewAxis++; + } else if (i instanceof SpecifiedIndex) { + numSpecified++; + } else { + throw new IllegalStateException("Unknown index: " + i); + } + } + + // Padding remaining dimensions with all() index if too few indices provided + if (indexes.length - numNewAxis < this.rank()) { + val newIndexes = new INDArrayIndex[this.rank() + numNewAxis]; + System.arraycopy(indexes, 0, newIndexes, 0, indexes.length); + + for (int e = indexes.length; e < newIndexes.length; e++) { + numAll++; + newIndexes[e] = NDArrayIndex.all(); + } + + indexes = newIndexes; + } + + Preconditions.checkState((numPoint + numInterval + numAll + numSpecified) == rank(), + "Illegal set of indices for array: need at least" + + " %s point/interval/all/specified indices for rank %s array (%ndShape), got indices %s", + rank(), rank(), this, indexes); + + int outRank = rank() + numNewAxis - numPoint; + Preconditions.checkState(outRank >= 0, "Illegal set of indices for array: %ndShape, %s", this, + indexes); + + //To work out sub-array, we need to work out 3 things: offset, shape and strides. We calculate all of these + long[] outShape = new long[outRank]; + long[] outStrides = new long[outRank]; + long offset = offset(); //Start with existing offset if view + + int outIdx = 0; //Axis number counter for output array + int inIdx = 0; //Axis number counter for input array + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] instanceof PointIndex) { + //Point indexes don't appear in output + PointIndex pi = (PointIndex) indexes[i]; + offset += pi.offset() * stride(inIdx); + inIdx++; + } else if (indexes[i] instanceof NDArrayIndexAll) { + //All index: doesn't change offset. Axis is in both in and output arrays + outShape[outIdx] = size(inIdx); + outStrides[outIdx] = stride(inIdx); + inIdx++; + outIdx++; + } else if (indexes[i] instanceof IntervalIndex) { + //Interval index: Axis is in both in and output arrays, but output might be smaller + IntervalIndex ii = (IntervalIndex) indexes[i]; + long start = ii.offset(); + long endInc = ii.end() - (ii.isInclusive() ? 0 : 1); + if (endInc >= size(inIdx)) { + throw new IllegalStateException( + "Indices are out of range: Cannot get interval index " + indexes[i] + + " on array with size(" + inIdx + ")=" + size(inIdx) + ". Array shape: " + + Arrays.toString(shape()) + + ", indices: " + Arrays.toString(indexes)); + } + long stride = ii.stride(); + long length = (endInc - start) / stride + 1; + + offset += ii.offset() * stride(inIdx); + outShape[outIdx] = length; + outStrides[outIdx] = ii.stride() * stride(inIdx); + inIdx++; + outIdx++; + } else if (indexes[i] instanceof NewAxis) { + //New axis: appends a 1 in shape. Axis not present in input, but is present in output + outShape[outIdx] = 1; + if (outIdx > 0) { //Stride doesn't matter for 1 size axis anyway... 
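+          // (any value would be valid here, since a length-1 axis is never stepped over;
+          //  reusing the previous output stride just keeps the stride array self-consistent)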
+ outStrides[outIdx] = outStrides[outIdx - 1]; } else { - INDArray ret = Nd4j.createUninitialized(this.dataType(), rows(), cindices.length); - for (int i = 0; i < cindices.length; i++) - ret.putColumn(i, getColumn(cindices[i])); - return ret; + outStrides[outIdx] = 1; } - + outIdx++; + } else if (indexes[i] instanceof SpecifiedIndex) { + //Specified index: axis present in both input and output + SpecifiedIndex si = (SpecifiedIndex) indexes[i]; + outShape[outIdx++] = si.length(); + inIdx++; + //Don't care about strides for specified index, as result won't be a view + } else { + throw new IllegalStateException("Unknown index type: " + i); //Should never happen + } } - protected INDArray create(int rows, int length) { - return create(new int[] {rows, length}); - } + //Note: If we have specified indices, we can't return a view. Instead, we copy the specified sub-arrays from + // the input array to the output array. + //How? Create the output array, then do loop over the specified indices only, and copy sub-arrays for all other axes + if (numSpecified > 0) { + INDArray out = Nd4j.create(dataType(), outShape); - @Override - public INDArray getRow(long r) { - if (isRowVector() && r == 0) - return this; - else if (isRowVector() && r > 0) - throw new IllegalArgumentException("Illegal index for row: requested row " + r + " but this.size(0)=" + this.size(0)); + //Need to copy subsets here + long[] specifiedSizes = new long[numSpecified]; + SpecifiedIndex[] si = new SpecifiedIndex[numSpecified]; + int j = 0; + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] instanceof SpecifiedIndex) { + specifiedSizes[j] = indexes[i].length(); + si[j] = (SpecifiedIndex) indexes[i]; + j++; + } + } + NdIndexIterator iter = new NdIndexIterator(specifiedSizes); - Preconditions.checkArgument(rank() == 2, "getRow() can be called on 2D arrays only"); - Preconditions.checkArgument(r < rows(), "Row index must be smaller than total number of rows"); + //What we need to do here: Iterate over sub-arrays for both input and output + //(1) Get from input: requested indices, except for: + // i. specified indices -> replace with loop + point + // ii. new axis indices -> ignore/exclude (don't appear in input) + // iii. interval indices -> replace with all + //(2) Get from output: requested indices, except for: + // i. point indices -> ignore/exclude (don't appear in output) + // ii. 
new axis indices -> replace with point(0) - return tensorAlongDimension(r, 1); - } + INDArrayIndex[] pointIdxsIn = new INDArrayIndex[indexes.length + - numNewAxis]; //Indices for source (this array) + int[] specifiedAxisIn = new int[numSpecified]; + int specCount = 0; + j = 0; + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] instanceof NewAxis) { + continue; //Skip new axis in source dims + } + if (indexes[i] instanceof SpecifiedIndex) { + specifiedAxisIn[specCount++] = j; + } + pointIdxsIn[j++] = indexes[i]; + } - @Override - public INDArray getRow(long r, boolean keepDim) { - INDArray row = getRow(r); - if(!keepDim) - return row; - return row.reshape(1, row.length()); - } + INDArrayIndex[] pointIdxsOut = new INDArrayIndex[indexes.length + - numPoint]; //Indices for destination (output array) + j = 0; + specCount = 0; + int[] specifiedAxisOut = new int[numSpecified]; + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] instanceof NewAxis) { + pointIdxsOut[j++] = NDArrayIndex.point(0); + continue; + } else if (indexes[i] instanceof PointIndex) { + continue; + } else if (indexes[i] instanceof SpecifiedIndex) { + specifiedAxisOut[specCount++] = j; + } else if (indexes[i] instanceof IntervalIndex) { + pointIdxsOut[j++] = NDArrayIndex.all(); + continue; + } + pointIdxsOut[j++] = indexes[i]; + } - public boolean equalsWithEps(Object o, double eps) { - Nd4j.getCompressor().autoDecompress(this); - - - if (o == null) - return false; - - if (!(o instanceof INDArray)) - return false; - - INDArray n = (INDArray) o; - Nd4j.getCompressor().autoDecompress(n); - - if (n == this) - return true; - - if (this.rank() != n.rank()) - return false; - - if (this.length() != n.length()) - return false; - - if (this.isEmpty() != n.isEmpty()) - return false; - - if (this.isEmpty() && n.isEmpty()) - return Shape.shapeEquals(this.shape(), n.shape()); - - if (this.dataType() != n.dataType()) - return false; - - // meh - if (this.dataType() == DataType.UTF8 && n.dataType() == DataType.UTF8) { - for (long e = 0; e < this.length(); e++) { - val str1 = this.getString(e); - val str2 = n.getString(e); - - if (!str1.equals(str2)) - return false; - } - - return true; + //Iterate over sub-arrays; copy from source to destination + while (iter.hasNext()) { + long[] specifiedIdxs = iter.next(); + for (int i = 0; i < specifiedIdxs.length; i++) { + long sourceIdx = si[i].getIndexes()[(int) specifiedIdxs[i]]; + pointIdxsIn[specifiedAxisIn[i]] = NDArrayIndex.point(sourceIdx); + int outI = (int) specifiedIdxs[i]; + pointIdxsOut[specifiedAxisOut[i]] = NDArrayIndex.point(outI); } - //epsilon equals - if (isScalar() && n.isScalar()) { - if (isZ()) { - val val = getLong(0); - val val2 = n.getLong(0); + out.get(pointIdxsOut).assign(get(pointIdxsIn)); + } - return val == val2; - } else if (isR()) { - val val = getDouble(0); - val val2 = n.getDouble(0); + return out; + } - if (Double.isNaN(val) != Double.isNaN(val2)) - return false; + char order = Shape.getOrder(outShape, outStrides, -1); + INDArray out = create(data, outShape, outStrides, offset, order); + return out; + } - return Math.abs(val - val2) < eps; - } else if (isB()) { - val val = getInt(0); - val val2 = n.getInt(0); - - return val == val2; - } - - } else if (isVector() && n.isVector()) { - val op = new EqualsWithEps(this, n, eps); - Nd4j.exec(op); - val diff = op.z().getDouble(0); - - return diff < 0.5; + @Override + public INDArray getColumns(int... 
cindices) { + if (!isMatrix() && !isVector()) { + throw new IllegalArgumentException("Unable to get columns from a non matrix or vector"); + } + if (isVector()) { + return Nd4j.pullRows(this, 0, cindices, this.ordering()); + } else { + INDArray ret = Nd4j.createUninitialized(this.dataType(), rows(), cindices.length); + for (int i = 0; i < cindices.length; i++) { + ret.putColumn(i, getColumn(cindices[i])); } + return ret; + } - if (!Arrays.equals(this.shape(), n.shape())) - return false; + } + + protected INDArray create(int rows, int length) { + return create(new int[]{rows, length}); + } + + @Override + public INDArray getRow(long r) { + if (isRowVector() && r == 0) { + return this; + } else if (isRowVector() && r > 0) { + throw new IllegalArgumentException( + "Illegal index for row: requested row " + r + " but this.size(0)=" + this.size(0)); + } + + Preconditions.checkArgument(rank() == 2, "getRow() can be called on 2D arrays only"); + Preconditions.checkArgument(r < rows(), "Row index must be smaller than total number of rows"); + + return tensorAlongDimension(r, 1); + } + + @Override + public INDArray getRow(long r, boolean keepDim) { + INDArray row = getRow(r); + if (!keepDim) { + return row; + } + return row.reshape(1, row.length()); + } + + public boolean equalsWithEps(Object o, double eps) { + Nd4j.getCompressor().autoDecompress(this); + + if (o == null) { + return false; + } + + if (!(o instanceof INDArray)) { + return false; + } + + INDArray n = (INDArray) o; + Nd4j.getCompressor().autoDecompress(n); + + if (n == this) { + return true; + } + + if (this.rank() != n.rank()) { + return false; + } + + if (this.length() != n.length()) { + return false; + } + + if (this.isEmpty() != n.isEmpty()) { + return false; + } + + if (this.isEmpty() && n.isEmpty()) { + return Shape.shapeEquals(this.shape(), n.shape()); + } + + if (this.dataType() != n.dataType()) { + return false; + } + + // meh + if (this.dataType() == DataType.UTF8 && n.dataType() == DataType.UTF8) { + for (long e = 0; e < this.length(); e++) { + val str1 = this.getString(e); + val str2 = n.getString(e); + + if (!str1.equals(str2)) { + return false; + } + } + + return true; + } + + //epsilon equals + if (isScalar() && n.isScalar()) { + if (isZ()) { + val val = getLong(0); + val val2 = n.getLong(0); + + return val == val2; + } else if (isR()) { + val val = getDouble(0); + val val2 = n.getDouble(0); + + if (Double.isNaN(val) != Double.isNaN(val2)) { + return false; + } + + return Math.abs(val - val2) < eps; + } else if (isB()) { + val val = getInt(0); + val val2 = n.getInt(0); + + return val == val2; + } + + } else if (isVector() && n.isVector()) { + val op = new EqualsWithEps(this, n, eps); + Nd4j.exec(op); + val diff = op.z().getDouble(0); + + return diff < 0.5; + } + + if (!Arrays.equals(this.shape(), n.shape())) { + return false; + } + + if (!Shape.shapeEquals(shape(), n.shape())) { + return false; + } + + if (slices() != n.slices()) { + return false; + } + + if (n.ordering() == ordering()) { + EqualsWithEps op = new EqualsWithEps(this, n, eps); + Nd4j.getExecutioner().exec(op); + double diff = op.z().getDouble(0); + + return diff < 0.5; + } else { + EqualsWithEps op = new EqualsWithEps(this, n, eps); + Nd4j.getExecutioner().exec(op); + double diff = op.z().getDouble(0); + + return diff < 0.5; + } + } + + @Override + public boolean equalShapes(@NonNull INDArray other) { + if (isEmpty() != other.isEmpty()) { + return false; + } + if (rank() != other.rank()) { + return false; + } + for (int i = 0; i < rank(); i++) { + if 
(size(i) != other.size(i)) { + return false; + } + } + return true; + } + + /** + * Compare two matrices. Returns true if and only if other is also a DoubleMatrix which has the + * same size and the maximal absolute difference in matrix elements is smaller than 1e-5. + * + * @param o + */ + @Override + public boolean equals(Object o) { + return equalsWithEps(o, Nd4j.EPS_THRESHOLD); + } + + @Override + public int hashCode() { + val longHash = Nd4j.exec(new HashCode(this))[0].getLong(0); + return Math.abs(longHash) <= Integer.MAX_VALUE ? (int) longHash + : (int) (longHash % Integer.MAX_VALUE); + } + + @Override + public DataBuffer shapeInfoDataBuffer() { + return shapeInformation; + } + + @Override + public LongBuffer shapeInfo() { + return shapeInformation.asNioLong(); + } + + public long[] shape() { + return jvmShapeInfo.shape; + } + + @Override + public String shapeInfoToString() { + return Shape.shapeToString(this); + } + + @Override + public long[] stride() { + return jvmShapeInfo.stride; + } - if (!Shape.shapeEquals(shape(), n.shape())) { - return false; - } + @Override + public long offset() { + return data().offset(); + } + @Override + public char ordering() { + return jvmShapeInfo.order; + } - if (slices() != n.slices()) - return false; + @Override + public long size(int dimension) { + if (dimension < 0) { + dimension += jvmShapeInfo.rank; + } - if (n.ordering() == ordering()) { - EqualsWithEps op = new EqualsWithEps(this, n, eps); - Nd4j.getExecutioner().exec(op); - double diff = op.z().getDouble(0); - - return diff < 0.5; + if (isScalar()) { + if (dimension == 0 || dimension == 1 || dimension < 0) { + return length(); } else { - EqualsWithEps op = new EqualsWithEps(this, n, eps); - Nd4j.getExecutioner().exec(op); - double diff = op.z().getDouble(0); - - return diff < 0.5; + throw new IllegalArgumentException("Illegal dimension for scalar " + dimension); } } - @Override - public boolean equalShapes(@NonNull INDArray other){ - if(isEmpty() != other.isEmpty()) - return false; - if(rank() != other.rank()) - return false; - for( int i=0; i= rank()) { + throw new IllegalArgumentException( + "Invalid size: cannot get size of dimension " + dimension + " for rank " + + rank() + " NDArray (array shape: " + Arrays.toString(this.shape()) + ")"); + } + + return jvmShapeInfo.shape[dimension]; + } + + @Override + public int rank() { + return jvmShapeInfo.rank; + } + + @Override + public long length() { + if (isEmpty()) { + return 0; + } + return jvmShapeInfo.length; + } + + @Override + public INDArray broadcast(INDArray result) { + Nd4j.getCompressor().autoDecompress(this); + + val shape = result.shape(); + + if (Shape.shapeEquals(shape, shape())) { + return this; + } + + // if we're on scalar, we can just create new array + if (this.isScalar()) { + return Nd4j.createUninitialized(this.dataType(), shape).assign(this.getDouble(0)); + } + + boolean compatible = true; + int count = shape.length - 1; + int thisCount = jvmShapeInfo.rank - 1; + for (int i = shape.length - 1; i > 0; i--) { + if (count < 0 || thisCount < 0) { + break; } - return true; + if (shape[count] != shape()[thisCount] && shape[count] != 1 && shape()[thisCount] != 1) { + compatible = false; + break; + } + + count--; + thisCount--; } - /** - * Compare two matrices. Returns true if and only if other is also a - * DoubleMatrix which has the same size and the maximal absolute - * difference in matrix elements is smaller than 1e-5. 
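
The index-resolution logic earlier in this hunk (get(INDArrayIndex...): count the point/interval/all/new-axis/specified indices, pad missing ones with all(), then derive offset, shape and strides) is easier to follow from the caller's side. A minimal sketch, assuming the NDArrayIndex helpers (point, interval, all, newAxis, indices) and an illustrative 2x3x4 array; class and variable names are made up for the example:

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

public class IndexingSketch {
  public static void main(String[] args) {
    INDArray arr = Nd4j.arange(0, 24).reshape(2, 3, 4);

    // point + all + interval: the point axis disappears from the output, the result stays a view
    INDArray view = arr.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(1, 3));
    System.out.println(Arrays.toString(view.shape()));      // [3, 2]

    // newAxis inserts a size-1 dimension; missing trailing indices are padded with all()
    INDArray expanded = arr.get(NDArrayIndex.newAxis(), NDArrayIndex.all());
    System.out.println(Arrays.toString(expanded.shape()));  // [1, 2, 3, 4]

    // a SpecifiedIndex (explicit positions) cannot be expressed as strides, so a copy is returned
    INDArray copied = arr.get(NDArrayIndex.all(), NDArrayIndex.indices(0, 2), NDArrayIndex.all());
    System.out.println(Arrays.toString(copied.shape()));    // [2, 2, 4]
  }
}
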
- * - * @param o - */ - @Override - public boolean equals(Object o) { - return equalsWithEps(o, Nd4j.EPS_THRESHOLD); - } + if (!compatible) { + throw new IllegalArgumentException( + "Incompatible broadcast from " + Arrays.toString(shape()) + " to " + + Arrays.toString(shape)); + } - @Override - public int hashCode() { - val longHash = Nd4j.exec(new HashCode(this))[0].getLong(0); - return Math.abs(longHash) <= Integer.MAX_VALUE ? (int) longHash : (int) (longHash % Integer.MAX_VALUE); - } - - @Override - public DataBuffer shapeInfoDataBuffer() { - return shapeInformation; - } - - @Override - public LongBuffer shapeInfo() { - return shapeInformation.asNioLong(); - } - - public long[] shape() { - return jvmShapeInfo.shape; - } - - @Override - public String shapeInfoToString() { - return Shape.shapeToString(this); - } - - @Override - public long[] stride() { - return jvmShapeInfo.stride; - } - - - @Override - public long offset() { - return data().offset(); - } - - @Override - public char ordering() { - return jvmShapeInfo.order; - } - - @Override - public long size(int dimension) { - if (dimension < 0) - dimension += jvmShapeInfo.rank; - - if (isScalar()) { - if (dimension == 0 || dimension == 1 || dimension < 0) - return length(); - else - throw new IllegalArgumentException("Illegal dimension for scalar " + dimension); - } - - if (dimension >= rank()) - throw new IllegalArgumentException("Invalid size: cannot get size of dimension " + dimension + " for rank " - + rank() + " NDArray (array shape: " + Arrays.toString(this.shape()) + ")"); - - return jvmShapeInfo.shape[dimension]; - } - - @Override - public int rank() { - return jvmShapeInfo.rank; - } - - @Override - public long length() { - if (isEmpty()) - return 0; - return jvmShapeInfo.length; - } - - @Override - public INDArray broadcast(INDArray result) { - Nd4j.getCompressor().autoDecompress(this); - - val shape = result.shape(); - - if (Shape.shapeEquals(shape, shape())) - return this; - - // if we're on scalar, we can just create new array - if (this.isScalar()) - return Nd4j.createUninitialized(this.dataType(), shape).assign(this.getDouble(0)); - - - - - boolean compatible = true; - int count = shape.length - 1; - int thisCount = jvmShapeInfo.rank - 1; - for (int i = shape.length - 1; i > 0; i--) { - if (count < 0 || thisCount < 0) - break; - if (shape[count] != shape()[thisCount] && shape[count] != 1 && shape()[thisCount] != 1) { - compatible = false; - break; - } - - count--; - thisCount--; - } - - if (!compatible) - throw new IllegalArgumentException("Incompatible broadcast from " + Arrays.toString(shape()) + " to " - + Arrays.toString(shape)); - - - - long[] retShape = new long[shape.length]; - List broadCastDimensions = new ArrayList<>(); - List nonBroadCastDimensions = new ArrayList<>(); - for (int i = 0; i < retShape.length; i++) { - if (shape().length == 1) { - if (i == 0) { - if (i < shape().length) - retShape[i] = Math.max(1, shape[i]); - else - retShape[i] = shape[i]; - } else { - if (i < shape().length) - retShape[i] = Math.max(shape[i], size(i)); - else - retShape[i] = shape[i]; - } + long[] retShape = new long[shape.length]; + List broadCastDimensions = new ArrayList<>(); + List nonBroadCastDimensions = new ArrayList<>(); + for (int i = 0; i < retShape.length; i++) { + if (shape().length == 1) { + if (i == 0) { + if (i < shape().length) { + retShape[i] = Math.max(1, shape[i]); } else { - if (i < rank() && size(i) == 1) - broadCastDimensions.add(i); - else - nonBroadCastDimensions.add(i); - if (i < shape().length) - 
retShape[i] = Math.max(shape[i], size(i)); - else - retShape[i] = shape[i]; + retShape[i] = shape[i]; } - - } - - - if (isRowVector()) { - //number of times to repeat each value - for (int i = 0; i < result.slices(); i++) { - result.putSlice(i, this); - } - } else if (isColumnVector()) { - for (int i = 0; i < result.columns(); i++) { - result.putColumn(i, this); + } else { + if (i < shape().length) { + retShape[i] = Math.max(shape[i], size(i)); + } else { + retShape[i] = shape[i]; } } - - else { - int[] repeat = new int[shape.length]; - for(int i = 0; i < shape.length; i++) { - if(i < rank()) { - if(size(i) == 1) - repeat[i] = (int) shape[i]; - else { - repeat[i] = 1; - } - } - - else { - repeat[i] = (int) shape[i]; - } - } - - if (this.isView()) { - Nd4j.getExecutioner().execAndReturn(new Tile(new INDArray[]{this.dup(this.ordering())},new INDArray[]{result},repeat)); - } else - Nd4j.getExecutioner().execAndReturn(new Tile(new INDArray[]{this},new INDArray[]{result},repeat)); - } - return result; + } else { + if (i < rank() && size(i) == 1) { + broadCastDimensions.add(i); + } else { + nonBroadCastDimensions.add(i); + } + if (i < shape().length) { + retShape[i] = Math.max(shape[i], size(i)); + } else { + retShape[i] = shape[i]; + } + } } - @Override - public INDArray broadcast(long... shape) { - return broadcast(Nd4j.createUninitialized(this.dataType(), shape, this.ordering())); + if (isRowVector()) { + //number of times to repeat each value + for (int i = 0; i < result.slices(); i++) { + result.putSlice(i, this); + } + } else if (isColumnVector()) { + for (int i = 0; i < result.columns(); i++) { + result.putColumn(i, this); + } + } else { + int[] repeat = new int[shape.length]; + for (int i = 0; i < shape.length; i++) { + if (i < rank()) { + if (size(i) == 1) { + repeat[i] = (int) shape[i]; + } else { + repeat[i] = 1; + } + } else { + repeat[i] = (int) shape[i]; + } + } + + if (this.isView()) { + Nd4j.getExecutioner().execAndReturn( + new Tile(new INDArray[]{this.dup(this.ordering())}, new INDArray[]{result}, + repeat)); + } else { + Nd4j.getExecutioner() + .execAndReturn(new Tile(new INDArray[]{this}, new INDArray[]{result}, repeat)); + } + } + return result; + + } + + @Override + public INDArray broadcast(long... shape) { + return broadcast(Nd4j.createUninitialized(this.dataType(), shape, this.ordering())); + } + + @Deprecated + @Override + public INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable) { + return dimShuffle(rearrange, ArrayUtil.toLongArray(newOrder), broadCastable); + } + + /** + * Dimshuffle: an extension of permute that adds the ability to broadcast various dimensions. + *
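
The broadcast(INDArray)/broadcast(long...) pair above first runs a right-aligned compatibility check (each trailing axis must either match or be 1), then fills the result either by copying slices/columns for row and column vectors or by running a Tile op. A short usage sketch; the shapes and values are assumptions made purely for illustration:

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class BroadcastSketch {
  public static void main(String[] args) {
    INDArray row = Nd4j.createFromArray(new float[]{1f, 2f, 3f}).reshape(1, 3);

    // [1, 3] -> [4, 3]: every row of the result is a copy of the vector
    INDArray out = row.broadcast(4, 3);
    System.out.println(Arrays.toString(out.shape()));   // [4, 3]

    // [1, 3] -> [4, 2] fails the compatibility loop and throws IllegalArgumentException
    // row.broadcast(4, 2);
  }
}
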

      + * See theano for more examples. This will only accept integers and xs. + *

      + * An x indicates a dimension should be broadcasted rather than permuted. + * + * @param rearrange the dimensions to swap to + * @return the newly permuted array + */ + @Override + public INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable) { + Nd4j.getCompressor().autoDecompress(this); + + if (broadCastable.length != jvmShapeInfo.rank) { + throw new IllegalArgumentException( + "The broadcastable dimensions must be the same length as the current shape"); + } + + boolean broadcast = false; + Set set = new HashSet<>(); + for (int i = 0; i < rearrange.length; i++) { + set.add(rearrange[i]); + if (rearrange[i] instanceof Integer) { + Integer j = (Integer) rearrange[i]; + if (j >= broadCastable.length) { + throw new IllegalArgumentException( + "Illegal dimension, dimension must be < broadcastable.length (aka the real dimensions"); + } + } else if (rearrange[i] instanceof Character) { + Character c = (Character) rearrange[i]; + if (c != 'x') { + throw new IllegalArgumentException("Illegal input: Must be x"); + } + broadcast = true; + + } else { + throw new IllegalArgumentException("Only characters and integers allowed"); + } } - @Deprecated - @Override - public INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable) { - return dimShuffle(rearrange, ArrayUtil.toLongArray(newOrder), broadCastable); + //just do permute + if (!broadcast) { + int[] ret = new int[rearrange.length]; + for (int i = 0; i < ret.length; i++) { + ret[i] = (Integer) rearrange[i]; + } + return permute(ret); + } else { + List drop = new ArrayList<>(); + for (int i = 0; i < broadCastable.length; i++) { + if (!set.contains(i)) { + if (broadCastable[i]) { + drop.add(i); + } else { + throw new IllegalArgumentException( + "We can't drop the given dimension because its not broadcastable"); + } + } + + } + + //list of dimensions to keep + int[] shuffle = new int[broadCastable.length]; + int count = 0; + for (int i = 0; i < rearrange.length; i++) { + if (rearrange[i] instanceof Integer) { + shuffle[count++] = (Integer) rearrange[i]; + } + } + + List augment = new ArrayList<>(); + for (int i = 0; i < rearrange.length; i++) { + if (rearrange[i] instanceof Character) { + augment.add(i); + } + } + + Integer[] augmentDims = augment.toArray(new Integer[1]); + + count = 0; + + int dropIdx = 0; + int[] newShape = new int[shuffle.length + drop.size()]; + for (int i = 0; i < newShape.length; i++) { + if (i < shuffle.length) { + newShape[count++] = shuffle[i]; + } else { + newShape[count++] = drop.get(dropIdx++); + } + } + + INDArray ret; //TODO is this correct? This was old behaviour before adding permute input check + if (newShape.length == this.rank()) { + ret = permute(newShape); + } else { + ret = dup(); + } + List newDims = new ArrayList<>(); + long[] shape = Arrays.copyOfRange(ret.shape(), 0, shuffle.length); + for (int i = 0; i < shape.length; i++) { + newDims.add(shape[i]); + } + + for (int i = 0; i < augmentDims.length; i++) { + newDims.add(augmentDims[i], 1L); + } + + long[] toReshape = ArrayUtil.toArrayLong(newDims); + + ret = ret.reshape(toReshape); + return ret; + } - /** - * Dimshuffle: an extension of permute that adds the ability - * to broadcast various dimensions. - *

      - * See theano for more examples. - * This will only accept integers and xs. - *
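
The 'x' notation described in this javadoc is easiest to see from the caller's side: integers give a permutation of the existing axes, and each 'x' inserts a broadcastable size-1 axis at that position in the result. A minimal sketch (shapes, class and variable names are illustrative assumptions):

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class DimShuffleSketch {
  public static void main(String[] args) {
    INDArray arr = Nd4j.zeros(3, 4);

    // plain permute: swap the two axes
    INDArray swapped = arr.permute(1, 0);
    System.out.println(Arrays.toString(swapped.shape()));   // [4, 3]

    // dimShuffle: permute to [4, 3], then the 'x' inserts a size-1 axis at position 2;
    // broadCastable needs one flag per input axis
    INDArray shuffled = arr.dimShuffle(new Object[]{1, 0, 'x'}, new long[]{1, 0},
        new boolean[]{false, false});
    System.out.println(Arrays.toString(shuffled.shape()));  // [4, 3, 1]
  }
}
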

      - * An x indicates a dimension should be broadcasted rather than permuted. - * - * @param rearrange the dimensions to swap to - * @return the newly permuted array - */ - @Override - public INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable) { - Nd4j.getCompressor().autoDecompress(this); - if (broadCastable.length != jvmShapeInfo.rank) + } + + @Override + public INDArray permute(int... rearrange) { + Preconditions.checkArgument(rearrange.length == rank(), + "Incorrect number of arguments for permute function:" + + " got arguments %s for rank %s array. Number of arguments must equal array rank", + rearrange, rank()); + Nd4j.getCompressor().autoDecompress(this); + boolean alreadyInOrder = true; + //IntBuffer shapeInfo = shapeInfo(); + int rank = jvmShapeInfo.rank; + for (int i = 0; i < rank; i++) { + if (rearrange[i] != i) { + alreadyInOrder = false; + break; + } + } + + if (alreadyInOrder) { + return this; + } + + checkArrangeArray(rearrange); + val newShape = doPermuteSwap(shape(), rearrange); + val newStride = doPermuteSwap(stride(), rearrange); + + char newOrder = Shape.getOrder(newShape, newStride, 1); + + INDArray value = create(data(), newShape, newStride, offset(), newOrder); + return value; + } + + @Override + public INDArray permutei(int... rearrange) { + Preconditions.checkArgument(rearrange.length == rank(), + "Incorrect number of arguments for permute function:" + + " got arguments %s for rank %s array. Number of arguments must equal array rank", + rearrange, rank()); + boolean alreadyInOrder = true; + val shapeInfo = shapeInfo(); + int rank = jvmShapeInfo.rank; + for (int i = 0; i < rank; i++) { + if (rearrange[i] != i) { + alreadyInOrder = false; + break; + } + } + + if (alreadyInOrder) { + return this; + } + + checkArrangeArray(rearrange); + val newShape = doPermuteSwap(shape(), rearrange); + val newStride = doPermuteSwap(stride(), rearrange); + char newOrder = Shape.getOrder(newShape, newStride, 1); + + val ews = shapeInfo.get(2 * rank + 2); + + val si = Nd4j.getShapeInfoProvider() + .createShapeInformation(newShape, newStride, ews, newOrder, dataType(), isEmpty()); + setShapeInformation(si); + + if (shapeInfo.get(2 * rank + 2) > 0) { + //for the backend to work - no ews for permutei + //^^ not true anymore? Not sure here. 
Marking this for raver + setShapeInformation(Nd4j.getShapeInfoProvider() + .createShapeInformation(newShape, newStride, 0, newOrder, dataType(), isEmpty())); + } + + //this.shape = null; + //this.stride = null; + + return this; + } + + + @Deprecated + protected long[] doPermuteSwap(LongBuffer shape, int[] rearrange) { + val ret = new long[rearrange.length]; + for (int i = 0; i < rearrange.length; i++) { + ret[i] = shape.get(rearrange[i]); + } + return ret; + } + + @Deprecated + protected int[] doPermuteSwap(IntBuffer shape, int[] rearrange) { + int[] ret = new int[rearrange.length]; + for (int i = 0; i < rearrange.length; i++) { + ret[i] = shape.get(rearrange[i]); + } + return ret; + } + + @Deprecated + protected int[] doPermuteSwap(DataBuffer shape, int[] rearrange) { + int[] ret = new int[rearrange.length]; + for (int i = 0; i < rearrange.length; i++) { + ret[i] = shape.getInt(rearrange[i]); + } + return ret; + } + + protected long[] doPermuteSwap(long[] shape, int[] rearrange) { + val ret = new long[rearrange.length]; + for (int i = 0; i < rearrange.length; i++) { + ret[i] = shape[rearrange[i]]; + } + + return ret; + } + + + protected void checkArrangeArray(int[] arr) { + Preconditions.checkArgument(arr.length == jvmShapeInfo.rank, + "Invalid rearrangement: number of arrangement (%s) != rank (%s)", + arr.length, jvmShapeInfo.rank); + for (int i = 0; i < arr.length; i++) { + if (arr[i] >= arr.length) { throw new IllegalArgumentException( - "The broadcastable dimensions must be the same length as the current shape"); - - boolean broadcast = false; - Set set = new HashSet<>(); - for (int i = 0; i < rearrange.length; i++) { - set.add(rearrange[i]); - if (rearrange[i] instanceof Integer) { - Integer j = (Integer) rearrange[i]; - if (j >= broadCastable.length) - throw new IllegalArgumentException( - "Illegal dimension, dimension must be < broadcastable.length (aka the real dimensions"); - } else if (rearrange[i] instanceof Character) { - Character c = (Character) rearrange[i]; - if (c != 'x') - throw new IllegalArgumentException("Illegal input: Must be x"); - broadcast = true; - - } else - throw new IllegalArgumentException("Only characters and integers allowed"); + "The specified dimensions can't be swapped. Given element " + i + + " was >= number of dimensions"); } - - //just do permute - if (!broadcast) { - int[] ret = new int[rearrange.length]; - for (int i = 0; i < ret.length; i++) - ret[i] = (Integer) rearrange[i]; - return permute(ret); - } else { - List drop = new ArrayList<>(); - for (int i = 0; i < broadCastable.length; i++) { - if (!set.contains(i)) { - if (broadCastable[i]) - drop.add(i); - else - throw new IllegalArgumentException( - "We can't drop the given dimension because its not broadcastable"); - } - - } - - - //list of dimensions to keep - int[] shuffle = new int[broadCastable.length]; - int count = 0; - for (int i = 0; i < rearrange.length; i++) { - if (rearrange[i] instanceof Integer) { - shuffle[count++] = (Integer) rearrange[i]; - } - } - - - List augment = new ArrayList<>(); - for (int i = 0; i < rearrange.length; i++) { - if (rearrange[i] instanceof Character) - augment.add(i); - } - - Integer[] augmentDims = augment.toArray(new Integer[1]); - - count = 0; - - int dropIdx = 0; - int[] newShape = new int[shuffle.length + drop.size()]; - for (int i = 0; i < newShape.length; i++) { - if (i < shuffle.length) { - newShape[count++] = shuffle[i]; - } else - newShape[count++] = drop.get(dropIdx++); - } - - INDArray ret; //TODO is this correct? 
This was old behaviour before adding permute input check - if(newShape.length == this.rank()){ - ret = permute(newShape); - } else { - ret = dup(); - } - List newDims = new ArrayList<>(); - long[] shape = Arrays.copyOfRange(ret.shape(), 0, shuffle.length); - for (int i = 0; i < shape.length; i++) { - newDims.add(shape[i]); - } - - for (int i = 0; i < augmentDims.length; i++) { - newDims.add(augmentDims[i], 1L); - } - - long[] toReshape = ArrayUtil.toArrayLong(newDims); - - - ret = ret.reshape(toReshape); - return ret; - + if (arr[i] < 0) { + throw new IllegalArgumentException("Invalid dimension: " + i + " : negative value"); } } - @Override - public INDArray permute(int... rearrange) { - Preconditions.checkArgument(rearrange.length == rank(), "Incorrect number of arguments for permute function:" + - " got arguments %s for rank %s array. Number of arguments must equal array rank", rearrange, rank()); - Nd4j.getCompressor().autoDecompress(this); - boolean alreadyInOrder = true; - //IntBuffer shapeInfo = shapeInfo(); - int rank = jvmShapeInfo.rank; - for (int i = 0; i < rank; i++) { - if (rearrange[i] != i) { - alreadyInOrder = false; - break; - } - } - - if (alreadyInOrder) - return this; - - checkArrangeArray(rearrange); - val newShape = doPermuteSwap(shape(), rearrange); - val newStride = doPermuteSwap(stride(), rearrange); - - char newOrder = Shape.getOrder(newShape, newStride, 1); - - INDArray value = create(data(), newShape, newStride, offset(), newOrder); - return value; + for (int i = 0; i < arr.length; i++) { + for (int j = 0; j < arr.length; j++) { + if (i != j && arr[i] == arr[j]) { + throw new IllegalArgumentException("Permute array must have unique elements"); + } + } } - @Override - public INDArray permutei(int... rearrange) { - Preconditions.checkArgument(rearrange.length == rank(), "Incorrect number of arguments for permute function:" + - " got arguments %s for rank %s array. Number of arguments must equal array rank", rearrange, rank()); - boolean alreadyInOrder = true; - val shapeInfo = shapeInfo(); - int rank = jvmShapeInfo.rank; - for (int i = 0; i < rank; i++) { - if (rearrange[i] != i) { - alreadyInOrder = false; - break; - } - } + } - if (alreadyInOrder) - return this; - - checkArrangeArray(rearrange); - val newShape = doPermuteSwap(shape(), rearrange); - val newStride = doPermuteSwap(stride(), rearrange); - char newOrder = Shape.getOrder(newShape, newStride, 1); - - val ews = shapeInfo.get(2 * rank + 2); - - val si = Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, ews, newOrder, dataType(), isEmpty()); - setShapeInformation(si); - - - if (shapeInfo.get(2 * rank + 2) > 0) { - //for the backend to work - no ews for permutei - //^^ not true anymore? Not sure here. 
Marking this for raver - setShapeInformation(Nd4j.getShapeInfoProvider().createShapeInformation(newShape, newStride, 0, newOrder, dataType(), isEmpty())); - } - - //this.shape = null; - //this.stride = null; - - - return this; - } - - - @Deprecated - protected long[] doPermuteSwap(LongBuffer shape, int[] rearrange) { - val ret = new long[rearrange.length]; - for (int i = 0; i < rearrange.length; i++) { - ret[i] = shape.get(rearrange[i]); - } - return ret; - } - - @Deprecated - protected int[] doPermuteSwap(IntBuffer shape, int[] rearrange) { - int[] ret = new int[rearrange.length]; - for (int i = 0; i < rearrange.length; i++) { - ret[i] = shape.get(rearrange[i]); - } - return ret; - } - - @Deprecated - protected int[] doPermuteSwap(DataBuffer shape, int[] rearrange) { - int[] ret = new int[rearrange.length]; - for (int i = 0; i < rearrange.length; i++) { - ret[i] = shape.getInt(rearrange[i]); - } - return ret; - } - - protected long[] doPermuteSwap(long[] shape, int[] rearrange) { - val ret = new long[rearrange.length]; - for (int i = 0; i < rearrange.length; i++) { - ret[i] = shape[rearrange[i]]; - } - - return ret; - } - - - protected void checkArrangeArray(int[] arr) { - Preconditions.checkArgument(arr.length == jvmShapeInfo.rank, "Invalid rearrangement: number of arrangement (%s) != rank (%s)", - arr.length, jvmShapeInfo.rank); - for (int i = 0; i < arr.length; i++) { - if (arr[i] >= arr.length) - throw new IllegalArgumentException("The specified dimensions can't be swapped. Given element " + i - + " was >= number of dimensions"); - if (arr[i] < 0) - throw new IllegalArgumentException("Invalid dimension: " + i + " : negative value"); - - - } - - for (int i = 0; i < arr.length; i++) { - for (int j = 0; j < arr.length; j++) { - if (i != j && arr[i] == arr[j]) - throw new IllegalArgumentException("Permute array must have unique elements"); - } - } - - } - - protected void autoProcessScalarCall() { + protected void autoProcessScalarCall() { /* if (Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.DISABLED && Nd4j.getExecutioner().getProfilingMode() != OpExecutioner.ProfilingMode.SCOPE_PANIC) OpProfiler.getInstance().processScalarCall();*/ + } + + /** + * Checks whether the matrix is a vector. + */ + @Override + public boolean isVector() { + if (jvmShapeInfo.rank == 1) { + return true; + } + + return isRowVector() || isColumnVector(); + } + + @Override + public boolean isVectorOrScalar() { + return isVector() || isScalar(); + } + + @Override + public boolean isSquare() { + return isMatrix() && rows() == columns(); + } + + @Override + public boolean isRowVector() { + return (rank() == 2 && rows() == 1) && length() > 1 || rank() == 1 && length() > 1; + } + + @Override + public boolean isColumnVector() { + return rank() == 2 && columns() == 1 && length() > 1; + } + + @Override + public boolean isColumnVectorOrScalar() { + return isColumnVector() || isScalar(); + } + + @Override + public boolean isRowVectorOrScalar() { + return isRowVector() || isScalar(); + } + + /** + * Generate string representation of the matrix. Printing will switch to scientific notation on a + * per element basis - when abs value is greater than or equal to 10000 - when abs value is less + * than or equal to 0.0001 and not zero + *
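
The printing contract in this javadoc is exposed through the toString overloads defined just below (toString(NDArrayStrings), toString(maxElements, forceSummarize, precision) and toStringFull()). A small sketch of how the options interact, with an illustrative 100x100 array:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class PrintOptionsSketch {
  public static void main(String[] args) {
    INDArray big = Nd4j.rand(100, 100);          // 10 000 elements, above the default 1000 limit

    System.out.println(big);                     // summarised: first/last three entries per dimension
    System.out.println(big.toString(20_000, false, 8));  // raised element limit, 8-digit precision
    System.out.println(big.toStringFull());      // never summarises
  }
}
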

      + * If the number of elements in the array is greater than 1000 (by default) only the first and + * last three elements in a dimension are included. This can be changed globally using + * {@link NDArrayStrings#setMaxPrintElements(long)} + */ + @Override + public String toString() { + return toString(new NDArrayStrings()); + } + + + @Override + public String toString(@NonNull NDArrayStrings options) { + if (wasClosed()) { + return ""; + } + if (!isCompressed() && !preventUnpack) { + return options.format(this); + } else if (isCompressed() && compressDebug) { + return "COMPRESSED ARRAY. SYSTEM PROPERTY compressdebug is true. This is to prevent auto decompression from being triggered."; + } else if (preventUnpack) { + return "Array string unpacking is disabled."; + } + return options.format(this); + } + + @Override + public String toString(long maxElements, boolean forceSummarize, int precision) { + return toString(new NDArrayStrings(maxElements, forceSummarize, precision)); + } + + + @Override + public String toStringFull() { + return toString(Long.MAX_VALUE, false, -1 * dataType().precision()); + } + + @Override + public Object element() { + + if (!isScalar()) { + throw new IllegalStateException("Unable to retrieve element from non scalar matrix"); + } + if (data.dataType() == DataType.FLOAT) { + return data.getFloat(0); + } + return data.getDouble(0); + } + + @Override + public INDArray remainder(INDArray denominator) { + if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { + return remainder(denominator, Nd4j.createUninitialized(this.dataType(), + Shape.broadcastOutputShape(this.shape(), denominator.shape()))); + } else { + return remainder(denominator, this.ulike()); + } + } + + @Override + public INDArray remainder(INDArray denominator, INDArray result) { + validateNumericalArray("remainder", false); + Preconditions.checkArgument(Shape.areShapesBroadcastable(this.shape(), denominator.shape()), + "Shapes must be broadcastable"); + + val op = new RemainderOp(this, denominator, result); + Nd4j.getExecutioner().exec(op); + return result; + } + + @Override + public INDArray remainder(Number denominator) { + return remainder(denominator, Nd4j.createUninitialized(this.dataType(), this.shape())); + } + + @Override + public INDArray remainder(Number denominator, INDArray result) { + validateNumericalArray("remainder", false); + + ScalarRemainder op = new ScalarRemainder(this, null, result, denominator); + Nd4j.getExecutioner().exec(op); + return result; + } + + @Override + public INDArray remainderi(INDArray denominator) { + validateNumericalArray("remainderi", false); + RemainderOp op = new RemainderOp(this, denominator, this); + Nd4j.getExecutioner().exec(op); + return this; + } + + @Override + public INDArray remainderi(Number denominator) { + validateNumericalArray("remainderi", false); + ScalarRemainder op = new ScalarRemainder(this, null, this, denominator); + Nd4j.getExecutioner().exec(op); + return this; + } + + @Override + public INDArray fmod(INDArray denominator) { + validateNumericalArray("fmod", false); + if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { + return fmod(denominator, Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), + Shape.broadcastOutputShape(this.shape(), denominator.shape()))); + } else { + return fmod(denominator, this.ulike()); + } + } + + @Override + public INDArray fmod(INDArray denominator, INDArray result) { + validateNumericalArray("fmod", false); + if (Shape.areShapesBroadcastable(this.shape(), 
denominator.shape())) { + val outShape = Shape.broadcastOutputShape(this.shape(), denominator.shape()); + Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), + "Result shape doesn't match expectations: " + Arrays.toString(result.shape())); + + Nd4j.exec(new FloorModOp(new INDArray[]{this, denominator}, new INDArray[]{result})); + + return result; + } else { + FModOp op = new FModOp(this, denominator, result); + Nd4j.getExecutioner().exec(op); + return result; + } + } + + @Override + public INDArray fmod(Number denominator) { + return fmod(denominator, Nd4j.createUninitialized(this.dataType(), this.shape())); + } + + @Override + public INDArray fmod(Number denominator, INDArray result) { + validateNumericalArray("fmod", false); + ScalarFMod op = new ScalarFMod(this, null, result, denominator); + Nd4j.getExecutioner().exec(op); + return result; + } + + @Override + public INDArray fmodi(INDArray denominator) { + validateNumericalArray("fmodi", false); + FModOp op = new FModOp(this, denominator, this); + Nd4j.getExecutioner().exec(op); + return this; + } + + @Override + public INDArray fmodi(Number denominator) { + validateNumericalArray("fmodi", false); + ScalarFMod op = new ScalarFMod(this, null, this, denominator); + Nd4j.getExecutioner().exec(op); + return this; + } + + @Override + public Iterator iterator() { + return new FirstAxisIterator(this); + } + + @Override + public long originalOffset() { + if (data().originalOffset() >= Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "Original offset of buffer can not be >= Integer.MAX_VALUE"); + } + + return data().originalOffset(); + } + + private void readObject(ObjectInputStream s) { + try { + s.defaultReadObject(); + read(s); + } catch (Exception e) { + throw new RuntimeException(e); } - /** - * Checks whether the matrix is a vector. - */ - @Override - public boolean isVector() { - if (jvmShapeInfo.rank == 1) - return true; + } - return isRowVector() || isColumnVector(); + private void writeObject(ObjectOutputStream out) throws IOException { + out.defaultWriteObject(); + write(out); + } + + //Custom serialization for Java serialization + protected void write(ObjectOutputStream out) throws IOException { + if (this.isView()) { + //As per Nd4j.write, duplicate before writing to the output stream + //BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. strategy) is necessary here + //Furthermore, because we only want to save the *actual* data for a view (not the full data), the shape info + // (mainly strides, offset, element-wise stride) may be different in the duped array vs. 
the view array + INDArray copy = this.dup(); + copy.shapeInfoDataBuffer().write(out); + copy.data().write(out); + } else { + shapeInformation.write(out); + data().write(out); } + } - @Override - public boolean isVectorOrScalar() { - return isVector() || isScalar(); - } + //Custom deserialization for Java serialization + protected void read(ObjectInputStream s) { + val headerShape = BaseDataBuffer.readHeader(s); - @Override - public boolean isSquare() { - return isMatrix() && rows() == columns(); - } + shapeInformation = Nd4j.createBuffer(new int[Shape.shapeInfoLength(rank())]); + shapeInformation.read(s, headerShape.getLeft(), headerShape.getMiddle(), + headerShape.getRight()); - @Override - public boolean isRowVector() { - return (rank() == 2 && rows() == 1) && length() > 1 || rank() == 1 && length() > 1; - } + setShapeInformation(Pair.create(shapeInformation, shapeInformation.asLong())); - @Override - public boolean isColumnVector() { - return rank() == 2 && columns() == 1 && length() > 1; - } + val headerData = BaseDataBuffer.readHeader(s); + data = Nd4j.createBuffer(headerData.getRight(), headerData.getMiddle(), false); + data().read(s, headerData.getLeft(), headerData.getMiddle(), headerData.getRight()); + } - @Override - public boolean isColumnVectorOrScalar() { - return isColumnVector() || isScalar(); - } + @Override + public INDArray argMax(int... dimension) { + return Nd4j.argMax(this, dimension); + } - @Override - public boolean isRowVectorOrScalar() { - return isRowVector() || isScalar(); - } + @Override + public boolean isAttached() { + if (isEmpty()) { + return false; + } - /** - * Generate string representation of the matrix. - * Printing will switch to scientific notation on a per element basis - * - when abs value is greater than or equal to 10000 - * - when abs value is less than or equal to 0.0001 and not zero - * - * If the number of elements in the array is greater than 1000 (by default) only the first and last three elements - * in a dimension are included. This can be changed globally using {@link NDArrayStrings#setMaxPrintElements(long)} - * - * - */ - @Override - public String toString() { - return toString(new NDArrayStrings()); - } + Preconditions.checkArgument(!(data == null && !isEmpty()), "Array has no buffer!"); + return data.isAttached() || + (data.underlyingDataBuffer() != null && data.underlyingDataBuffer().isAttached()) || + (data.originalDataBuffer() != null && data.originalDataBuffer().isAttached()); + } - @Override - public String toString(@NonNull NDArrayStrings options) { - if(wasClosed()) - return ""; - if (!isCompressed() && !preventUnpack) - return options.format(this); - else if (isCompressed() && compressDebug) - return "COMPRESSED ARRAY. SYSTEM PROPERTY compressdebug is true. 
This is to prevent auto decompression from being triggered."; - else if (preventUnpack) - return "Array string unpacking is disabled."; - return options.format(this); - } + @Override + public boolean isInScope() { + if (!isAttached()) { + return true; + } - @Override - public String toString(long maxElements, boolean forceSummarize, int precision){ - return toString(new NDArrayStrings(maxElements, forceSummarize, precision)); - } + return data.isInScope(); + } + @Override + public INDArray detach() { + if (!isAttached()) { + return this; + } - @Override - public String toStringFull(){ - return toString(Long.MAX_VALUE, false, -1 * dataType().precision()); - } + WorkspaceUtils.assertValidArray(this, "Cannot detach INDArray"); - @Override - public Object element() { - - if (!isScalar()) - throw new IllegalStateException("Unable to retrieve element from non scalar matrix"); - if (data.dataType() == DataType.FLOAT) - return data.getFloat(0); - return data.getDouble(0); - } - - @Override - public INDArray remainder(INDArray denominator) { - if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { - return remainder(denominator, Nd4j.createUninitialized(this.dataType(), Shape.broadcastOutputShape(this.shape(), denominator.shape()))); - } else - return remainder(denominator, this.ulike()); - } - - @Override - public INDArray remainder(INDArray denominator, INDArray result) { - validateNumericalArray("remainder", false); - Preconditions.checkArgument(Shape.areShapesBroadcastable(this.shape(), denominator.shape()),"Shapes must be broadcastable"); - - val op = new RemainderOp(this, denominator, result); - Nd4j.getExecutioner().exec(op); - return result; - } - - @Override - public INDArray remainder(Number denominator) { - return remainder(denominator, Nd4j.createUninitialized(this.dataType(), this.shape())); - } - - @Override - public INDArray remainder(Number denominator, INDArray result) { - validateNumericalArray("remainder", false); - - ScalarRemainder op = new ScalarRemainder(this, null, result, denominator); - Nd4j.getExecutioner().exec(op); - return result; - } - - @Override - public INDArray remainderi(INDArray denominator) { - validateNumericalArray("remainderi", false); - RemainderOp op = new RemainderOp(this, denominator, this); - Nd4j.getExecutioner().exec(op); - return this; - } - - @Override - public INDArray remainderi(Number denominator) { - validateNumericalArray("remainderi", false); - ScalarRemainder op = new ScalarRemainder(this, null, this, denominator); - Nd4j.getExecutioner().exec(op); - return this; - } - - @Override - public INDArray fmod(INDArray denominator) { - validateNumericalArray("fmod", false); - if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { - return fmod(denominator, Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), Shape.broadcastOutputShape(this.shape(), denominator.shape()))); - } else - return fmod(denominator, this.ulike()); - } - - @Override - public INDArray fmod(INDArray denominator, INDArray result) { - validateNumericalArray("fmod", false); - if (Shape.areShapesBroadcastable(this.shape(), denominator.shape())) { - val outShape = Shape.broadcastOutputShape(this.shape(), denominator.shape()); - Preconditions.checkArgument(Shape.shapeEquals(outShape, result.shape()), "Result shape doesn't match expectations: " + Arrays.toString(result.shape())); - - Nd4j.exec(new FloorModOp(new INDArray[]{this, denominator}, new INDArray[]{result})); - - return result; - } else { - FModOp op = new FModOp(this, denominator, 
result); - Nd4j.getExecutioner().exec(op); - return result; - } - } - - @Override - public INDArray fmod(Number denominator) { - return fmod(denominator, Nd4j.createUninitialized(this.dataType(), this.shape())); - } - - @Override - public INDArray fmod(Number denominator, INDArray result) { - validateNumericalArray("fmod", false); - ScalarFMod op = new ScalarFMod(this, null, result, denominator); - Nd4j.getExecutioner().exec(op); - return result; - } - - @Override - public INDArray fmodi(INDArray denominator) { - validateNumericalArray("fmodi", false); - FModOp op = new FModOp(this, denominator, this); - Nd4j.getExecutioner().exec(op); - return this; - } - - @Override - public INDArray fmodi(Number denominator) { - validateNumericalArray("fmodi", false); - ScalarFMod op = new ScalarFMod(this, null, this, denominator); - Nd4j.getExecutioner().exec(op); - return this; - } - - @Override - public Iterator iterator() { - return new FirstAxisIterator(this); - } - - @Override - public long originalOffset() { - if (data().originalOffset() >= Integer.MAX_VALUE) - throw new IllegalArgumentException("Original offset of buffer can not be >= Integer.MAX_VALUE"); - - return data().originalOffset(); - } - - private void readObject(ObjectInputStream s) { - try { - s.defaultReadObject(); - read(s); - } catch (Exception e) { - throw new RuntimeException(e); - } - - } - - private void writeObject(ObjectOutputStream out) throws IOException { - out.defaultWriteObject(); - write(out); - } - - //Custom serialization for Java serialization - protected void write(ObjectOutputStream out) throws IOException { - if (this.isView()) { - //As per Nd4j.write, duplicate before writing to the output stream - //BaseDataBuffer.write(...) doesn't know about strides etc, so dup (or equiv. strategy) is necessary here - //Furthermore, because we only want to save the *actual* data for a view (not the full data), the shape info - // (mainly strides, offset, element-wise stride) may be different in the duped array vs. the view array - INDArray copy = this.dup(); - copy.shapeInfoDataBuffer().write(out); - copy.data().write(out); - } else { - shapeInformation.write(out); - data().write(out); - } - } - - //Custom deserialization for Java serialization - protected void read(ObjectInputStream s) { - val headerShape = BaseDataBuffer.readHeader(s); - - shapeInformation = Nd4j.createBuffer(new int[Shape.shapeInfoLength(rank())]); - shapeInformation.read(s, headerShape.getLeft(), headerShape.getMiddle(), headerShape.getRight()); - - setShapeInformation(Pair.create(shapeInformation, shapeInformation.asLong())); - - val headerData = BaseDataBuffer.readHeader(s); - data = Nd4j.createBuffer(headerData.getRight(), headerData.getMiddle(), false); - data().read(s, headerData.getLeft(), headerData.getMiddle(), headerData.getRight()); - } - - @Override - public INDArray argMax(int... 
dimension) { - return Nd4j.argMax(this, dimension); - } - - @Override - public boolean isAttached() { - if (isEmpty()) - return false; - - Preconditions.checkArgument(!(data == null && !isEmpty()), "Array has no buffer!"); - - return data.isAttached() || - (data.underlyingDataBuffer() != null && data.underlyingDataBuffer().isAttached()) || - (data.originalDataBuffer() != null && data.originalDataBuffer().isAttached()); - } - - @Override - public boolean isInScope() { - if (!isAttached()) - return true; - - return data.isInScope(); - } - - @Override - public INDArray detach() { - if (!isAttached()) - return this; - - WorkspaceUtils.assertValidArray(this, "Cannot detach INDArray"); - - Nd4j.getExecutioner().commit(); + Nd4j.getExecutioner().commit(); /* two options here 1) we're within some workspace 2) we're out of any workspace */ - if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) { - if (!isView()) { - Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); + if (Nd4j.getMemoryManager().getCurrentWorkspace() == null) { + if (!isView()) { + Nd4j.getExecutioner().commit(); + DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); - Nd4j.getMemoryManager().memcpy(buffer, this.data()); - - return Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); - } else { - INDArray copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); - copy.assign(this); - Nd4j.getExecutioner().commit(); - - return copy; - } - } else { - MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); - Nd4j.getMemoryManager().setCurrentWorkspace(null); - INDArray copy = null; - - if (!isView()) { - Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); - - //Pointer.memcpy(buffer.pointer(), this.data.pointer(), this.lengthLong() * Nd4j.sizeOfDataType(this.data.dataType())); - Nd4j.getMemoryManager().memcpy(buffer, this.data()); - - copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); //this.dup(this.ordering()); - - - } else { - copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); - copy.assign(this); - Nd4j.getExecutioner().commit(); - } - - Nd4j.getMemoryManager().setCurrentWorkspace(workspace); - - return copy; - } - } - - @Override - public INDArray leverage() { - WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); - if (!isAttached()) - return this; - - MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); - if (workspace == null) { - return this.detach(); - } - - MemoryWorkspace parentWorkspace = workspace.getParentWorkspace(); - - if (this.data.getParentWorkspace() == parentWorkspace) - return this; - - // if there's no parent ws - just detach - if (parentWorkspace == null) - return this.detach(); - else { - Nd4j.getExecutioner().commit(); - - // temporary set parent ws as current ws - Nd4j.getMemoryManager().setCurrentWorkspace(parentWorkspace); - - INDArray copy = null; - if (!this.isView()) { - Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.length(), false); - Nd4j.getMemoryManager().memcpy(buffer, this.data()); - - copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); - } else { - copy = this.dup(this.ordering()); - Nd4j.getExecutioner().commit(); - } - - // restore current ws - Nd4j.getMemoryManager().setCurrentWorkspace(workspace); - return copy; - } - } - - 
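
The workspace-related methods in this hunk (detach(), leverage(), leverageTo(String), migrate()) all follow the same pattern: commit pending ops, copy the buffer out of (or into) the target workspace, and return the copy (views are duplicated first). A usage sketch; the workspace id and shapes are illustrative assumptions:

import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class WorkspaceSketch {
  public static void main(String[] args) {
    INDArray detached;
    try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("SKETCH_WS")) {
      INDArray scoped = Nd4j.rand(3, 3);   // allocated in workspace memory, so isAttached() == true
      detached = scoped.detach();          // copied to off-workspace memory, usable after the block
    }
    System.out.println(detached.isAttached());   // false
  }
}
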
@Override - public INDArray leverageTo(String id) { - return leverageTo(id, false); - } - - @Override - public INDArray leverageTo(String id, boolean enforceExistence) throws Nd4jNoSuchWorkspaceException { - WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); - if (!isAttached()) - return this; - - if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExists(id)) { - if(enforceExistence){ - throw new Nd4jNoSuchWorkspaceException(id); - } else { - return this; - } - } - - MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); - MemoryWorkspace target = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(id); - - if (this.data.getParentWorkspace() == target) - return this; - - Nd4j.getMemoryManager().setCurrentWorkspace(target); - INDArray copy = null; - if (!this.isView()) { - Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); - Nd4j.getMemoryManager().memcpy(buffer, this.data()); - - copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); - } else { - copy = this.dup(this.ordering()); - Nd4j.getExecutioner().commit(); - } - - Nd4j.getMemoryManager().setCurrentWorkspace(current); - - return copy; - } - - public INDArray leverageOrDetach(String id){ - if(!isAttached()){ - return this; - } - - if(!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(id)){ - return detach(); - } - return leverageTo(id); - } - - @Override - public INDArray migrate() { - return migrate(false); - } - - @Override - public INDArray migrate(boolean detachOnNoWs){ - WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); - - MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); - - if (current == null) { - if(detachOnNoWs){ - return detach(); - } else { - return this; - } - } - - INDArray copy = null; - - if (!this.isView()) { - Nd4j.getExecutioner().commit(); - DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); - Nd4j.getMemoryManager().memcpy(buffer, this.data()); - - copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); - } else { - copy = this.dup(this.ordering()); - Nd4j.getExecutioner().commit(); - } - - return copy; - } - - @Override - public Number percentileNumber(Number quantile) { - validateNumericalArray("percentileNumber", false); - if (quantile.intValue() < 0 || quantile.intValue() > 100) - throw new ND4JIllegalStateException("Percentile value should be in 0...100 range"); - - if (isScalar()) - return this.getDouble(0); - - INDArray sorted = Nd4j.sort(this.dup(this.ordering()), true); - - return getPercentile(quantile, sorted); - } - - @Override - public Number medianNumber() { - validateNumericalArray("medianNumber", false); - if(isScalar()) - return getNumber(0); - return percentileNumber(50); - } - - @Override - public INDArray median(int... dimension) { - validateNumericalArray("median", false); - //Check edge case: size 1 element. 
No dimension == full array - if(dimension.length == 0){ - return Nd4j.scalar(dataType(), medianNumber().doubleValue()); - } - long shapeProd = 1; - for (int d : dimension) { - shapeProd *= size(d); - } - if (shapeProd == 1) { - long[] newShape = ArrayUtil.removeIndex(shape(), dimension); - return dup('c').reshape('c', newShape); - } - return percentile(50, dimension); - } - - protected double getPercentile(Number quantile, INDArray sorted) { - validateNumericalArray("getPercentile", false); - if (quantile.intValue() == 0) - return sorted.getDouble(0); - else if (quantile.intValue() == 100) - return sorted.getDouble(sorted.length() - 1); - - double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1); - if (pos < 1) - return sorted.getDouble(0); - else if (pos >= sorted.length()) - return sorted.getDouble(sorted.length() - 1); - - double fposition = FastMath.floor(pos); - int position = (int)fposition; - - double diff = pos - fposition; - - double lower = sorted.getDouble(position-1); - double upper = sorted.getDouble(position); - - return lower + diff * (upper - lower); - } - - @Override - public INDArray percentile(Number quantile, int... dimension) { - validateNumericalArray("percentile", false); - if (quantile.doubleValue() < 0 || quantile.doubleValue() > 100) - throw new ND4JIllegalStateException("Percentile value should be in 0...100 range"); - - if (isScalar()) - return Nd4j.scalar(this.getDouble(0)); - - INDArray sorted = Nd4j.getNDArrayFactory().sort(this.dup(this.ordering()), false, dimension); - - // there's no practical sense doing this on GPU, stride will be just size of TAD. - INDArray ret = Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), sorted.tensorsAlongDimension(dimension)); - for (int i = 0; i < ret.length(); i++) { - ret.putScalar(i, getPercentile(quantile, sorted.tensorAlongDimension(i, dimension))); - } - - return ret; - - } - - protected abstract int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer); - - @Override - public int toFlatArray(FlatBufferBuilder builder) { - if(isView()){ - return dup(this.ordering()).toFlatArray(builder); - } - int shape = FlatArray.createShapeVector(builder, this.shapeInfoDataBuffer().asLong()); - int buffer = this.isEmpty() ? 0 : this.dataType() == DataType.UTF8 ? stringBuffer(builder, this.data()) : FlatArray.createBufferVector(builder, this.data().asBytes()); - val type = this.isEmpty() ? 
FlatBuffersMapper.getDataTypeAsByte(this.dataType()) : FlatBuffersMapper.getDataTypeAsByte(this.data().dataType()); - int array = FlatArray.createFlatArray(builder, shape, buffer, type, ByteOrder.BE); - - return array; - } - - protected static DataTypeEx convertType(DataType type) { - if (type == DataType.HALF) { - return DataTypeEx.FLOAT16; - } else if (type == DataType.FLOAT) { - return DataTypeEx.FLOAT; - } else if (type == DataType.DOUBLE) { - return DataTypeEx.DOUBLE; - - } else if(type == DataType.INT) { - return DataTypeEx.INT8; - } else if(type == DataType.LONG) { - return DataTypeEx.INT16; - - } else - throw new IllegalStateException("Unknown dataType: [" + type + "]"); - } - - @Override - public boolean isEmpty() { - return Shape.isEmpty(jvmShapeInfo.javaShapeInformation); - } - - @Override - public long[] shapeInfoJava() { - return jvmShapeInfo.javaShapeInformation; - } - - @Override - public DataType dataType() { - if (data != null) - return data.dataType(); - - val e = Shape.extras(jvmShapeInfo.javaShapeInformation); - - if (e != 0) { - val t = ArrayOptionsHelper.dataType(jvmShapeInfo.javaShapeInformation); - return t; - } - - return DataType.UNKNOWN; - } - - @Override - public boolean isR() { - val dtype = dataType(); - return dtype == DataType.FLOAT || dtype == DataType.DOUBLE || dtype == DataType.HALF || dtype == DataType.BFLOAT16; - } - - @Override - public boolean isZ() { - return !isR() && !isB() && !isS(); - } - - @Override - public boolean isB() { - return dataType() == DataType.BOOL; - } - - @Override - public boolean isS() { - return dataType() == DataType.UTF8; - } - - @Override - public INDArray castTo(DataType dataType) { - if(dataType == dataType()) //No-op if correct datatype - return this; - if(isEmpty() && rank() == 0){ - return Nd4j.empty(dataType); - } - val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering()); - result.assign(this); - return result; - } - - @Override - public boolean all() { - val r = Nd4j.getExecutioner().exec(new All(this)); - return r.getDouble(0) != 0.0; - } - - @Override - public boolean any() { - val r = Nd4j.getExecutioner().exec(new Any(this)); - return r.getDouble(0) != 0.0; - } - - @Override - public boolean none() { - return !any(); - } - - - /** - * Validate that the operation is being applied on a numerical array (not boolean or utf8). - * Some operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 arrays - * @param opName Operation name to print in the exception - */ - protected void validateNumericalArray(String opName, boolean allowEmpty){ - if(dataType() == DataType.BOOL || dataType() == DataType.UTF8) - throw new IllegalStateException("Cannot apply operation " + opName + " to array with " + dataType() + " datatype. 
Array shape: " + Arrays.toString(shape())); - if(!allowEmpty && isEmpty()) - throw new IllegalStateException("Cannot perform operation " + opName + " on empty array with datatype " + dataType()); - } - - @Override - public boolean closeable() { - if (released || isAttached()) - return false; - - // empty arrays have no buffer at all - if (isEmpty()) - return true; - - if (isView()) - return false; - - return data.closeable(); - } - - @Override - public void close() { - // empty arrays have no buffer at all - if (released || isEmpty()) - return; + Nd4j.getMemoryManager().memcpy(buffer, this.data()); + return Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); + } else { + INDArray copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); + copy.assign(this); Nd4j.getExecutioner().commit(); - if (!closeable()) - throw new ND4JIllegalStateException("Can't release this INDArray"); + return copy; + } + } else { + MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); + Nd4j.getMemoryManager().setCurrentWorkspace(null); + INDArray copy = null; - data.close(); + if (!isView()) { + Nd4j.getExecutioner().commit(); + DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); - released = true; + //Pointer.memcpy(buffer.pointer(), this.data.pointer(), this.lengthLong() * Nd4j.sizeOfDataType(this.data.dataType())); + Nd4j.getMemoryManager().memcpy(buffer, this.data()); + + copy = Nd4j.createArrayFromShapeBuffer(buffer, + this.shapeInfoDataBuffer()); //this.dup(this.ordering()); + + + } else { + copy = Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); + copy.assign(this); + Nd4j.getExecutioner().commit(); + } + + Nd4j.getMemoryManager().setCurrentWorkspace(workspace); + + return copy; + } + } + + @Override + public INDArray leverage() { + WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); + if (!isAttached()) { + return this; + } + + MemoryWorkspace workspace = Nd4j.getMemoryManager().getCurrentWorkspace(); + if (workspace == null) { + return this.detach(); } - @Override - public INDArray like() { - return Nd4j.create(this.dataType(), this.shape(), Nd4j.getStrides(this.shape(), this.ordering()), this.ordering()); + MemoryWorkspace parentWorkspace = workspace.getParentWorkspace(); + + if (this.data.getParentWorkspace() == parentWorkspace) { + return this; + } + + // if there's no parent ws - just detach + if (parentWorkspace == null) { + return this.detach(); + } else { + Nd4j.getExecutioner().commit(); + + // temporary set parent ws as current ws + Nd4j.getMemoryManager().setCurrentWorkspace(parentWorkspace); + + INDArray copy = null; + if (!this.isView()) { + Nd4j.getExecutioner().commit(); + DataBuffer buffer = Nd4j.createBuffer(this.length(), false); + Nd4j.getMemoryManager().memcpy(buffer, this.data()); + + copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); + } else { + copy = this.dup(this.ordering()); + Nd4j.getExecutioner().commit(); + } + + // restore current ws + Nd4j.getMemoryManager().setCurrentWorkspace(workspace); + return copy; + } + } + + @Override + public INDArray leverageTo(String id) { + return leverageTo(id, false); + } + + @Override + public INDArray leverageTo(String id, boolean enforceExistence) + throws Nd4jNoSuchWorkspaceException { + WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); + if (!isAttached()) { + return this; + } + + if 
(!Nd4j.getWorkspaceManager().checkIfWorkspaceExists(id)) { + if (enforceExistence) { + throw new Nd4jNoSuchWorkspaceException(id); + } else { + return this; + } } - @Override - public INDArray ulike() { - return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); + MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); + MemoryWorkspace target = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(id); + + if (this.data.getParentWorkspace() == target) { + return this; + } + + Nd4j.getMemoryManager().setCurrentWorkspace(target); + INDArray copy = null; + if (!this.isView()) { + Nd4j.getExecutioner().commit(); + DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); + Nd4j.getMemoryManager().memcpy(buffer, this.data()); + + copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); + } else { + copy = this.dup(this.ordering()); + Nd4j.getExecutioner().commit(); } - @Override - public boolean wasClosed() { - // data can be null if that's empty array - return released || (data() != null && data().wasClosed()); + Nd4j.getMemoryManager().setCurrentWorkspace(current); + + return copy; + } + + public INDArray leverageOrDetach(String id) { + if (!isAttached()) { + return this; } - @Override - public long getId(){ - return arrayId; + if (!Nd4j.getWorkspaceManager().checkIfWorkspaceExistsAndActive(id)) { + return detach(); } + return leverageTo(id); + } + + @Override + public INDArray migrate() { + return migrate(false); + } + + @Override + public INDArray migrate(boolean detachOnNoWs) { + WorkspaceUtils.assertValidArray(this, "Cannot leverage INDArray to new workspace"); + + MemoryWorkspace current = Nd4j.getMemoryManager().getCurrentWorkspace(); + + if (current == null) { + if (detachOnNoWs) { + return detach(); + } else { + return this; + } + } + + INDArray copy = null; + + if (!this.isView()) { + Nd4j.getExecutioner().commit(); + DataBuffer buffer = Nd4j.createBuffer(this.dataType(), this.length(), false); + Nd4j.getMemoryManager().memcpy(buffer, this.data()); + + copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer()); + } else { + copy = this.dup(this.ordering()); + Nd4j.getExecutioner().commit(); + } + + return copy; + } + + @Override + public Number percentileNumber(Number quantile) { + validateNumericalArray("percentileNumber", false); + if (quantile.intValue() < 0 || quantile.intValue() > 100) { + throw new ND4JIllegalStateException("Percentile value should be in 0...100 range"); + } + + if (isScalar()) { + return this.getDouble(0); + } + + INDArray sorted = Nd4j.sort(this.dup(this.ordering()), true); + + return getPercentile(quantile, sorted); + } + + @Override + public Number medianNumber() { + validateNumericalArray("medianNumber", false); + if (isScalar()) { + return getNumber(0); + } + return percentileNumber(50); + } + + @Override + public INDArray median(int... dimension) { + validateNumericalArray("median", false); + //Check edge case: size 1 element. 
No dimension == full array + if (dimension.length == 0) { + return Nd4j.scalar(dataType(), medianNumber().doubleValue()); + } + long shapeProd = 1; + for (int d : dimension) { + shapeProd *= size(d); + } + if (shapeProd == 1) { + long[] newShape = ArrayUtil.removeIndex(shape(), dimension); + return dup('c').reshape('c', newShape); + } + return percentile(50, dimension); + } + + protected double getPercentile(Number quantile, INDArray sorted) { + validateNumericalArray("getPercentile", false); + if (quantile.intValue() == 0) { + return sorted.getDouble(0); + } else if (quantile.intValue() == 100) { + return sorted.getDouble(sorted.length() - 1); + } + + double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1); + if (pos < 1) { + return sorted.getDouble(0); + } else if (pos >= sorted.length()) { + return sorted.getDouble(sorted.length() - 1); + } + + double fposition = FastMath.floor(pos); + int position = (int) fposition; + + double diff = pos - fposition; + + double lower = sorted.getDouble(position - 1); + double upper = sorted.getDouble(position); + + return lower + diff * (upper - lower); + } + + @Override + public INDArray percentile(Number quantile, int... dimension) { + validateNumericalArray("percentile", false); + if (quantile.doubleValue() < 0 || quantile.doubleValue() > 100) { + throw new ND4JIllegalStateException("Percentile value should be in 0...100 range"); + } + + if (isScalar()) { + return Nd4j.scalar(this.getDouble(0)); + } + + INDArray sorted = Nd4j.getNDArrayFactory().sort(this.dup(this.ordering()), false, dimension); + + // there's no practical sense doing this on GPU, stride will be just size of TAD. + INDArray ret = Nd4j.createUninitialized(Nd4j.defaultFloatingPointType(), + sorted.tensorsAlongDimension(dimension)); + for (int i = 0; i < ret.length(); i++) { + ret.putScalar(i, getPercentile(quantile, sorted.tensorAlongDimension(i, dimension))); + } + + return ret; + + } + + protected abstract int stringBuffer(FlatBufferBuilder builder, DataBuffer buffer); + + @Override + public int toFlatArray(FlatBufferBuilder builder) { + if (isView()) { + return dup(this.ordering()).toFlatArray(builder); + } + int shape = FlatArray.createShapeVector(builder, this.shapeInfoDataBuffer().asLong()); + int buffer = this.isEmpty() ? 0 + : this.dataType() == DataType.UTF8 ? stringBuffer(builder, this.data()) + : FlatArray.createBufferVector(builder, this.data().asBytes()); + val type = this.isEmpty() ? 
FlatBuffersMapper.getDataTypeAsByte(this.dataType()) + : FlatBuffersMapper.getDataTypeAsByte(this.data().dataType()); + int array = FlatArray.createFlatArray(builder, shape, buffer, type, ByteOrder.BE); + + return array; + } + + protected static DataTypeEx convertType(DataType type) { + if (type == DataType.HALF) { + return DataTypeEx.FLOAT16; + } else if (type == DataType.FLOAT) { + return DataTypeEx.FLOAT; + } else if (type == DataType.DOUBLE) { + return DataTypeEx.DOUBLE; + + } else if (type == DataType.INT) { + return DataTypeEx.INT8; + } else if (type == DataType.LONG) { + return DataTypeEx.INT16; + + } else { + throw new IllegalStateException("Unknown dataType: [" + type + "]"); + } + } + + @Override + public boolean isEmpty() { + return Shape.isEmpty(jvmShapeInfo.javaShapeInformation); + } + + @Override + public long[] shapeInfoJava() { + return jvmShapeInfo.javaShapeInformation; + } + + @Override + public DataType dataType() { + if (data != null) { + return data.dataType(); + } + + val e = Shape.extras(jvmShapeInfo.javaShapeInformation); + + if (e != 0) { + val t = ArrayOptionsHelper.dataType(jvmShapeInfo.javaShapeInformation); + return t; + } + + return DataType.UNKNOWN; + } + + @Override + public boolean isR() { + val dtype = dataType(); + return dtype == DataType.FLOAT || dtype == DataType.DOUBLE || dtype == DataType.HALF + || dtype == DataType.BFLOAT16; + } + + @Override + public boolean isZ() { + return !isR() && !isB() && !isS(); + } + + @Override + public boolean isB() { + return dataType() == DataType.BOOL; + } + + @Override + public boolean isS() { + return dataType() == DataType.UTF8; + } + + @Override + public INDArray castTo(DataType dataType) { + if (dataType == dataType()) //No-op if correct datatype + { + return this; + } + if (isEmpty() && rank() == 0) { + return Nd4j.empty(dataType); + } + val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering()); + result.assign(this); + return result; + } + + @Override + public boolean all() { + val r = Nd4j.getExecutioner().exec(new All(this)); + return r.getDouble(0) != 0.0; + } + + @Override + public boolean any() { + val r = Nd4j.getExecutioner().exec(new Any(this)); + return r.getDouble(0) != 0.0; + } + + @Override + public boolean none() { + return !any(); + } + + + /** + * Validate that the operation is being applied on a numerical array (not boolean or utf8). Some + * operations (such as sum, norm2, add(Number) etc don't make sense when applied to boolean/utf8 + * arrays + * + * @param opName Operation name to print in the exception + */ + protected void validateNumericalArray(String opName, boolean allowEmpty) { + if (dataType() == DataType.BOOL || dataType() == DataType.UTF8) { + throw new IllegalStateException( + "Cannot apply operation " + opName + " to array with " + dataType() + + " datatype. 
Array shape: " + Arrays.toString(shape())); + } + if (!allowEmpty && isEmpty()) { + throw new IllegalStateException( + "Cannot perform operation " + opName + " on empty array with datatype " + dataType()); + } + } + + @Override + public boolean closeable() { + if (released || isAttached()) { + return false; + } + + // empty arrays have no buffer at all + if (isEmpty()) { + return true; + } + + if (isView()) { + return false; + } + + return data.closeable(); + } + + @Override + public void close() { + // empty arrays have no buffer at all + if (released || isEmpty()) { + return; + } + + Nd4j.getExecutioner().commit(); + + if (!closeable()) { + throw new ND4JIllegalStateException("Can't release this INDArray"); + } + + data.close(); + + released = true; + } + + @Override + public INDArray like() { + return Nd4j.create(this.dataType(), this.shape(), + Nd4j.getStrides(this.shape(), this.ordering()), this.ordering()); + } + + @Override + public INDArray ulike() { + return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); + } + + @Override + public boolean wasClosed() { + // data can be null if that's empty array + return released || (data() != null && data().wasClosed()); + } + + @Override + public long getId() { + return arrayId; + } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseShapeInfoProvider.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseShapeInfoProvider.java index 98a8f95ab..aeb91f0b7 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseShapeInfoProvider.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseShapeInfoProvider.java @@ -47,7 +47,7 @@ public abstract class BaseShapeInfoProvider implements ShapeInfoProvider { } /** - * This method creates shapeInformation buffer, based on shape & order being passed in + * This method creates shapeInformation buffer, based on shape and order being passed in * * @param shape * @param order diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 61b53e23d..f4d4b200e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -2216,7 +2216,7 @@ public interface INDArray extends Serializable, AutoCloseable { * Dimshuffle: an extension of permute that adds the ability * to broadcast various dimensions. * This will only accept integers and xs. - *

      + *

      * An x indicates a dimension should be broadcasted rather than permuted. * * Examples originally from the theano docs: @@ -2226,15 +2226,15 @@ public interface INDArray extends Serializable, AutoCloseable { A few examples of patterns and their effect: - ('x') -> make a 0d (scalar) into a 1d vector - (0, 1) -> identity for 2d vectors - (1, 0) -> inverts the first and second dimensions - ('x', 0) -> make a row out of a 1d vector (N to 1xN) - (0, 'x') -> make a column out of a 1d vector (N to Nx1) - (2, 0, 1) -> AxBxC to CxAxB - (0, 'x', 1) -> AxB to Ax1xB - (1, 'x', 0) -> AxB to Bx1xA - (1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A) + ('x') -> make a 0d (scalar) into a 1d vector + (0, 1) -> identity for 2d vectors + (1, 0) -> inverts the first and second dimensions + ('x', 0) -> make a row out of a 1d vector (N to 1xN) + (0, 'x') -> make a column out of a 1d vector (N to Nx1) + (2, 0, 1) -> AxBxC to CxAxB + (0, 'x', 1) -> AxB to Ax1xB + (1, 'x', 0) -> AxB to Bx1xA + (1,) -> This remove dimensions 0. It must be a broadcastable dimension (1xA to A) * @param rearrange the dimensions to swap to * @param newOrder the new order (think permute) @@ -2244,7 +2244,7 @@ public interface INDArray extends Serializable, AutoCloseable { INDArray dimShuffle(Object[] rearrange, int[] newOrder, boolean[] broadCastable); /** - * See {@link #dimShuffle(Object[], int[], boolean[]) + * See {@link #dimShuffle(Object[], int[], boolean[])} */ INDArray dimShuffle(Object[] rearrange, long[] newOrder, boolean[] broadCastable); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/ShapeInfoProvider.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/ShapeInfoProvider.java index 1f3768038..33f8378f1 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/ShapeInfoProvider.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ndarray/ShapeInfoProvider.java @@ -33,7 +33,7 @@ public interface ShapeInfoProvider { Pair createShapeInformation(long[] shape, DataType dataType); /** - * This method creates long shapeInformation buffer, based on shape & order being passed in + * This method creates long shapeInformation buffer, based on shape and order being passed in * @param shape * @return */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java index d258e4b3a..a54f4700a 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java @@ -65,7 +65,8 @@ public interface OpContext extends AutoCloseable { /** * This method sets root-level seed for rng - * @param seed + * @param rootState + * @param nodeState */ void setRngStates(long rootState, long nodeState); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/Random.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/Random.java index d3740509c..30e349e94 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/Random.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/Random.java @@ -251,7 +251,7 @@ public interface Random extends AutoCloseable { * The reason for this is due to ints * having the same space usage as floats. * This also plays nice with blas. - *

      + *

      * If the data opType is set to double, * then these will be whole doubles. * @@ -272,7 +272,7 @@ public interface Random extends AutoCloseable { * The reason for this is due to ints * having the same space usage as floats. * This also plays nice with blas. - *

      + *

      * If the data opType is set to double, * then these will be whole doubles. * diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java index 9015301ea..801eb8d89 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/BaseDistribution.java @@ -35,233 +35,236 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.Iterator; public abstract class BaseDistribution implements Distribution { - protected Random random; - protected double solverAbsoluteAccuracy; + + protected Random random; + protected double solverAbsoluteAccuracy; - public BaseDistribution(Random rng) { - this.random = rng; + public BaseDistribution(Random rng) { + this.random = rng; + } + + + public BaseDistribution() { + this(Nd4j.getRandom()); + } + + /** + * For a random variable {@code X} whose values are distributed according to this distribution, + * this method returns {@code P(x0 < X <= x1)}. + * + * @param x0 Lower bound (excluded). + * @param x1 Upper bound (included). + * @return the probability that a random variable with this distribution takes a value between + * {@code x0} and {@code x1}, excluding the lower and including the upper endpoint. + * @throws org.apache.commons.math3.exception.NumberIsTooLargeException if {@code x0 > x1}. + *

      + * The default implementation + * uses the identity + * {@code P(x0 < X <= x1) = + * P(X <= x1) - P(X <= x0)} + * @since 3.1 + */ + + public double probability(double x0, double x1) { + if (x0 > x1) { + throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0, + x1, true); } + return cumulativeProbability(x1) - cumulativeProbability(x0); + } - - public BaseDistribution() { - this(Nd4j.getRandom()); - } - - /** - * For a random variable {@code X} whose values are distributed according - * to this distribution, this method returns {@code P(x0 < X <= x1)}. + /** + * {@inheritDoc} + *

      + * The default implementation returns + *

        + *
      • {@link #getSupportLowerBound()} for {@code p = 0},
      • + *
      • {@link #getSupportUpperBound()} for {@code p = 1}.
      • + *
      + */ + @Override + public double inverseCumulativeProbability(final double p) throws OutOfRangeException { + /* + * IMPLEMENTATION NOTES + * -------------------- + * Where applicable, use is made of the one-sided Chebyshev inequality + * to bracket the root. This inequality states that + * P(X - mu >= k * sig) <= 1 / (1 + k^2), + * mu: mean, sig: standard deviation. Equivalently + * 1 - P(X < mu + k * sig) <= 1 / (1 + k^2), + * F(mu + k * sig) >= k^2 / (1 + k^2). * - * @param x0 Lower bound (excluded). - * @param x1 Upper bound (included). - * @return the probability that a random variable with this distribution - * takes a value between {@code x0} and {@code x1}, excluding the lower - * and including the upper endpoint. - * @throws org.apache.commons.math3.exception.NumberIsTooLargeException if {@code x0 > x1}. - *

      - * The default implementation uses the identity - * {@code P(x0 < X <= x1) = P(X <= x1) - P(X <= x0)} - * @since 3.1 + * For k = sqrt(p / (1 - p)), we find + * F(mu + k * sig) >= p, + * and (mu + k * sig) is an upper-bound for the root. + * + * Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and + * P(Y >= -mu + k * sig) <= 1 / (1 + k^2), + * P(-X >= -mu + k * sig) <= 1 / (1 + k^2), + * P(X <= mu - k * sig) <= 1 / (1 + k^2), + * F(mu - k * sig) <= 1 / (1 + k^2). + * + * For k = sqrt((1 - p) / p), we find + * F(mu - k * sig) <= p, + * and (mu - k * sig) is a lower-bound for the root. + * + * In cases where the Chebyshev inequality does not apply, geometric + * progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket + * the root. */ - - public double probability(double x0, double x1) { - if (x0 > x1) { - throw new NumberIsTooLargeException(LocalizedFormats.LOWER_ENDPOINT_ABOVE_UPPER_ENDPOINT, x0, x1, true); - } - return cumulativeProbability(x1) - cumulativeProbability(x0); + if (p < 0.0 || p > 1.0) { + throw new OutOfRangeException(p, 0, 1); } - /** - * {@inheritDoc} - *

      - * The default implementation returns - *

        - *
      • {@link #getSupportLowerBound()} for {@code p = 0},
      • - *
      • {@link #getSupportUpperBound()} for {@code p = 1}.
      • - *
      - */ - @Override - public double inverseCumulativeProbability(final double p) throws OutOfRangeException { - /* - * IMPLEMENTATION NOTES - * -------------------- - * Where applicable, use is made of the one-sided Chebyshev inequality - * to bracket the root. This inequality states that - * P(X - mu >= k * sig) <= 1 / (1 + k^2), - * mu: mean, sig: standard deviation. Equivalently - * 1 - P(X < mu + k * sig) <= 1 / (1 + k^2), - * F(mu + k * sig) >= k^2 / (1 + k^2). - * - * For k = sqrt(p / (1 - p)), we find - * F(mu + k * sig) >= p, - * and (mu + k * sig) is an upper-bound for the root. - * - * Then, introducing Y = -X, mean(Y) = -mu, sd(Y) = sig, and - * P(Y >= -mu + k * sig) <= 1 / (1 + k^2), - * P(-X >= -mu + k * sig) <= 1 / (1 + k^2), - * P(X <= mu - k * sig) <= 1 / (1 + k^2), - * F(mu - k * sig) <= 1 / (1 + k^2). - * - * For k = sqrt((1 - p) / p), we find - * F(mu - k * sig) <= p, - * and (mu - k * sig) is a lower-bound for the root. - * - * In cases where the Chebyshev inequality does not apply, geometric - * progressions 1, 2, 4, ... and -1, -2, -4, ... are used to bracket - * the root. - */ - if (p < 0.0 || p > 1.0) { - throw new OutOfRangeException(p, 0, 1); + double lowerBound = getSupportLowerBound(); + if (p == 0.0) { + return lowerBound; + } + + double upperBound = getSupportUpperBound(); + if (p == 1.0) { + return upperBound; + } + + final double mu = getNumericalMean(); + final double sig = FastMath.sqrt(getNumericalVariance()); + final boolean chebyshevApplies; + chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) || Double.isInfinite(sig) + || Double.isNaN(sig)); + + if (lowerBound == Double.NEGATIVE_INFINITY) { + if (chebyshevApplies) { + lowerBound = mu - sig * FastMath.sqrt((1. - p) / p); + } else { + lowerBound = -1.0; + while (cumulativeProbability(lowerBound) >= p) { + lowerBound *= 2.0; } + } + } - double lowerBound = getSupportLowerBound(); - if (p == 0.0) { - return lowerBound; + if (upperBound == Double.POSITIVE_INFINITY) { + if (chebyshevApplies) { + upperBound = mu + sig * FastMath.sqrt(p / (1. - p)); + } else { + upperBound = 1.0; + while (cumulativeProbability(upperBound) < p) { + upperBound *= 2.0; } + } + } - double upperBound = getSupportUpperBound(); - if (p == 1.0) { - return upperBound; - } + final UnivariateFunction toSolve = new UnivariateFunction() { - final double mu = getNumericalMean(); - final double sig = FastMath.sqrt(getNumericalVariance()); - final boolean chebyshevApplies; - chebyshevApplies = !(Double.isInfinite(mu) || Double.isNaN(mu) || Double.isInfinite(sig) || Double.isNaN(sig)); + public double value(final double x) { + return cumulativeProbability(x) - p; + } + }; - if (lowerBound == Double.NEGATIVE_INFINITY) { - if (chebyshevApplies) { - lowerBound = mu - sig * FastMath.sqrt((1. - p) / p); + double x = UnivariateSolverUtils.solve(toSolve, lowerBound, upperBound, + getSolverAbsoluteAccuracy()); + + if (!isSupportConnected()) { + /* Test for plateau. 
*/ + final double dx = getSolverAbsoluteAccuracy(); + if (x - dx >= getSupportLowerBound()) { + double px = cumulativeProbability(x); + if (cumulativeProbability(x - dx) == px) { + upperBound = x; + while (upperBound - lowerBound > dx) { + final double midPoint = 0.5 * (lowerBound + upperBound); + if (cumulativeProbability(midPoint) < px) { + lowerBound = midPoint; } else { - lowerBound = -1.0; - while (cumulativeProbability(lowerBound) >= p) { - lowerBound *= 2.0; - } + upperBound = midPoint; } + } + return upperBound; } - - if (upperBound == Double.POSITIVE_INFINITY) { - if (chebyshevApplies) { - upperBound = mu + sig * FastMath.sqrt(p / (1. - p)); - } else { - upperBound = 1.0; - while (cumulativeProbability(upperBound) < p) { - upperBound *= 2.0; - } - } - } - - final UnivariateFunction toSolve = new UnivariateFunction() { - - public double value(final double x) { - return cumulativeProbability(x) - p; - } - }; - - double x = UnivariateSolverUtils.solve(toSolve, lowerBound, upperBound, getSolverAbsoluteAccuracy()); - - if (!isSupportConnected()) { - /* Test for plateau. */ - final double dx = getSolverAbsoluteAccuracy(); - if (x - dx >= getSupportLowerBound()) { - double px = cumulativeProbability(x); - if (cumulativeProbability(x - dx) == px) { - upperBound = x; - while (upperBound - lowerBound > dx) { - final double midPoint = 0.5 * (lowerBound + upperBound); - if (cumulativeProbability(midPoint) < px) { - lowerBound = midPoint; - } else { - upperBound = midPoint; - } - } - return upperBound; - } - } - } - return x; + } } + return x; + } - /** - * Returns the solver absolute accuracy for inverse cumulative computation. - * You can override this method in order to use a Brent solver with an - * absolute accuracy different from the default. - * - * @return the maximum absolute error in inverse cumulative probability estimates - */ - protected double getSolverAbsoluteAccuracy() { - return solverAbsoluteAccuracy; - } + /** + * Returns the solver absolute accuracy for inverse cumulative computation. You can override this + * method in order to use a Brent solver with an absolute accuracy different from the default. + * + * @return the maximum absolute error in inverse cumulative probability estimates + */ + protected double getSolverAbsoluteAccuracy() { + return solverAbsoluteAccuracy; + } - /** - * {@inheritDoc} - */ - @Override - public void reseedRandomGenerator(long seed) { - random.setSeed(seed); - } + /** + * {@inheritDoc} + */ + @Override + public void reseedRandomGenerator(long seed) { + random.setSeed(seed); + } - /** - * {@inheritDoc} - *

      - * The default implementation uses the - * - * inversion method. - * - */ - @Override - public double sample() { - return inverseCumulativeProbability(random.nextDouble()); - } + /** + * {@inheritDoc} + * The default implementation uses the + * + * inversion method. + * + */ + @Override + public double sample() { + return inverseCumulativeProbability(random.nextDouble()); + } - /** - * {@inheritDoc} - *

      - * The default implementation generates the sample by calling - * {@link #sample()} in a loop. - */ - @Override - public double[] sample(long sampleSize) { - if (sampleSize <= 0) { - throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); - } - double[] out = new double[(int) sampleSize]; - for (int i = 0; i < sampleSize; i++) { - out[i] = sample(); - } - return out; + /** + * {@inheritDoc} + *

      + * The default implementation generates the sample by calling {@link #sample()} in a loop. + */ + @Override + public double[] sample(long sampleSize) { + if (sampleSize <= 0) { + throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize); } + double[] out = new double[(int) sampleSize]; + for (int i = 0; i < sampleSize; i++) { + out[i] = sample(); + } + return out; + } - /** - * {@inheritDoc} - * - * @return zero. - * @since 3.1 - */ - @Override - public double probability(double x) { - return 0d; - } + /** + * {@inheritDoc} + * + * @return zero. + * @since 3.1 + */ + @Override + public double probability(double x) { + return 0d; + } - @Override - public INDArray sample(int[] shape) { - INDArray ret = Nd4j.create(shape); - return sample(ret); - } + @Override + public INDArray sample(int[] shape) { + INDArray ret = Nd4j.create(shape); + return sample(ret); + } - @Override - public INDArray sample(long[] shape) { - INDArray ret = Nd4j.create(shape); - return sample(ret); - } + @Override + public INDArray sample(long[] shape) { + INDArray ret = Nd4j.create(shape); + return sample(ret); + } - @Override - public INDArray sample(INDArray target) { - Iterator idxIter = new NdIndexIterator(target.shape()); //For consistent values irrespective of c vs. fortran ordering - long len = target.length(); - for (long i = 0; i < len; i++) { - target.putScalar(idxIter.next(), sample()); - } - return target; + @Override + public INDArray sample(INDArray target) { + Iterator idxIter = new NdIndexIterator( + target.shape()); //For consistent values irrespective of c vs. fortran ordering + long len = target.length(); + for (long i = 0; i < len; i++) { + target.putScalar(idxIter.next(), sample()); } + return target; + } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/Distribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/Distribution.java index e224f5866..3375d57df 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/Distribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/Distribution.java @@ -89,8 +89,8 @@ public interface Distribution { * variable {@code X} distributed according to this distribution, the * returned value is *

     * <ul>
-     * <li>inf{x in R | P(X<=x) >= p} for {@code 0 < p <= 1},</li>
-     * <li>inf{x in R | P(X<=x) > 0} for {@code p = 0}.</li>
+     * <li>{@code inf{x in R | P(X<=x) >= p}} for {@code 0 < p <= 1},</li>
+     * <li>{@code inf{x in R | P(X<=x) > 0}} for {@code p = 0}.</li>
     * </ul>
      * * @param p the cumulative probability @@ -122,7 +122,7 @@ public interface Distribution { * Access the lower bound of the support. This method must return the same * value as {@code inverseCumulativeProbability(0)}. In other words, this * method must return - *

      inf {x in R | P(X <= x) > 0}.

      + *

      {@code inf {x in R | P(X <= x) > 0}}.

      * * @return lower bound of the support (might be * {@code Double.NEGATIVE_INFINITY}) @@ -133,7 +133,7 @@ public interface Distribution { * Access the upper bound of the support. This method must return the same * value as {@code inverseCumulativeProbability(1)}. In other words, this * method must return - *

      inf {x in R | P(X <= x) = 1}.

      + *

      {@code inf {x in R | P(X <= x) = 1}}.

      * * @return upper bound of the support (might be * {@code Double.POSITIVE_INFINITY}) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/BinomialDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/BinomialDistribution.java index 2d295d53f..2d50a779d 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/BinomialDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/BinomialDistribution.java @@ -166,7 +166,7 @@ public class BinomialDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For {@code n} trials and probability parameter {@code p}, the mean is * {@code n * p}. */ @@ -177,7 +177,7 @@ public class BinomialDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For {@code n} trials and probability parameter {@code p}, the variance is * {@code n * p * (1 - p)}. */ @@ -189,7 +189,7 @@ public class BinomialDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The lower bound of the support is always 0 except for the probability * parameter {@code p = 1}. * @@ -203,7 +203,7 @@ public class BinomialDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The upper bound of the support is the number of trials except for the * probability parameter {@code p = 0}. * @@ -227,7 +227,7 @@ public class BinomialDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java index b56722c30..cb4136829 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/ConstantDistribution.java @@ -83,7 +83,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * is returned, as in these cases the actual value is within * {@code Double.MIN_VALUE} of 0 or 1. @@ -131,7 +131,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For mean parameter {@code mu}, the mean is {@code mu}. */ public double getNumericalMean() { @@ -140,7 +140,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For standard deviation parameter {@code s}, the variance is {@code s^2}. */ public double getNumericalVariance() { @@ -150,7 +150,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The lower bound of the support is always negative infinity * no matter the parameters. * @@ -163,7 +163,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The upper bound of the support is always positive infinity * no matter the parameters. * @@ -190,7 +190,7 @@ public class ConstantDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/LogNormalDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/LogNormalDistribution.java index f0c9aa396..8788539fd 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/LogNormalDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/LogNormalDistribution.java @@ -172,7 +172,6 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * is returned, as in these cases the actual value is within * {@code Double.MIN_VALUE} of 0 or 1. @@ -238,7 +237,6 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * For mean parameter {@code mu}, the mean is {@code mu}. */ public double getNumericalMean() { @@ -247,7 +245,6 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * For standard deviation parameter {@code s}, the variance is {@code s^2}. */ public double getNumericalVariance() { @@ -257,7 +254,6 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * The lower bound of the support is always negative infinity * no matter the parameters. * @@ -270,7 +266,7 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The upper bound of the support is always positive infinity * no matter the parameters. * @@ -297,7 +293,7 @@ public class LogNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/NormalDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/NormalDistribution.java index a7ccc5caf..6cb7b5995 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/NormalDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/NormalDistribution.java @@ -176,7 +176,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * is returned, as in these cases the actual value is within * {@code Double.MIN_VALUE} of 0 or 1. @@ -242,7 +241,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * For mean parameter {@code mu}, the mean is {@code mu}. */ public double getNumericalMean() { @@ -251,7 +249,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * For standard deviation parameter {@code s}, the variance is {@code s^2}. */ public double getNumericalVariance() { @@ -261,7 +258,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * The lower bound of the support is always negative infinity * no matter the parameters. * @@ -274,7 +270,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * The upper bound of the support is always positive infinity * no matter the parameters. * @@ -301,7 +296,6 @@ public class NormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java index 3b1faaf71..455388aa8 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/OrthogonalDistribution.java @@ -34,27 +34,28 @@ import org.nd4j.common.util.ArrayUtil; @Slf4j public class OrthogonalDistribution extends BaseDistribution { - /** - * Default inverse cumulative probability accuracy. - * - * @since 2.1 - */ - public static final double DEFAULT_INVERSE_ABSOLUTE_ACCURACY = 1e-9; - /** - * Serializable version identifier. - */ - private static final long serialVersionUID = 8589540077390120676L; - /** - * Mean of this distribution. - */ - private final double gain; - private INDArray gains; + /** + * Default inverse cumulative probability accuracy. + * + * @since 2.1 + */ + public static final double DEFAULT_INVERSE_ABSOLUTE_ACCURACY = 1e-9; + /** + * Serializable version identifier. + */ + private static final long serialVersionUID = 8589540077390120676L; - public OrthogonalDistribution(double gain) { - this.gain = gain; - this.random = Nd4j.getRandom(); - } + /** + * Mean of this distribution. + */ + private final double gain; + private INDArray gains; + + public OrthogonalDistribution(double gain) { + this.gain = gain; + this.random = Nd4j.getRandom(); + } /* max doesn't want this distripution public OrthogonalDistribution(@NonNull INDArray gains) { @@ -62,196 +63,192 @@ public class OrthogonalDistribution extends BaseDistribution { this.random = Nd4j.getRandom(); } */ - /** - * Access the mean. - * - * @return the mean for this distribution. - */ - public double getMean() { - throw new UnsupportedOperationException(); + + /** + * Access the mean. + * + * @return the mean for this distribution. + */ + public double getMean() { + throw new UnsupportedOperationException(); + } + + /** + * Access the standard deviation. + * + * @return the standard deviation for this distribution. + */ + public double getStandardDeviation() { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} + */ + public double density(double x) { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} If {@code x} is more than 40 standard deviations from the mean, 0 or 1 is + * returned, as in these cases the actual value is within {@code Double.MIN_VALUE} of 0 or 1. 
+ */ + public double cumulativeProbability(double x) { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} + * + * @since 3.2 + */ + @Override + public double inverseCumulativeProbability(final double p) throws OutOfRangeException { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} + * + * @deprecated See + * {@link org.apache.commons.math3.distribution.RealDistribution#cumulativeProbability(double, + * double)} + */ + @Override + @Deprecated + public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} + */ + @Override + public double probability(double x0, double x1) throws NumberIsTooLargeException { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} + */ + @Override + protected double getSolverAbsoluteAccuracy() { + throw new UnsupportedOperationException(); + } + + /** + * {@inheritDoc} For mean parameter {@code mu}, the mean is {@code mu}. + */ + public double getNumericalMean() { + return getMean(); + } + + /** + * {@inheritDoc} For standard deviation parameter {@code s}, the variance is {@code s^2}. + */ + public double getNumericalVariance() { + final double s = getStandardDeviation(); + return s * s; + } + + /** + * {@inheritDoc} The lower bound of the support is always negative infinity no matter the + * parameters. + * + * @return lower bound of the support (always {@code Double.NEGATIVE_INFINITY}) + */ + public double getSupportLowerBound() { + return Double.NEGATIVE_INFINITY; + } + + /** + * {@inheritDoc} + *

      + * The upper bound of the support is always positive infinity no matter the parameters. + * + * @return upper bound of the support (always {@code Double.POSITIVE_INFINITY}) + */ + public double getSupportUpperBound() { + return Double.POSITIVE_INFINITY; + } + + /** + * {@inheritDoc} + */ + public boolean isSupportLowerBoundInclusive() { + return false; + } + + /** + * {@inheritDoc} + */ + public boolean isSupportUpperBoundInclusive() { + return false; + } + + /** + * {@inheritDoc} + *

      + * The support of this distribution is connected. + * + * @return {@code true} + */ + public boolean isSupportConnected() { + return true; + } + + /** + * {@inheritDoc} + */ + @Override + public double sample() { + throw new UnsupportedOperationException(); + } + + @Override + public INDArray sample(int[] shape) { + return sample(ArrayUtil.toLongArray(shape)); + } + + @Override + public INDArray sample(long[] shape) { + long numRows = 1; + for (int i = 0; i < shape.length - 1; i++) { + numRows *= shape[i]; } + long numCols = shape[shape.length - 1]; - /** - * Access the standard deviation. - * - * @return the standard deviation for this distribution. - */ - public double getStandardDeviation() { - throw new UnsupportedOperationException(); + val dtype = Nd4j.defaultFloatingPointType(); + + val flatShape = new long[]{numRows, numCols}; + val flatRng = Nd4j.getExecutioner().exec( + new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, + 1.0), random); + + val m = flatRng.rows(); + val n = flatRng.columns(); + + val s = Nd4j.create(dtype, m < n ? m : n); + val u = Nd4j.create(dtype, m, m); + val v = Nd4j.create(dtype, new long[]{n, n}, 'f'); + + Nd4j.exec(new Svd(flatRng, true, s, u, v)); + + if (gains == null) { + if (u.rows() >= numRows && u.columns() >= numCols) { + return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain) + .reshape(shape); + } else { + return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain) + .reshape(shape); + } + } else { + throw new UnsupportedOperationException(); } + } - /** - * {@inheritDoc} - */ - public double density(double x) { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - *

      - * If {@code x} is more than 40 standard deviations from the mean, 0 or 1 - * is returned, as in these cases the actual value is within - * {@code Double.MIN_VALUE} of 0 or 1. - */ - public double cumulativeProbability(double x) { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - * - * @since 3.2 - */ - @Override - public double inverseCumulativeProbability(final double p) throws OutOfRangeException { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - * - * @deprecated See {@link org.apache.commons.math3.distribution.RealDistribution#cumulativeProbability(double, double)} - */ - @Override - @Deprecated - public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - */ - @Override - public double probability(double x0, double x1) throws NumberIsTooLargeException { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - */ - @Override - protected double getSolverAbsoluteAccuracy() { - throw new UnsupportedOperationException(); - } - - /** - * {@inheritDoc} - *

      - * For mean parameter {@code mu}, the mean is {@code mu}. - */ - public double getNumericalMean() { - return getMean(); - } - - /** - * {@inheritDoc} - *

      - * For standard deviation parameter {@code s}, the variance is {@code s^2}. - */ - public double getNumericalVariance() { - final double s = getStandardDeviation(); - return s * s; - } - - /** - * {@inheritDoc} - *

      - * The lower bound of the support is always negative infinity - * no matter the parameters. - * - * @return lower bound of the support (always - * {@code Double.NEGATIVE_INFINITY}) - */ - public double getSupportLowerBound() { - return Double.NEGATIVE_INFINITY; - } - - /** - * {@inheritDoc} - *

      - * The upper bound of the support is always positive infinity - * no matter the parameters. - * - * @return upper bound of the support (always - * {@code Double.POSITIVE_INFINITY}) - */ - public double getSupportUpperBound() { - return Double.POSITIVE_INFINITY; - } - - /** - * {@inheritDoc} - */ - public boolean isSupportLowerBoundInclusive() { - return false; - } - - /** - * {@inheritDoc} - */ - public boolean isSupportUpperBoundInclusive() { - return false; - } - - /** - * {@inheritDoc} - *

      - * The support of this distribution is connected. - * - * @return {@code true} - */ - public boolean isSupportConnected() { - return true; - } - - /** - * {@inheritDoc} - */ - @Override - public double sample() { - throw new UnsupportedOperationException(); - } - - @Override - public INDArray sample(int[] shape) { - return sample(ArrayUtil.toLongArray(shape)); - } - - @Override - public INDArray sample(long[] shape){ - long numRows = 1; - for (int i = 0; i < shape.length - 1; i++) - numRows *= shape[i]; - long numCols = shape[shape.length - 1]; - - val dtype = Nd4j.defaultFloatingPointType(); - - val flatShape = new long[]{numRows, numCols}; - val flatRng = Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.createUninitialized(dtype, flatShape, Nd4j.order()), 0.0, 1.0), random); - - val m = flatRng.rows(); - val n = flatRng.columns(); - - val s = Nd4j.create(dtype, m < n ? m : n); - val u = Nd4j.create(dtype, m, m); - val v = Nd4j.create(dtype, new long[] {n, n}, 'f'); - - Nd4j.exec(new Svd(flatRng, true, s, u, v)); - - if (gains == null) { - if (u.rows() >= numRows && u.columns() >= numCols) { - return u.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape); - } else { - return v.get(NDArrayIndex.interval(0, numRows), NDArrayIndex.interval(0, numCols)).mul(gain).reshape(shape); - } - } else { - throw new UnsupportedOperationException(); - } - } - - @Override - public INDArray sample(INDArray target){ - return target.assign(sample(target.shape())); - } + @Override + public INDArray sample(INDArray target) { + return target.assign(sample(target.shape())); + } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/SaddlePointExpansion.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/SaddlePointExpansion.java index 446c0c264..6b547e091 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/SaddlePointExpansion.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/SaddlePointExpansion.java @@ -84,7 +84,6 @@ public class SaddlePointExpansion { * href="http://mathworld.wolfram.com/StirlingsSeries.html"> * http://mathworld.wolfram.com/StirlingsSeries.html * - *

      * * @param z the value. * @return the Stirling's series error. @@ -117,7 +116,6 @@ public class SaddlePointExpansion { * href="http://www.herine.net/stat/papers/dbinom.pdf"> * http://www.herine.net/stat/papers/dbinom.pdf * - *

      * * @param x the x value. * @param mu the average. diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/TruncatedNormalDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/TruncatedNormalDistribution.java index 3043c9ebf..75cb216c4 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/TruncatedNormalDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/TruncatedNormalDistribution.java @@ -172,7 +172,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * If {@code x} is more than 40 standard deviations from the mean, 0 or 1 * is returned, as in these cases the actual value is within * {@code Double.MIN_VALUE} of 0 or 1. @@ -238,7 +238,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For mean parameter {@code mu}, the mean is {@code mu}. */ public double getNumericalMean() { @@ -247,7 +247,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For standard deviation parameter {@code s}, the variance is {@code s^2}. */ public double getNumericalVariance() { @@ -257,7 +257,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The lower bound of the support is always negative infinity * no matter the parameters. * @@ -270,7 +270,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The upper bound of the support is always positive infinity * no matter the parameters. * @@ -297,7 +297,7 @@ public class TruncatedNormalDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java index 07627f05c..bd5e1635e 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/api/rng/distribution/impl/UniformDistribution.java @@ -105,7 +105,7 @@ public class UniformDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For lower bound {@code lower} and upper bound {@code upper}, the mean is * {@code 0.5 * (lower + upper)}. */ @@ -115,7 +115,7 @@ public class UniformDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * For lower bound {@code lower} and upper bound {@code upper}, the * variance is {@code (upper - lower)^2 / 12}. */ @@ -126,7 +126,7 @@ public class UniformDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The lower bound of the support is equal to the lower bound parameter * of the distribution. * @@ -138,7 +138,7 @@ public class UniformDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The upper bound of the support is equal to the upper bound parameter * of the distribution. * @@ -164,7 +164,7 @@ public class UniformDistribution extends BaseDistribution { /** * {@inheritDoc} - *

      + *

      * The support of this distribution is connected. * * @return {@code true} diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java index 1665ca165..0e0d5c4b0 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/checkutil/NDArrayCreationUtil.java @@ -58,7 +58,7 @@ public class NDArrayCreationUtil { } - /** Get an array of INDArrays (2d) all with the specified shape. Pair returned to aid + /** Get an array of INDArrays (2d) all with the specified shape. {@code Pair} returned to aid * debugging: String contains information on how to reproduce the matrix (i.e., which function, and arguments) * Each NDArray in the returned array has been obtained by applying an operation such as transpose, tensorAlongDimension, * etc to an original array. @@ -88,7 +88,7 @@ public class NDArrayCreationUtil { * eg. rank 2: 1,1; 1,2; 2,1; 2,2; 3,4 * Motivated by TADs that often hit bugs when a "1" occurs as the size of a dimension * - * @param rank any rank including true scalars i.e rank >= 0 + * @param rank any rank including true scalars i.e rank >= 0 * @param order what order array to return i.e 'c' or 'f' order arrays * @return List of arrays and the shapes as strings */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java index e57f29072..b05689764 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncDataSetIterator.java @@ -355,7 +355,7 @@ public class AsyncDataSetIterator implements DataSetIterator { * yet been called, or the {@code remove} method has already * been called after the last call to the {@code next} * method - * @implSpec The default implementation throws an instance of + * The default implementation throws an instance of * {@link UnsupportedOperationException} and performs no other action. */ @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java index 822fa3ce2..5d372309b 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/AsyncMultiDataSetIterator.java @@ -299,7 +299,7 @@ public class AsyncMultiDataSetIterator implements MultiDataSetIterator { * yet been called, or the {@code remove} method has already * been called after the last call to the {@code next} * method - * @implSpec The default implementation throws an instance of + * The default implementation throws an instance of * {@link UnsupportedOperationException} and performs no other action. 
*/ @Override diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java index f66afa29f..222990cc5 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/DataSet.java @@ -560,7 +560,6 @@ public class DataSet implements org.nd4j.linalg.dataset.api.DataSet { /** - * @Deprecated * Subtract by the column means and divide by the standard deviation */ @Deprecated diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java index 00e81c22f..6af1d6bcd 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/KFoldIterator.java @@ -117,7 +117,6 @@ public class KFoldIterator implements DataSetIterator { /** * Shuffles the dataset and resets to the first fold * - * @return void */ @Override public void reset() { @@ -129,7 +128,7 @@ public class KFoldIterator implements DataSetIterator { /** * The number of examples in every fold is (N / k), - * except when (N % k) > 0, when the first (N % k) folds contain (N / k) + 1 examples + * except when (N % k) > 0, when the first (N % k) folds contain (N / k) + 1 examples * * @return examples in a fold */ diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java index ff82d068c..1ebe44dca 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/iterator/TestDataSetIterator.java @@ -49,7 +49,6 @@ public class TestDataSetIterator implements DataSetIterator { * Initializes with a default batch of 5 * * @param dataset the dataset to make the iterator from - * @param batch the batchsize for the iterator */ public TestDataSetIterator(DataSet dataset) { this(dataset, 5); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java index 326b9d45f..6ec22e018 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dimensionalityreduction/RandomProjection.java @@ -65,9 +65,9 @@ public class RandomProjection { * The minimum number n' of components to guarantee the eps-embedding is * given by: * - * n' >= 4 log(n) / (eps² / 2 - eps³ / 3) + * {@code n' >= 4 log(n) / (eps² / 2 - eps³ / 3)} * - * see http://cseweb.ucsd.edu/~dasgupta/papers/jl.pdf §2.1 + * http://cseweb.ucsd.edu/~dasgupta/papers/jl.pdf §2.1 * @param n Number of samples. If an array is given, it will compute * a safe number of components array-wise. * @param eps Maximum distortion rate as defined by the Johnson-Lindenstrauss lemma. 
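The minimum-components bound quoted in the RandomProjection Javadoc above can be written out directly. A small sketch (the helper name johnsonLindenstraussMinDim and the sample values are mine, not an API in this patch):

    public class JlBoundSketch {
        /** n' >= 4 * log(n) / (eps^2 / 2 - eps^3 / 3), per the Javadoc above. */
        static long johnsonLindenstraussMinDim(long nSamples, double eps) {
            double denom = (eps * eps) / 2.0 - (eps * eps * eps) / 3.0;
            return (long) Math.ceil(4.0 * Math.log(nSamples) / denom);
        }

        public static void main(String[] args) {
            // 10,000 samples at 10% distortion needs roughly 7,895 components.
            System.out.println(johnsonLindenstraussMinDim(10_000, 0.1));
        }
    }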
diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/EnvironmentalAction.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/EnvironmentalAction.java index e369b61e2..ee9a94719 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/EnvironmentalAction.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/env/EnvironmentalAction.java @@ -30,7 +30,6 @@ public interface EnvironmentalAction { /** * This method will be executed with corresponding Env Var value * - * @param name * @param value */ void process(String value); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java index 3458ed06b..d4d11fa50 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BaseNDArrayFactory.java @@ -276,7 +276,6 @@ public abstract class BaseNDArrayFactory implements NDArrayFactory { * Rotate a matrix 90 degrees * * @param toRotate the matrix to rotate - * @return the rotated matrix */ @Override public void rot90(INDArray toRotate) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BlasWrapper.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BlasWrapper.java index b4655fd6f..1965883af 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BlasWrapper.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/BlasWrapper.java @@ -33,7 +33,7 @@ public interface BlasWrapper { */ /** - * Compute x <-> y (swap two matrices) + * Compute {@code x <-> y} (swap two matrices) */ INDArray swap(INDArray x, INDArray y); @@ -69,14 +69,14 @@ public interface BlasWrapper { INDArray scal(double alpha, INDArray x); /** - * Compute x <- alpha * x (scale a matrix) + * Compute {@code x <- alpha * x} (scale a matrix) */ @Deprecated INDArray scal(float alpha, INDArray x); /** - * Compute y <- x (copy a matrix) + * Compute {@code y <- x} (copy a matrix) */ INDArray copy(INDArray x, INDArray y); @@ -84,13 +84,13 @@ public interface BlasWrapper { INDArray axpy(double da, INDArray dx, INDArray dy); /** - * Compute y <- alpha * x + y (elementwise addition) + * Compute {@code y <- alpha * x + y }(elementwise addition) */ @Deprecated INDArray axpy(float da, INDArray dx, INDArray dy); /** - * Compute y <- y + x * alpha + * Compute {@code y <- y + x * alpha} * @param da the alpha to multiply by * @param dx * @param dy @@ -130,7 +130,7 @@ public interface BlasWrapper { INDArray gemv(double alpha, INDArray a, INDArray x, double beta, INDArray y); /** - * Compute y <- alpha*op(a)*x + beta * y (general matrix vector + * Compute {@code y <- alpha*op(a)*x + beta * y} (general matrix vector * multiplication) */ @Deprecated @@ -142,7 +142,7 @@ public interface BlasWrapper { INDArray ger(double alpha, INDArray x, INDArray y, INDArray a); /** - * Compute A <- alpha * x * y^T + A (general rank-1 update) + * Compute {@code A <- alpha * x * y^T + A} (general rank-1 update) */ INDArray ger(float alpha, INDArray x, INDArray y, INDArray a); @@ -193,14 +193,14 @@ public interface BlasWrapper { /** * Generalized Least Squares via *GELSD. - *
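For readers less used to the BLAS update notation that the BlasWrapper Javadoc above now wraps in {@code} (for example {@code y <- alpha * x + y} for axpy), a plain-Java rendering of the convention may help; this is only an illustration of the notation, not the ND4J implementation:

    public class AxpySketch {
        /** In-place y <- alpha * x + y, the convention behind the axpy documentation. */
        static void axpy(double alpha, double[] x, double[] y) {
            for (int i = 0; i < y.length; i++) {
                y[i] = alpha * x[i] + y[i];
            }
        }

        public static void main(String[] args) {
            double[] x = {1, 2, 3};
            double[] y = {10, 10, 10};
            axpy(2.0, x, y);
            System.out.println(java.util.Arrays.toString(y));  // [12.0, 14.0, 16.0]
        }
    }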

      + *

      * Note that B must be padded to contain the solution matrix. This occurs when A has fewer rows * than columns. - *
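A small shape sketch of the padding rule just described (plain Java, illustrative sizes of my own choosing): with A of shape (m, n) and k right-hand sides, B has to be allocated with max(m, n) rows so the solver can overwrite it with the (n, k) solution when m < n.

    public class GelsdPaddingSketch {
        public static void main(String[] args) {
            int m = 3, n = 5, k = 2;                       // m < n, so B needs padding
            double[][] a = new double[m][n];               // A is (m, n)
            double[][] b = new double[Math.max(m, n)][k];  // B padded to (max(m, n), k) = (5, 2)
            System.out.println("B rows: " + b.length);     // 5: room for the (n, k) solution
        }
    }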

      - * For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m < n, since B is overwritten to contain + *

      + * For example: in A * X = B, A is (m,n), X is (n,k) and B is (m,k). Now if m < n, since B is overwritten to contain * the solution (in classical LAPACK style), B needs to be padded to be an (n,k) matrix. - *

      - * Likewise, if m > n, the solution consists only of the first n rows of B. + *

      + * Likewise, if m > n, the solution consists only of the first n rows of B. * * @param A an (m,n) matrix * @param B an (max(m,n), k) matrix (well, at least) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java index fe162c8e2..a2b91fb15 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/NDArrayFactory.java @@ -193,7 +193,6 @@ public interface NDArrayFactory { * Rotate a matrix 90 degrees * * @param toRotate the matrix to rotate - * @return the rotated matrix */ void rot90(INDArray toRotate); @@ -340,7 +339,6 @@ public interface NDArrayFactory { * * @param array the ndarray to shuffle * @param dimension the dimension to do the shuffle - * @return */ void shuffle(INDArray array, Random rnd, int... dimension); @@ -350,7 +348,6 @@ public interface NDArrayFactory { * * @param array the ndarray to shuffle * @param dimension the dimension to do the shuffle - * @return */ void shuffle(Collection array, Random rnd, int... dimension); @@ -360,7 +357,6 @@ public interface NDArrayFactory { * * @param array the ndarray to shuffle * @param dimensions the dimensions to do the shuffle - * @return */ void shuffle(List array, Random rnd, List dimensions); @@ -1370,9 +1366,9 @@ public interface NDArrayFactory { INDArray createFromNpyFile(File file); /** - * Create a Map from given npz file. + * Create a {@code Map} from given npz file. * @param file the file to create the map from - * @return Map + * @return {@code Map} */ Map createFromNpzFile(File file) throws Exception; @@ -1386,7 +1382,7 @@ public interface NDArrayFactory { * * * @param array the array to convert - * @returnthe created pointer representing + * @return the created pointer representing * a pointer to a numpy header */ Pointer convertToNumpy(INDArray array); diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 2dfff5fbd..f542e3cce 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -1441,7 +1441,7 @@ public class Nd4j { } /** - * See {@link #createBuffer(DataType dataType, long length, boolean initialize) with default datatype. + * See {@link #createBuffer(DataType dataType, long length, boolean initialize)} with default datatype. */ public static DataBuffer createBuffer(long length, boolean initialize) { return createBuffer(Nd4j.dataType(), length, initialize); @@ -2828,7 +2828,7 @@ public class Nd4j { } /** - * @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...)) + * @deprecated use {@link Nd4j#rand(org.nd4j.linalg.api.buffer.DataType, char, long...)} */ @Deprecated public static INDArray rand(@NonNull DataType dataType, int[] shape, char order) { @@ -3306,7 +3306,7 @@ public class Nd4j { * Generate an array with random values generated according to a binomial distribution with the specified * number of trials and probability * - * @param nTrials Number of trials. Must be >= 0 + * @param nTrials Number of trials. Must be >= 0 * @param p Probability. 
Must be in range 0 to 1 * @param shape Shape of the result array * @return Result array @@ -3319,7 +3319,7 @@ public class Nd4j { * Fill the target array with random values generated according to a binomial distribution with the specified * number of trials and probability * - * @param nTrials Number of trials. Must be >= 0 + * @param nTrials Number of trials. Must be >= 0 * @param p Probability. Must be in range 0 to 1 * @param target Result array * @return Result array @@ -3333,7 +3333,7 @@ public class Nd4j { /** * Exponential distribution: P(x) = lambda * exp(-lambda * x) * - * @param lambda Must be > 0 + * @param lambda Must be > 0 * @param shape Shape of the array to generate */ public static INDArray randomExponential(double lambda, long... shape) { @@ -3341,9 +3341,9 @@ public class Nd4j { } /** - * Exponential distribution: P(x) = lambda * exp(-lambda * x) + * Exponential distribution: {@code P(x) = lambda * exp(-lambda * x)} * - * @param lambda Must be > 0 + * @param lambda Must be > 0 * @param target Array to hold the result */ public static INDArray randomExponential(double lambda, INDArray target) { @@ -3925,7 +3925,7 @@ public class Nd4j { } /** - * See {@link @see #create(int, int, int[], char)} + * See {@link Nd4j#create(int, int, int[], char)} */ public static INDArray zeros(int rows, int columns, int[] stride) { return create(rows, columns, stride, order()); @@ -4630,7 +4630,7 @@ public class Nd4j { /** * Concatenates two matrices vertically. Matrices must have identical numbers of columns.
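The exponential sampler documented above has a convenient invariant: with rate lambda the mean is 1/lambda. A short usage sketch (lambda and the shape are my own illustrative values; randomExponential(double, long...) is the signature shown in this hunk):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class ExponentialSamplingSketch {
        public static void main(String[] args) {
            double lambda = 2.0;                                      // must be > 0
            INDArray samples = Nd4j.randomExponential(lambda, 1000);  // P(x) = lambda * exp(-lambda * x)
            System.out.println(samples.meanNumber());                 // should sit near 1 / lambda = 0.5
        }
    }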
      - * Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3] + * Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3] * * @param arrs Arrays to vstack */ @@ -4646,7 +4646,7 @@ public class Nd4j { /** * Concatenates two matrices vertically. Matrices must have identical numbers of columns.
      - * Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3] + * Note that for vstack on rank 1 arrays, this is equivalent to {@link Nd4j#pile(INDArray...)}. Example: vstack([3],[3]) -> [2,3] * * @param arrs Arrays to vstack */ @@ -5462,7 +5462,7 @@ public class Nd4j { Examples -------- - >>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1) + {@code >>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1) array([[ 1, 2, 3], [ 4, 5, 6], [ 0, 8, 9], @@ -5473,6 +5473,7 @@ public class Nd4j { mask = tri(*m.shape[-2:], k=k-1, dtype=bool) return where(mask, zeros(1, m.dtype), m) + } * @param m source array * @param k to zero below the k-th diagonal @@ -5517,8 +5518,8 @@ public class Nd4j { * @param n number of rows in the array * @param m number of columns in the array ( can be just equal to n) * @param k The sub-diagonal at and below which the array is filled. - `k` = 0 is the main diagonal, while `k` < 0 is below it, - and `k` > 0 is above. The default is 0. + `k` = 0 is the main diagonal, while `k` > 0 is below it, + and `k` > 0 is above. The default is 0. * @return array with ones at and below the given diagonal and zeros elsewhere */ public static INDArray tri(int n,int m,int k) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java index b56410cd3..0b72cd3a2 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java @@ -269,14 +269,14 @@ public abstract class Nd4jBackend { /** * Constructs a new exception with the specified cause and a detail - * message of (cause==null ? null : cause.toString()) (which - * typically contains the class and detail message of cause). + * message of {@code (cause==null ? null : cause.toString())} (which + * typically contains the class and detail message of cause). * This constructor is useful for exceptions that are little more than * wrappers for other throwables (for example, {@link * PrivilegedActionException}). * * @param cause the cause (which is saved for later retrieval by the - * {@link #getCause()} method). (A null value is + * {@link #getCause()} method). (A null value is * permitted, and indicates that the cause is nonexistent or * unknown.) * @since 1.4 diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java index be72896aa..e007ef168 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java @@ -30,176 +30,210 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Condition; public class NDBase { + public NDBase() { } /** * Boolean and array reduction operation, optionally along specified dimensions
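The vstack note above (two rank-1 inputs of length 3 stack into a [2, 3] result) can be exercised directly. A minimal sketch, assuming the usual Nd4j.createFromArray factory is available in this code base:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class VstackSketch {
        public static void main(String[] args) {
            INDArray a = Nd4j.createFromArray(1f, 2f, 3f);  // rank 1, shape [3]
            INDArray b = Nd4j.createFromArray(4f, 5f, 6f);  // rank 1, shape [3]
            INDArray stacked = Nd4j.vstack(a, b);           // shape [2, 3], as described above
            System.out.println(java.util.Arrays.toString(stacked.shape()));
        }
    }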
      * - * @param x Input variable (NDARRAY type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NDARRAY type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public INDArray all(INDArray x, int... dimensions) { - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.All(x, dimensions)); } /** * Boolean or array reduction operation, optionally along specified dimensions
      * - * @param x Input variable (NDARRAY type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NDARRAY type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public INDArray any(INDArray x, int... dimensions) { - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(x, dimensions)); } /** - * Argmax array reduction operation, optionally along specified dimensions.
      - * Output values are the index of the maximum value of each slice along the specified dimension.
      - * + * Argmax array reduction operation, optionally along specified dimensions.
      Output values are + * the index of the maximum value of each slice along the specified dimension.
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param in Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true (NUMERIC type) + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray argmax(INDArray in, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("argmax", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions))[0]; + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions))[0]; } /** - * Argmax array reduction operation, optionally along specified dimensions.
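The keepDims behaviour spelled out above is easiest to see on a tiny matrix. A usage sketch of the argmax overloads from this hunk (the input values are my own):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDBase;

    public class ArgmaxSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.createFromArray(new float[][] {{1, 9, 3}, {7, 2, 5}});  // shape [2, 3]
            NDBase base = new NDBase();
            INDArray kept = base.argmax(in, true, 1);   // shape [2, 1], values [[1], [0]]
            INDArray dropped = base.argmax(in, 1);      // shape [2],    values [1, 0]
            System.out.println(kept);
            System.out.println(dropped);
        }
    }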
      - * Output values are the index of the maximum value of each slice along the specified dimension.
      - * + * Argmax array reduction operation, optionally along specified dimensions.
      Output values are + * the index of the maximum value of each slice along the specified dimension.
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or - * of rank (input rank) if keepdims = true (NUMERIC type) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray argmax(INDArray in, int... dimensions) { NDValidation.validateNumerical("argmax", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions))[0]; + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions))[0]; } /** - * Argmin array reduction operation, optionally along specified dimensions.
      - * Output values are the index of the minimum value of each slice along the specified dimension.
      - * + * Argmin array reduction operation, optionally along specified dimensions.
      Output values are + * the index of the minimum value of each slice along the specified dimension.
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * - * @param in Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + * @param in Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray argmin(INDArray in, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("argmin", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions))[0]; + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions))[0]; } /** - * Argmin array reduction operation, optionally along specified dimensions.
      - * Output values are the index of the minimum value of each slice along the specified dimension.
      - * + * Argmin array reduction operation, optionally along specified dimensions.
      Output values are + * the index of the minimum value of each slice along the specified dimension.
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * - * @param in Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + * @param in Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray argmin(INDArray in, int... dimensions) { NDValidation.validateNumerical("argmin", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions))[0]; + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions))[0]; } /** * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
      * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
      - * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
      - * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
      + * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) + * instead.
      Likewise, if transposeB is true, matrices from matricesB will have shape (K, + * N).
      *
      - * The result of this operation will be a batch of multiplied matrices. The
      - * result has the same length as both input batches and each output matrix is of shape (M, K).
      + * The result of this operation will be a batch of multiplied matrices. The
      result has the + * same length as both input batches and each output matrix is of shape (M, K).
      * - * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) - * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) * @param transposeA Whether to transpose A arrays or not * @param transposeB Whether to transpose B arrays or not */ public INDArray[] batchMmul(INDArray[] inputsA, INDArray[] inputsB, boolean transposeA, boolean transposeB) { NDValidation.validateNumerical("batchMmul", "inputsA", inputsA); - Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + Preconditions.checkArgument(inputsA.length >= 1, + "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); NDValidation.validateNumerical("batchMmul", "inputsB", inputsB); - Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, transposeA, transposeB)); + Preconditions.checkArgument(inputsB.length >= 1, + "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, transposeA, + transposeB)); } /** * Matrix multiply a batch of matrices. matricesA and matricesB have to be arrays of same
      * length and each pair taken from these sets has to have dimensions (M, N) and (N, K),
      - * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) instead.
      - * Likewise, if transposeB is true, matrices from matricesB will have shape (K, N).
      + * respectively. If transposeA is true, matrices from matricesA will have shape (N, M) + * instead.
      Likewise, if transposeB is true, matrices from matricesB will have shape (K, + * N).
      *
      - * The result of this operation will be a batch of multiplied matrices. The
      - * result has the same length as both input batches and each output matrix is of shape (M, K).
      + * The result of this operation will be a batch of multiplied matrices. The
      result has the + * same length as both input batches and each output matrix is of shape (M, K).
      * * @param inputsA First array of input matrices, all of shape (M, N) or (N, M) (NUMERIC type) - * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) + * @param inputsB Second array of input matrices, all of shape (N, K) or (K, N) (NUMERIC type) */ public INDArray[] batchMmul(INDArray[] inputsA, INDArray... inputsB) { NDValidation.validateNumerical("batchMmul", "inputsA", inputsA); - Preconditions.checkArgument(inputsA.length >= 1, "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); + Preconditions.checkArgument(inputsA.length >= 1, + "inputsA has incorrect size/length. Expected: inputsA.length >= 1, got %s", inputsA.length); NDValidation.validateNumerical("batchMmul", "inputsB", inputsB); - Preconditions.checkArgument(inputsB.length >= 1, "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, false, false)); + Preconditions.checkArgument(inputsB.length >= 1, + "inputsB has incorrect size/length. Expected: inputsB.length >= 1, got %s", inputsB.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul(inputsA, inputsB, false, false)); } /** - * Cast the array to a new datatype - for example, Integer -> Float
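A short shape-oriented sketch of the batched multiply described above (two illustrative pairs; each A_i is (2, 3) and each B_i is (3, 4), so every output is (2, 4)):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDBase;

    public class BatchMmulSketch {
        public static void main(String[] args) {
            INDArray[] a = {Nd4j.rand(2, 3), Nd4j.rand(2, 3)};  // batch of (M, N) = (2, 3)
            INDArray[] b = {Nd4j.rand(3, 4), Nd4j.rand(3, 4)};  // batch of (N, K) = (3, 4)
            INDArray[] out = new NDBase().batchMmul(a, b);      // no transposes
            // Same batch length as the inputs, each result of shape (M, K) = (2, 4).
            System.out.println(out.length + " x " + java.util.Arrays.toString(out[0].shape()));
        }
    }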
      + * Cast the array to a new datatype - for example, Integer -> Float
      * - * @param arg Input variable to cast (NDARRAY type) + * @param arg Input variable to cast (NDARRAY type) * @param datatype Datatype to cast to * @return output Output array (after casting) (NDARRAY type) */ @@ -208,119 +242,129 @@ public class NDBase { } /** - * Concatenate a set of inputs along the specified dimension.
      - * Note that inputs must have identical rank and identical dimensions, other than the dimension to stack on.
      - * For example, if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, x+y, c]
      + * Concatenate a set of inputs along the specified dimension.
      Note that inputs must have + * identical rank and identical dimensions, other than the dimension to stack on.
      For example, + * if 2 inputs have shape [a, x, c] and [a, y, c] and dimension = 1, then the output has shape [a, + * x+y, c]
      + *

      + * Inputs must satisfy the following constraints:
      Input arrays must all be the same datatype: + * isSameType(inputs)
      * - * Inputs must satisfy the following constraints:
      - * Input arrays must all be the same datatype: isSameType(inputs)
      - * - * @param inputs Input variables (NUMERIC type) + * @param inputs Input variables (NUMERIC type) * @param dimension Dimension to concatenate on * @return output (NUMERIC type) */ public INDArray concat(int dimension, INDArray... inputs) { NDValidation.validateNumerical("concat", "inputs", inputs); - Preconditions.checkArgument(inputs.length >= 1, "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); + Preconditions.checkArgument(inputs.length >= 1, + "inputs has incorrect size/length. Expected: inputs.length >= 1, got %s", inputs.length); Preconditions.checkArgument(isSameType(inputs), "Input arrays must all be the same datatype"); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Concat(inputs, dimension))[0]; } /** - * Cumulative product operation.
      - * For input: [ a, b, c], output is:
      - * exclusive=false, reverse=false: [a, a*b, a*b*c]
      - * exclusive=true, reverse=false, [0, a, a*b]
      - * exclusive=false, reverse=true: [a*b*c, b*c, c]
      - * exclusive=true, reverse=true: [b*c, c, 0]
      + * Cumulative product operation.
      For input: [ a, b, c], output is:
      exclusive=false, + * reverse=false: [a, a*b, a*b*c]
exclusive=true, reverse=false: [0, a, a*b]<br>
      + * exclusive=false, reverse=true: [a*b*c, b*c, c]
      exclusive=true, reverse=true: [b*c, c, + * 0]
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param exclusive If true: exclude the first value - * @param reverse If true: reverse the direction of the accumulation - * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumululative sum operations + * along (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public INDArray cumprod(INDArray in, boolean exclusive, boolean reverse, int... axis) { NDValidation.validateNumerical("cumprod", "in", in); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, exclusive, reverse, axis))[0]; + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, exclusive, reverse, + axis))[0]; } /** - * Cumulative product operation.
      - * For input: [ a, b, c], output is:
      - * exclusive=false, reverse=false: [a, a*b, a*b*c]
      - * exclusive=true, reverse=false, [0, a, a*b]
      - * exclusive=false, reverse=true: [a*b*c, b*c, c]
      - * exclusive=true, reverse=true: [b*c, c, 0]
      + * Cumulative product operation.
      For input: [ a, b, c], output is:
      exclusive=false, + * reverse=false: [a, a*b, a*b*c]
exclusive=true, reverse=false: [0, a, a*b]<br>
      + * exclusive=false, reverse=true: [a*b*c, b*c, c]
      exclusive=true, reverse=true: [b*c, c, + * 0]
      * - * @param in Input variable (NUMERIC type) - * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along + * (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) */ public INDArray cumprod(INDArray in, int... axis) { NDValidation.validateNumerical("cumprod", "in", in); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, false, false, axis))[0]; + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd(in, false, false, axis))[0]; } /** - * Cumulative sum operation.
      - * For input: [ a, b, c], output is:
      - * exclusive=false, reverse=false: [a, a+b, a+b+c]
      - * exclusive=true, reverse=false, [0, a, a+b]
      - * exclusive=false, reverse=true: [a+b+c, b+c, c]
      - * exclusive=true, reverse=true: [b+c, c, 0]
      + * Cumulative sum operation.
      For input: [ a, b, c], output is:
      exclusive=false, + * reverse=false: [a, a+b, a+b+c]
exclusive=true, reverse=false: [0, a, a+b]<br>
      + * exclusive=false, reverse=true: [a+b+c, b+c, c]
      exclusive=true, reverse=true: [b+c, c, + * 0]
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param exclusive If true: exclude the first value - * @param reverse If true: reverse the direction of the accumulation - * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @param reverse If true: reverse the direction of the accumulation + * @param axis Scalar axis argument for dimension to perform cumululative sum operations + * along (Size: AtLeast(min=1)) * @return output (NUMERIC type) */ public INDArray cumsum(INDArray in, boolean exclusive, boolean reverse, int... axis) { NDValidation.validateNumerical("cumsum", "in", in); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, exclusive, reverse, axis))[0]; + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, exclusive, reverse, axis))[0]; } /** - * Cumulative sum operation.
      - * For input: [ a, b, c], output is:
      - * exclusive=false, reverse=false: [a, a+b, a+b+c]
      - * exclusive=true, reverse=false, [0, a, a+b]
      - * exclusive=false, reverse=true: [a+b+c, b+c, c]
      - * exclusive=true, reverse=true: [b+c, c, 0]
      + * Cumulative sum operation.
      For input: [ a, b, c], output is:
      exclusive=false, + * reverse=false: [a, a+b, a+b+c]
exclusive=true, reverse=false: [0, a, a+b]<br>
      + * exclusive=false, reverse=true: [a+b+c, b+c, c]
      exclusive=true, reverse=true: [b+c, c, + * 0]
      * - * @param in Input variable (NUMERIC type) - * @param axis Scalar axis argument for dimension to perform cumululative sum operations along (Size: AtLeast(min=1)) + * @param in Input variable (NUMERIC type) + * @param axis Scalar axis argument for dimension to perform cumululative sum operations along + * (Size: AtLeast(min=1)) * @return output (NUMERIC type) */ public INDArray cumsum(INDArray in, int... axis) { NDValidation.validateNumerical("cumsum", "in", in); - Preconditions.checkArgument(axis.length >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, false, false, axis))[0]; + Preconditions.checkArgument(axis.length >= 1, + "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axis.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum(in, false, false, axis))[0]; } /** - * Pairwise dot product reduction along dimension
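The exclusive/reverse table above maps onto calls like the following, a small sketch using the cumsum overloads from this hunk (input values are my own):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDBase;

    public class CumsumSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.createFromArray(1f, 2f, 3f);      // [a, b, c]
            NDBase base = new NDBase();
            System.out.println(base.cumsum(in, 0));               // [1, 3, 6] = [a, a+b, a+b+c]
            System.out.println(base.cumsum(in, true, false, 0));  // [0, 1, 3] = [0, a, a+b]
            System.out.println(base.cumsum(in, false, true, 0));  // [6, 5, 3] = [a+b+c, b+c, c]
            System.out.println(base.cumsum(in, true, true, 0));   // [5, 3, 0] = [b+c, c, 0]
        }
    }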
      - * output = sum(i=0 ... size(dim)-1) x[i] * y[i]
      + * Pairwise dot product reduction along dimension
      output = sum(i=0 ... size(dim)-1) x[i] * + * y[i]
      * - * @param x first input (NUMERIC type) - * @param y second input (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x first input (NUMERIC type) + * @param y second input (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output output variable (NUMERIC type) */ public INDArray dot(INDArray x, INDArray y, int... dimensions) { NDValidation.validateNumerical("dot", "x", x); NDValidation.validateNumerical("dot", "y", y); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce3.Dot(x, y, dimensions)); } /** - * Dynamically partition the input variable values into the specified number of paritions, using the indices.
      - * Example:
      + * Dynamically partition the input variable values into the specified number of paritions, using + * the indices.
      Example:
      *


      * input = [1,2,3,4,5]
      * numPartitions = 2
      @@ -329,39 +373,47 @@ public class NDBase { * out[1] = [1,4] }
      *

      * - * @param x Input variable (NUMERIC type) - * @param partitions 1D input with values 0 to numPartitions-1 (INT type) - * @param numPartitions Number of partitions, >= 1 + * @param x Input variable (NUMERIC type) + * @param partitions 1D input with values 0 to numPartitions-1 (INT type) + * @param numPartitions Number of partitions, >= 1 */ public INDArray[] dynamicPartition(INDArray x, INDArray partitions, int numPartitions) { NDValidation.validateNumerical("dynamicPartition", "x", x); NDValidation.validateInteger("dynamicPartition", "partitions", partitions); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, numPartitions)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition(x, partitions, + numPartitions)); } /** - * Dynamically merge the specified input arrays into a single array, using the specified indices
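A runnable version of the partitioning example sketched above; the partition vector below is the one that reproduces out[0] = [2, 3, 5] and out[1] = [1, 4]:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDBase;

    public class DynamicPartitionSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(1f, 2f, 3f, 4f, 5f);
            INDArray partitions = Nd4j.createFromArray(1, 0, 0, 1, 0);  // partition id per element
            INDArray[] out = new NDBase().dynamicPartition(x, partitions, 2);
            System.out.println(out[0]);  // [2, 3, 5]
            System.out.println(out[1]);  // [1, 4]
        }
    }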
      + * Dynamically merge the specified input arrays into a single array, using the specified + * indices
      * - * @param indices Indices to use when merging. Must be >= 1, same length as input variables (INT type) - * @param x Input variables. (NUMERIC type) + * @param indices Indices to use when merging. Must be >= 1, same length as input variables + * (INT type) + * @param x Input variables. (NUMERIC type) * @return output Merged output variable (NUMERIC type) */ public INDArray dynamicStitch(INDArray[] indices, INDArray... x) { NDValidation.validateInteger("dynamicStitch", "indices", indices); - Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + Preconditions.checkArgument(indices.length >= 1, + "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); NDValidation.validateNumerical("dynamicStitch", "x", x); - Preconditions.checkArgument(x.length >= 1, "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x))[0]; + Preconditions.checkArgument(x.length >= 1, + "x has incorrect size/length. Expected: x.length >= 1, got %s", x.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch(indices, x))[0]; } /** * Equals operation: elementwise x == y
      - * + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray eq(INDArray x, double y) { NDValidation.validateNumerical("eq", "x", x); @@ -369,18 +421,20 @@ public class NDBase { } /** - * Equal to operation: elementwise x == y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Equal to operation: elementwise x == y
      If x and y arrays have equal shape, the output shape + * is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input 1 (NUMERIC type) * @param y Input 2 (NUMERIC type) - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray eq(INDArray x, INDArray y) { NDValidation.validateNumerical("eq", "x", x); @@ -389,13 +443,11 @@ public class NDBase { } /** - * Reshape the input by adding a 1 at the specified location.
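The broadcasting note above applies to eq just as to the other comparisons. A small sketch (the [1, 3] versus [2, 3] shapes are my own illustration):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDBase;

    public class EqBroadcastSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(new float[][] {{1, 2, 3}});             // shape [1, 3]
            INDArray y = Nd4j.createFromArray(new float[][] {{1, 0, 3}, {1, 2, 0}});  // shape [2, 3]
            NDBase base = new NDBase();
            INDArray mask = base.eq(x, y);          // shape [2, 3]: [[true, false, true], [true, true, false]]
            INDArray scalarMask = base.eq(y, 1.0);  // elementwise y == 1.0
            System.out.println(mask);
            System.out.println(scalarMask);
        }
    }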
      - * For example, if input has shape [a, b], then output shape is:
      - * axis = 0: [1, a, b]
      - * axis = 1: [a, 1, b]
      - * axis = 2: [a, b, 1]
      + * Reshape the input by adding a 1 at the specified location.
      For example, if input has shape + * [a, b], then output shape is:
      axis = 0: [1, a, b]
      axis = 1: [a, 1, b]
      axis = 2: [a, + * b, 1]
      * - * @param x Input variable (NDARRAY type) + * @param x Input variable (NDARRAY type) * @param axis Axis to expand * @return output Output variable (NUMERIC type) */ @@ -404,40 +456,45 @@ public class NDBase { } /** - * Generate an output variable with the specified (dynamic) shape with all elements set to the specified value
      + * Generate an output variable with the specified (dynamic) shape with all elements set to the + * specified value
      * - * @param shape Shape: must be a 1D array/variable (INT type) + * @param shape Shape: must be a 1D array/variable (INT type) * @param dataType Datatype of the output array - * @param value Value to set all elements to + * @param value Value to set all elements to * @return output Output variable (NUMERIC type) */ public INDArray fill(INDArray shape, DataType dataType, double value) { NDValidation.validateInteger("fill", "shape", shape); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(shape, dataType, value))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.Fill(shape, dataType, value))[0]; } /** - * Gather slices from the input variable where the indices are specified as fixed int[] values.
      - * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
      + * Gather slices from the input variable where the indices are specified as fixed int[] + * values.
      Output shape is same as input shape, except for axis dimension, which has size + * equal to indices.length.
      * - * @param df Input variable (NUMERIC type) + * @param df Input variable (NUMERIC type) * @param indices Indices to get (Size: AtLeast(min=1)) - * @param axis Axis that the indices refer to + * @param axis Axis that the indices refer to * @return output Output variable with slices pulled from the specified axis (NUMERIC type) */ public INDArray gather(INDArray df, int[] indices, int axis) { NDValidation.validateNumerical("gather", "df", df); - Preconditions.checkArgument(indices.length >= 1, "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); + Preconditions.checkArgument(indices.length >= 1, + "indices has incorrect size/length. Expected: indices.length >= 1, got %s", indices.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Gather(df, indices, axis))[0]; } /** - * Gather slices from the input variable where the indices are specified as dynamic array values.
      - * Output shape is same as input shape, except for axis dimension, which has size equal to indices.length.
      + * Gather slices from the input variable where the indices are specified as dynamic array + * values.
      Output shape is same as input shape, except for axis dimension, which has size + * equal to indices.length.
      * - * @param df Input variable (NUMERIC type) + * @param df Input variable (NUMERIC type) * @param indices Indices to get slices for. Rank 0 or 1 input (INT type) - * @param axis Axis that the indices refer to + * @param axis Axis that the indices refer to * @return output Output variable with slices pulled from the specified axis (NUMERIC type) */ public INDArray gather(INDArray df, INDArray indices, int axis) { @@ -449,8 +506,8 @@ public class NDBase { /** * Gather slices from df with shape specified by indices.
      * - * @param df (NUMERIC type) - * @param indices (NUMERIC type) + * @param df (NUMERIC type) + * @param indices (NUMERIC type) * @return output (NUMERIC type) */ public INDArray gatherNd(INDArray df, INDArray indices) { @@ -460,13 +517,14 @@ public class NDBase { } /** - * Greater than operation: elementwise x > y
      - * + * Greater than operation: elementwise x > y
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray gt(INDArray x, double y) { NDValidation.validateNumerical("gt", "x", x); @@ -474,18 +532,20 @@ public class NDBase { } /** - * Greater than operation: elementwise x > y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Greater than operation: elementwise x > y
      If x and y arrays have equal shape, the output + * shape is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input 1 (NUMERIC type) * @param y Input 2 (NUMERIC type) - * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray gt(INDArray x, INDArray y) { NDValidation.validateNumerical("gt", "x", x); @@ -494,27 +554,30 @@ public class NDBase { } /** - * Greater than or equals operation: elementwise x >= y
      - * + * Greater than or equals operation: elementwise x >= y
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray gte(INDArray x, double y) { NDValidation.validateNumerical("gte", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual(x, y)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual(x, y)); } /** - * Greater than or equal to operation: elementwise x >= y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Greater than or equal to operation: elementwise x >= y
      If x and y arrays have equal + * shape, the output shape is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input 1 (NUMERIC type) @@ -524,7 +587,8 @@ public class NDBase { public INDArray gte(INDArray x, INDArray y) { NDValidation.validateNumerical("gte", "x", x); NDValidation.validateNumerical("gte", "y", y); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(x, y))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual(x, y))[0]; } /** @@ -539,20 +603,22 @@ public class NDBase { } /** - * Compute the inverse permutation indices for a permutation operation
      - * Example: if input is [2, 0, 1] then output is [1, 2, 0]
      - * The idea is that x.permute(input).permute(invertPermutation(input)) == x
      + * Compute the inverse permutation indices for a permutation operation
      Example: if input is + * [2, 0, 1] then output is [1, 2, 0]
      The idea is that + * x.permute(input).permute(invertPermutation(input)) == x
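A small sketch of that round-trip property (illustrative; assumes the imports this class already uses):

    INDArray perm = Nd4j.createFromArray(2, 0, 1);           // INT-typed permutation indices
    INDArray inv  = new NDBase().invertPermutation(perm);    // -> [1, 2, 0]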
      * * @param input 1D indices for permutation (INT type) * @return output 1D inverted permutation (INT type) */ public INDArray invertPermutation(INDArray input) { NDValidation.validateInteger("invertPermutation", "input", input); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(input))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation(input))[0]; } /** - * Is the director a numeric tensor? In the current version of ND4J/SameDiff, this always returns true/1
+ * Is the given array a numeric tensor? In the current version of ND4J/SameDiff, this always returns + * true/1<br>
      * * @param x Input variable (NUMERIC type) * @return output scalar boolean with value true or false (NDARRAY type) @@ -563,26 +629,27 @@ public class NDBase { } /** - * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
      - * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
      + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
      For + * example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
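The example above as a call (an illustrative sketch, assuming an initialized ND4J backend):

    INDArray spaced = new NDBase().linspace(DataType.DOUBLE, 3.0, 4.0, 3);   // -> [3.0, 3.5, 4.0]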
      * * @param dataType Data type of the output array - * @param start Start value - * @param stop Stop value - * @param number Number of values to generate + * @param start Start value + * @param stop Stop value + * @param number Number of values to generate * @return output INDArray with linearly spaced elements (NUMERIC type) */ public INDArray linspace(DataType dataType, double start, double stop, long number) { - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.Linspace(dataType, start, stop, number))[0]; } /** - * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
      - * For example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
      + * Create a new 1d array with values evenly spaced between values 'start' and 'stop'
      For + * example, linspace(start=3.0, stop=4.0, number=3) will generate [3.0, 3.5, 4.0]
      * - * @param start Start value (NUMERIC type) - * @param stop Stop value (NUMERIC type) - * @param number Number of values to generate (LONG type) + * @param start Start value (NUMERIC type) + * @param stop Stop value (NUMERIC type) + * @param number Number of values to generate (LONG type) * @param dataType Data type of the output array * @return output INDArray with linearly spaced elements (NUMERIC type) */ @@ -590,17 +657,19 @@ public class NDBase { NDValidation.validateNumerical("linspace", "start", start); NDValidation.validateNumerical("linspace", "stop", stop); NDValidation.validateInteger("linspace", "number", number); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.Linspace(start, stop, number, dataType))[0]; } /** - * Less than operation: elementwise x < y
      - * + * Less than operation: elementwise x < y
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray lt(INDArray x, double y) { NDValidation.validateNumerical("lt", "x", x); @@ -608,18 +677,20 @@ public class NDBase { } /** - * Less than operation: elementwise x < y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Less than operation: elementwise x < y
      If x and y arrays have equal shape, the output + * shape is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
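The lt family mirrors the gt sketch shown earlier; for instance new NDBase().lt(x, y) returns the elementwise x < y boolean mask (illustrative, same assumptions as above).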
      * * @param x Input 1 (NUMERIC type) * @param y Input 2 (NUMERIC type) - * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray lt(INDArray x, INDArray y) { NDValidation.validateNumerical("lt", "x", x); @@ -628,32 +699,36 @@ public class NDBase { } /** - * Less than or equals operation: elementwise x <= y
      - * + * Less than or equals operation: elementwise x <= y
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray lte(INDArray x, double y) { NDValidation.validateNumerical("lte", "x", x); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual(x, y)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual(x, y)); } /** - * Less than or equal to operation: elementwise x <= y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Less than or equal to operation: elementwise x <= y
      If x and y arrays have equal shape, + * the output shape is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input 1 (NUMERIC type) * @param y Input 2 (NUMERIC type) - * @return output Output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray lte(INDArray x, INDArray y) { NDValidation.validateNumerical("lte", "x", x); @@ -662,21 +737,23 @@ public class NDBase { } /** - * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 where satisfied, 0 otherwise
      + * Returns a boolean mask of equal shape to the input, where the condition is satisfied - value 1 + * where satisfied, 0 otherwise
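A minimal sketch (illustrative; the Conditions factory from org.nd4j.linalg.indexing.conditions is assumed to be available):

    INDArray in   = Nd4j.createFromArray(0.5, 2.0, -1.0, 4.0);
    INDArray mask = new NDBase().matchCondition(in, Conditions.greaterThan(1.0));   // -> [false, true, false, true]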
      * - * @param in Input (NUMERIC type) + * @param in Input (NUMERIC type) * @param condition Condition * @return output Boolean mask (NUMERIC type) */ public INDArray matchCondition(INDArray in, Condition condition) { NDValidation.validateNumerical("matchCondition", "in", in); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform(in, condition)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform(in, condition)); } /** * Returns a count of the number of elements that satisfy the condition
      * - * @param in Input (NUMERIC type) + * @param in Input (NUMERIC type) * @param condition Condition * @return output Number of elements that the condition is satisfied for (NUMERIC type) */ @@ -686,98 +763,115 @@ public class NDBase { } /** - * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
      - * + * Returns a count of the number of elements that satisfy the condition (for each slice along the + * specified dimensions)
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param in Input variable (NUMERIC type) - * @param condition Condition - * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param keepDim If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Number of elements that the condition is satisfied for (NUMERIC type) */ public INDArray matchConditionCount(INDArray in, Condition condition, boolean keepDim, int... dimensions) { NDValidation.validateNumerical("matchConditionCount", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, keepDim, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, keepDim, + dimensions)); } /** - * Returns a count of the number of elements that satisfy the condition (for each slice along the specified dimensions)
      - * + * Returns a count of the number of elements that satisfy the condition (for each slice along the + * specified dimensions)
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param in Input variable (NUMERIC type) - * @param condition Condition - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param in Input variable (NUMERIC type) + * @param condition Condition + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Number of elements that the condition is satisfied for (NUMERIC type) */ public INDArray matchConditionCount(INDArray in, Condition condition, int... dimensions) { NDValidation.validateNumerical("matchConditionCount", "in", in); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, false, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition(in, condition, false, + dimensions)); } /** * Max array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
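A short sketch of the keepDims behaviour (illustrative only; zeros are used because only the shapes matter here):

    INDArray x = Nd4j.create(DataType.DOUBLE, 2, 3, 4);   // shape [2,3,4]
    INDArray kept    = new NDBase().max(x, true, 1);      // keepDims = true  -> shape [2,1,4]
    INDArray dropped = new NDBase().max(x, false, 1);     // keepDims = false -> shape [2,4]

The same keepDims pattern applies to the mean, min, prod and norm reductions below.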
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray max(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("max", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Max(x, keepDims, dimensions)); } /** * Max array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray max(INDArray x, int... dimensions) { NDValidation.validateNumerical("max", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Max(x, false, dimensions)); } /** * Element-wise maximum operation: out[i] = max(first[i], second[i])
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
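A sketch of the broadcast case described above (illustrative):

    INDArray a = Nd4j.create(DataType.DOUBLE, 1, 10);
    INDArray b = Nd4j.create(DataType.DOUBLE, 5, 10);
    INDArray m = new NDBase().max(a, b);                  // broadcasts to shape [5,10]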
      * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * - * @param first First input array (NUMERIC type) + * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) * @return output Output variable (NUMERIC type) */ @@ -789,49 +883,55 @@ public class NDBase { /** * Mean (average) array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray mean(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("mean", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, keepDims, dimensions)); } /** * Mean (average) array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray mean(INDArray x, int... dimensions) { NDValidation.validateNumerical("mean", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, false, dimensions)); } /** - * The merge operation is a control operation that forwards the either of the inputs to the output, when
      - * the first of them becomes available. If both are available, the output is undefined (either input could
      - * be forwarded to the output)
+ * The merge operation is a control operation that forwards either of the inputs to the + * output, when<br>
      the first of them becomes available. If both are available, the output is + * undefined (either input could
      be forwarded to the output)
      * * @param x Input variable (NUMERIC type) * @param y Input variable (NUMERIC type) @@ -845,53 +945,59 @@ public class NDBase { /** * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray min(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("min", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Min(x, keepDims, dimensions)); } /** * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output Reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray min(INDArray x, int... dimensions) { NDValidation.validateNumerical("min", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Min(x, false, dimensions)); } /** * Element-wise minimum operation: out[i] = min(first[i], second[i])
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * - * @param first First input array (NUMERIC type) + * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) * @return output Second input array (NUMERIC type) */ @@ -902,11 +1008,11 @@ public class NDBase { } /** - * Matrix multiplication: out = mmul(x,y)
      - * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
      + * Matrix multiplication: out = mmul(x,y)
      Supports specifying transpose argument to perform + * operation such as mmul(a^T, b), etc.
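For instance (illustrative; shapes chosen so the transposed product is defined):

    INDArray a = Nd4j.create(DataType.DOUBLE, 3, 2);
    INDArray b = Nd4j.create(DataType.DOUBLE, 3, 4);
    INDArray c = new NDBase().mmul(a, b, true, false, false);   // a^T * b -> shape [2,4]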
      * - * @param x First input variable (NUMERIC type) - * @param y Second input variable (NUMERIC type) + * @param x First input variable (NUMERIC type) + * @param y Second input variable (NUMERIC type) * @param transposeX Transpose x (first argument) * @param transposeY Transpose y (second argument) * @param transposeZ Transpose result array @@ -916,12 +1022,13 @@ public class NDBase { boolean transposeZ) { NDValidation.validateNumerical("mmul", "x", x); NDValidation.validateNumerical("mmul", "y", y); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.Mmul(x, y, transposeX, transposeY, transposeZ))[0]; } /** - * Matrix multiplication: out = mmul(x,y)
      - * Supports specifying transpose argument to perform operation such as mmul(a^T, b), etc.
      + * Matrix multiplication: out = mmul(x,y)
      Supports specifying transpose argument to perform + * operation such as mmul(a^T, b), etc.
      * * @param x First input variable (NUMERIC type) * @param y Second input variable (NUMERIC type) @@ -935,12 +1042,13 @@ public class NDBase { /** * Not equals operation: elementwise x != y
      - * + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input array (NUMERIC type) * @param y Double value argument to use in operation - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray neq(INDArray x, double y) { NDValidation.validateNumerical("neq", "x", x); @@ -948,18 +1056,20 @@ public class NDBase { } /** - * Not equal to operation: elementwise x != y
      - * If x and y arrays have equal shape, the output shape is the same as these inputs.
      - * - * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      - * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      - * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      - * + * Not equal to operation: elementwise x != y
      If x and y arrays have equal shape, the output + * shape is the same as these inputs.
      + *

      + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
      For + * example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
      + * Broadcast rules are the same as NumPy: + * https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
      + *

      * Return boolean array with values true where satisfied, or false otherwise.
      * * @param x Input 1 (NUMERIC type) * @param y Input 2 (NUMERIC type) - * @return output Boolean array out, with values true/false based on where the condition is satisfied (NUMERIC type) + * @return output Boolean array out, with values true/false based on where the condition is + * satisfied (NUMERIC type) */ public INDArray neq(INDArray x, INDArray y) { NDValidation.validateNumerical("neq", "x", x); @@ -968,180 +1078,192 @@ public class NDBase { } /** - * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
      - * out = sum_i abs(x[i])
      - * + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset + * along the specified dimensions:
      out = sum_i abs(x[i])
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray norm1(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("norm1", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(x, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(x, keepDims, dimensions)); } /** - * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset along the specified dimensions:
      - * out = sum_i abs(x[i])
      - * + * Norm1 (L1 norm) reduction operation: The output contains the L1 norm for each tensor/subset + * along the specified dimensions:
      out = sum_i abs(x[i])
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray norm1(INDArray x, int... dimensions) { NDValidation.validateNumerical("norm1", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1(x, false, dimensions)); } /** - * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
      - * out = sqrt(sum_i x[i]^2)
      - * + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset + * along the specified dimensions:
      out = sqrt(sum_i x[i]^2)
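A small sketch (illustrative):

    INDArray x = Nd4j.createFromArray(new double[][]{{3.0, 4.0}, {0.0, 5.0}});
    INDArray rowNorms = new NDBase().norm2(x, true, 1);   // L2 norm of each row, keepDims -> [[5.0], [5.0]]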
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions * @param dimensions dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray norm2(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("norm2", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(x, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(x, keepDims, dimensions)); } /** - * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset along the specified dimensions:
      - * out = sqrt(sum_i x[i]^2)
      - * + * Norm2 (L2 norm) reduction operation: The output contains the L2 norm for each tensor/subset + * along the specified dimensions:
      out = sqrt(sum_i x[i]^2)
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param dimensions dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray norm2(INDArray x, int... dimensions) { NDValidation.validateNumerical("norm2", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2(x, false, dimensions)); } /** - * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
      - * specified dimensions:
      - * out = max(abs(x[i]))
      - * + * Max norm (infinity norm) reduction operation: The output contains the max norm for each + * tensor/subset along the
      specified dimensions:
      out = max(abs(x[i]))
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray normmax(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("normmax", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, keepDims, dimensions)); } /** - * Max norm (infinity norm) reduction operation: The output contains the max norm for each tensor/subset along the
      - * specified dimensions:
      - * out = max(abs(x[i]))
      - * + * Max norm (infinity norm) reduction operation: The output contains the max norm for each + * tensor/subset along the
      specified dimensions:
      out = max(abs(x[i]))
      + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param dimensions dimensions to reduce over (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray normmax(INDArray x, int... dimensions) { NDValidation.validateNumerical("normmax", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, false, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax(x, false, dimensions)); } /** - * Convert the array to a one-hot array with walues and for each entry
      - * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
      - * with {out[i, ..., j, in[i,...,j]] with other values being set to
+ * Convert the array to a one-hot array with values 'on' and 'off' for each entry<br>
      If input has shape [ + * a, ..., n] then output has shape [ a, ..., n, depth],
with out[i, ..., j, in[i,...,j]] = 'on', + * with other values being set to 'off'<br>
      * - * @param indices Indices - value 0 to depth-1 (NUMERIC type) - * @param depth Number of classes - * @param axis - * @param on - * @param off + * @param indices Indices - value 0 to depth-1 (NUMERIC type) + * @param depth Number of classes + * @param axis + * @param on + * @param off * @param dataType Output data type * @return output Output variable (NUMERIC type) */ public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off, DataType dataType) { NDValidation.validateNumerical("oneHot", "indices", indices); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, dataType))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, dataType))[0]; } /** - * Convert the array to a one-hot array with walues and for each entry
      - * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
      - * with {out[i, ..., j, in[i,...,j]] with other values being set to
+ * Convert the array to a one-hot array with values 'on' and 'off' for each entry<br>
      If input has shape [ + * a, ..., n] then output has shape [ a, ..., n, depth],
with out[i, ..., j, in[i,...,j]] = 'on', + * with other values being set to 'off'<br>
      * * @param indices Indices - value 0 to depth-1 (NUMERIC type) - * @param depth Number of classes - * @param axis - * @param on - * @param off + * @param depth Number of classes + * @param axis + * @param on + * @param off * @return output Output variable (NUMERIC type) */ public INDArray oneHot(INDArray indices, int depth, int axis, double on, double off) { NDValidation.validateNumerical("oneHot", "indices", indices); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, DataType.FLOAT))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.OneHot(indices, depth, axis, on, off, + DataType.FLOAT))[0]; } /** - * Convert the array to a one-hot array with walues 0 and 1 for each entry
      - * If input has shape [ a, ..., n] then output has shape [ a, ..., n, depth],
      - * with out[i, ..., j, in[i,...,j]] = 1 with other values being set to 0
      - * see oneHot(SDVariable, int, int, double, double)
+ * Convert the array to a one-hot array with values 0 and 1 for each entry<br>
      If input has shape + * [ a, ..., n] then output has shape [ a, ..., n, depth],
      with out[i, ..., j, in[i,...,j]] = + * 1 with other values being set to 0
      see oneHot(SDVariable, int, int, double, double)
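A minimal sketch of the two-argument form (illustrative):

    INDArray idx = Nd4j.createFromArray(0, 2, 1);
    INDArray oh  = new NDBase().oneHot(idx, 3);           // -> [[1,0,0], [0,0,1], [0,1,0]]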
      * * @param indices Indices - value 0 to depth-1 (NUMERIC type) - * @param depth Number of classes + * @param depth Number of classes * @return output Output variable (NUMERIC type) */ public INDArray oneHot(INDArray indices, int depth) { @@ -1150,8 +1272,9 @@ public class NDBase { } /** - * Return a variable of all 1s, with the same shape as the input variable. Note that this is dynamic:
      - * if the input shape changes in later execution, the returned variable's shape will also be updated
      + * Return a variable of all 1s, with the same shape as the input variable. Note that this is + * dynamic:
      if the input shape changes in later execution, the returned variable's shape will + * also be updated
      * * @param input Input INDArray (NUMERIC type) * @return output A new INDArray with the same (dynamic) shape as the input (NUMERIC type) @@ -1164,8 +1287,8 @@ public class NDBase { /** * As per onesLike(String, SDVariable) but the output datatype may be specified
      * - * @param input (NUMERIC type) - * @param dataType + * @param input (NUMERIC type) + * @param dataType * @return output (NUMERIC type) */ public INDArray onesLike(INDArray input, DataType dataType) { @@ -1174,10 +1297,11 @@ public class NDBase { } /** - * Array permutation operation: permute the dimensions according to the specified permutation indices.
      - * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
      + * Array permutation operation: permute the dimensions according to the specified permutation + * indices.
      Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape + * [c,a,b]
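For example (illustrative; the dimensions argument of this overload is itself an INT array):

    INDArray x = Nd4j.create(DataType.DOUBLE, 2, 3, 4);                      // shape [a,b,c]
    INDArray p = new NDBase().permute(x, Nd4j.createFromArray(2, 0, 1));     // -> shape [4,2,3]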
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param dimensions Permute dimensions (INT type) * @return output Output variable (permuted input) (NUMERIC type) */ @@ -1188,69 +1312,77 @@ public class NDBase { } /** - * Array permutation operation: permute the dimensions according to the specified permutation indices.
      - * Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape [c,a,b]
      + * Array permutation operation: permute the dimensions according to the specified permutation + * indices.
      Example: if input has shape [a,b,c] and dimensions = [2,0,1] the output has shape + * [c,a,b]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) * @return output Output variable (permuted input) (NUMERIC type) */ public INDArray permute(INDArray x, int... dimensions) { NDValidation.validateNumerical("permute", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Permute(x, dimensions))[0]; } /** * Product array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output (NUMERIC type) */ public INDArray prod(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("prod", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(x, keepDims, dimensions)); } /** * Product array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output (NUMERIC type) */ public INDArray prod(INDArray x, int... dimensions) { NDValidation.validateNumerical("prod", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Prod(x, false, dimensions)); } /** * Create a new variable with a 1d array, where the values start at from and increment by step
      - * up to (but not including) limit.
      - * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
      + * up to (but not including) limit.
      For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, + * 2.0, 2.5]
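The example above as a call (illustrative):

    INDArray r = new NDBase().range(1.0, 3.0, 0.5, DataType.DOUBLE);   // -> [1.0, 1.5, 2.0, 2.5]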
      * - * @param from Initial/smallest value - * @param to Largest value (exclusive) - * @param step Step size - * @param dataType + * @param from Initial/smallest value + * @param to Largest value (exclusive) + * @param step Step size + * @param dataType * @return output INDArray with the specified values (NUMERIC type) */ public INDArray range(double from, double to, double step, DataType dataType) { @@ -1259,13 +1391,13 @@ public class NDBase { /** * Create a new variable with a 1d array, where the values start at from and increment by step
      - * up to (but not including) limit.
      - * For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, 2.0, 2.5]
      + * up to (but not including) limit.
      For example, range(1.0, 3.0, 0.5) will return [1.0, 1.5, + * 2.0, 2.5]
      * - * @param from Initial/smallest value (NUMERIC type) - * @param to Largest value (exclusive) (NUMERIC type) - * @param step Step size (NUMERIC type) - * @param dataType + * @param from Initial/smallest value (NUMERIC type) + * @param to Largest value (exclusive) (NUMERIC type) + * @param step Step size (NUMERIC type) + * @param dataType * @return output INDArray with the specified values (NUMERIC type) */ public INDArray range(INDArray from, INDArray to, INDArray step, DataType dataType) { @@ -1276,10 +1408,12 @@ public class NDBase { } /** - * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D scalar variable
      + * Returns the rank (number of dimensions, i.e., length(shape)) of the specified INDArray as a 0D + * scalar variable
      * * @param in Input variable (NUMERIC type) - * @return output (scalar) output variable with value equal to the rank of the input variable (NUMERIC type) + * @return output (scalar) output variable with value equal to the rank of the input variable + * (NUMERIC type) */ public INDArray rank(INDArray in) { NDValidation.validateNumerical("rank", "in", in); @@ -1287,42 +1421,45 @@ public class NDBase { } /** - * Element-wise replace where condition:
-   * out[i] = from[i] if condition(update[i]) is satisfied, or<br>
-   * out[i] = update[i] if condition(update[i]) is NOT satisfied<br>
+   * Element-wise replace where condition:<br>
+   * out[i] = from[i] if condition(update[i]) is
+   * satisfied, or<br>
+   * out[i] = update[i] if condition(update[i]) is NOT satisfied<br>
      * - * @param update Source array (NUMERIC type) - * @param from Replacement values array (used conditionally). Must be same shape as 'update' array (NUMERIC type) + * @param update Source array (NUMERIC type) + * @param from Replacement values array (used conditionally). Must be same shape as 'update' + * array (NUMERIC type) * @param condition Condition to check on update array elements * @return output New array with values replaced where condition is satisfied (NUMERIC type) */ public INDArray replaceWhere(INDArray update, INDArray from, Condition condition) { NDValidation.validateNumerical("replaceWhere", "update", update); NDValidation.validateNumerical("replaceWhere", "from", from); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, condition)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace(update, from, + condition)); } /** - * Element-wise replace where condition:
-   * out[i] = value if condition(update[i]) is satisfied, or<br>
-   * out[i] = update[i] if condition(update[i]) is NOT satisfied<br>
+   * Element-wise replace where condition:<br>
+   * out[i] = value if condition(update[i]) is satisfied,
+   * or<br>
+   * out[i] = update[i] if condition(update[i]) is NOT satisfied<br>
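Editorial aside, not part of this patch: a sketch of the scalar replaceWhere(...) overload from this hunk, using the Conditions factory from nd4j. NDBase is constructed directly; the class name is hypothetical.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.ops.NDBase;
import org.nd4j.linalg.indexing.conditions.Conditions;

public class ReplaceWhereSketch {
  public static void main(String[] args) {
    INDArray update = Nd4j.createFromArray(new double[]{-2.0, 1.0, -0.5, 3.0});
    // Where the condition holds (element < 0) the scalar value 0.0 is written; other elements are kept
    INDArray out = new NDBase().replaceWhere(update, 0.0, Conditions.lessThan(0.0));
    System.out.println(out); // [0.0, 1.0, 0.0, 3.0]
  }
}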
      * - * @param update Source array (NUMERIC type) - * @param value Value to set at the output, if the condition is satisfied + * @param update Source array (NUMERIC type) + * @param value Value to set at the output, if the condition is satisfied * @param condition Condition to check on update array elements * @return output New array with values replaced where condition is satisfied (NUMERIC type) */ public INDArray replaceWhere(INDArray update, double value, Condition condition) { NDValidation.validateNumerical("replaceWhere", "update", update); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(update, value, condition)); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet(update, value, + condition)); } /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
-   * input, but with the specified shape.<br>
-   * Note that prod(shape) must match length(input) == prod(input.shape)<br>
+   * Reshape the input variable to the specified (fixed) shape. The output variable will have the
+   * same values as the<br>
+   * input, but with the specified shape.<br>
+   * Note that prod(shape) must
+   * match length(input) == prod(input.shape)<br>
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param shape New shape for variable (NUMERIC type) * @return output Output variable (NUMERIC type) */ @@ -1333,76 +1470,77 @@ public class NDBase { } /** - * Reshape the input variable to the specified (fixed) shape. The output variable will have the same values as the
-   * input, but with the specified shape.<br>
-   * Note that prod(shape) must match length(input) == prod(input.shape)<br>
+   * Reshape the input variable to the specified (fixed) shape. The output variable will have the
+   * same values as the<br>
+   * input, but with the specified shape.<br>
+   * Note that prod(shape) must
+   * match length(input) == prod(input.shape)<br>
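Editorial aside, not part of this patch: the reshape(INDArray, long...) overload from this hunk, showing the prod(shape) == length(input) constraint. NDBase is constructed directly; the class name is hypothetical.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.ops.NDBase;

public class ReshapeSketch {
  public static void main(String[] args) {
    INDArray flat = Nd4j.arange(12); // 12 elements, 0..11
    // prod(shape) must equal the input length: 3 * 4 == 12
    INDArray grid = new NDBase().reshape(flat, 3, 4);
    System.out.println(grid.shapeInfoToString());
  }
}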
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param shape New shape for variable (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray reshape(INDArray x, long... shape) { NDValidation.validateNumerical("reshape", "x", x); - Preconditions.checkArgument(shape.length >= 0, "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); + Preconditions.checkArgument(shape.length >= 0, + "shape has incorrect size/length. Expected: shape.length >= 0, got %s", shape.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Reshape(x, shape))[0]; } /** - * Reverse the values of an array for the specified dimensions
-   * If input is:<br>
-   * [ 1, 2, 3]<br>
-   * [ 4, 5, 6]<br>
-   * then<br>
-   * reverse(in, 0):<br>
-   * [3, 2, 1]<br>
-   * [6, 5, 4]<br>
-   * reverse(in, 1):<br>
-   * [4, 5, 6]<br>
-   * [1, 2 3]<br>
+   * Reverse the values of an array for the specified dimensions<br>
+   * If input is:<br>
+   * [ 1, 2, 3]<br>
+   * [ 4, 5, 6]<br>
+   * then<br>
+   * reverse(in, 0):<br>
+   * [3, 2, 1]<br>
+   * [6, 5, 4]<br>
+   * reverse(in, 1):<br>
+   * [4, 5, 6]<br>
+   * [1, 2, 3]<br>
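Editorial aside, not part of this patch: a sketch of reverse(INDArray, int...) as documented above; the printed results should match the worked example in the Javadoc. NDBase is constructed directly; the class name is hypothetical.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.ops.NDBase;

public class ReverseSketch {
  public static void main(String[] args) {
    INDArray in = Nd4j.createFromArray(new double[][]{{1, 2, 3}, {4, 5, 6}});
    // Reverse along one dimension at a time; compare with the [1,2,3]/[4,5,6] example in the Javadoc
    System.out.println(new NDBase().reverse(in, 0));
    System.out.println(new NDBase().reverse(in, 1));
  }
}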
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param dimensions Input variable (Size: AtLeast(min=0)) * @return output Output variable (NUMERIC type) */ public INDArray reverse(INDArray x, int... dimensions) { NDValidation.validateNumerical("reverse", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse(x, dimensions))[0]; } /** - * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
      + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values + * are reversed
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param seq_lengths Length of the sequences (INT type) - * @param seqDim Sequence dimension - * @param batchDim Batch dimension + * @param seqDim Sequence dimension + * @param batchDim Batch dimension * @return output Reversed sequences (NUMERIC type) */ public INDArray reverseSequence(INDArray x, INDArray seq_lengths, int seqDim, int batchDim) { NDValidation.validateNumerical("reverseSequence", "x", x); NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, seqDim, batchDim))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, seqDim, + batchDim))[0]; } /** - * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values are reversed
      + * Reverse sequence op: for each slice along dimension seqDimension, the first seqLength values + * are reversed
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param seq_lengths Length of the sequences (INT type) * @return output Reversed sequences (NUMERIC type) */ public INDArray reverseSequence(INDArray x, INDArray seq_lengths) { NDValidation.validateNumerical("reverseSequence", "x", x); NDValidation.validateInteger("reverseSequence", "seq_lengths", seq_lengths); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, -1, 0))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence(x, seq_lengths, -1, + 0))[0]; } /** - * Element-wise scalar floor modulus operation: out = floorMod(in, value).
      - * i.e., returns the remainder after division by 'value'
      + * Element-wise scalar floor modulus operation: out = floorMod(in, value).
      i.e., returns the + * remainder after division by 'value'
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param value Scalar value to compare * @return output Output variable (NUMERIC type) */ @@ -1414,7 +1552,7 @@ public class NDBase { /** * Element-wise scalar maximum operation: out = max(in, value)
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param value Scalar value to compare * @return output Scalar value to compare (NUMERIC type) */ @@ -1426,7 +1564,7 @@ public class NDBase { /** * Element-wise scalar minimum operation: out = min(in, value)
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param value Scalar value to compare * @return output Output variable (NUMERIC type) */ @@ -1438,7 +1576,7 @@ public class NDBase { /** * Return a variable with equal shape to the input, but all elements set to value 'set'
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param set Value to set * @return output Output variable (NUMERIC type) */ @@ -1449,13 +1587,15 @@ public class NDBase { /** * Scatter addition operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1469,13 +1609,15 @@ public class NDBase { /** * Scatter division operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1489,13 +1631,15 @@ public class NDBase { /** * Scatter max operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1509,13 +1653,15 @@ public class NDBase { /** * Scatter min operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1529,13 +1675,15 @@ public class NDBase { /** * Scatter multiplication operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1549,13 +1697,15 @@ public class NDBase { /** * Scatter subtraction operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1569,13 +1719,15 @@ public class NDBase { /** * Scatter update operation.
      - * + *

      * If indices is rank 0 (a scalar), then out[index, ...] = out[index, ...] + op(updates[...])
      - * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = out[indices[i], ...] + op(updates[i, ...])
      - * If indices is rank 2+, then for each position (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + op(updates[i, ..., k, ...])
      - * Note that if multiple indices refer to the same location, the contributions from each is handled correctly.
      + * If indices is rank 1 (a vector), then for each position i, out[indices[i], ...] = + * out[indices[i], ...] + op(updates[i, ...])
      If indices is rank 2+, then for each position + * (i,...,k), out[indices[i], ..., indices[k], ...] = out[indices[i], ..., indices[k], ...] + + * op(updates[i, ..., k, ...])
      Note that if multiple indices refer to the same location, the + * contributions from each is handled correctly.
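Editorial aside, not part of this patch: a sketch of the scatterUpdate(...) op validated in this hunk, with a rank-1 indices array so each index selects one target row. NDBase is constructed directly; the class name is hypothetical.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.ops.NDBase;

public class ScatterUpdateSketch {
  public static void main(String[] args) {
    INDArray ref = Nd4j.zeros(DataType.FLOAT, 3, 2);
    INDArray indices = Nd4j.createFromArray(0, 2); // rank 1: one target row per row of 'updates'
    INDArray updates = Nd4j.createFromArray(new float[][]{{1, 1}, {5, 5}});
    // Rows 0 and 2 of ref receive the corresponding rows of updates; row 1 is left untouched
    System.out.println(new NDBase().scatterUpdate(ref, indices, updates));
  }
}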
      * - * @param ref Initial/source variable (NUMERIC type) + * @param ref Initial/source variable (NUMERIC type) * @param indices Indices array (NUMERIC type) * @param updates Updates to add to the initial/source array (NUMERIC type) * @return output The updated variable (NUMERIC type) @@ -1584,138 +1736,136 @@ public class NDBase { NDValidation.validateNumerical("scatterUpdate", "ref", ref); NDValidation.validateNumerical("scatterUpdate", "indices", indices); NDValidation.validateNumerical("scatterUpdate", "updates", updates); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(ref, indices, updates))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate(ref, indices, updates))[0]; } /** * Segment max operation.
      + *

      + * If data = [3, 6, 1, 4, 9, 2, 8]
      segmentIds = [0, 0, 1, 1, 1, 2, 2]
      then output = + * [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      Note that the segment IDs must be sorted from + * smallest to largest segment.
      See {unsortedSegment (String, SDVariable, SDVariable, int) + * ops
      for the same op without this sorted requirement
      * - * If data = [3, 6, 1, 4, 9, 2, 8]
      - * segmentIds = [0, 0, 1, 1, 1, 2, 2]
      - * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      - * Note that the segment IDs must be sorted from smallest to largest segment.
      - * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
      - * for the same op without this sorted requirement
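Editorial aside, not part of this patch: the segmentMax(...) op from this hunk, run on the exact data from its Javadoc example. Segment IDs must be sorted for the segment* ops. NDBase is constructed directly; the class name is hypothetical.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.ops.NDBase;

public class SegmentMaxSketch {
  public static void main(String[] args) {
    INDArray data = Nd4j.createFromArray(new double[]{3, 6, 1, 4, 9, 2, 8});
    INDArray segmentIds = Nd4j.createFromArray(0, 0, 1, 1, 1, 2, 2); // sorted from smallest to largest segment
    INDArray maxPerSegment = new NDBase().segmentMax(data, segmentIds);
    System.out.println(maxPerSegment); // [6, 9, 8], as in the Javadoc example
  }
}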
      - * - * @param data Data to perform segment max on (NDARRAY type) + * @param data Data to perform segment max on (NDARRAY type) * @param segmentIds Variable for the segment IDs (NUMERIC type) * @return output Segment output (NUMERIC type) */ public INDArray segmentMax(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentMax", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(data, segmentIds))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax(data, segmentIds))[0]; } /** * Segment mean operation.
      + *

      + * If data = [3, 6, 1, 4, 9, 2, 8]
      segmentIds = [0, 0, 1, 1, 1, 2, 2]
      then output = + * [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      Note that the segment IDs must be sorted from + * smallest to largest segment.
      See {unsortedSegment (String, SDVariable, SDVariable, int) + * ops
      for the same op without this sorted requirement
      * - * If data = [3, 6, 1, 4, 9, 2, 8]
      - * segmentIds = [0, 0, 1, 1, 1, 2, 2]
      - * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      - * Note that the segment IDs must be sorted from smallest to largest segment.
      - * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
      - * for the same op without this sorted requirement
      - * - * @param data Data to perform segment max on (NDARRAY type) + * @param data Data to perform segment max on (NDARRAY type) * @param segmentIds Variable for the segment IDs (NUMERIC type) * @return output Segment output (NUMERIC type) */ public INDArray segmentMean(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentMean", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(data, segmentIds))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean(data, + segmentIds))[0]; } /** * Segment min operation.
      + *

      + * If data = [3, 6, 1, 4, 9, 2, 8]
      segmentIds = [0, 0, 1, 1, 1, 2, 2]
      then output = + * [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      Note that the segment IDs must be sorted from + * smallest to largest segment.
      See {unsortedSegment (String, SDVariable, SDVariable, int) + * ops
      for the same op without this sorted requirement
      * - * If data = [3, 6, 1, 4, 9, 2, 8]
      - * segmentIds = [0, 0, 1, 1, 1, 2, 2]
      - * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      - * Note that the segment IDs must be sorted from smallest to largest segment.
      - * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
      - * for the same op without this sorted requirement
      - * - * @param data Data to perform segment max on (NDARRAY type) + * @param data Data to perform segment max on (NDARRAY type) * @param segmentIds Variable for the segment IDs (NUMERIC type) * @return output Segment output (NUMERIC type) */ public INDArray segmentMin(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentMin", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(data, segmentIds))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin(data, segmentIds))[0]; } /** * Segment product operation.
      + *

      + * If data = [3, 6, 1, 4, 9, 2, 8]
      segmentIds = [0, 0, 1, 1, 1, 2, 2]
      then output = + * [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      Note that the segment IDs must be sorted from + * smallest to largest segment.
      See {unsortedSegment (String, SDVariable, SDVariable, int) + * ops
      for the same op without this sorted requirement
      * - * If data = [3, 6, 1, 4, 9, 2, 8]
      - * segmentIds = [0, 0, 1, 1, 1, 2, 2]
      - * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      - * Note that the segment IDs must be sorted from smallest to largest segment.
      - * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
      - * for the same op without this sorted requirement
      - * - * @param data Data to perform segment max on (NDARRAY type) + * @param data Data to perform segment max on (NDARRAY type) * @param segmentIds Variable for the segment IDs (NUMERIC type) * @return output Segment output (NUMERIC type) */ public INDArray segmentProd(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentProd", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(data, segmentIds))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd(data, + segmentIds))[0]; } /** * Segment sum operation.
      + *

      + * If data = [3, 6, 1, 4, 9, 2, 8]
      segmentIds = [0, 0, 1, 1, 1, 2, 2]
      then output = + * [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      Note that the segment IDs must be sorted from + * smallest to largest segment.
      See {unsortedSegment (String, SDVariable, SDVariable, int) + * ops
      for the same op without this sorted requirement
      * - * If data = [3, 6, 1, 4, 9, 2, 8]
      - * segmentIds = [0, 0, 1, 1, 1, 2, 2]
      - * then output = [6, 9, 8] = [op(3,6), op(1,4,9), op(2,8)]
      - * Note that the segment IDs must be sorted from smallest to largest segment.
      - * See {unsortedSegment (String, SDVariable, SDVariable, int) ops
      - * for the same op without this sorted requirement
      - * - * @param data Data to perform segment max on (NDARRAY type) + * @param data Data to perform segment max on (NDARRAY type) * @param segmentIds Variable for the segment IDs (NUMERIC type) * @return output Segment output (NUMERIC type) */ public INDArray segmentSum(INDArray data, INDArray segmentIds) { NDValidation.validateNumerical("segmentSum", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(data, segmentIds))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum(data, segmentIds))[0]; } /** - * Generate a sequence mask (with values 0 or 1) based on the specified lengths
      - * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
      + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
      Specifically, + * out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
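Editorial aside, not part of this patch: the sequenceMask(lengths, maxLen, dataType) overload from this hunk. NDBase is constructed directly; the class name is hypothetical.

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.ops.NDBase;

public class SequenceMaskSketch {
  public static void main(String[] args) {
    INDArray lengths = Nd4j.createFromArray(1, 3, 2);
    // mask[i, j] = (j < lengths[i]) ? 1.0 : 0.0, so the result has shape [3, 4]
    INDArray mask = new NDBase().sequenceMask(lengths, 4, DataType.FLOAT);
    System.out.println(mask);
  }
}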
      * - * @param lengths Lengths of the sequences (NUMERIC type) - * @param maxLen Maximum sequence length - * @param dataType + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length + * @param dataType * @return output Output variable (NUMERIC type) */ public INDArray sequenceMask(INDArray lengths, int maxLen, DataType dataType) { NDValidation.validateNumerical("sequenceMask", "lengths", lengths); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; } /** - * Generate a sequence mask (with values 0 or 1) based on the specified lengths
      - * Specifically, out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
      + * Generate a sequence mask (with values 0 or 1) based on the specified lengths
      Specifically, + * out[i, ..., k, j] = (j < lengths[i, ..., k] ? 1.0 : 0.0)
      * - * @param lengths Lengths of the sequences (NUMERIC type) - * @param maxLen Maximum sequence length (INT type) - * @param dataType + * @param lengths Lengths of the sequences (NUMERIC type) + * @param maxLen Maximum sequence length (INT type) + * @param dataType * @return output Output variable (NUMERIC type) */ public INDArray sequenceMask(INDArray lengths, INDArray maxLen, DataType dataType) { NDValidation.validateNumerical("sequenceMask", "lengths", lengths); NDValidation.validateInteger("sequenceMask", "maxLen", maxLen); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.SequenceMask(lengths, maxLen, dataType))[0]; } /** * see sequenceMask(String, SDVariable, SDVariable, DataType)
      * * @param lengths (NUMERIC type) - * @param dataType + * @param dataType * @return output (NUMERIC type) */ public INDArray sequenceMask(INDArray lengths, DataType dataType) { @@ -1735,10 +1885,12 @@ public class NDBase { } /** - * Returns the size (number of elements, i.e., prod(shape)) of the specified INDArray as a 0D scalar variable
      + * Returns the size (number of elements, i.e., prod(shape)) of the specified INDArray as a 0D + * scalar variable
      * * @param in Input variable (NUMERIC type) - * @return output 0D (scalar) output variable with value equal to the number of elements in the specified array (NUMERIC type) + * @return output 0D (scalar) output variable with value equal to the number of elements in the + * specified array (NUMERIC type) */ public INDArray size(INDArray in) { NDValidation.validateNumerical("size", "in", in); @@ -1746,10 +1898,10 @@ public class NDBase { } /** - * Returns a rank 0 (scalar) variable for the size of the specified dimension.
      - * For example, if X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30
      + * Returns a rank 0 (scalar) variable for the size of the specified dimension.
      For example, if + * X has shape [10,20,30] then sizeAt(X,1)=20. Similarly, sizeAt(X,-1)=30
      * - * @param in Input variable (NUMERIC type) + * @param in Input variable (NUMERIC type) * @param dimension Dimension to get size of * @return output Scalar INDArray for size at specified variable (NUMERIC type) */ @@ -1759,40 +1911,36 @@ public class NDBase { } /** - * Get a subset of the specified input, by specifying the first element and the size of the array.
      - * For example, if input is:
      - * [a, b, c]
      - * [d, e, f]
      - * then slice(input, begin=[0,1], size=[2,1] will return:
      - * [b]
      - * [e]
      - * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
      + * Get a subset of the specified input, by specifying the first element and the size of the + * array.
      For example, if input is:
      [a, b, c]
      [d, e, f]
      then slice(input, + * begin=[0,1], size=[2,1] will return:
      [b]
      [e]
      Note that for each dimension i, + * begin[i] + size[i] <= input.size(i)
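Editorial aside, not part of this patch: the slice(input, begin, size) overload from this hunk, mirroring the begin=[0,1], size=[2,1] example above on a concrete 2x3 array. NDBase is constructed directly; the class name is hypothetical.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.ops.NDBase;

public class SliceSketch {
  public static void main(String[] args) {
    INDArray input = Nd4j.createFromArray(new double[][]{{1, 2, 3}, {4, 5, 6}});
    // begin=[0,1], size=[2,1]: both rows, a single column starting at index 1 -> [[2], [5]]
    INDArray sub = new NDBase().slice(input, new int[]{0, 1}, 2, 1);
    System.out.println(sub);
  }
}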
      * * @param input input Variable to get subset of (NUMERIC type) - * @param begin Beginning index. Must be same length as rank of input array (Size: AtLeast(min=1)) - * @param size Size of the output array. Must be same length as rank of input array (Size: AtLeast(min=1)) + * @param begin Beginning index. Must be same length as rank of input array (Size: + * AtLeast(min=1)) + * @param size Size of the output array. Must be same length as rank of input array (Size: + * AtLeast(min=1)) * @return output Subset of the input (NUMERIC type) */ public INDArray slice(INDArray input, int[] begin, int... size) { NDValidation.validateNumerical("slice", "input", input); - Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); - Preconditions.checkArgument(size.length >= 1, "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length); + Preconditions.checkArgument(begin.length >= 1, + "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(size.length >= 1, + "size has incorrect size/length. Expected: size.length >= 1, got %s", size.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Slice(input, begin, size))[0]; } /** - * Get a subset of the specified input, by specifying the first element and the size of the array.
      - * For example, if input is:
      - * [a, b, c]
      - * [d, e, f]
      - * then slice(input, begin=[0,1], size=[2,1] will return:
      - * [b]
      - * [e]
      - * Note that for each dimension i, begin[i] + size[i] <= input.size(i)
      + * Get a subset of the specified input, by specifying the first element and the size of the + * array.
      For example, if input is:
      [a, b, c]
      [d, e, f]
      then slice(input, + * begin=[0,1], size=[2,1] will return:
      [b]
      [e]
      Note that for each dimension i, + * begin[i] + size[i] <= input.size(i)
      * * @param input input Variable to get subset of (NUMERIC type) * @param begin Beginning index. Must be same length as rank of input array (INT type) - * @param size Size of the output array. Must be same length as rank of input array (INT type) + * @param size Size of the output array. Must be same length as rank of input array (INT type) * @return output Subset of the input (NUMERIC type) */ public INDArray slice(INDArray input, INDArray begin, INDArray size) { @@ -1804,50 +1952,54 @@ public class NDBase { /** * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x (NUMERIC type) - * @param keepDims - * @param dimensions (Size: AtLeast(min=0)) + * @param x (NUMERIC type) + * @param keepDims + * @param dimensions (Size: AtLeast(min=0)) * @return output (NUMERIC type) */ public INDArray squaredNorm(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("squaredNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, keepDims, dimensions)); } /** * Squared L2 norm: see norm2(String, SDVariable, boolean, int...)
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x (NUMERIC type) - * @param dimensions (Size: AtLeast(min=0)) + * @param x (NUMERIC type) + * @param dimensions (Size: AtLeast(min=0)) * @return output (NUMERIC type) */ public INDArray squaredNorm(INDArray x, int... dimensions) { NDValidation.validateNumerical("squaredNorm", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, false, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm(x, false, dimensions)); } /** - * Remove a single dimension of size 1.
      - * For example, if input has shape [a,b,1,c] then squeeze(input, 2) returns an array of shape [a,b,c]
      + * Remove a single dimension of size 1.
      For example, if input has shape [a,b,1,c] then + * squeeze(input, 2) returns an array of shape [a,b,c]
      * - * @param x Input variable (NUMERIC type) + * @param x Input variable (NUMERIC type) * @param axis Size 1 dimension to remove * @return output Output variable (NUMERIC type) */ @@ -1857,168 +2009,198 @@ public class NDBase { } /** - * Stack a set of N INDArray of rank X into one rank X+1 variable.
      - * If inputs have shape [a,b,c] then output has shape:
      - * axis = 0: [N,a,b,c]
      - * axis = 1: [a,N,b,c]
      - * axis = 2: [a,b,N,c]
      - * axis = 3: [a,b,c,N]
      - * see unstack(String[], SDVariable, int, int)
      + * Stack a set of N INDArray of rank X into one rank X+1 variable.
      If inputs have shape + * [a,b,c] then output has shape:
      axis = 0: [N,a,b,c]
      axis = 1: [a,N,b,c]
      axis = 2: + * [a,b,N,c]
      axis = 3: [a,b,c,N]
      see unstack(String[], SDVariable, int, int)
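Editorial aside, not part of this patch: the stack(axis, values...) op from this hunk, stacking two rank-1 inputs into a rank-2 output. NDBase is constructed directly; the class name is hypothetical.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.ops.NDBase;

public class StackSketch {
  public static void main(String[] args) {
    INDArray a = Nd4j.createFromArray(new double[]{1, 2, 3});
    INDArray b = Nd4j.createFromArray(new double[]{4, 5, 6});
    // Two inputs of shape [3] stacked on axis 0 give an output of shape [2, 3]
    INDArray stacked = new NDBase().stack(0, a, b);
    System.out.println(stacked);
  }
}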
      * * @param values Input variables to stack. Must have the same shape for all inputs (NDARRAY type) - * @param axis Axis to stack on + * @param axis Axis to stack on * @return output Output variable (NDARRAY type) */ public INDArray stack(int axis, INDArray... values) { - Preconditions.checkArgument(values.length >= 1, "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length); + Preconditions.checkArgument(values.length >= 1, + "values has incorrect size/length. Expected: values.length >= 1, got %s", values.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Stack(values, axis))[0]; } /** * Stardard deviation array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N + * (population stdev) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: + * remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray standardDeviation(INDArray x, boolean biasCorrected, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("standardDeviation", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, keepDims, + dimensions)); } /** * Stardard deviation array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N (population stdev) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample stdev). If false: divide by N + * (population stdev) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray standardDeviation(INDArray x, boolean biasCorrected, int... dimensions) { NDValidation.validateNumerical("standardDeviation", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, false, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation(x, biasCorrected, false, + dimensions)); } /** - * Get a subset of the specified input, by specifying the first element, last element, and the strides.
      - * For example, if input is:
      - * [a, b, c]
      - * [d, e, f]
      - * [g, h, i]
      - * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
      - * [b, c]
      - * [h, i]
      + * Get a subset of the specified input, by specifying the first element, last element, and the + * strides.
      For example, if input is:
      [a, b, c]
      [d, e, f]
      [g, h, i]
      then + * stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
      [b, + * c]
      [h, i]
      * - * @param in Variable to get subset of (NUMERIC type) - * @param begin Beginning index (Size: AtLeast(min=1)) - * @param end End index (Size: AtLeast(min=1)) - * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) - * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] is ignored, and a value of 0 is used instead for the beginning index for that dimension - * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is ignored, and a value of size(i)-1 is used instead for the end index for that dimension - * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is set, then other dimensions are inserted as required at the specified position - * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is inserted at this point - * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values are ignored, and a size 1 dimension is removed at this point. Note that begin/end/stride values must result in a size 1 output for these dimensions + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means + * take every second element. (Size: AtLeast(min=1)) + * @param beginMask Bit mask: If the ith bit is set to 1, then the value in the begin long[] + * is ignored, and a value of 0 is used instead for the beginning index for + * that dimension + * @param endMask Bit mask: If the ith bit is set to 1, then the value in the end long[] is + * ignored, and a value of size(i)-1 is used instead for the end index for + * that dimension + * @param ellipsisMask Bit mask: only one non-zero value is allowed here. If a non-zero value is + * set, then other dimensions are inserted as required at the specified + * position + * @param newAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values + * are ignored, and a size 1 dimension is inserted at this point + * @param shrinkAxisMask Bit mask: if the ith bit is set to 1, then the begin/end/stride values + * are ignored, and a size 1 dimension is removed at this point. Note that + * begin/end/stride values must result in a size 1 output for these + * dimensions * @return output A subset of the input array (NUMERIC type) */ public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long[] strides, int beginMask, int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { NDValidation.validateNumerical("stridedSlice", "in", in); - Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); - Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); - Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask))[0]; + Preconditions.checkArgument(begin.length >= 1, + "begin has incorrect size/length. 
Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, + "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, + "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, beginMask, + endMask, ellipsisMask, newAxisMask, shrinkAxisMask))[0]; } /** - * Get a subset of the specified input, by specifying the first element, last element, and the strides.
      - * For example, if input is:
      - * [a, b, c]
      - * [d, e, f]
      - * [g, h, i]
      - * then stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
      - * [b, c]
      - * [h, i]
      + * Get a subset of the specified input, by specifying the first element, last element, and the + * strides.
      For example, if input is:
      [a, b, c]
      [d, e, f]
      [g, h, i]
      then + * stridedSlice(input, begin=[0,1], end=[2,2], strides=[2,1], all masks = 0) will return:
      [b, + * c]
      [h, i]
      * - * @param in Variable to get subset of (NUMERIC type) - * @param begin Beginning index (Size: AtLeast(min=1)) - * @param end End index (Size: AtLeast(min=1)) - * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take every second element. (Size: AtLeast(min=1)) + * @param in Variable to get subset of (NUMERIC type) + * @param begin Beginning index (Size: AtLeast(min=1)) + * @param end End index (Size: AtLeast(min=1)) + * @param strides Stride ("step size") for each dimension. For example, stride of 2 means take + * every second element. (Size: AtLeast(min=1)) * @return output A subset of the input array (NUMERIC type) */ public INDArray stridedSlice(INDArray in, long[] begin, long[] end, long... strides) { NDValidation.validateNumerical("stridedSlice", "in", in); - Preconditions.checkArgument(begin.length >= 1, "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); - Preconditions.checkArgument(end.length >= 1, "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); - Preconditions.checkArgument(strides.length >= 1, "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, 0))[0]; + Preconditions.checkArgument(begin.length >= 1, + "begin has incorrect size/length. Expected: begin.length >= 1, got %s", begin.length); + Preconditions.checkArgument(end.length >= 1, + "end has incorrect size/length. Expected: end.length >= 1, got %s", end.length); + Preconditions.checkArgument(strides.length >= 1, + "strides has incorrect size/length. Expected: strides.length >= 1, got %s", strides.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.shape.StridedSlice(in, begin, end, strides, 0, 0, 0, 0, + 0))[0]; } /** * Sum array reduction operation, optionally along specified dimensions.
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param keepDims If true: keep the dimensions that are reduced on (as length 1). False: remove + * the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray sum(INDArray x, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("sum", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(x, keepDims, dimensions)); } /** * Sum array reduction operation, optionally along specified dimensions.
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) - * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of rank (input rank) if keepdims = true (NUMERIC type) + * @param x Input variable (NUMERIC type) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) + * @return output reduced array of rank (input rank - num dimensions) if keepDims = false, or of + * rank (input rank) if keepdims = true (NUMERIC type) */ public INDArray sum(INDArray x, int... dimensions) { NDValidation.validateNumerical("sum", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(x, false, dimensions)); } /** - * Switch operation
-   * Predictate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output<br>
+   * Switch operation<br>
+   * Predicate - if false, values are output to left (first) branch/output; if
+   * true, to right (second) branch/output<br>
      * - * @param x Input variable (NDARRAY type) - * @param predicate Predictate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output (BOOL type) + * @param x Input variable (NDARRAY type) + * @param predicate Predictate - if false, values are output to left (first) branch/output; if + * true, to right (second) branch/output (BOOL type) */ public INDArray[] switchOp(INDArray x, INDArray predicate) { NDValidation.validateBool("switchOp", "predicate", predicate); @@ -2028,29 +2210,35 @@ public class NDBase { /** * //TODO: Ops must be documented.
      * - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) - * @param transposeX Transpose x (first argument) - * @param transposeY Transpose y (second argument) - * @param transposeZ Transpose result array + * @param transposeX Transpose x (first argument) + * @param transposeY Transpose y (second argument) + * @param transposeZ Transpose result array * @return output Output variable (NUMERIC type) */ public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int[] dimensionsY, boolean transposeX, boolean transposeY, boolean transposeZ) { NDValidation.validateNumerical("tensorMmul", "x", x); NDValidation.validateNumerical("tensorMmul", "y", y); - Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); - Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, transposeX, transposeY, transposeZ))[0]; + Preconditions.checkArgument(dimensionsX.length >= 1, + "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", + dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, + "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", + dimensionsY.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, + transposeX, transposeY, transposeZ))[0]; } /** * //TODO: Ops must be documented.
      * - * @param x Input variable x (NUMERIC type) - * @param y Input variable y (NUMERIC type) + * @param x Input variable x (NUMERIC type) + * @param y Input variable y (NUMERIC type) * @param dimensionsX dimensions for first input array (x) (Size: AtLeast(min=1)) * @param dimensionsY dimensions for second input array (y) (Size: AtLeast(min=1)) * @return output Output variable (NUMERIC type) @@ -2058,25 +2246,25 @@ public class NDBase { public INDArray tensorMmul(INDArray x, INDArray y, int[] dimensionsX, int... dimensionsY) { NDValidation.validateNumerical("tensorMmul", "x", x); NDValidation.validateNumerical("tensorMmul", "y", y); - Preconditions.checkArgument(dimensionsX.length >= 1, "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", dimensionsX.length); - Preconditions.checkArgument(dimensionsY.length >= 1, "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", dimensionsY.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, false, false, false))[0]; + Preconditions.checkArgument(dimensionsX.length >= 1, + "dimensionsX has incorrect size/length. Expected: dimensionsX.length >= 1, got %s", + dimensionsX.length); + Preconditions.checkArgument(dimensionsY.length >= 1, + "dimensionsY has incorrect size/length. Expected: dimensionsY.length >= 1, got %s", + dimensionsY.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.reduce.TensorMmul(x, y, dimensionsX, dimensionsY, false, + false, false))[0]; } /** - * Repeat (tile) the input tensor the specified number of times.
      - * For example, if input is
      - * [1, 2]
      - * [3, 4]
      - * and repeat is [2, 3]
      - * then output is
      - * [1, 2, 1, 2, 1, 2]
      - * [3, 4, 3, 4, 3, 4]
      - * [1, 2, 1, 2, 1, 2]
      - * [3, 4, 3, 4, 3, 4]
      + * Repeat (tile) the input tensor the specified number of times.
      For example, if input is
      + * [1, 2]
      [3, 4]
      and repeat is [2, 3]
      then output is
      [1, 2, 1, 2, 1, 2]
      [3, 4, + * 3, 4, 3, 4]
      [1, 2, 1, 2, 1, 2]
      [3, 4, 3, 4, 3, 4]
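A minimal sketch of the tile example above, using the varargs overload that appears a few lines further down (same assumptions: NDBase no-arg constructor, standard Nd4j imports):

    INDArray x = Nd4j.createFromArray(new double[][]{{1, 2}, {3, 4}});
    INDArray tiled = new NDBase().tile(x, 2, 3);
    // tiled has shape [4, 6] and matches the block pattern shown in the javadoc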
      * - * @param x Input variable (NDARRAY type) - * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the input array (INT type) + * @param x Input variable (NDARRAY type) + * @param repeat Number of times to repeat in each axis. Must have length equal to the rank of the + * input array (INT type) * @return output Output variable (NDARRAY type) */ public INDArray tile(INDArray x, INDArray repeat) { @@ -2087,12 +2275,13 @@ public class NDBase { /** * see tile(String, SDVariable, int...)
      * - * @param x (NDARRAY type) - * @param repeat (Size: AtLeast(min=1)) + * @param x (NDARRAY type) + * @param repeat (Size: AtLeast(min=1)) * @return output (NDARRAY type) */ public INDArray tile(INDArray x, int... repeat) { - Preconditions.checkArgument(repeat.length >= 1, "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length); + Preconditions.checkArgument(repeat.length >= 1, + "repeat has incorrect size/length. Expected: repeat.length >= 1, got %s", repeat.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Tile(x, repeat))[0]; } @@ -2107,122 +2296,126 @@ public class NDBase { } /** - * Unsorted segment max operation. As per segmentMax(String, SDVariable, SDVariable) but without
      - * the requirement for the indices to be sorted.
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [6, 9, 8] = [max(3,6), max(1,4,9), max(2,8)]
      + * Unsorted segment max operation. As per segmentMax(String, SDVariable, SDVariable) but + * without
      the requirement for the indices to be sorted.
      If data = [1, 3, 2, 6, 4, 9, + * 8]
      segmentIds = [1, 0, 2, 0, 1, 1, 2]
      then output = [6, 9, 8] = [max(3,6), max(1,4,9), + * max(2,8)]
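The worked example above can be reproduced directly; a sketch assuming the usual Nd4j imports and the NDBase no-arg constructor:

    INDArray data = Nd4j.createFromArray(1.0, 3.0, 2.0, 6.0, 4.0, 9.0, 8.0);
    INDArray segmentIds = Nd4j.createFromArray(1, 0, 2, 0, 1, 1, 2);
    // three segments (ids 0..2); the ids do not need to be sorted
    INDArray out = new NDBase().unsortedSegmentMax(data, segmentIds, 3);
    // out = [6, 9, 8], as in the javadoc example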
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentMax(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentMax", "data", data); NDValidation.validateNumerical("unsortedSegmentMax", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax(data, segmentIds, + numSegments))[0]; } /** - * Unsorted segment mean operation. As per segmentMean(String, SDVariable, SDVariable) but without
      - * the requirement for the indices to be sorted.
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
      + * Unsorted segment mean operation. As per segmentMean(String, SDVariable, SDVariable) but + * without
      the requirement for the indices to be sorted.
      If data = [1, 3, 2, 6, 4, 9, + * 8]
      segmentIds = [1, 0, 2, 0, 1, 1, 2]
      then output = [4.5, 4.666, 5] = [mean(3,6), + * mean(1,4,9), mean(2,8)]
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentMean(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentMean", "data", data); NDValidation.validateNumerical("unsortedSegmentMean", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean(data, segmentIds, + numSegments))[0]; } /** - * Unsorted segment min operation. As per segmentMin(String, SDVariable, SDVariable) but without
      - * the requirement for the indices to be sorted.
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [3, 1, 2] = [min(3,6), min(1,4,9), min(2,8)]
      + * Unsorted segment min operation. As per segmentMin(String, SDVariable, SDVariable) but + * without
      the requirement for the indices to be sorted.
      If data = [1, 3, 2, 6, 4, 9, + * 8]
      segmentIds = [1, 0, 2, 0, 1, 1, 2]
      then output = [3, 1, 2] = [min(3,6), min(1,4,9), + * min(2,8)]
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentMin(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentMin", "data", data); NDValidation.validateNumerical("unsortedSegmentMin", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin(data, segmentIds, + numSegments))[0]; } /** - * Unsorted segment product operation. As per segmentProd(String, SDVariable, SDVariable) but without
      - * the requirement for the indices to be sorted.
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [4.5, 4.666, 5] = [mean(3,6), mean(1,4,9), mean(2,8)]
      + * Unsorted segment product operation. As per segmentProd(String, SDVariable, SDVariable) but + * without
      the requirement for the indices to be sorted.
      If data = [1, 3, 2, 6, 4, 9, + * 8]
      segmentIds = [1, 0, 2, 0, 1, 1, 2]
then output = [18, 36, 16] = [prod(3,6), + * prod(1,4,9), prod(2,8)]
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentProd(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentProd", "data", data); NDValidation.validateNumerical("unsortedSegmentProd", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd(data, segmentIds, + numSegments))[0]; } /** - * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values in each segment
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrtN(3), sqrtN(2)]
      + * Unsorted segment sqrtN operation. Simply returns the sqrt of the count of the number of values + * in each segment
      If data = [1, 3, 2, 6, 4, 9, 8]
      segmentIds = [1, 0, 2, 0, 1, 1, + * 2]
then output = [1.414, 1.732, 1.414] = [sqrt(2), sqrt(3), sqrt(2)]
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentSqrtN(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentSqrtN", "data", data); NDValidation.validateNumerical("unsortedSegmentSqrtN", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN(data, segmentIds, + numSegments))[0]; } /** - * Unsorted segment sum operation. As per segmentSum(String, SDVariable, SDVariable) but without
      - * the requirement for the indices to be sorted.
      - * If data = [1, 3, 2, 6, 4, 9, 8]
      - * segmentIds = [1, 0, 2, 0, 1, 1, 2]
      - * then output = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]
      + * Unsorted segment sum operation. As per segmentSum(String, SDVariable, SDVariable) but + * without
      the requirement for the indices to be sorted.
      If data = [1, 3, 2, 6, 4, 9, + * 8]
      segmentIds = [1, 0, 2, 0, 1, 1, 2]
      then output = [9, 14, 10] = [sum(3,6), + * sum(1,4,9), sum(2,8)]
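The same inputs give the sums listed above; the remaining unsorted segment ops (mean, min, prod, sqrtN) use the identical calling pattern. A sketch under the same assumptions as the earlier snippets:

    INDArray data = Nd4j.createFromArray(1.0, 3.0, 2.0, 6.0, 4.0, 9.0, 8.0);
    INDArray segmentIds = Nd4j.createFromArray(1, 0, 2, 0, 1, 1, 2);
    INDArray sums = new NDBase().unsortedSegmentSum(data, segmentIds, 3);
    // sums = [9, 14, 10] = [sum(3,6), sum(1,4,9), sum(2,8)]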
      * - * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) - * @param segmentIds Variable for the segment IDs (NUMERIC type) + * @param data Data (variable) to perform unsorted segment max on (NUMERIC type) + * @param segmentIds Variable for the segment IDs (NUMERIC type) * @param numSegments Number of segments * @return output Unsorted segment output (NUMERIC type) */ public INDArray unsortedSegmentSum(INDArray data, INDArray segmentIds, int numSegments) { NDValidation.validateNumerical("unsortedSegmentSum", "data", data); NDValidation.validateNumerical("unsortedSegmentSum", "segmentIds", segmentIds); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, numSegments))[0]; + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum(data, segmentIds, + numSegments))[0]; } /** - * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified axis.
      - * If input has shape [a,b,c] then output has shape:
      - * axis = 0: [b,c]
      - * axis = 1: [a,c]
      - * axis = 2: [a,b]
      + * Unstack a variable of rank X into N rank X-1 variables by taking slices along the specified + * axis.
      If input has shape [a,b,c] then output has shape:
      axis = 0: [b,c]
      axis = 1: + * [a,c]
      axis = 2: [a,b]
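A short sketch of the unstack shapes described above, under the same assumptions:

    INDArray x = Nd4j.rand(new int[]{2, 3, 4});
    INDArray[] slices = new NDBase().unstack(x, 0, 2);
    // slices.length == 2; each slice has shape [3, 4] (axis 0 removed)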
      * * @param value Input variable to unstack (NDARRAY type) - * @param axis Axis to unstack on - * @param num Number of output variables + * @param axis Axis to unstack on + * @param num Number of output variables */ public INDArray[] unstack(INDArray value, int axis, int num) { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.Unstack(value, axis, num)); @@ -2230,50 +2423,61 @@ public class NDBase { /** * Variance array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
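A sketch of the keepDims shapes described above, same assumptions as the earlier snippets:

    INDArray x = Nd4j.rand(new int[]{3, 4, 5});
    INDArray kept = new NDBase().variance(x, true, true, 1);     // keepDims = true  -> shape [3, 1, 5]
    INDArray dropped = new NDBase().variance(x, true, false, 1); // keepDims = false -> shape [3, 5]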
      * - * @param x Input variable (NUMERIC type) - * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) - * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N + * (population variance) + * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: + * remove the reduction dimensions + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray variance(INDArray x, boolean biasCorrected, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("variance", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, keepDims, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec( + new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, keepDims, + dimensions)); } /** * Variance array reduction operation, optionally along specified dimensions
      - * + *

      * Note that if keepDims = true, the output variable has the same rank as the input variable,
      - * with the reduced dimensions having size 1. This can be useful for later broadcast operations (such as subtracting
      - * the mean along a dimension).
      - * Example: if input has shape [a,b,c] and dimensions=[1] then output has shape:
      - * keepDims = true: [a,1,c]
      - * keepDims = false: [a,c]
      + * with the reduced dimensions having size 1. This can be useful for later broadcast operations + * (such as subtracting
      the mean along a dimension).
      Example: if input has shape [a,b,c] + * and dimensions=[1] then output has shape:
      keepDims = true: [a,1,c]
      keepDims = false: + * [a,c]
      * - * @param x Input variable (NUMERIC type) - * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N (population variance) - * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) + * @param x Input variable (NUMERIC type) + * @param biasCorrected If true: divide by (N-1) (i.e., sample variable). If false: divide by N + * (population variance) + * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array + * reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (NUMERIC type) */ public INDArray variance(INDArray x, boolean biasCorrected, int... dimensions) { NDValidation.validateNumerical("variance", "x", x); - Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, false, dimensions)); + Preconditions.checkArgument(dimensions.length >= 0, + "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", + dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.summarystats.Variance(x, biasCorrected, false, + dimensions)); } /** - * Return a variable of all 0s, with the same shape as the input variable. Note that this is dynamic:
      - * if the input shape changes in later execution, the returned variable's shape will also be updated
      + * Return a variable of all 0s, with the same shape as the input variable. Note that this is + * dynamic:
      if the input shape changes in later execution, the returned variable's shape will + * also be updated
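For the eager INDArray path, the zeros-with-matching-shape behaviour described above is also available from the Nd4j factory; a sketch assuming the standard imports:

    INDArray x = Nd4j.rand(2, 3);
    INDArray z = Nd4j.zerosLike(x);   // shape [2, 3], every value 0.0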
      * * @param input Input (NUMERIC type) * @return output A new Variable with the same (dynamic) shape as the input (NUMERIC type) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/Indices.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/Indices.java index fb505698a..e9c521eae 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/Indices.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/indexing/Indices.java @@ -173,13 +173,13 @@ public class Indices { /** * Fill in the missing indices to be the * same length as the original shape. - *

      + *

      * Think of this as what fills in the indices for numpy or matlab: * Given a which is (4,3,2) in numpy: - *

      + *

      * a[1:3] is filled in by the rest * to give back the full slice - *

      + *

      * This algorithm fills in that delta * * @param shape the original shape @@ -244,7 +244,7 @@ public class Indices { /** * Calculate the shape for the given set of indices. - *

      + *

      * The shape is defined as (for each dimension) * the difference between the end index + 1 and * the begin index @@ -344,12 +344,12 @@ public class Indices { /** * Calculate the shape for the given set of indices and offsets. - *

      + *

      * The shape is defined as (for each dimension) * the difference between the end index + 1 and * the begin index - *

      - * If specified, this will check for whether any of the indices are >= to end - 1 + *

      + * If specified, this will check for whether any of the indices are >= to end - 1 * and if so, prune it down * * @param shape the original shape diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaBeliefUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaBeliefUpdater.java index 7d6dbb16c..857f9e467 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaBeliefUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaBeliefUpdater.java @@ -90,7 +90,6 @@ public class AdaBeliefUpdater implements GradientUpdater { * * @param gradient the gradient to get the update for * @param iteration - * @return the gradient */ @Override public void applyUpdater(INDArray gradient, int iteration, int epoch) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java index 83a740be8..a9a2b44e5 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java @@ -32,78 +32,84 @@ import java.util.Map; @Data public class AdaDeltaUpdater implements GradientUpdater { - public static final String MSG_STATE = "msg"; - public static final String MSDX_STATE = "msdx"; - private final AdaDelta config; + public static final String MSG_STATE = "msg"; + public static final String MSDX_STATE = "msdx"; - private INDArray msg; //E[g^2]_t by arxiv paper, algorithm 1 - private INDArray msdx; //E[delta x^2]_t by arxiv paper, algorithm 1 + private final AdaDelta config; + + private INDArray msg; //E[g^2]_t by arxiv paper, algorithm 1 + private INDArray msdx; //E[delta x^2]_t by arxiv paper, algorithm 1 + public AdaDeltaUpdater(AdaDelta config) { + this.config = config; + } - public AdaDeltaUpdater(AdaDelta config) { - this.config = config; + @Override + public void setState(Map stateMap, boolean initialize) { + if (!stateMap.containsKey(MSG_STATE) || !stateMap.containsKey(MSDX_STATE) + || stateMap.size() != 2) { + throw new IllegalStateException( + "State map should contain only keys [" + MSG_STATE + "," + MSDX_STATE + "] but has keys " + + stateMap.keySet()); } + this.msg = stateMap.get(MSG_STATE); + this.msdx = stateMap.get(MSDX_STATE); + } - @Override - public void setState(Map stateMap, boolean initialize) { - if(!stateMap.containsKey(MSG_STATE) || !stateMap.containsKey(MSDX_STATE) || stateMap.size() != 2){ - throw new IllegalStateException("State map should contain only keys [" + MSG_STATE + "," + MSDX_STATE + "] but has keys " + stateMap.keySet()); - } - this.msg = stateMap.get(MSG_STATE); - this.msdx = stateMap.get(MSDX_STATE); - } + @Override + public Map getState() { + Map r = new HashMap<>(); + r.put(MSG_STATE, msg); + r.put(MSDX_STATE, msdx); + return r; + } - @Override - public Map getState() { - Map r = new HashMap<>(); - r.put(MSG_STATE, msg); - r.put(MSDX_STATE, msdx); - return r; - } + @Override + public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, + boolean initialize) { + if (!viewArray.isRowVector()) { + throw new IllegalArgumentException("Invalid input: expect row vector input"); + } + if (initialize) { + viewArray.assign(0); + } + long length = viewArray.length(); + this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2)); + this.msdx = viewArray.get(NDArrayIndex.point(0), 
NDArrayIndex.interval(length / 2, length)); - @Override - public void setStateViewArray(INDArray viewArray, long[] gradientShape, char gradientOrder, boolean initialize) { - if (!viewArray.isRowVector()) - throw new IllegalArgumentException("Invalid input: expect row vector input"); - if (initialize) - viewArray.assign(0); - long length = viewArray.length(); - this.msg = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, length / 2)); - this.msdx = viewArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(length / 2, length)); + //Reshape to match the expected shape of the input gradient arrays + this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, gradientOrder == 'f'); + this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, gradientOrder == 'f'); + if (msg == null || msdx == null) { + throw new IllegalStateException("Could not correctly reshape gradient view arrays"); + } + } - //Reshape to match the expected shape of the input gradient arrays - this.msg = Shape.newShapeNoCopy(this.msg, gradientShape, gradientOrder == 'f'); - this.msdx = Shape.newShapeNoCopy(this.msdx, gradientShape, gradientOrder == 'f'); - if (msg == null || msdx == null) - throw new IllegalStateException("Could not correctly reshape gradient view arrays"); - } + /** + * Get the updated gradient for the given gradient and also update the state of ada delta. + * + * @param gradient the gradient to get the updated gradient for + * @param iteration + */ + @Override + public void applyUpdater(INDArray gradient, int iteration, int epoch) { + if (msg == null || msdx == null) { + throw new IllegalStateException("Updater has not been initialized with view state"); + } - /** - * Get the updated gradient for the given gradient - * and also update the state of ada delta. 
- * - * @param gradient the gradient to get the - * updated gradient for - * @param iteration - * @return the update gradient - */ - @Override - public void applyUpdater(INDArray gradient, int iteration, int epoch) { - if (msg == null || msdx == null) - throw new IllegalStateException("Updater has not been initialized with view state"); + double rho = config.getRho(); + double epsilon = config.getEpsilon(); - double rho = config.getRho(); - double epsilon = config.getEpsilon(); + //Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf + //E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t + //Calculate update: + //dX = - g * RMS[delta x]_{t-1} / RMS[g]_t + //Note: negative is applied in the DL4J step function: params -= update rather than params += update + //Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2 - //Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf - //E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t - //Calculate update: - //dX = - g * RMS[delta x]_{t-1} / RMS[g]_t - //Note: negative is applied in the DL4J step function: params -= update rather than params += update - //Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2 - - Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho, epsilon)); - } + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho, + epsilon)); + } } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java index 7f7d27593..704075e13 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java @@ -92,7 +92,6 @@ public class AdaMaxUpdater implements GradientUpdater { * * @param gradient the gradient to get the update for * @param iteration - * @return the gradient */ @Override public void applyUpdater(INDArray gradient, int iteration, int epoch) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java index 6ad7255af..996c97268 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java @@ -93,7 +93,6 @@ public class AdamUpdater implements GradientUpdater { * * @param gradient the gradient to get the update for * @param iteration - * @return the gradient */ @Override public void applyUpdater(INDArray gradient, int iteration, int epoch) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/GradientUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/GradientUpdater.java index 1cfca0e7d..e32b45f66 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/GradientUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/GradientUpdater.java @@ -48,7 +48,6 @@ public interface GradientUpdater { * * @param gradient the gradient to modify * @param iteration - * @return the modified gradient */ void applyUpdater(INDArray gradient, int iteration, int epoch); } diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java index 
1bc3adcf5..cdb2e39ec 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java @@ -92,7 +92,6 @@ public class NadamUpdater implements GradientUpdater { * * @param gradient the gradient to get the update for * @param iteration - * @return the gradient */ @Override public void applyUpdater(INDArray gradient, int iteration, int epoch) { diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java index 891bec8a5..d580d8833 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java @@ -77,7 +77,6 @@ public class NesterovsUpdater implements GradientUpdater { * * @param gradient the gradient to get the update for * @param iteration - * @return */ @Override public void applyUpdater(INDArray gradient, int iteration, int epoch) { diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java index 03ec92701..dfd953b4f 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java @@ -152,12 +152,12 @@ public class MultiDimensionalMap implements Serializable { /** * Returns the value to which the specified key is mapped, * or {@code null} if this map contains no mapping for the key. - *

      + *

      *

      More formally, if this map contains a mapping from a key * {@code k} to a value {@code v} such that {@code (key==null ? k==null : * key.equals(k))}, then this method returns {@code v}; otherwise * it returns {@code null}. (There can be at most one such mapping.) - *

      + *

      *

      If this map permits null values, then a return value of * {@code null} does not necessarily indicate that the map * contains no mapping for the key; it's also possible that the map @@ -214,15 +214,15 @@ public class MultiDimensionalMap implements Serializable { * from key k to value v such that * (key==null ? k==null : key.equals(k)), that mapping * is removed. (The map can contain at most one such mapping.) - *

      + *

      *

      Returns the value to which this map previously associated the key, * or null if the map contained no mapping for the key. - *

      + *

      *

      If this map permits null values, then a return value of * null does not necessarily indicate that the map * contained no mapping for the key; it's also possible that the map * explicitly mapped the key to null. - *

      + *

      *

      The map will not contain a mapping for the specified key once the * call returns. * diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java index d16c190cb..fdd2bff15 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java @@ -108,12 +108,12 @@ public class MultiDimensionalSet implements Set> { * If this applyTransformToDestination makes any guarantees as to what order its elements * are returned by its iterator, this method must return the * elements in the same order. - *

      + *

      *

      The returned array will be "safe" in that no references to it * are maintained by this applyTransformToDestination. (In other words, this method must * allocate a new array even if this applyTransformToDestination is backed by an array). * The caller is thus free to modify the returned array. - *

      + *

      *

      This method acts as bridge between array-based and collection-based * APIs. * @@ -130,27 +130,27 @@ public class MultiDimensionalSet implements Set> { * If the applyTransformToDestination fits in the specified array, it is returned therein. * Otherwise, a new array is allocated with the runtime type of the * specified array and the size of this applyTransformToDestination. - *

      + *

      *

      If this applyTransformToDestination fits in the specified array with room to spare * (i.e., the array has more elements than this applyTransformToDestination), the element in * the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to * null. (This is useful in determining the length of this * applyTransformToDestination only if the caller knows that this applyTransformToDestination does not contain * any null elements.) - *

      + *

      *

      If this applyTransformToDestination makes any guarantees as to what order its elements * are returned by its iterator, this method must return the elements * in the same order. - *

      + *

      *

      Like the {@link #toArray()} method, this method acts as bridge between * array-based and collection-based APIs. Further, this method allows * precise control over the runtime type of the output array, and may, * under certain circumstances, be used to save allocation costs. - *

      + *

      *

      Suppose x is a applyTransformToDestination known to contain only strings. * The following code can be used to dump the applyTransformToDestination into a newly allocated * array of String: - *

      + *

      *

            *     String[] y = x.toArray(new String[0]);
      * @@ -181,7 +181,7 @@ public class MultiDimensionalSet implements Set> { * unchanged and returns false. In combination with the * restriction on constructors, this ensures that sets never contain * duplicate elements. - *

      + *

      *

      The stipulation above does not imply that sets must accept all * elements; sets may refuse to add any particular element, including * null, and throw an exception, as described in the diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java index 13780f3a6..240031440 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/util/ArrayUtil.java @@ -204,9 +204,9 @@ public class ArrayUtil { /** * Credit to mikio braun from jblas - *

      + *

      * Create a random permutation of the numbers 0, ..., size - 1. - *

      + *

      * see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145 */ public static int[] randomPermutation(int size) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java index eb59a2c5f..dfff491e4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/HelperUtils.java @@ -64,7 +64,7 @@ public class HelperUtils { if("CUDA".equalsIgnoreCase(backend) && cudnnHelperClassName != null && !cudnnHelperClassName.isEmpty()) { if(DL4JClassLoading.loadClassByName(cudnnHelperClassName) != null) { log.debug("Attempting to initialize cudnn helper {}",cudnnHelperClassName); - helperRet = (LayerHelper) DL4JClassLoading.createNewInstance( + helperRet = DL4JClassLoading.createNewInstance( cudnnHelperClassName, (Class) layerHelperSuperClass, new Object[]{arguments}); @@ -76,7 +76,7 @@ public class HelperUtils { ClassLoader classLoader = DL4JClassLoading.getDl4jClassloader(); DL4JClassLoading.setDl4jClassloaderFromClass(layerHelperSuperClass); try { - helperRet = (LayerHelper) DL4JClassLoading.createNewInstance( + helperRet = DL4JClassLoading.createNewInstance( cudnnHelperClassName, (Class) layerHelperSuperClass, arguments); @@ -99,7 +99,7 @@ public class HelperUtils { } } else if("CPU".equalsIgnoreCase(backend) && oneDnnClassName != null && !oneDnnClassName.isEmpty()) { - helperRet = DL4JClassLoading.createNewInstance( + helperRet = DL4JClassLoading.createNewInstance( oneDnnClassName, arguments); log.trace("Created oneDNN helper: {}, layer {}", oneDnnClassName,layerName); diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 2e587fa8e..b0e7e9b81 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -3,11 +3,14 @@ plugins { id 'maven-publish' } +/* configurations.archives.artifacts.with { archives -> + archives.each { println(it.name) } } +*/ dependencies { //Todo clean this @@ -19,7 +22,7 @@ dependencies { //TODO for the two below.. either platform specific uber jars or a single big one with all platforms api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" //api group: "org.bytedeco", name: "javacpp", version: "1.5.7" - api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" + // api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT' rootProject.getAllprojects().each { Project sproj -> if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 0a638ff15..1d083f0ce 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -212,7 +212,7 @@ tasks.withType(org.bytedeco.gradle.javacpp.BuildTask) { // Disable the standard javacpp generated tasks and use own // versions below. 
This allows to build for each variant [javacppBuildParser, javacppBuildCommand, javacppCompileJava, javacppBuildCompiler].each { - it.enabled false; + it.enabled false } chipList.each { thisChip -> diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java index 03ec92701..dfd953b4f 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalMap.java @@ -152,12 +152,12 @@ public class MultiDimensionalMap implements Serializable { /** * Returns the value to which the specified key is mapped, * or {@code null} if this map contains no mapping for the key. - *

      + *

      *

      More formally, if this map contains a mapping from a key * {@code k} to a value {@code v} such that {@code (key==null ? k==null : * key.equals(k))}, then this method returns {@code v}; otherwise * it returns {@code null}. (There can be at most one such mapping.) - *

      + *

      *

      If this map permits null values, then a return value of * {@code null} does not necessarily indicate that the map * contains no mapping for the key; it's also possible that the map @@ -214,15 +214,15 @@ public class MultiDimensionalMap implements Serializable { * from key k to value v such that * (key==null ? k==null : key.equals(k)), that mapping * is removed. (The map can contain at most one such mapping.) - *

      + *

      *

      Returns the value to which this map previously associated the key, * or null if the map contained no mapping for the key. - *

      + *

      *

      If this map permits null values, then a return value of * null does not necessarily indicate that the map * contained no mapping for the key; it's also possible that the map * explicitly mapped the key to null. - *

      + *

      *

      The map will not contain a mapping for the specified key once the * call returns. * diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java index d16c190cb..fdd2bff15 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/collection/MultiDimensionalSet.java @@ -108,12 +108,12 @@ public class MultiDimensionalSet implements Set> { * If this applyTransformToDestination makes any guarantees as to what order its elements * are returned by its iterator, this method must return the * elements in the same order. - *

      + *

      *

      The returned array will be "safe" in that no references to it * are maintained by this applyTransformToDestination. (In other words, this method must * allocate a new array even if this applyTransformToDestination is backed by an array). * The caller is thus free to modify the returned array. - *

      + *

      *

      This method acts as bridge between array-based and collection-based * APIs. * @@ -130,27 +130,27 @@ public class MultiDimensionalSet implements Set> { * If the applyTransformToDestination fits in the specified array, it is returned therein. * Otherwise, a new array is allocated with the runtime type of the * specified array and the size of this applyTransformToDestination. - *

      + *

      *

      If this applyTransformToDestination fits in the specified array with room to spare * (i.e., the array has more elements than this applyTransformToDestination), the element in * the array immediately following the end of the applyTransformToDestination is applyTransformToDestination to * null. (This is useful in determining the length of this * applyTransformToDestination only if the caller knows that this applyTransformToDestination does not contain * any null elements.) - *

      + *

      *

      If this applyTransformToDestination makes any guarantees as to what order its elements * are returned by its iterator, this method must return the elements * in the same order. - *

      + *

      *

      Like the {@link #toArray()} method, this method acts as bridge between * array-based and collection-based APIs. Further, this method allows * precise control over the runtime type of the output array, and may, * under certain circumstances, be used to save allocation costs. - *

      + *

      *

      Suppose x is a applyTransformToDestination known to contain only strings. * The following code can be used to dump the applyTransformToDestination into a newly allocated * array of String: - *

      + *

      *

            *     String[] y = x.toArray(new String[0]);
      * @@ -181,7 +181,7 @@ public class MultiDimensionalSet implements Set> { * unchanged and returns false. In combination with the * restriction on constructors, this ensures that sets never contain * duplicate elements. - *

      + *

      *

      The stipulation above does not imply that sets must accept all * elements; sets may refuse to add any particular element, including * null, and throw an exception, as described in the diff --git a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java index 13780f3a6..240031440 100644 --- a/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java +++ b/cavis-nd4j/cavis-nd4j-common/src/main/java/org/nd4j/common/util/ArrayUtil.java @@ -204,9 +204,9 @@ public class ArrayUtil { /** * Credit to mikio braun from jblas - *

      + *

      * Create a random permutation of the numbers 0, ..., size - 1. - *

      + *

      * see Algorithm P, D.E. Knuth: The Art of Computer Programming, Vol. 2, p. 145 */ public static int[] randomPermutation(int size) { diff --git a/settings.gradle b/settings.gradle index 2e4e68cce..aaf58f336 100644 --- a/settings.gradle +++ b/settings.gradle @@ -148,7 +148,6 @@ include ':cavis-ui:cavis-ui-standalone' include ':cavis-ui:cavis-ui-vertx' include ':cavis-zoo' include ':cavis-zoo:cavis-zoo-models' - include ':brutex-extended-tests' include ':cavis-full' From 582cbdf67d98d819c05df415801cdf40735db92a Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 21 Oct 2022 22:03:01 +0200 Subject: [PATCH 058/126] Make Writable and Record first class citizen of the datavec.api Signed-off-by: brian --- .gitignore | 1 + .../test/java/net/brutex/spark/BrianTest.java | 3 +- .../java/net/brutex/spark/BrianTest2.java | 2 +- .../src/main/java/org/datavec/api/Record.java | 55 +++++++++++++++++++ .../datavec/api/{writable => }/Writable.java | 39 +++++++------ .../org/datavec/api/conf/Configuration.java | 2 +- .../api/formats/input/InputFormat.java | 2 +- .../datavec/api/io/WritableComparable.java | 2 +- .../datavec/api/io/WritableComparator.java | 2 +- .../org/datavec/api/io/WritableConverter.java | 2 +- .../org/datavec/api/io/WritableUtils.java | 2 +- .../converters/DoubleWritableConverter.java | 1 + .../io/converters/FloatWritableConverter.java | 1 + .../io/converters/LabelWriterConverter.java | 2 +- .../io/converters/SelfWritableConverter.java | 2 +- .../api/io/filters/BalancedPathFilter.java | 2 +- .../io/labels/ParentPathLabelGenerator.java | 2 +- .../api/io/labels/PathLabelGenerator.java | 2 +- .../io/labels/PathMultiLabelGenerator.java | 2 +- .../io/labels/PatternPathLabelGenerator.java | 2 +- .../java/org/datavec/api/records/Record.java | 55 ------------------- .../datavec/api/records/SequenceRecord.java | 4 +- .../org/datavec/api/records/impl/Record.java | 4 +- .../api/records/impl/SequenceRecord.java | 2 +- .../api/records/mapper/RecordMapper.java | 2 +- .../api/records/reader/BaseRecordReader.java | 2 +- .../api/records/reader/RecordReader.java | 4 +- .../records/reader/SequenceRecordReader.java | 4 +- .../reader/impl/ComposableRecordReader.java | 4 +- .../impl/ConcatenatingRecordReader.java | 4 +- .../records/reader/impl/FileRecordReader.java | 4 +- .../records/reader/impl/LineRecordReader.java | 4 +- .../collection/CollectionRecordReader.java | 4 +- .../CollectionSequenceRecordReader.java | 4 +- .../collection/ListStringRecordReader.java | 4 +- .../impl/csv/CSVLineSequenceRecordReader.java | 4 +- .../csv/CSVMultiSequenceRecordReader.java | 2 +- .../csv/CSVNLinesSequenceRecordReader.java | 4 +- .../reader/impl/csv/CSVRecordReader.java | 4 +- .../reader/impl/csv/CSVRegexRecordReader.java | 2 +- .../impl/csv/CSVSequenceRecordReader.java | 2 +- .../CSVVariableSlidingWindowRecordReader.java | 4 +- .../impl/filebatch/FileBatchRecordReader.java | 4 +- .../FileBatchSequenceRecordReader.java | 4 +- .../impl/inmemory/InMemoryRecordReader.java | 4 +- .../InMemorySequenceRecordReader.java | 4 +- .../reader/impl/jackson/FieldSelection.java | 2 +- .../impl/jackson/JacksonLineRecordReader.java | 2 +- .../JacksonLineSequenceRecordReader.java | 2 +- .../impl/jackson/JacksonReaderUtils.java | 2 +- .../impl/jackson/JacksonRecordReader.java | 4 +- .../reader/impl/misc/MatlabRecordReader.java | 2 +- .../impl/misc/SVMLightRecordReader.java | 4 +- .../impl/regex/RegexLineRecordReader.java | 4 +- .../impl/regex/RegexSequenceRecordReader.java | 2 +- .../TransformProcessRecordReader.java | 4 +- 
.../TransformProcessSequenceRecordReader.java | 4 +- .../api/records/writer/RecordWriter.java | 2 +- .../records/writer/SequenceRecordWriter.java | 2 +- .../records/writer/impl/LineRecordWriter.java | 2 +- .../writer/impl/csv/CSVRecordWriter.java | 2 +- .../writer/impl/misc/MatlabRecordWriter.java | 2 +- .../impl/misc/SVMLightRecordWriter.java | 2 +- .../util/TimeSeriesWritableUtils.java | 2 +- .../org/datavec/api/transform/Transform.java | 2 +- .../api/transform/TransformProcess.java | 1 + .../transform/analysis/AnalysisCounter.java | 2 +- .../counter/BytesAnalysisCounter.java | 2 +- .../counter/CategoricalAnalysisCounter.java | 2 +- .../counter/DoubleAnalysisCounter.java | 2 +- .../counter/IntegerAnalysisCounter.java | 2 +- .../analysis/counter/LongAnalysisCounter.java | 2 +- .../counter/NDArrayAnalysisCounter.java | 2 +- .../counter/StringAnalysisCounter.java | 2 +- .../CategoricalHistogramCounter.java | 2 +- .../histogram/DoubleHistogramCounter.java | 2 +- .../analysis/histogram/HistogramCounter.java | 2 +- .../histogram/NDArrayHistogramCounter.java | 2 +- .../histogram/StringHistogramCounter.java | 2 +- .../quality/QualityAnalysisAddFunction.java | 2 +- .../quality/QualityAnalysisState.java | 2 +- .../bytes/BytesQualityAnalysisState.java | 2 +- .../CategoricalQualityAddFunction.java | 2 +- .../CategoricalQualityAnalysisState.java | 2 +- .../integer/IntegerQualityAddFunction.java | 2 +- .../integer/IntegerQualityAnalysisState.java | 2 +- .../quality/longq/LongQualityAddFunction.java | 2 +- .../longq/LongQualityAnalysisState.java | 2 +- .../quality/real/RealQualityAddFunction.java | 2 +- .../real/RealQualityAnalysisState.java | 2 +- .../string/StringQualityAddFunction.java | 2 +- .../string/StringQualityAnalysisState.java | 2 +- .../quality/time/TimeQualityAddFunction.java | 2 +- .../time/TimeQualityAnalysisState.java | 2 +- .../transform/condition/BooleanCondition.java | 2 +- .../api/transform/condition/Condition.java | 2 +- .../condition/column/BaseColumnCondition.java | 2 +- .../column/BooleanColumnCondition.java | 2 +- .../column/CategoricalColumnCondition.java | 2 +- .../condition/column/ColumnCondition.java | 2 +- .../column/DoubleColumnCondition.java | 2 +- .../column/FloatColumnCondition.java | 2 +- .../column/InfiniteColumnCondition.java | 2 +- .../column/IntegerColumnCondition.java | 2 +- .../column/InvalidValueColumnCondition.java | 1 + .../condition/column/LongColumnCondition.java | 2 +- .../condition/column/NaNColumnCondition.java | 2 +- .../column/NullWritableColumnCondition.java | 2 +- .../column/StringColumnCondition.java | 2 +- .../condition/column/TimeColumnCondition.java | 2 +- .../column/TrivialColumnCondition.java | 2 +- .../sequence/SequenceLengthCondition.java | 2 +- .../string/StringRegexColumnCondition.java | 2 +- .../transform/filter/BaseColumnFilter.java | 2 +- .../api/transform/filter/ConditionFilter.java | 2 +- .../datavec/api/transform/filter/Filter.java | 2 +- .../transform/filter/FilterInvalidValues.java | 1 + .../transform/filter/InvalidNumColumns.java | 2 +- .../org/datavec/api/transform/join/Join.java | 2 +- .../transform/metadata/BinaryMetaData.java | 2 +- .../transform/metadata/BooleanMetaData.java | 2 +- .../metadata/CategoricalMetaData.java | 2 +- .../transform/metadata/ColumnMetaData.java | 2 +- .../transform/metadata/DoubleMetaData.java | 2 +- .../api/transform/metadata/FloatMetaData.java | 2 +- .../transform/metadata/IntegerMetaData.java | 2 +- .../api/transform/metadata/LongMetaData.java | 2 +- .../transform/metadata/NDArrayMetaData.java | 2 +- 
.../transform/metadata/StringMetaData.java | 2 +- .../api/transform/metadata/TimeMetaData.java | 2 +- .../NDArrayColumnsMathOpTransform.java | 2 +- .../ndarray/NDArrayDistanceTransform.java | 2 +- .../ndarray/NDArrayMathFunctionTransform.java | 2 +- .../ndarray/NDArrayScalarOpTransform.java | 2 +- .../transform/ops/AggregableCheckingOp.java | 2 +- .../api/transform/ops/AggregableMultiOp.java | 2 +- .../api/transform/ops/AggregatorImpls.java | 2 +- .../api/transform/ops/ByteWritableOp.java | 2 +- .../ops/DispatchWithConditionOp.java | 2 +- .../api/transform/ops/DoubleWritableOp.java | 2 +- .../api/transform/ops/FloatWritableOp.java | 2 +- .../api/transform/ops/IntWritableOp.java | 2 +- .../api/transform/ops/LongWritableOp.java | 2 +- .../transform/ops/StringAggregatorImpls.java | 2 +- .../api/transform/ops/StringWritableOp.java | 2 +- .../reduce/AggregableColumnReduction.java | 2 +- .../reduce/AggregableReductionUtils.java | 2 +- .../api/transform/reduce/ColumnReduction.java | 2 +- .../transform/reduce/IAssociativeReducer.java | 2 +- .../datavec/api/transform/reduce/Reducer.java | 2 +- .../impl/GeographicMidpointReduction.java | 2 +- .../datavec/api/transform/schema/Schema.java | 1 + .../api/transform/schema/SequenceSchema.java | 1 + .../schema/conversion/TypeConversion.java | 2 +- .../sequence/ReduceSequenceTransform.java | 2 +- .../sequence/SequenceComparator.java | 2 +- .../api/transform/sequence/SequenceSplit.java | 2 +- .../comparator/BaseColumnComparator.java | 2 +- .../comparator/NumericalColumnComparator.java | 2 +- .../sequence/comparator/StringComparator.java | 2 +- .../BaseSequenceExpansionTransform.java | 2 +- .../sequence/merge/SequenceMerge.java | 2 +- .../split/SequenceSplitTimeSeparation.java | 2 +- .../split/SplitMaxLengthSequence.java | 2 +- .../trim/SequenceTrimToLengthTransform.java | 2 +- .../sequence/trim/SequenceTrimTransform.java | 2 +- .../window/OverlappingTimeWindowFunction.java | 2 +- .../ReduceSequenceByWindowTransform.java | 2 +- .../sequence/window/TimeWindowFunction.java | 2 +- .../sequence/window/WindowFunction.java | 2 +- .../serde/legacy/LegacyJsonFormat.java | 1 + .../stringreduce/IStringReducer.java | 2 +- .../transform/stringreduce/StringReducer.java | 3 +- .../transform/BaseColumnTransform.java | 2 +- .../transform/BaseColumnsMathOpTransform.java | 5 +- .../transform/transform/BaseTransform.java | 2 +- .../CategoricalToIntegerTransform.java | 2 +- .../CategoricalToOneHotTransform.java | 2 +- .../categorical/FirstDigitTransform.java | 2 +- .../IntegerToCategoricalTransform.java | 2 +- .../transform/categorical/PivotTransform.java | 1 + .../StringToCategoricalTransform.java | 2 +- .../column/AddConstantColumnTransform.java | 2 +- .../column/DuplicateColumnsTransform.java | 2 +- .../RemoveAllColumnsExceptForTransform.java | 2 +- .../column/RemoveColumnsTransform.java | 2 +- .../column/RenameColumnsTransform.java | 2 +- .../column/ReorderColumnsTransform.java | 2 +- .../ConditionalCopyValueTransform.java | 2 +- .../ConditionalReplaceValueTransform.java | 2 +- ...ionalReplaceValueTransformWithDefault.java | 2 +- .../doubletransform/BaseDoubleTransform.java | 2 +- .../doubletransform/ConvertToDouble.java | 2 +- .../DoubleColumnsMathOpTransform.java | 2 +- .../DoubleMathFunctionTransform.java | 2 +- .../DoubleMathOpTransform.java | 2 +- .../doubletransform/Log2Normalizer.java | 2 +- .../doubletransform/MinMaxNormalizer.java | 2 +- .../StandardizeNormalizer.java | 2 +- .../SubtractMeanNormalizer.java | 2 +- .../floattransform/BaseFloatTransform.java | 2 +- 
.../floattransform/ConvertToFloat.java | 2 +- .../FloatColumnsMathOpTransform.java | 3 +- .../FloatMathFunctionTransform.java | 2 +- .../floattransform/FloatMathOpTransform.java | 3 +- .../integer/BaseIntegerTransform.java | 2 +- .../transform/integer/ConvertToInteger.java | 2 +- .../IntegerColumnsMathOpTransform.java | 3 +- .../integer/IntegerMathOpTransform.java | 2 +- .../integer/IntegerToOneHotTransform.java | 2 +- ...ReplaceEmptyIntegerWithValueTransform.java | 2 +- .../ReplaceInvalidWithIntegerTransform.java | 2 +- .../LongColumnsMathOpTransform.java | 3 +- .../longtransform/LongMathOpTransform.java | 2 +- .../nlp/TextToCharacterIndexTransform.java | 2 +- .../nlp/TextToTermIndexSequenceTransform.java | 2 +- .../transform/parse/ParseDoubleTransform.java | 2 +- .../sequence/SequenceDifferenceTransform.java | 1 + .../SequenceMovingWindowReduceTransform.java | 2 +- .../sequence/SequenceOffsetTransform.java | 2 +- .../string/AppendStringColumnTransform.java | 2 +- .../transform/string/BaseStringTransform.java | 2 +- .../string/ChangeCaseStringTransform.java | 2 +- .../string/ConcatenateStringColumns.java | 2 +- .../transform/string/ConvertToString.java | 2 +- .../MapAllStringsExceptListTransform.java | 2 +- .../string/RemoveWhiteSpaceTransform.java | 2 +- .../string/ReplaceEmptyStringTransform.java | 2 +- .../string/ReplaceStringTransform.java | 2 +- .../StringListToCategoricalSetTransform.java | 2 +- .../StringListToCountsNDArrayTransform.java | 2 +- .../transform/string/StringMapTransform.java | 2 +- .../time/DeriveColumnsFromTimeTransform.java | 2 +- .../transform/time/StringToTimeTransform.java | 2 +- .../transform/time/TimeMathOpTransform.java | 2 +- .../transform/ui/HtmlSequencePlotting.java | 2 +- .../org/datavec/api/util/RecordUtils.java | 2 +- .../api/util/ndarray/RecordConverter.java | 1 + .../org/datavec/api/vector/Vectorizer.java | 2 +- .../datavec/api/writable/ArrayWritable.java | 2 + .../datavec/api/writable/ByteWritable.java | 1 + .../datavec/api/writable/BytesWritable.java | 1 + .../datavec/api/writable/DoubleWritable.java | 1 + .../datavec/api/writable/FloatWritable.java | 1 + .../org/datavec/api/writable/IntWritable.java | 1 + .../datavec/api/writable/LongWritable.java | 1 + .../java/org/datavec/api/writable/Text.java | 1 + .../api/writable/UnsafeWritableInjector.java | 1 + .../datavec/api/writable/WritableFactory.java | 1 + .../datavec/api/writable/WritableType.java | 2 + ...AbstractTimeSeriesWritableRecordBatch.java | 2 +- .../batch/AbstractWritableRecordBatch.java | 2 +- .../writable/batch/NDArrayRecordBatch.java | 2 +- .../api/writable/comparator/Comparators.java | 2 +- .../comparator/DoubleWritableComparator.java | 2 +- .../comparator/FloatWritableComparator.java | 2 +- .../comparator/IntWritableComparator.java | 2 +- .../comparator/LongWritableComparator.java | 2 +- .../comparator/TextWritableComparator.java | 2 +- .../comparator/WritableComparator.java | 2 +- .../impl/CSVLineSequenceRecordReaderTest.java | 2 +- .../CSVMultiSequenceRecordReaderTest.java | 2 +- .../CSVNLinesSequenceRecordReaderTest.java | 2 +- .../reader/impl/CSVRecordReaderTest.java | 4 +- .../impl/CSVSequenceRecordReaderTest.java | 2 +- ...VariableSlidingWindowRecordReaderTest.java | 2 +- .../impl/FileBatchRecordReaderTest.java | 2 +- .../reader/impl/FileRecordReaderTest.java | 4 +- .../impl/JacksonLineRecordReaderTest.java | 2 +- .../reader/impl/JacksonRecordReaderTest.java | 4 +- .../reader/impl/LibSvmRecordReaderTest.java | 2 +- .../records/reader/impl/LineReaderTest.java | 4 +- 
.../reader/impl/RegexRecordReaderTest.java | 4 +- .../reader/impl/SVMLightRecordReaderTest.java | 4 +- .../impl/TestCollectionRecordReaders.java | 2 +- .../reader/impl/TestSerialization.java | 2 +- .../TransformProcessRecordReaderTests.java | 2 +- .../writer/impl/CSVRecordWriterTest.java | 2 +- .../writer/impl/LibSvmRecordWriterTest.java | 2 +- .../writer/impl/SVMLightRecordWriterTest.java | 1 + .../api/split/TestStreamInputSplit.java | 2 +- .../api/transform/TestTransformProcess.java | 2 +- .../transform/condition/TestConditions.java | 1 + .../api/transform/filter/TestFilters.java | 2 +- .../datavec/api/transform/join/TestJoin.java | 2 +- .../transform/ops/AggregableMultiOpTest.java | 2 +- .../api/transform/ops/DispatchOpTest.java | 2 +- .../transform/reduce/TestMultiOpReduce.java | 1 + .../api/transform/reduce/TestReductions.java | 2 +- .../TestReduceSequenceByWindowFunction.java | 2 +- .../transform/sequence/TestSequenceSplit.java | 2 +- .../sequence/TestWindowFunctions.java | 2 +- .../serde/testClasses/CustomCondition.java | 2 +- .../serde/testClasses/CustomFilter.java | 2 +- .../serde/testClasses/CustomTransform.java | 2 +- .../transform/stringreduce/TestReduce.java | 2 +- .../transform/transform/TestTransforms.java | 2 +- .../TestNDArrayWritableTransforms.java | 2 +- .../parse/ParseDoubleTransformTest.java | 3 +- .../org/datavec/api/transform/ui/TestUI.java | 2 +- .../datavec/api/util/TimeSeriesUtilsTest.java | 2 +- .../api/writable/RecordConverterTest.java | 1 + .../datavec/api/writable/WritableTest.java | 1 + .../org/datavec/arrow/ArrowConverter.java | 1 + .../arrow/recordreader/ArrowRecord.java | 4 +- .../arrow/recordreader/ArrowRecordReader.java | 4 +- .../arrow/recordreader/ArrowRecordWriter.java | 3 +- .../ArrowWritableRecordBatch.java | 2 +- .../ArrowWritableRecordTimeSeriesBatch.java | 2 +- .../org/datavec/arrow/ArrowConverterTest.java | 3 +- .../org/datavec/arrow/RecordMapperTest.java | 2 +- ...rowWritableRecordTimeSeriesBatchTests.java | 2 +- .../recordreader/BaseAudioRecordReader.java | 4 +- .../recordreader/NativeAudioRecordReader.java | 2 +- .../recordreader/WavFileRecordReader.java | 2 +- .../org/datavec/audio/AudioReaderTest.java | 2 +- .../codec/reader/BaseCodecRecordReader.java | 2 +- .../codec/reader/CodecRecordReader.java | 2 +- .../codec/reader/NativeCodecRecordReader.java | 2 +- .../datavec/codec/reader/CodecReaderTest.java | 2 +- .../datavec/poi/excel/ExcelRecordReader.java | 4 +- .../datavec/poi/excel/ExcelRecordWriter.java | 1 + .../poi/excel/ExcelRecordReaderTest.java | 2 +- .../poi/excel/ExcelRecordWriterTest.java | 2 +- .../reduce/geo/CoordinatesReduction.java | 2 +- .../geo/CoordinatesDistanceTransform.java | 2 +- .../geo/IPAddressToLocationTransform.java | 2 +- .../transform/reduce/TestGeoReduction.java | 2 +- .../transform/TestGeoTransforms.java | 3 +- .../reader/mapfile/MapFileRecordReader.java | 4 +- .../mapfile/MapFileSequenceRecordReader.java | 4 +- .../reader/mapfile/record/RecordWritable.java | 4 +- .../record/SequenceRecordWritable.java | 10 ++-- .../writer/mapfile/AbstractMapFileWriter.java | 1 + .../writer/mapfile/MapFileRecordWriter.java | 2 +- .../mapfile/MapFileSequenceRecordWriter.java | 2 +- .../reader/TestMapFileRecordReader.java | 8 +-- .../TestMapFileRecordReaderMultipleParts.java | 8 +-- ...ileRecordReaderMultiplePartsSomeEmpty.java | 8 +-- .../writer/TestMapFileRecordWriter.java | 2 +- .../org/datavec/image/data/ImageWritable.java | 3 +- .../recordreader/BaseImageRecordReader.java | 4 +- .../ObjectDetectionRecordReader.java | 4 +- 
.../transform/ImageTransformProcess.java | 2 +- .../FileBatchRecordReaderTest.java | 2 +- .../recordreader/TestImageRecordReader.java | 4 +- .../TestObjectDetectionRecordReader.java | 4 +- .../datavec/nlp/reader/TfidfRecordReader.java | 4 +- .../nlp/transforms/BagOfWordsTransform.java | 2 +- .../nlp/transforms/GazeteerTransform.java | 2 +- .../nlp/transforms/MultiNlpTransform.java | 2 +- ...rBagOfWordsTermSequenceIndexTransform.java | 2 +- .../vectorizer/AbstractTfidfVectorizer.java | 2 +- .../nlp/vectorizer/TextVectorizer.java | 4 +- .../nlp/vectorizer/TfidfVectorizer.java | 2 +- .../nlp/reader/TfidfRecordReaderTest.java | 4 +- .../nlp/transforms/TestGazeteerTransform.java | 2 +- .../nlp/transforms/TestMultiNLPTransform.java | 2 +- ...OfWordsTermSequenceIndexTransformTest.java | 2 +- .../local/transforms/AnalyzeLocal.java | 2 +- .../transforms/LocalTransformExecutor.java | 1 + ...lTransformProcessSequenceRecordReader.java | 3 +- .../SequenceEmptyRecordFunction.java | 2 +- .../aggregate/AnalysisAddFunction.java | 2 +- .../histogram/HistogramAddFunction.java | 2 +- .../functions/EmptyRecordFunction.java | 2 +- .../functions/LineRecordReaderFunction.java | 2 +- .../functions/RecordReaderFunction.java | 2 +- .../SequenceRecordReaderFunction.java | 2 +- .../data/RecordReaderBytesFunction.java | 2 +- .../SequenceRecordReaderBytesFunction.java | 2 +- ...ExecuteJoinFromCoGroupFlatMapFunction.java | 2 +- ...JoinFromCoGroupFlatMapFunctionAdapter.java | 2 +- .../transforms/join/ExtractKeysFunction.java | 2 +- .../join/FilterAndFlattenJoinedValues.java | 2 +- .../FilterAndFlattenJoinedValuesAdapter.java | 2 +- .../local/transforms/join/JoinedValue.java | 2 +- .../misc/ColumnAsKeyPairFunction.java | 2 +- .../misc/ColumnToKeyPairTransform.java | 2 +- .../misc/NDArrayToWritablesFunction.java | 2 +- .../misc/SequenceMergeFunction.java | 2 +- .../SequenceWritablesToStringFunction.java | 2 +- .../misc/StringToWritablesFunction.java | 2 +- .../misc/WritablesToNDArrayFunction.java | 2 +- .../misc/WritablesToStringFunction.java | 2 +- .../UnzipForCalculateSortedRankFunction.java | 2 +- .../reduce/MapToPairForReducerFunction.java | 2 +- .../transforms/reduce/ReducerFunction.java | 2 +- .../sequence/ConvertToSequenceLengthOne.java | 2 +- .../LocalGroupToSequenceFunction.java | 2 +- .../LocalMapToPairByColumnFunction.java | 2 +- ...calMapToPairByMultipleColumnsFunction.java | 2 +- .../sequence/LocalSequenceFilterFunction.java | 2 +- .../LocalSequenceTransformFunction.java | 2 +- .../transform/LocalTransformFunction.java | 2 +- .../LocalTransformProcessFunction.java | 2 +- .../LocalTransformProcessFunctionAdapter.java | 2 +- .../transform/SequenceSplitFunction.java | 2 +- .../SequenceSplitFunctionAdapter.java | 2 +- .../FilterWritablesBySchemaFunction.java | 2 +- .../transform/filter/LocalFilterFunction.java | 2 +- ...ocalTransformProcessRecordReaderTests.java | 4 +- .../transforms/analysis/TestAnalyzeLocal.java | 2 +- .../TestLineRecordReaderFunction.java | 2 +- .../TestNDArrayToWritablesFunction.java | 3 +- .../TestWritablesToNDArrayFunction.java | 2 +- .../TestWritablesToStringFunctions.java | 2 +- .../transforms/transform/ExecutionTest.java | 1 + .../transform/TestGeoTransforms.java | 2 +- .../transform/TestPythonTransformProcess.java | 1 + .../transforms/transform/join/TestJoin.java | 1 + .../rank/TestCalculateSortedRank.java | 2 +- .../sequence/TestConvertToSequence.java | 2 +- .../org/datavec/python/PythonCondition.java | 1 + .../org/datavec/python/PythonTransform.java | 2 +- 
.../spark/SequenceEmptyRecordFunction.java | 2 +- .../spark/functions/EmptyRecordFunction.java | 2 +- .../functions/LineRecordReaderFunction.java | 2 +- .../spark/functions/RecordReaderFunction.java | 2 +- .../SequenceRecordReaderFunction.java | 2 +- .../data/RecordReaderBytesFunction.java | 2 +- .../SequenceRecordReaderBytesFunction.java | 2 +- ...PairSequenceRecordReaderBytesFunction.java | 2 +- .../spark/storage/SparkStorageUtils.java | 2 +- .../functions/RecordLoadPairFunction.java | 2 +- .../functions/RecordSavePrepPairFunction.java | 2 +- .../SequenceRecordLoadPairFunction.java | 2 +- .../SequenceRecordSavePrepPairFunction.java | 2 +- .../datavec/spark/transform/AnalyzeSpark.java | 2 +- .../datavec/spark/transform/DataFrames.java | 1 + .../spark/transform/Normalization.java | 2 +- .../transform/SparkTransformExecutor.java | 2 +- .../analysis/CategoricalToPairFunction.java | 2 +- .../analysis/SelectColumnFunction.java | 2 +- .../analysis/SequenceFlatMapFunction.java | 2 +- .../analysis/SequenceLengthFunction.java | 2 +- .../analysis/StringLengthFunction.java | 2 +- .../analysis/WritableToDoubleFunction.java | 2 +- .../analysis/WritableToStringFunction.java | 2 +- .../aggregate/AnalysisAddFunction.java | 2 +- .../histogram/HistogramAddFunction.java | 2 +- .../SequenceLengthAnalysisCounter.java | 2 +- .../analysis/unique/UniqueAddFunction.java | 2 +- .../analysis/unique/UniqueMergeFunction.java | 2 +- .../FilterWritablesBySchemaFunction.java | 2 +- .../transform/filter/SparkFilterFunction.java | 2 +- ...ExecuteJoinFromCoGroupFlatMapFunction.java | 2 +- .../transform/join/ExtractKeysFunction.java | 2 +- .../join/FilterAndFlattenJoinedValues.java | 2 +- .../spark/transform/join/JoinedValue.java | 2 +- .../misc/ColumnAsKeyPairFunction.java | 2 +- .../misc/ColumnToKeyPairTransform.java | 4 +- .../misc/NDArrayToWritablesFunction.java | 2 +- .../transform/misc/SequenceMergeFunction.java | 2 +- .../SequenceWritablesToStringFunction.java | 2 +- .../misc/StringToWritablesFunction.java | 2 +- .../misc/WritablesToNDArrayFunction.java | 2 +- .../misc/WritablesToStringFunction.java | 2 +- .../UnzipForCalculateSortedRankFunction.java | 2 +- .../reduce/MapToPairForReducerFunction.java | 2 +- .../transform/reduce/ReducerFunction.java | 2 +- .../sequence/ConvertToSequenceLengthOne.java | 2 +- .../SparkGroupToSequenceFunction.java | 2 +- .../SparkMapToPairByColumnFunction.java | 2 +- ...arkMapToPairByMultipleColumnsFunction.java | 2 +- .../sequence/SparkSequenceFilterFunction.java | 2 +- .../SparkSequenceTransformFunction.java | 2 +- .../sparkfunction/SequenceToRows.java | 2 +- .../transform/sparkfunction/ToRecord.java | 1 + .../spark/transform/sparkfunction/ToRow.java | 2 +- .../DataFrameToSequenceCreateCombiner.java | 2 +- .../DataFrameToSequenceMergeCombiner.java | 7 +-- .../DataFrameToSequenceMergeValue.java | 3 +- .../transform/SequenceSplitFunction.java | 2 +- .../transform/SparkTransformFunction.java | 2 +- .../SparkTransformProcessFunction.java | 2 +- .../spark/transform/utils/SparkExport.java | 2 +- .../spark/transform/utils/SparkUtils.java | 1 + .../TestLineRecordReaderFunction.java | 2 +- .../TestNDArrayToWritablesFunction.java | 3 +- ...PairSequenceRecordReaderBytesFunction.java | 2 +- .../TestRecordReaderBytesFunction.java | 2 +- .../functions/TestRecordReaderFunction.java | 2 +- ...TestSequenceRecordReaderBytesFunction.java | 2 +- .../TestSequenceRecordReaderFunction.java | 2 +- .../TestWritablesToNDArrayFunction.java | 1 + .../TestWritablesToStringFunctions.java | 2 +- 
.../spark/storage/TestSparkStorageUtils.java | 1 + .../spark/transform/DataFramesTests.java | 2 +- .../spark/transform/ExecutionTest.java | 2 +- .../spark/transform/NormalizationTests.java | 2 +- .../transform/analysis/TestAnalysis.java | 1 + .../spark/transform/join/TestJoin.java | 1 + .../rank/TestCalculateSortedRank.java | 2 +- .../sequence/TestConvertToSequence.java | 2 +- .../org/datavec/spark/util/TestSparkUtil.java | 2 +- .../org/nd4j/linalg/dataset/api/DataSet.java | 2 +- .../RecordReaderDataSetiteratorTest.java | 4 +- .../RecordReaderMultiDataSetIteratorTest.java | 4 +- .../tools/SpecialImageRecordReader.java | 2 +- .../org/deeplearning4j/eval/EvalTest.java | 2 +- .../exceptions/TestRecordReaders.java | 3 +- .../datavec/RecordReaderDataSetIterator.java | 4 +- .../RecordReaderMultiDataSetIterator.java | 4 +- .../sequencevectors/SequenceVectorsTest.java | 2 +- .../spark/datavec/DataVecDataSetFunction.java | 2 +- .../DataVecSequenceDataSetFunction.java | 3 +- .../DataVecSequencePairDataSetFunction.java | 3 +- .../spark/datavec/RecordReaderFunction.java | 2 +- .../export/StringToDataSetExportFunction.java | 2 +- .../spark/datavec/iterator/DataVecRecord.java | 2 +- .../datavec/iterator/DataVecRecords.java | 2 +- .../spark/datavec/iterator/IteratorUtils.java | 2 +- .../datavec/iterator/RRMDSIFunction.java | 2 +- .../iterator/SparkSourceDummyReader.java | 4 +- .../iterator/SparkSourceDummySeqReader.java | 2 +- .../deeplearning4j/spark/util/MLLibUtil.java | 2 +- .../datavec/TestDataVecDataSetFunctions.java | 2 +- .../datavec/iterator/TestIteratorUtils.java | 2 +- 519 files changed, 658 insertions(+), 636 deletions(-) create mode 100644 cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/Record.java rename cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/{writable => }/Writable.java (63%) delete mode 100644 cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/Record.java diff --git a/.gitignore b/.gitignore index 09430be6d..fbe938d6a 100644 --- a/.gitignore +++ b/.gitignore @@ -82,3 +82,4 @@ bruai4j-native-common/cmake* *.dll /bruai4j-native/bruai4j-native-common/blasbuild/ /bruai4j-native/bruai4j-native-common/build/ +/cavis-native/cavis-native-lib/blasbuild/ diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java index 7813efe6a..cc88a0914 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java @@ -21,7 +21,6 @@ package net.brutex.spark; import com.fasterxml.jackson.core.Version; -import lombok.extern.log4j.Log4j2; import lombok.extern.slf4j.Slf4j; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; @@ -34,7 +33,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.filter.FilterInvalidValues; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.transform.SparkTransformExecutor; import org.datavec.spark.transform.misc.StringToWritablesFunction; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java index be62228c1..4e340c69a 100644 --- 
a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java
+++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java
@@ -34,7 +34,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
 import org.datavec.api.transform.TransformProcess;
 import org.datavec.api.transform.filter.FilterInvalidValues;
 import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.writable.Writable;
+import org.datavec.api.Writable;
 import org.datavec.spark.transform.SparkTransformExecutor;
 import org.datavec.spark.transform.misc.StringToWritablesFunction;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/Record.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/Record.java
new file mode 100644
index 000000000..48e6339f8
--- /dev/null
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/Record.java
@@ -0,0 +1,55 @@
+/*
+ *
+ * ******************************************************************************
+ * *
+ * * This program and the accompanying materials are made available under the
+ * * terms of the Apache License, Version 2.0 which is available at
+ * * https://www.apache.org/licenses/LICENSE-2.0.
+ * *
+ * * See the NOTICE file distributed with this work for additional
+ * * information regarding copyright ownership.
+ * * Unless required by applicable law or agreed to in writing, software
+ * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * * License for the specific language governing permissions and limitations
+ * * under the License.
+ * *
+ * * SPDX-License-Identifier: Apache-2.0
+ * *****************************************************************************
+ *
+ */
+
+package org.datavec.api;
+
+import org.datavec.api.records.metadata.RecordMetaData;
+
+import java.io.Serializable;
+import java.util.List;
+
+public interface Record extends Serializable {
+
+    /**
+     * Get the record values, as a {@code List<Writable>}
+     *
+     * @return Record values
+     */
+    List<Writable> getRecord();
+
+    /**
+     * Set the record values for this Record
+     */
+    void setRecord(List<Writable> record);
+
+    /**
+     * Get the RecordMetaData for this record
+     *
+     * @return Metadata for this record (or null, if none has been set)
+     */
+    RecordMetaData getMetaData();
+
+    /**
+     * Set the Record metadata
+     */
+    void setMetaData(RecordMetaData recordMetaData);
+
+}
diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Writable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/Writable.java
similarity index 63%
rename from cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Writable.java
rename to cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/Writable.java
index 25466136d..22b68adfa 100644
--- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Writable.java
+++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/Writable.java
@@ -1,24 +1,25 @@
 /*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * */ -package org.datavec.api.writable; +package org.datavec.api; import com.fasterxml.jackson.annotation.JsonTypeInfo; @@ -26,6 +27,8 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; import java.io.Serializable; +import org.datavec.api.writable.WritableFactory; +import org.datavec.api.writable.WritableType; @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") public interface Writable extends Serializable { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/conf/Configuration.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/conf/Configuration.java index 922b31aed..eddbea9bb 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/conf/Configuration.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/conf/Configuration.java @@ -24,7 +24,7 @@ import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonGenerator; import org.apache.commons.lang3.StringUtils; import org.datavec.api.util.ReflectionUtils; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.WritableType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/formats/input/InputFormat.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/formats/input/InputFormat.java index 5814af57c..03743d2f8 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/formats/input/InputFormat.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/formats/input/InputFormat.java @@ -24,7 +24,7 @@ package org.datavec.api.formats.input; import org.datavec.api.conf.Configuration; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparable.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparable.java index 955207efc..e9746b93e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparable.java @@ -21,7 +21,7 @@ package org.datavec.api.io; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public interface WritableComparable extends Writable, Comparable { } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparator.java index bcd2b8074..c6c8a7fe7 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableComparator.java @@ -22,7 +22,7 @@ package org.datavec.api.io; import org.datavec.api.util.ReflectionUtils; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInput; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableConverter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableConverter.java index 9afc8fa56..2ad9ccdbc 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableConverter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableConverter.java @@ -21,7 +21,7 @@ package org.datavec.api.io; import org.datavec.api.io.converters.WritableConverterException; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public interface WritableConverter { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableUtils.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableUtils.java index 7070ce47b..3aa7b007a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableUtils.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/WritableUtils.java @@ -23,7 +23,7 @@ package org.datavec.api.io; import org.datavec.api.conf.Configuration; import org.datavec.api.util.ReflectionUtils; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.*; import java.nio.charset.StandardCharsets; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/DoubleWritableConverter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/DoubleWritableConverter.java index eb2a09f2b..c3722a509 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/DoubleWritableConverter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/DoubleWritableConverter.java @@ -20,6 +20,7 @@ package org.datavec.api.io.converters; +import org.datavec.api.Writable; import org.datavec.api.io.WritableConverter; import org.datavec.api.writable.*; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/FloatWritableConverter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/FloatWritableConverter.java index 16ee74e52..19bb1f3fd 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/FloatWritableConverter.java +++ 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/FloatWritableConverter.java @@ -20,6 +20,7 @@ package org.datavec.api.io.converters; +import org.datavec.api.Writable; import org.datavec.api.io.WritableConverter; import org.datavec.api.writable.*; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/LabelWriterConverter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/LabelWriterConverter.java index 470f88417..31d16728e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/LabelWriterConverter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/LabelWriterConverter.java @@ -22,7 +22,7 @@ package org.datavec.api.io.converters; import org.datavec.api.io.WritableConverter; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/SelfWritableConverter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/SelfWritableConverter.java index 1e07532b3..46107941d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/SelfWritableConverter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/converters/SelfWritableConverter.java @@ -21,7 +21,7 @@ package org.datavec.api.io.converters; import org.datavec.api.io.WritableConverter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class SelfWritableConverter implements WritableConverter { @Override diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java index 348b4e0fd..ad6d49a83 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/filters/BalancedPathFilter.java @@ -22,7 +22,7 @@ package org.datavec.api.io.filters; import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.io.labels.PathLabelGenerator; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.net.URI; import java.util.*; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/ParentPathLabelGenerator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/ParentPathLabelGenerator.java index a68a7e2c5..4156e7cd2 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/ParentPathLabelGenerator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/ParentPathLabelGenerator.java @@ -22,7 +22,7 @@ package org.datavec.api.io.labels; import org.apache.commons.io.FilenameUtils; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.File; import java.net.URI; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathLabelGenerator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathLabelGenerator.java index d5bb50d2a..76b1af21c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathLabelGenerator.java +++ 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathLabelGenerator.java @@ -20,7 +20,7 @@ package org.datavec.api.io.labels; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.net.URI; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathMultiLabelGenerator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathMultiLabelGenerator.java index 0332353de..63597ffc9 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathMultiLabelGenerator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PathMultiLabelGenerator.java @@ -20,7 +20,7 @@ package org.datavec.api.io.labels; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PatternPathLabelGenerator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PatternPathLabelGenerator.java index 96265b0c1..311bccaa5 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PatternPathLabelGenerator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/io/labels/PatternPathLabelGenerator.java @@ -22,7 +22,7 @@ package org.datavec.api.io.labels; import org.apache.commons.io.FilenameUtils; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.File; import java.net.URI; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/Record.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/Record.java deleted file mode 100644 index 93c89685e..000000000 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/Record.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.api.records; - -import org.datavec.api.records.metadata.RecordMetaData; -import org.datavec.api.writable.Writable; - -import java.io.Serializable; -import java.util.List; - -public interface Record extends Serializable { - - /** - * Get the record values, as a {@code List} - * - * @return Record values - */ - List getRecord(); - - /** - * Get the record values for this Record - */ - void setRecord(List record); - - /** - * Get the RecordMetaData for this record - * - * @return Metadata for this record (or null, if none has been set) - */ - RecordMetaData getMetaData(); - - /** - * Set the Record metadata - */ - void setMetaData(RecordMetaData recordMetaData); - -} diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/SequenceRecord.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/SequenceRecord.java index 6fbb68913..cfd474279 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/SequenceRecord.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/SequenceRecord.java @@ -21,7 +21,7 @@ package org.datavec.api.records; import org.datavec.api.records.metadata.RecordMetaData; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.util.List; @@ -29,7 +29,7 @@ import java.util.List; public interface SequenceRecord extends Serializable { /** - * Get the sequence record values + * Get the sequence record values. The outer list is the sequence of Records * * @return Sequence record values */ diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/impl/Record.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/impl/Record.java index 7024a718d..f8aeb8236 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/impl/Record.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/impl/Record.java @@ -23,13 +23,13 @@ package org.datavec.api.records.impl; import lombok.AllArgsConstructor; import lombok.Data; import org.datavec.api.records.metadata.RecordMetaData; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; @AllArgsConstructor @Data -public class Record implements org.datavec.api.records.Record { +public class Record implements org.datavec.api.Record { private List record; private RecordMetaData metaData; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/impl/SequenceRecord.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/impl/SequenceRecord.java index 11b7ae5c0..9a2e30c4c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/impl/SequenceRecord.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/impl/SequenceRecord.java @@ -23,7 +23,7 @@ package org.datavec.api.records.impl; import lombok.AllArgsConstructor; import lombok.Data; import org.datavec.api.records.metadata.RecordMetaData; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/mapper/RecordMapper.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/mapper/RecordMapper.java index 
29a1abccc..7159633fe 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/mapper/RecordMapper.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/mapper/RecordMapper.java @@ -28,7 +28,7 @@ import org.datavec.api.records.writer.RecordWriter; import org.datavec.api.split.InputSplit; import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.split.partition.Partitioner; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java index 9fd71d227..73ff0b265 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/BaseRecordReader.java @@ -24,7 +24,7 @@ import org.datavec.api.records.listener.RecordListener; import org.datavec.api.split.InputSplit; import org.datavec.api.split.StreamInputSplit; import org.datavec.api.split.streams.FileStreamCreatorFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.io.Closeable; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java index a72793529..332358f74 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java @@ -22,11 +22,11 @@ package org.datavec.api.records.reader; import org.datavec.api.conf.Configurable; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.*; import java.net.URI; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/SequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/SequenceRecordReader.java index 69fe24d5a..c987f617d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/SequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/SequenceRecordReader.java @@ -20,10 +20,10 @@ package org.datavec.api.records.reader; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.metadata.RecordMetaData; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java index 905854f03..fcb394f95 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java +++ 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ComposableRecordReader.java @@ -21,12 +21,12 @@ package org.datavec.api.records.reader.impl; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.BaseRecordReader; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java index ab436407a..73cc36cef 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/ConcatenatingRecordReader.java @@ -21,12 +21,12 @@ package org.datavec.api.records.reader.impl; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.BaseRecordReader; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java index a9448b981..fb010f258 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java @@ -23,14 +23,14 @@ package org.datavec.api.records.reader.impl; import lombok.Getter; import lombok.Setter; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataURI; import org.datavec.api.records.reader.BaseRecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.*; import java.net.URI; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java index b05b739df..0569f46f3 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/LineRecordReader.java @@ -26,7 +26,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.apache.commons.io.LineIterator; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import 
org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataLine; import org.datavec.api.records.reader.BaseRecordReader; @@ -34,7 +34,7 @@ import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputStreamInputSplit; import org.datavec.api.split.StringSplit; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.base.Preconditions; import org.nd4j.common.primitives.Triple; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionRecordReader.java index a8e02e2c4..0fe9738dc 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionRecordReader.java @@ -22,12 +22,12 @@ package org.datavec.api.records.reader.impl.collection; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataIndex; import org.datavec.api.records.reader.BaseRecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java index e33f0a9ec..8216ecea1 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/CollectionSequenceRecordReader.java @@ -22,14 +22,14 @@ package org.datavec.api.records.reader.impl.collection; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataIndex; import org.datavec.api.records.reader.BaseRecordReader; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/ListStringRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/ListStringRecordReader.java index 7c99ca300..0f634f189 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/ListStringRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/collection/ListStringRecordReader.java @@ -21,13 +21,13 @@ package org.datavec.api.records.reader.impl.collection; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import 
org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.BaseRecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.split.ListStringSplit; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVLineSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVLineSequenceRecordReader.java index d2f7e3ffd..79e660220 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVLineSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVLineSequenceRecordReader.java @@ -20,11 +20,11 @@ package org.datavec.api.records.reader.impl.csv; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVMultiSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVMultiSequenceRecordReader.java index 5e4571f81..be79558fa 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVMultiSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVMultiSequenceRecordReader.java @@ -24,7 +24,7 @@ import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataInterval; import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.base.Preconditions; import java.io.BufferedReader; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVNLinesSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVNLinesSequenceRecordReader.java index 86e9c3c64..43b05dd3f 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVNLinesSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVNLinesSequenceRecordReader.java @@ -21,14 +21,14 @@ package org.datavec.api.records.reader.impl.csv; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataLineInterval; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.primitives.Triple; import java.io.DataInputStream; diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVRecordReader.java index e947ebafa..4ed000c1b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVRecordReader.java @@ -22,13 +22,13 @@ package org.datavec.api.records.reader.impl.csv; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataLine; import org.datavec.api.records.reader.impl.LineRecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.base.Preconditions; import java.io.BufferedReader; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVRegexRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVRegexRecordReader.java index 6d0a31a06..e7ff005e4 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVRegexRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVRegexRecordReader.java @@ -21,7 +21,7 @@ package org.datavec.api.records.reader.impl.csv; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.ArrayList; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVSequenceRecordReader.java index 8398bf274..d4d69835b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVSequenceRecordReader.java @@ -28,7 +28,7 @@ import org.datavec.api.records.metadata.RecordMetaDataURI; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.FileRecordReader; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.*; import java.net.URI; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVVariableSlidingWindowRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVVariableSlidingWindowRecordReader.java index 02a94f8d8..951c506ee 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVVariableSlidingWindowRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/csv/CSVVariableSlidingWindowRecordReader.java @@ -21,13 +21,13 @@ package org.datavec.api.records.reader.impl.csv; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.metadata.RecordMetaData; import 
org.datavec.api.records.metadata.RecordMetaDataLineInterval; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchRecordReader.java index b827400d6..36053f187 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchRecordReader.java @@ -21,12 +21,12 @@ package org.datavec.api.records.reader.impl.filebatch; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.loader.FileBatch; import org.nd4j.common.base.Preconditions; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchSequenceRecordReader.java index 133089920..c56b8efa8 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/filebatch/FileBatchSequenceRecordReader.java @@ -21,13 +21,13 @@ package org.datavec.api.records.reader.impl.filebatch; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.loader.FileBatch; import org.nd4j.common.base.Preconditions; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemoryRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemoryRecordReader.java index d9023b46a..ada33b680 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemoryRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemoryRecordReader.java @@ -22,12 +22,12 @@ package org.datavec.api.records.reader.impl.inmemory; import lombok.Data; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.InputSplit; -import 
org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemorySequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemorySequenceRecordReader.java index 76be03200..b01ff71e5 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemorySequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/inmemory/InMemorySequenceRecordReader.java @@ -22,13 +22,13 @@ package org.datavec.api.records.reader.impl.inmemory; import lombok.Data; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/FieldSelection.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/FieldSelection.java index e3c36bb53..f8b75da51 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/FieldSelection.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/FieldSelection.java @@ -21,7 +21,7 @@ package org.datavec.api.records.reader.impl.jackson; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineRecordReader.java index e759b6aa6..442a9b867 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineRecordReader.java @@ -24,7 +24,7 @@ import java.util.List; import org.datavec.api.records.reader.impl.LineRecordReader; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.databind.ObjectMapper; public class JacksonLineRecordReader extends LineRecordReader { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java index 3c2d81f69..89ec25d1e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonLineSequenceRecordReader.java @@ -27,7 +27,7 @@ import org.datavec.api.records.metadata.RecordMetaData; 
import org.datavec.api.records.metadata.RecordMetaDataURI; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.FileRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.*; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonReaderUtils.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonReaderUtils.java index 8626188bd..d6d390152 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonReaderUtils.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonReaderUtils.java @@ -21,7 +21,7 @@ package org.datavec.api.records.reader.impl.jackson; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java index de0d41573..78a31aaf9 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/jackson/JacksonRecordReader.java @@ -26,13 +26,13 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.datavec.api.conf.Configuration; import org.datavec.api.io.labels.PathLabelGenerator; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataURI; import org.datavec.api.records.reader.BaseRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/MatlabRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/MatlabRecordReader.java index 419c82c4d..0b2e088d8 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/MatlabRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/MatlabRecordReader.java @@ -23,7 +23,7 @@ package org.datavec.api.records.reader.impl.misc; import org.datavec.api.records.reader.impl.FileRecordReader; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/SVMLightRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/SVMLightRecordReader.java index 6d5bc5ea1..009629aa8 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/SVMLightRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/misc/SVMLightRecordReader.java @@ -21,14 +21,14 @@ package org.datavec.api.records.reader.impl.misc; import lombok.extern.slf4j.Slf4j; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataLine; import org.datavec.api.records.reader.impl.LineRecordReader; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.conf.Configuration; import java.io.DataInputStream; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexLineRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexLineRecordReader.java index 3a216d784..b8de6f081 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexLineRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexLineRecordReader.java @@ -21,13 +21,13 @@ package org.datavec.api.records.reader.impl.regex; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataLine; import org.datavec.api.records.reader.impl.LineRecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java index ebf685d50..662448191 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/regex/RegexSequenceRecordReader.java @@ -30,7 +30,7 @@ import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.FileRecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.base.Preconditions; import org.slf4j.Logger; import org.slf4j.LoggerFactory; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java index 2b8a38d58..6cb29af95 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReader.java @@ -21,13 +21,13 @@ package org.datavec.api.records.reader.impl.transform; import 
org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.transform.TransformProcess; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java index cb9213dea..940db6445 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/transform/TransformProcessSequenceRecordReader.java @@ -22,14 +22,14 @@ package org.datavec.api.records.reader.impl.transform; import lombok.AllArgsConstructor; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.transform.TransformProcess; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/RecordWriter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/RecordWriter.java index 822ebe53a..7443c81a4 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/RecordWriter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/RecordWriter.java @@ -26,7 +26,7 @@ import org.datavec.api.conf.Configuration; import org.datavec.api.split.InputSplit; import org.datavec.api.split.partition.PartitionMetaData; import org.datavec.api.split.partition.Partitioner; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Closeable; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/SequenceRecordWriter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/SequenceRecordWriter.java index afb517a4b..2b20d1fc0 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/SequenceRecordWriter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/SequenceRecordWriter.java @@ -24,7 +24,7 @@ package org.datavec.api.records.writer; import org.datavec.api.conf.Configurable; import org.datavec.api.split.partition.PartitionMetaData; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Closeable; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/LineRecordWriter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/LineRecordWriter.java index 
476eb7235..cd4b77b05 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/LineRecordWriter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/LineRecordWriter.java @@ -24,7 +24,7 @@ package org.datavec.api.records.writer.impl; import org.apache.commons.lang3.NotImplementedException; import org.datavec.api.split.partition.PartitionMetaData; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.IOException; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/csv/CSVRecordWriter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/csv/CSVRecordWriter.java index 1ea232c5c..d9704c4ce 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/csv/CSVRecordWriter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/csv/CSVRecordWriter.java @@ -23,7 +23,7 @@ package org.datavec.api.records.writer.impl.csv; import org.datavec.api.records.writer.impl.FileRecordWriter; import org.datavec.api.split.partition.PartitionMetaData; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.IOException; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/MatlabRecordWriter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/MatlabRecordWriter.java index 490c6e39d..e4e2ea45e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/MatlabRecordWriter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/MatlabRecordWriter.java @@ -24,7 +24,7 @@ package org.datavec.api.records.writer.impl.misc; import org.apache.commons.lang3.NotImplementedException; import org.datavec.api.records.writer.impl.FileRecordWriter; import org.datavec.api.split.partition.PartitionMetaData; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.IOException; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/SVMLightRecordWriter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/SVMLightRecordWriter.java index c15f9ede6..0b18661d3 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/SVMLightRecordWriter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/writer/impl/misc/SVMLightRecordWriter.java @@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.misc.SVMLightRecordReader; import org.datavec.api.records.writer.impl.FileRecordWriter; import org.datavec.api.split.partition.PartitionMetaData; import org.datavec.api.writable.ArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.IOException; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/timeseries/util/TimeSeriesWritableUtils.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/timeseries/util/TimeSeriesWritableUtils.java index 916478f0e..1f9f4f6a5 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/timeseries/util/TimeSeriesWritableUtils.java +++ 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/timeseries/util/TimeSeriesWritableUtils.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/Transform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/Transform.java index 96c7fdc8a..391e73463 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/Transform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/Transform.java @@ -20,7 +20,7 @@ package org.datavec.api.transform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonTypeInfo; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java index 9673c9a4f..2ccbebb95 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/TransformProcess.java @@ -23,6 +23,7 @@ package org.datavec.api.transform; import lombok.Data; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import org.datavec.api.Writable; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.transform.analysis.DataAnalysis; import org.datavec.api.transform.analysis.columns.ColumnAnalysis; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/AnalysisCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/AnalysisCounter.java index d45264ce6..8bb8aed35 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/AnalysisCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/AnalysisCounter.java @@ -20,7 +20,7 @@ package org.datavec.api.transform.analysis; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/BytesAnalysisCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/BytesAnalysisCounter.java index 07d263db7..d00c68e13 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/BytesAnalysisCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/BytesAnalysisCounter.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.analysis.counter; import lombok.AllArgsConstructor; import lombok.Data; import org.datavec.api.transform.analysis.AnalysisCounter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/CategoricalAnalysisCounter.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/CategoricalAnalysisCounter.java index 297fef35b..a2426289a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/CategoricalAnalysisCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/CategoricalAnalysisCounter.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.analysis.counter; import lombok.AllArgsConstructor; import lombok.Data; import org.datavec.api.transform.analysis.AnalysisCounter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.HashMap; import java.util.HashSet; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/DoubleAnalysisCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/DoubleAnalysisCounter.java index 7f16b8d7a..74729c33a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/DoubleAnalysisCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/DoubleAnalysisCounter.java @@ -24,7 +24,7 @@ import com.tdunning.math.stats.TDigest; import lombok.AllArgsConstructor; import lombok.Data; import org.datavec.api.transform.analysis.AnalysisCounter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/IntegerAnalysisCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/IntegerAnalysisCounter.java index 0a37ac1d4..9a12a14ac 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/IntegerAnalysisCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/IntegerAnalysisCounter.java @@ -24,7 +24,7 @@ import com.tdunning.math.stats.TDigest; import lombok.AllArgsConstructor; import lombok.Data; import org.datavec.api.transform.analysis.AnalysisCounter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/LongAnalysisCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/LongAnalysisCounter.java index 4243deeef..b7777ccdb 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/LongAnalysisCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/LongAnalysisCounter.java @@ -24,7 +24,7 @@ import com.tdunning.math.stats.TDigest; import lombok.AllArgsConstructor; import lombok.Data; import org.datavec.api.transform.analysis.AnalysisCounter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/NDArrayAnalysisCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/NDArrayAnalysisCounter.java index d4a7c2bc6..5f68925ba 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/NDArrayAnalysisCounter.java +++ 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/NDArrayAnalysisCounter.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.analysis.counter; import org.datavec.api.transform.analysis.AnalysisCounter; import org.datavec.api.transform.analysis.columns.NDArrayAnalysis; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.HashMap; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/StringAnalysisCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/StringAnalysisCounter.java index a18237513..7add60b54 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/StringAnalysisCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/counter/StringAnalysisCounter.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.analysis.counter; import lombok.AllArgsConstructor; import lombok.Data; import org.datavec.api.transform.analysis.AnalysisCounter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/CategoricalHistogramCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/CategoricalHistogramCounter.java index c4487b891..5bc64f07d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/CategoricalHistogramCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/CategoricalHistogramCounter.java @@ -20,7 +20,7 @@ package org.datavec.api.transform.analysis.histogram; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.HashMap; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/DoubleHistogramCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/DoubleHistogramCounter.java index cb2ea37e2..0b42048a9 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/DoubleHistogramCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/DoubleHistogramCounter.java @@ -20,7 +20,7 @@ package org.datavec.api.transform.analysis.histogram; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class DoubleHistogramCounter implements HistogramCounter { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/HistogramCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/HistogramCounter.java index 4f7503626..93c22c724 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/HistogramCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/HistogramCounter.java @@ -20,7 +20,7 @@ package org.datavec.api.transform.analysis.histogram; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/NDArrayHistogramCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/NDArrayHistogramCounter.java index 2ab063cff..d0f39b922 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/NDArrayHistogramCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/NDArrayHistogramCounter.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.analysis.histogram; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; public class NDArrayHistogramCounter implements HistogramCounter { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/StringHistogramCounter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/StringHistogramCounter.java index 179e0f8bb..9d6f858fd 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/StringHistogramCounter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/histogram/StringHistogramCounter.java @@ -20,7 +20,7 @@ package org.datavec.api.transform.analysis.histogram; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class StringHistogramCounter implements HistogramCounter { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/QualityAnalysisAddFunction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/QualityAnalysisAddFunction.java index 9fac13e63..d79aa7699 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/QualityAnalysisAddFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/QualityAnalysisAddFunction.java @@ -31,7 +31,7 @@ import org.datavec.api.transform.analysis.quality.string.StringQualityAnalysisSt import org.datavec.api.transform.analysis.quality.time.TimeQualityAnalysisState; import org.datavec.api.transform.metadata.*; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.BiFunction; import java.io.Serializable; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/QualityAnalysisState.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/QualityAnalysisState.java index a8c88a093..d149b833a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/QualityAnalysisState.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/QualityAnalysisState.java @@ -21,7 +21,7 @@ package org.datavec.api.transform.analysis.quality; import org.datavec.api.transform.quality.columns.ColumnQuality; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/bytes/BytesQualityAnalysisState.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/bytes/BytesQualityAnalysisState.java index f6c6e8c3c..f553af2e3 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/bytes/BytesQualityAnalysisState.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/bytes/BytesQualityAnalysisState.java @@ -24,7 +24,7 @@ import lombok.Getter; import org.datavec.api.transform.analysis.quality.QualityAnalysisState; import org.datavec.api.transform.quality.columns.BytesQuality; import org.datavec.api.transform.quality.columns.ColumnQuality; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class BytesQualityAnalysisState implements QualityAnalysisState { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAddFunction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAddFunction.java index 4346d4c1e..22c4c774d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAddFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAddFunction.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.metadata.CategoricalMetaData; import org.datavec.api.transform.quality.columns.CategoricalQuality; import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.BiFunction; import java.io.Serializable; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAnalysisState.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAnalysisState.java index 44aaac563..d70da5cf6 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAnalysisState.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/categorical/CategoricalQualityAnalysisState.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.analysis.quality.QualityAnalysisState; import org.datavec.api.transform.metadata.CategoricalMetaData; import org.datavec.api.transform.quality.columns.CategoricalQuality; import org.datavec.api.transform.quality.columns.ColumnQuality; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class CategoricalQualityAnalysisState implements QualityAnalysisState { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/integer/IntegerQualityAddFunction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/integer/IntegerQualityAddFunction.java index d9db5b8e1..e25dd1255 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/integer/IntegerQualityAddFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/integer/IntegerQualityAddFunction.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.metadata.IntegerMetaData; import 
org.datavec.api.transform.quality.columns.IntegerQuality; import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.BiFunction; import java.io.Serializable; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/integer/IntegerQualityAnalysisState.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/integer/IntegerQualityAnalysisState.java index 5e4a3028f..46cfddb2f 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/integer/IntegerQualityAnalysisState.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/integer/IntegerQualityAnalysisState.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.analysis.quality.QualityAnalysisState; import org.datavec.api.transform.metadata.IntegerMetaData; import org.datavec.api.transform.quality.columns.ColumnQuality; import org.datavec.api.transform.quality.columns.IntegerQuality; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class IntegerQualityAnalysisState implements QualityAnalysisState { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/longq/LongQualityAddFunction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/longq/LongQualityAddFunction.java index 8e8aa386f..b28042c8c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/longq/LongQualityAddFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/longq/LongQualityAddFunction.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.metadata.LongMetaData; import org.datavec.api.transform.quality.columns.LongQuality; import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.BiFunction; import java.io.Serializable; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/longq/LongQualityAnalysisState.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/longq/LongQualityAnalysisState.java index f60c71a80..86e254808 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/longq/LongQualityAnalysisState.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/longq/LongQualityAnalysisState.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.analysis.quality.QualityAnalysisState; import org.datavec.api.transform.metadata.LongMetaData; import org.datavec.api.transform.quality.columns.ColumnQuality; import org.datavec.api.transform.quality.columns.LongQuality; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class LongQualityAnalysisState implements QualityAnalysisState { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/real/RealQualityAddFunction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/real/RealQualityAddFunction.java index de1b859e1..85931597e 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/real/RealQualityAddFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/real/RealQualityAddFunction.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.metadata.DoubleMetaData; import org.datavec.api.transform.quality.columns.DoubleQuality; import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.BiFunction; import java.io.Serializable; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/real/RealQualityAnalysisState.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/real/RealQualityAnalysisState.java index 817867151..e4e61e03c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/real/RealQualityAnalysisState.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/real/RealQualityAnalysisState.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.analysis.quality.QualityAnalysisState; import org.datavec.api.transform.metadata.DoubleMetaData; import org.datavec.api.transform.quality.columns.ColumnQuality; import org.datavec.api.transform.quality.columns.DoubleQuality; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class RealQualityAnalysisState implements QualityAnalysisState { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/string/StringQualityAddFunction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/string/StringQualityAddFunction.java index bd81fe577..bbe128800 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/string/StringQualityAddFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/string/StringQualityAddFunction.java @@ -25,7 +25,7 @@ import lombok.AllArgsConstructor; import org.datavec.api.transform.metadata.StringMetaData; import org.datavec.api.transform.quality.columns.StringQuality; import org.datavec.api.writable.NullWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.BiFunction; import java.io.Serializable; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/string/StringQualityAnalysisState.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/string/StringQualityAnalysisState.java index f911d40c9..c4155f0b6 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/string/StringQualityAnalysisState.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/string/StringQualityAnalysisState.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.analysis.quality.QualityAnalysisState; import org.datavec.api.transform.metadata.StringMetaData; import org.datavec.api.transform.quality.columns.ColumnQuality; import org.datavec.api.transform.quality.columns.StringQuality; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class StringQualityAnalysisState implements QualityAnalysisState { diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/time/TimeQualityAddFunction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/time/TimeQualityAddFunction.java index e96d4e9f5..c16a929bc 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/time/TimeQualityAddFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/time/TimeQualityAddFunction.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.metadata.TimeMetaData; import org.datavec.api.transform.quality.columns.TimeQuality; import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.BiFunction; import java.io.Serializable; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/time/TimeQualityAnalysisState.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/time/TimeQualityAnalysisState.java index aca9badfb..cee78e4af 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/time/TimeQualityAnalysisState.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/analysis/quality/time/TimeQualityAnalysisState.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.analysis.quality.QualityAnalysisState; import org.datavec.api.transform.metadata.TimeMetaData; import org.datavec.api.transform.quality.columns.ColumnQuality; import org.datavec.api.transform.quality.columns.TimeQuality; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class TimeQualityAnalysisState implements QualityAnalysisState { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/BooleanCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/BooleanCondition.java index 6e128bc66..e31f1785e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/BooleanCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/BooleanCondition.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.condition; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java index 5928881f7..5be0ac648 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/Condition.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.condition; import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonTypeInfo; diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/BaseColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/BaseColumnCondition.java index dd7e66bdb..da77bb07a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/BaseColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/BaseColumnCondition.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.condition.SequenceConditionMode; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/BooleanColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/BooleanColumnCondition.java index 2e59fe5de..cfc7eddc8 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/BooleanColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/BooleanColumnCondition.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.condition.column; import lombok.Data; import org.datavec.api.transform.condition.SequenceConditionMode; import org.datavec.api.writable.BooleanWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @Data public class BooleanColumnCondition extends BaseColumnCondition { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/CategoricalColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/CategoricalColumnCondition.java index d10ee29f3..949d18318 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/CategoricalColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/CategoricalColumnCondition.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.condition.ConditionOp; import org.datavec.api.transform.condition.SequenceConditionMode; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Set; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/ColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/ColumnCondition.java index 98d416456..7e1ac4d50 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/ColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/ColumnCondition.java @@ -24,7 +24,7 @@ import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.condition.SequenceConditionMode; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/DoubleColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/DoubleColumnCondition.java index f19b92ab2..176c5d5ee 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/DoubleColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/DoubleColumnCondition.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.condition.ConditionOp; import org.datavec.api.transform.condition.SequenceConditionMode; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Set; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/FloatColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/FloatColumnCondition.java index f4d40b45e..393ff97ae 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/FloatColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/FloatColumnCondition.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.condition.ConditionOp; import org.datavec.api.transform.condition.SequenceConditionMode; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Set; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/InfiniteColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/InfiniteColumnCondition.java index dd7c77132..6ea477fdf 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/InfiniteColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/InfiniteColumnCondition.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.condition.column; import lombok.Data; import org.datavec.api.transform.condition.SequenceConditionMode; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @Data public class InfiniteColumnCondition extends BaseColumnCondition { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/IntegerColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/IntegerColumnCondition.java index 0029eb044..b8ff3c78e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/IntegerColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/IntegerColumnCondition.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.condition.ConditionOp; import org.datavec.api.transform.condition.SequenceConditionMode; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Set; diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/InvalidValueColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/InvalidValueColumnCondition.java index 4c69b5f45..d8df9fd9d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/InvalidValueColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/InvalidValueColumnCondition.java @@ -21,6 +21,7 @@ package org.datavec.api.transform.condition.column; import lombok.Data; +import org.datavec.api.Writable; import org.datavec.api.writable.*; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/LongColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/LongColumnCondition.java index a83be4fcf..9f079ff80 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/LongColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/LongColumnCondition.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.condition.ConditionOp; import org.datavec.api.transform.condition.SequenceConditionMode; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Set; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/NaNColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/NaNColumnCondition.java index 875d45a5a..448294278 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/NaNColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/NaNColumnCondition.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.condition.column; import lombok.Data; import org.datavec.api.transform.condition.SequenceConditionMode; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @Data public class NaNColumnCondition extends BaseColumnCondition { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/NullWritableColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/NullWritableColumnCondition.java index 6c4819efb..5bed35c8b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/NullWritableColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/NullWritableColumnCondition.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.condition.column; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.writable.NullWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @EqualsAndHashCode(callSuper = true) diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/StringColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/StringColumnCondition.java index c5bee1731..0c96c1fca 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/StringColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/StringColumnCondition.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.condition.ConditionOp; import org.datavec.api.transform.condition.SequenceConditionMode; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Set; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TimeColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TimeColumnCondition.java index 00c2714ce..bde341a0d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TimeColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TimeColumnCondition.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.condition.ConditionOp; import org.datavec.api.transform.condition.SequenceConditionMode; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Set; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TrivialColumnCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TrivialColumnCondition.java index 52a9a6040..e2b94b12a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TrivialColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/column/TrivialColumnCondition.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.condition.column; import lombok.Data; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/sequence/SequenceLengthCondition.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/sequence/SequenceLengthCondition.java index 15d60608f..70e03ac39 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/sequence/SequenceLengthCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/sequence/SequenceLengthCondition.java @@ -25,7 +25,7 @@ import lombok.EqualsAndHashCode; import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.condition.ConditionOp; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/string/StringRegexColumnCondition.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/string/StringRegexColumnCondition.java index 4c44c8356..348e01549 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/string/StringRegexColumnCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/condition/string/StringRegexColumnCondition.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.condition.SequenceConditionMode; import org.datavec.api.transform.condition.column.BaseColumnCondition; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @EqualsAndHashCode(callSuper = true) diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/BaseColumnFilter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/BaseColumnFilter.java index 73718fa7f..8aac0d2f1 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/BaseColumnFilter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/BaseColumnFilter.java @@ -21,7 +21,7 @@ package org.datavec.api.transform.filter; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/ConditionFilter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/ConditionFilter.java index cc7d24e9e..6caefdb99 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/ConditionFilter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/ConditionFilter.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java index ccccc6656..66d4bf03c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/Filter.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.filter; import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonTypeInfo; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/FilterInvalidValues.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/FilterInvalidValues.java index 3a5a35b68..1f4dcc25d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/FilterInvalidValues.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/FilterInvalidValues.java @@ -23,6 +23,7 @@ package org.datavec.api.transform.filter; import 
lombok.Data; import lombok.EqualsAndHashCode; import lombok.ToString; +import org.datavec.api.Writable; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/InvalidNumColumns.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/InvalidNumColumns.java index 70782abe8..064f5178c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/InvalidNumColumns.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/filter/InvalidNumColumns.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.filter; import lombok.AllArgsConstructor; import lombok.Data; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/join/Join.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/join/Join.java index d71b3c0c5..42d41e911 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/join/Join.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/join/Join.java @@ -25,7 +25,7 @@ import org.apache.commons.lang3.ArrayUtils; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.NullWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.util.*; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BinaryMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BinaryMetaData.java index 91a9238f1..07ea6d15b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BinaryMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BinaryMetaData.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.metadata; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BooleanMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BooleanMetaData.java index 66d8872b1..24a260aad 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BooleanMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/BooleanMetaData.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.metadata; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/CategoricalMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/CategoricalMetaData.java index 95004405d..3057e975e 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/CategoricalMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/CategoricalMetaData.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.metadata; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java index f13bad69e..e7c471b53 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/ColumnMetaData.java @@ -21,7 +21,7 @@ package org.datavec.api.transform.metadata; import org.datavec.api.transform.ColumnType; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonTypeInfo; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/DoubleMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/DoubleMetaData.java index aaa85a489..526a9c192 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/DoubleMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/DoubleMetaData.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.metadata; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/FloatMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/FloatMetaData.java index 7bcb7abe2..1c1bfe0c2 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/FloatMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/FloatMetaData.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.metadata; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/IntegerMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/IntegerMetaData.java index c856da307..02bcac4c1 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/IntegerMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/IntegerMetaData.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.metadata; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; 
import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/LongMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/LongMetaData.java index 01119430e..29ce85c47 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/LongMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/LongMetaData.java @@ -25,7 +25,7 @@ import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.LongWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/NDArrayMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/NDArrayMetaData.java index 9449eb780..ee5eb13b4 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/NDArrayMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/NDArrayMetaData.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/StringMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/StringMetaData.java index bf78d97e8..3fba49b42 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/StringMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/StringMetaData.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.metadata; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @EqualsAndHashCode(callSuper = true) diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/TimeMetaData.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/TimeMetaData.java index c339ffe21..e502baa38 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/TimeMetaData.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/metadata/TimeMetaData.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; import org.datavec.api.writable.LongWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.joda.time.DateTimeZone; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayColumnsMathOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayColumnsMathOpTransform.java index 
ddbf34503..6e4b99798 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayColumnsMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayColumnsMathOpTransform.java @@ -29,7 +29,7 @@ import org.datavec.api.transform.metadata.NDArrayMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseColumnsMathOpTransform; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayDistanceTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayDistanceTransform.java index 3bba47c92..f16aa0431 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayDistanceTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayDistanceTransform.java @@ -29,7 +29,7 @@ import org.datavec.api.transform.metadata.DoubleMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.ops.transforms.Transforms; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayMathFunctionTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayMathFunctionTransform.java index 74a91332c..731289091 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayMathFunctionTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayMathFunctionTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.MathFunction; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayScalarOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayScalarOpTransform.java index 708625c8a..faf33e837 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayScalarOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ndarray/NDArrayScalarOpTransform.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.NDArrayMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import 
org.nd4j.linalg.ops.transforms.Transforms; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregableCheckingOp.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregableCheckingOp.java index c6390269d..15c3909a8 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregableCheckingOp.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregableCheckingOp.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import lombok.Data; import lombok.Getter; import org.datavec.api.transform.metadata.ColumnMetaData; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregableMultiOp.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregableMultiOp.java index 9e6fef114..7f01e09c6 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregableMultiOp.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregableMultiOp.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import lombok.Data; import lombok.Getter; import lombok.NonNull; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.ArrayList; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregatorImpls.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregatorImpls.java index ce1b2b94d..d360db4aa 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregatorImpls.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/AggregatorImpls.java @@ -27,7 +27,7 @@ import lombok.NoArgsConstructor; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.UnsafeWritableInjector; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class AggregatorImpls { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/ByteWritableOp.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/ByteWritableOp.java index a0b9c30ec..436c9c1a6 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/ByteWritableOp.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/ByteWritableOp.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.ops; import lombok.AllArgsConstructor; import lombok.Data; import lombok.Getter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java index eb25b7b56..a4947d8b0 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DispatchWithConditionOp.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.ops; import lombok.Getter; import lombok.NonNull; import org.datavec.api.transform.condition.Condition; -import 
org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DoubleWritableOp.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DoubleWritableOp.java index 59ff912c7..a8dc1a6d5 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DoubleWritableOp.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/DoubleWritableOp.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.ops; import lombok.AllArgsConstructor; import lombok.Data; import lombok.Getter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/FloatWritableOp.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/FloatWritableOp.java index a194f62b0..ebd6ebd9a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/FloatWritableOp.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/FloatWritableOp.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.ops; import lombok.AllArgsConstructor; import lombok.Data; import lombok.Getter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/IntWritableOp.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/IntWritableOp.java index 636601129..968a7fae6 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/IntWritableOp.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/IntWritableOp.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.ops; import lombok.AllArgsConstructor; import lombok.Data; import lombok.Getter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/LongWritableOp.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/LongWritableOp.java index d4467cf17..b736f365c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/LongWritableOp.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/LongWritableOp.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.ops; import lombok.AllArgsConstructor; import lombok.Data; import lombok.Getter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/StringAggregatorImpls.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/StringAggregatorImpls.java index 6bf357603..9e5762243 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/StringAggregatorImpls.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/StringAggregatorImpls.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.ops; import lombok.Getter; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class StringAggregatorImpls { diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/StringWritableOp.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/StringWritableOp.java index 8d1524718..846721aae 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/StringWritableOp.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ops/StringWritableOp.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.ops; import lombok.AllArgsConstructor; import lombok.Data; import lombok.Getter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableColumnReduction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableColumnReduction.java index b6db27c0f..364fbf8be 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableColumnReduction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableColumnReduction.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.reduce; import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.ops.IAggregableReduceOp; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableReductionUtils.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableReductionUtils.java index 77db4cbcd..2f02600b0 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableReductionUtils.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/AggregableReductionUtils.java @@ -24,7 +24,7 @@ import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ReduceOp; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.ops.*; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.ArrayList; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/ColumnReduction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/ColumnReduction.java index 57a9fecf3..39c6e23ef 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/ColumnReduction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/ColumnReduction.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.reduce; import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.metadata.ColumnMetaData; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/IAssociativeReducer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/IAssociativeReducer.java index e36830f65..f67508747 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/IAssociativeReducer.java +++ 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/IAssociativeReducer.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.reduce; import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/Reducer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/Reducer.java index 0979773a3..c754d6c6d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/Reducer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/Reducer.java @@ -31,7 +31,7 @@ import org.datavec.api.transform.condition.column.TrivialColumnCondition; import org.datavec.api.transform.metadata.*; import org.datavec.api.transform.ops.*; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/impl/GeographicMidpointReduction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/impl/GeographicMidpointReduction.java index b3538c8a7..df2a3a1b8 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/impl/GeographicMidpointReduction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/reduce/impl/GeographicMidpointReduction.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.reduce.AggregableColumnReduction; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.base.Preconditions; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java index 003b212b6..c2da91813 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/Schema.java @@ -22,6 +22,7 @@ package org.datavec.api.transform.schema; import lombok.Data; import lombok.EqualsAndHashCode; +import org.datavec.api.Writable; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.metadata.*; import org.datavec.api.transform.serde.JsonMappers; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/SequenceSchema.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/SequenceSchema.java index 1a2cfa245..fdf78ffe5 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/SequenceSchema.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/SequenceSchema.java @@ -22,6 +22,7 @@ package org.datavec.api.transform.schema; import lombok.Data; import 
lombok.EqualsAndHashCode; +import org.datavec.api.Writable; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.writable.*; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java index bc7fa2a98..123f4ec17 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java @@ -20,7 +20,7 @@ package org.datavec.api.transform.schema.conversion; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class TypeConversion { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/ReduceSequenceTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/ReduceSequenceTransform.java index bb61f9ae9..a49f10a8c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/ReduceSequenceTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/ReduceSequenceTransform.java @@ -28,7 +28,7 @@ import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.reduce.IAssociativeReducer; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java index c8616ecda..48cc13e3e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceComparator.java @@ -21,7 +21,7 @@ package org.datavec.api.transform.sequence; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonTypeInfo; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java index a1a4c4312..e63af85ef 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/SequenceSplit.java @@ -21,7 +21,7 @@ package org.datavec.api.transform.sequence; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonTypeInfo; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/BaseColumnComparator.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/BaseColumnComparator.java index 02f0209fd..10ec6736c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/BaseColumnComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/BaseColumnComparator.java @@ -25,7 +25,7 @@ import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.sequence.SequenceComparator; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/NumericalColumnComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/NumericalColumnComparator.java index 419f68e78..819bc8b70 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/NumericalColumnComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/NumericalColumnComparator.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/StringComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/StringComparator.java index 9c173b9b6..6c3308c2e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/StringComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/comparator/StringComparator.java @@ -21,7 +21,7 @@ package org.datavec.api.transform.sequence.comparator; import lombok.Data; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/expansion/BaseSequenceExpansionTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/expansion/BaseSequenceExpansionTransform.java index 276ff5dff..b81a2e0e2 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/expansion/BaseSequenceExpansionTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/expansion/BaseSequenceExpansionTransform.java @@ -24,7 +24,7 @@ import lombok.*; import org.datavec.api.transform.Transform; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/merge/SequenceMerge.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/merge/SequenceMerge.java index 1ac4145ee..835793c49 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/merge/SequenceMerge.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/merge/SequenceMerge.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.sequence.merge; import lombok.Data; import org.datavec.api.transform.sequence.SequenceComparator; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SequenceSplitTimeSeparation.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SequenceSplitTimeSeparation.java index 1ffb60477..da917392b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SequenceSplitTimeSeparation.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SequenceSplitTimeSeparation.java @@ -25,7 +25,7 @@ import lombok.EqualsAndHashCode; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.sequence.SequenceSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SplitMaxLengthSequence.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SplitMaxLengthSequence.java index 1d80c1f5c..9a441766d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SplitMaxLengthSequence.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/split/SplitMaxLengthSequence.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.sequence.SequenceSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimToLengthTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimToLengthTransform.java index df873a753..90e003f83 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimToLengthTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimToLengthTransform.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.Transform; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.base.Preconditions; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimTransform.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimTransform.java index ca9cc060b..9f1fe36ce 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/trim/SequenceTrimTransform.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.transform.Transform; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/OverlappingTimeWindowFunction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/OverlappingTimeWindowFunction.java index 98dbc0c51..eae44d38a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/OverlappingTimeWindowFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/OverlappingTimeWindowFunction.java @@ -28,7 +28,7 @@ import org.datavec.api.transform.metadata.TimeMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.writable.LongWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.joda.time.DateTimeZone; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/ReduceSequenceByWindowTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/ReduceSequenceByWindowTransform.java index 77474f0e6..7f8564106 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/ReduceSequenceByWindowTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/ReduceSequenceByWindowTransform.java @@ -28,7 +28,7 @@ import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.reduce.IAssociativeReducer; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/TimeWindowFunction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/TimeWindowFunction.java index 594b0dbc1..816c7ce9a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/TimeWindowFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/TimeWindowFunction.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.metadata.TimeMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.writable.LongWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; 
import org.joda.time.DateTimeZone; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java index 45e00109f..27f1d74d7 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/sequence/window/WindowFunction.java @@ -21,7 +21,7 @@ package org.datavec.api.transform.sequence.window; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonTypeInfo; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java index 299434430..11d9894bf 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/serde/legacy/LegacyJsonFormat.java @@ -22,6 +22,7 @@ package org.datavec.api.transform.serde.legacy; import lombok.AccessLevel; import lombok.NoArgsConstructor; +import org.datavec.api.Writable; import org.datavec.api.transform.Transform; import org.datavec.api.transform.analysis.columns.*; import org.datavec.api.transform.condition.BooleanCondition; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java index 3cebf70a1..cdfe45f63 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/IStringReducer.java @@ -21,7 +21,7 @@ package org.datavec.api.transform.stringreduce; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonTypeInfo; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/StringReducer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/StringReducer.java index 17d3ef39b..df805feaa 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/StringReducer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/stringreduce/StringReducer.java @@ -22,14 +22,13 @@ package org.datavec.api.transform.stringreduce; import lombok.Data; import lombok.EqualsAndHashCode; -import org.datavec.api.transform.ReduceOp; import org.datavec.api.transform.StringReduceOp; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.StringMetaData; import org.datavec.api.transform.reduce.ColumnReduction; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import 
com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnTransform.java index 67ef0ea43..576a095c7 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnTransform.java @@ -25,7 +25,7 @@ import lombok.NoArgsConstructor; import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnsMathOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnsMathOpTransform.java index d0bda6912..dcf6f491a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnsMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseColumnsMathOpTransform.java @@ -27,10 +27,7 @@ import org.datavec.api.transform.MathOp; import org.datavec.api.transform.Transform; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.transform.doubletransform.DoubleMathOpTransform; -import org.datavec.api.transform.transform.integer.IntegerMathOpTransform; -import org.datavec.api.transform.transform.longtransform.LongMathOpTransform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseTransform.java index 7c5735364..949bc1db1 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/BaseTransform.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.transform; import lombok.Data; import org.datavec.api.transform.Transform; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToIntegerTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToIntegerTransform.java index 236e0cc8e..588a04aad 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToIntegerTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToIntegerTransform.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.metadata.IntegerMetaData; import 
org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseTransform; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToOneHotTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToOneHotTransform.java index 9a43b80fc..6e42d2f7c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToOneHotTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/CategoricalToOneHotTransform.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.metadata.IntegerMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseTransform; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/FirstDigitTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/FirstDigitTransform.java index 8c09737a4..53a2d4a9b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/FirstDigitTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/FirstDigitTransform.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseTransform; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.base.Preconditions; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/IntegerToCategoricalTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/IntegerToCategoricalTransform.java index 881b88013..eeeecb7b9 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/IntegerToCategoricalTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/IntegerToCategoricalTransform.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.metadata.CategoricalMetaData; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/PivotTransform.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/PivotTransform.java index 04b23f1e9..053a3c46b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/PivotTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/PivotTransform.java @@ -21,6 +21,7 @@ package org.datavec.api.transform.transform.categorical; import lombok.Data; +import org.datavec.api.Writable; import org.datavec.api.transform.metadata.CategoricalMetaData; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/StringToCategoricalTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/StringToCategoricalTransform.java index 6e3dd7172..248c1a9b5 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/StringToCategoricalTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/categorical/StringToCategoricalTransform.java @@ -24,7 +24,7 @@ import lombok.Data; import org.datavec.api.transform.metadata.CategoricalMetaData; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/AddConstantColumnTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/AddConstantColumnTransform.java index 48a191f35..884c025cf 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/AddConstantColumnTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/AddConstantColumnTransform.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.Transform; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/DuplicateColumnsTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/DuplicateColumnsTransform.java index 62f419855..db655e058 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/DuplicateColumnsTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/DuplicateColumnsTransform.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.Transform; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveAllColumnsExceptForTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveAllColumnsExceptForTransform.java index 52e13cc8b..c3c10337e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveAllColumnsExceptForTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveAllColumnsExceptForTransform.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseTransform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveColumnsTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveColumnsTransform.java index 62de1b280..c1d60fbb5 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveColumnsTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RemoveColumnsTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseTransform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RenameColumnsTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RenameColumnsTransform.java index d50e52a70..366e5be5d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RenameColumnsTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/RenameColumnsTransform.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.Transform; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/ReorderColumnsTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/ReorderColumnsTransform.java index 0d0deb76e..b7d26b78b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/ReorderColumnsTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/column/ReorderColumnsTransform.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.Transform; import 
org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalCopyValueTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalCopyValueTransform.java index 809a11457..3f28b2b8f 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalCopyValueTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalCopyValueTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.Transform; import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalReplaceValueTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalReplaceValueTransform.java index 435eee6e7..ca2101f87 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalReplaceValueTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalReplaceValueTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.Transform; import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalReplaceValueTransformWithDefault.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalReplaceValueTransformWithDefault.java index e64b84e89..ba22a7b85 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalReplaceValueTransformWithDefault.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/condition/ConditionalReplaceValueTransformWithDefault.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.ColumnOp; import org.datavec.api.transform.Transform; import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/BaseDoubleTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/BaseDoubleTransform.java index 291ae3978..fe4d138be 
100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/BaseDoubleTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/BaseDoubleTransform.java @@ -26,7 +26,7 @@ import lombok.NoArgsConstructor; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.DoubleMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; /** * diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/ConvertToDouble.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/ConvertToDouble.java index 901a8d5a2..4b349faed 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/ConvertToDouble.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/ConvertToDouble.java @@ -25,7 +25,7 @@ import lombok.NoArgsConstructor; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.DoubleMetaData; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.WritableType; @NoArgsConstructor diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleColumnsMathOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleColumnsMathOpTransform.java index 596eb737e..f0bf2c085 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleColumnsMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleColumnsMathOpTransform.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.metadata.DoubleMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseColumnsMathOpTransform; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleMathFunctionTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleMathFunctionTransform.java index e06426643..c435e029a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleMathFunctionTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleMathFunctionTransform.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.transform.doubletransform; import lombok.Data; import org.datavec.api.transform.MathFunction; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleMathOpTransform.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleMathOpTransform.java index a2c6e1821..9596799a8 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/DoubleMathOpTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.DoubleMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/Log2Normalizer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/Log2Normalizer.java index c00a2be58..0d2974e73 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/Log2Normalizer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/Log2Normalizer.java @@ -24,7 +24,7 @@ import lombok.Data; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.DoubleMetaData; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/MinMaxNormalizer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/MinMaxNormalizer.java index 79a3dc02c..37622a958 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/MinMaxNormalizer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/MinMaxNormalizer.java @@ -24,7 +24,7 @@ import lombok.Data; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.DoubleMetaData; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/StandardizeNormalizer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/StandardizeNormalizer.java index bb12b6541..4ede89241 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/StandardizeNormalizer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/StandardizeNormalizer.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.transform.doubletransform; import lombok.Data; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/SubtractMeanNormalizer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/SubtractMeanNormalizer.java index 5f2deb32d..066747bd8 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/SubtractMeanNormalizer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/doubletransform/SubtractMeanNormalizer.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.transform.doubletransform; import lombok.Data; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/BaseFloatTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/BaseFloatTransform.java index 1c3e3a3ae..b9380d7a5 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/BaseFloatTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/BaseFloatTransform.java @@ -26,7 +26,7 @@ import lombok.NoArgsConstructor; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.FloatMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; /** * diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/ConvertToFloat.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/ConvertToFloat.java index cbb72e4e5..e25d7944c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/ConvertToFloat.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/ConvertToFloat.java @@ -25,7 +25,7 @@ import lombok.NoArgsConstructor; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.FloatMetaData; import org.datavec.api.writable.FloatWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.WritableType; @NoArgsConstructor diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatColumnsMathOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatColumnsMathOpTransform.java index b45fa9f82..5055b6311 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatColumnsMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatColumnsMathOpTransform.java @@ -26,9 +26,8 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.FloatMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseColumnsMathOpTransform; -import org.datavec.api.transform.transform.floattransform.FloatMathOpTransform; import org.datavec.api.writable.FloatWritable; -import 
org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatMathFunctionTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatMathFunctionTransform.java index 0054750f5..185e24b00 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatMathFunctionTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatMathFunctionTransform.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.transform.floattransform; import lombok.Data; import org.datavec.api.transform.MathFunction; import org.datavec.api.writable.FloatWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatMathOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatMathOpTransform.java index a980c289a..2ce2c5b9c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/floattransform/FloatMathOpTransform.java @@ -25,9 +25,8 @@ import org.datavec.api.transform.MathOp; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.FloatMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; -import org.datavec.api.transform.transform.floattransform.FloatColumnsMathOpTransform; import org.datavec.api.writable.FloatWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/BaseIntegerTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/BaseIntegerTransform.java index ae1b4eaa4..dc45b5f4a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/BaseIntegerTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/BaseIntegerTransform.java @@ -25,7 +25,7 @@ import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @EqualsAndHashCode(callSuper = true) @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ConvertToInteger.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ConvertToInteger.java index c26e7e2be..14b18e84f 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ConvertToInteger.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ConvertToInteger.java @@ -25,7 +25,7 @@ import lombok.NoArgsConstructor; 
import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.IntegerMetaData; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.WritableType; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerColumnsMathOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerColumnsMathOpTransform.java index 878123df0..36249c00b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerColumnsMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerColumnsMathOpTransform.java @@ -26,9 +26,8 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.IntegerMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseColumnsMathOpTransform; -import org.datavec.api.transform.transform.doubletransform.DoubleColumnsMathOpTransform; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerMathOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerMathOpTransform.java index 1eac7f2db..1239363ac 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerMathOpTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.IntegerMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerToOneHotTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerToOneHotTransform.java index 1bd907723..f8b3e0e99 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerToOneHotTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/IntegerToOneHotTransform.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.metadata.IntegerMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseTransform; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ReplaceEmptyIntegerWithValueTransform.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ReplaceEmptyIntegerWithValueTransform.java index 8e3e44412..9f6a93997 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ReplaceEmptyIntegerWithValueTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ReplaceEmptyIntegerWithValueTransform.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.transform.integer; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @EqualsAndHashCode(callSuper = true) diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ReplaceInvalidWithIntegerTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ReplaceInvalidWithIntegerTransform.java index cde5b1182..a5e2d534f 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ReplaceInvalidWithIntegerTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/integer/ReplaceInvalidWithIntegerTransform.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.transform.integer; import lombok.Data; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/longtransform/LongColumnsMathOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/longtransform/LongColumnsMathOpTransform.java index cf1211dc7..4a5f2a571 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/longtransform/LongColumnsMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/longtransform/LongColumnsMathOpTransform.java @@ -26,9 +26,8 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.LongMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseColumnsMathOpTransform; -import org.datavec.api.transform.transform.doubletransform.DoubleColumnsMathOpTransform; import org.datavec.api.writable.LongWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/longtransform/LongMathOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/longtransform/LongMathOpTransform.java index 365edefa6..f798c66eb 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/longtransform/LongMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/longtransform/LongMathOpTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.LongMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.LongWritable; -import 
org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToCharacterIndexTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToCharacterIndexTransform.java index 20d1b1c2e..2f25f4a62 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToCharacterIndexTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToCharacterIndexTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.IntegerMetaData; import org.datavec.api.transform.sequence.expansion.BaseSequenceExpansionTransform; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.*; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToTermIndexSequenceTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToTermIndexSequenceTransform.java index fa2990e78..725b19152 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToTermIndexSequenceTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/nlp/TextToTermIndexSequenceTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.IntegerMetaData; import org.datavec.api.transform.sequence.expansion.BaseSequenceExpansionTransform; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/parse/ParseDoubleTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/parse/ParseDoubleTransform.java index aa11c3b47..86504d51a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/parse/ParseDoubleTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/parse/ParseDoubleTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseTransform; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.ArrayList; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceDifferenceTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceDifferenceTransform.java index 4ba0e8968..f08ff5ea6 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceDifferenceTransform.java +++ 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceDifferenceTransform.java @@ -21,6 +21,7 @@ package org.datavec.api.transform.transform.sequence; import lombok.Data; +import org.datavec.api.Writable; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.Transform; import org.datavec.api.transform.metadata.*; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceMovingWindowReduceTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceMovingWindowReduceTransform.java index 11895d47f..ecc2677fa 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceMovingWindowReduceTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceMovingWindowReduceTransform.java @@ -31,7 +31,7 @@ import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.reduce.AggregableReductionUtils; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceOffsetTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceOffsetTransform.java index eeba657c5..cb1c537dd 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceOffsetTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/sequence/SequenceOffsetTransform.java @@ -26,7 +26,7 @@ import lombok.Getter; import org.datavec.api.transform.Transform; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/AppendStringColumnTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/AppendStringColumnTransform.java index 2b7aa6de2..5b993d64c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/AppendStringColumnTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/AppendStringColumnTransform.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.StringMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/BaseStringTransform.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/BaseStringTransform.java index 0cacf36be..1bdd13490 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/BaseStringTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/BaseStringTransform.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.StringMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @EqualsAndHashCode(callSuper = true) diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ChangeCaseStringTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ChangeCaseStringTransform.java index cf55bddfc..823152466 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ChangeCaseStringTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ChangeCaseStringTransform.java @@ -22,7 +22,7 @@ package org.datavec.api.transform.transform.string; import lombok.Data; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ConcatenateStringColumns.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ConcatenateStringColumns.java index 6e0ae78fa..f1f705e55 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ConcatenateStringColumns.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ConcatenateStringColumns.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseTransform; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ConvertToString.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ConvertToString.java index cd19b025e..aaefe8d78 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ConvertToString.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ConvertToString.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.transform.string; import lombok.Data; import lombok.NoArgsConstructor; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @NoArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/MapAllStringsExceptListTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/MapAllStringsExceptListTransform.java index 
6bc6e0898..96784240e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/MapAllStringsExceptListTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/MapAllStringsExceptListTransform.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.transform.string; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.HashSet; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/RemoveWhiteSpaceTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/RemoveWhiteSpaceTransform.java index ae16a92a5..f50d3772e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/RemoveWhiteSpaceTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/RemoveWhiteSpaceTransform.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.transform.string; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @EqualsAndHashCode(callSuper = true) diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ReplaceEmptyStringTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ReplaceEmptyStringTransform.java index 1023e7e22..b930abc34 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ReplaceEmptyStringTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ReplaceEmptyStringTransform.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.transform.string; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; @EqualsAndHashCode(callSuper = true) diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ReplaceStringTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ReplaceStringTransform.java index a10056749..d386f0686 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ReplaceStringTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/ReplaceStringTransform.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.transform.string; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Map; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCategoricalSetTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCategoricalSetTransform.java index 108da34e7..c6379631f 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCategoricalSetTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCategoricalSetTransform.java @@ -28,7 +28,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseTransform; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java index e682dc099..799c8034f 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringListToCountsNDArrayTransform.java @@ -29,7 +29,7 @@ import org.datavec.api.transform.metadata.NDArrayMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseTransform; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringMapTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringMapTransform.java index 7a39aa577..6f78b4ed3 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringMapTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/string/StringMapTransform.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.transform.string; import lombok.Data; import lombok.EqualsAndHashCode; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.Map; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/DeriveColumnsFromTimeTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/DeriveColumnsFromTimeTransform.java index 425b4cc68..f581ea683 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/DeriveColumnsFromTimeTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/DeriveColumnsFromTimeTransform.java @@ -33,7 +33,7 @@ import org.datavec.api.util.jackson.DateTimeFieldTypeDeserializer; import org.datavec.api.util.jackson.DateTimeFieldTypeSerializer; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.joda.time.DateTime; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/StringToTimeTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/StringToTimeTransform.java index e5141f2c2..286881583 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/StringToTimeTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/StringToTimeTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.TimeMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.LongWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.joda.time.DateTimeZone; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/TimeMathOpTransform.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/TimeMathOpTransform.java index 1c5e5fb0c..22f4e5390 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/TimeMathOpTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/transform/time/TimeMathOpTransform.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.TimeMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.LongWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/HtmlSequencePlotting.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/HtmlSequencePlotting.java index f7b40bdbb..dd3931700 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/HtmlSequencePlotting.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/transform/ui/HtmlSequencePlotting.java @@ -31,7 +31,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.ui.components.RenderableComponentLineChart; import org.datavec.api.transform.ui.components.RenderableComponentTable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.joda.time.DateTimeZone; import org.joda.time.format.DateTimeFormat; import org.joda.time.format.DateTimeFormatter; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/RecordUtils.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/RecordUtils.java index 098aac53f..95efe3e8a 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/RecordUtils.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/RecordUtils.java @@ -22,7 +22,7 @@ package org.datavec.api.util; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.FloatWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.ArrayList; import java.util.List; diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java index 360f2aa74..f4b10da10 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java @@ -23,6 +23,7 @@ package org.datavec.api.util.ndarray; import com.google.common.base.Preconditions; import it.unimi.dsi.fastutil.doubles.DoubleArrayList; import lombok.NonNull; +import org.datavec.api.Writable; import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.schema.Schema; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/vector/Vectorizer.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/vector/Vectorizer.java index 55c0dba4c..a0687d124 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/vector/Vectorizer.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/vector/Vectorizer.java @@ -21,7 +21,7 @@ package org.datavec.api.vector; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.reader.RecordReader; public interface Vectorizer { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ArrayWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ArrayWritable.java index 59e8289ee..e4fede64d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ArrayWritable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ArrayWritable.java @@ -20,6 +20,8 @@ package org.datavec.api.writable; +import org.datavec.api.Writable; + public abstract class ArrayWritable implements Writable { public abstract long length(); diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java index 68bf9ebd6..0effeea11 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/ByteWritable.java @@ -22,6 +22,7 @@ package org.datavec.api.writable; import com.google.common.math.DoubleMath; +import org.datavec.api.Writable; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/BytesWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/BytesWritable.java index 1caa52031..44201573b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/BytesWritable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/BytesWritable.java @@ -23,6 +23,7 @@ package org.datavec.api.writable; import lombok.Getter; import lombok.NoArgsConstructor; import lombok.Setter; +import org.datavec.api.Writable; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java index 8a6ef79ed..5914451d2 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/DoubleWritable.java @@ -22,6 +22,7 @@ package org.datavec.api.writable; import com.google.common.math.DoubleMath; +import org.datavec.api.Writable; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java index 783e77b9a..52aaf6ce9 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/FloatWritable.java @@ -22,6 +22,7 @@ package org.datavec.api.writable; import com.google.common.math.DoubleMath; +import org.datavec.api.Writable; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java index 56739a8f6..94794fe90 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/IntWritable.java @@ -22,6 +22,7 @@ package org.datavec.api.writable; import com.google.common.math.DoubleMath; +import org.datavec.api.Writable; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java index 599bde104..8054816d3 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/LongWritable.java @@ -22,6 +22,7 @@ package org.datavec.api.writable; import com.google.common.math.DoubleMath; +import org.datavec.api.Writable; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java index b80d491d2..33df815dc 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/Text.java @@ -21,6 +21,7 @@ package org.datavec.api.writable; +import org.datavec.api.Writable; import org.datavec.api.io.BinaryComparable; import org.datavec.api.io.WritableComparable; import org.datavec.api.io.WritableComparator; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/UnsafeWritableInjector.java 
b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/UnsafeWritableInjector.java index 04ee44bcb..29ac6456d 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/UnsafeWritableInjector.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/UnsafeWritableInjector.java @@ -20,6 +20,7 @@ package org.datavec.api.writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; public class UnsafeWritableInjector { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableFactory.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableFactory.java index b1e542b87..2b7d419ea 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableFactory.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableFactory.java @@ -28,6 +28,7 @@ import java.io.IOException; import java.lang.reflect.Constructor; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import org.datavec.api.Writable; public class WritableFactory { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableType.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableType.java index 3de22f696..3afbe4088 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableType.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableType.java @@ -20,6 +20,8 @@ package org.datavec.api.writable; +import org.datavec.api.Writable; + public enum WritableType { Boolean, Byte, Double, Float, Int, Long, Null, Text, NDArray, Image,Arrow,Bytes; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractTimeSeriesWritableRecordBatch.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractTimeSeriesWritableRecordBatch.java index 25cc863f4..829ab6172 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractTimeSeriesWritableRecordBatch.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractTimeSeriesWritableRecordBatch.java @@ -20,7 +20,7 @@ package org.datavec.api.writable.batch; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.Collection; import java.util.Iterator; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractWritableRecordBatch.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractWritableRecordBatch.java index 715d2a674..920865586 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractWritableRecordBatch.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/AbstractWritableRecordBatch.java @@ -20,7 +20,7 @@ package org.datavec.api.writable.batch; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.Collection; import java.util.Iterator; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java index 0a8d58d0c..ed9485432 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/batch/NDArrayRecordBatch.java @@ -24,7 +24,7 @@ import com.google.common.base.Preconditions; import lombok.Data; import lombok.NonNull; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/Comparators.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/Comparators.java index 03b34ce60..86cd5b040 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/Comparators.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/Comparators.java @@ -20,7 +20,7 @@ package org.datavec.api.writable.comparator; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.WritableType; import java.util.Comparator; diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/DoubleWritableComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/DoubleWritableComparator.java index c678b7142..d85bad54b 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/DoubleWritableComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/DoubleWritableComparator.java @@ -21,7 +21,7 @@ package org.datavec.api.writable.comparator; import lombok.EqualsAndHashCode; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @EqualsAndHashCode public class DoubleWritableComparator implements WritableComparator { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/FloatWritableComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/FloatWritableComparator.java index 6b2eaee97..642af28f7 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/FloatWritableComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/FloatWritableComparator.java @@ -20,7 +20,7 @@ package org.datavec.api.writable.comparator; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class FloatWritableComparator implements WritableComparator { @Override diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/IntWritableComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/IntWritableComparator.java index 5332507c4..41ccd5ccc 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/IntWritableComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/IntWritableComparator.java @@ -20,7 +20,7 @@ package org.datavec.api.writable.comparator; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class IntWritableComparator implements WritableComparator { @Override diff --git 
a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/LongWritableComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/LongWritableComparator.java index 9f733337f..1bac5418e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/LongWritableComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/LongWritableComparator.java @@ -21,7 +21,7 @@ package org.datavec.api.writable.comparator; import lombok.EqualsAndHashCode; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @EqualsAndHashCode public class LongWritableComparator implements WritableComparator { diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/TextWritableComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/TextWritableComparator.java index 8fe66e67e..c6c198db4 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/TextWritableComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/TextWritableComparator.java @@ -20,7 +20,7 @@ package org.datavec.api.writable.comparator; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class TextWritableComparator implements WritableComparator { @Override diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java index 044a27c0e..8cdf2c7e8 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/comparator/WritableComparator.java @@ -20,7 +20,7 @@ package org.datavec.api.writable.comparator; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonTypeInfo; import java.io.Serializable; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java index bfc531e7c..0aca2243c 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java @@ -25,7 +25,7 @@ import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVLineSequenceRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java index e52f71cc2..a52d589f1 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java @@ -25,7 +25,7 @@ import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVMultiSequenceRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java index 8096c391c..6125f2b0b 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java @@ -26,7 +26,7 @@ import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVNLinesSequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java index 16ed450df..7e2320ed7 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java @@ -21,7 +21,7 @@ package org.datavec.api.records.reader.impl; import org.apache.commons.io.FileUtils; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRegexRecordReader; @@ -33,7 +33,7 @@ import org.datavec.api.split.StringSplit; import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java index 4032b912c..c8a411f7d 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java @@ -26,7 +26,7 @@ import org.datavec.api.records.reader.SequenceRecordReader; import 
org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.split.NumberedFileInputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java index 315016932..8ca6d1c0a 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java @@ -24,7 +24,7 @@ import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordReader; import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java index 809295f6d..d8b97cfeb 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java @@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader; import org.datavec.api.records.reader.impl.filebatch.FileBatchSequenceRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java index 680254125..6b685c084 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java @@ -20,12 +20,12 @@ package org.datavec.api.records.reader.impl; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.split.CollectionInputSplit; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; diff --git 
a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java index 3fcb9e9f5..148580c00 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java @@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.jackson.JacksonLineSequenceRecordRead import org.datavec.api.split.CollectionInputSplit; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java index 08b94fdec..c335cd255 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java @@ -22,7 +22,7 @@ package org.datavec.api.records.reader.impl; import org.apache.commons.io.FileUtils; import org.datavec.api.io.labels.PathLabelGenerator; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.jackson.FieldSelection; @@ -31,7 +31,7 @@ import org.datavec.api.split.InputSplit; import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java index 13560a907..6103eb958 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java @@ -26,7 +26,7 @@ import org.datavec.api.records.reader.impl.misc.SVMLightRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java index 73c6053dc..a4e4df160 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java +++ 
b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java @@ -23,13 +23,13 @@ package org.datavec.api.records.reader.impl; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.apache.commons.io.IOUtils; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputStreamInputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java index 481e3a8ba..d0150e777 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java @@ -20,7 +20,7 @@ package org.datavec.api.records.reader.impl; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataLine; @@ -32,7 +32,7 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java index e091c6945..426961e89 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java @@ -21,12 +21,12 @@ package org.datavec.api.records.reader.impl; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.reader.impl.misc.SVMLightRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java index c3287ffa6..9e7502d62 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java +++ 
b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java @@ -25,7 +25,7 @@ import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java index c6de5ebcb..732a8a333 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java @@ -36,7 +36,7 @@ import org.datavec.api.transform.MathFunction; import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.common.io.ClassPathResource; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java index 3645c034a..eccc07946 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java @@ -28,7 +28,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.LongWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.joda.time.DateTimeZone; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java index 9b90e9221..fe357aab3 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java @@ -25,7 +25,7 @@ import org.datavec.api.records.writer.impl.csv.CSVRecordWriter; import org.datavec.api.split.FileSplit; import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java 
index f0013e516..ee59836c7 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java @@ -29,7 +29,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java index 48ee43c47..04e0d7134 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java @@ -21,6 +21,7 @@ package org.datavec.api.records.writer.impl; import org.apache.commons.io.FileUtils; +import org.datavec.api.Writable; import org.datavec.api.conf.Configuration; import org.datavec.api.records.reader.impl.misc.SVMLightRecordReader; import org.datavec.api.records.writer.impl.misc.SVMLightRecordWriter; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java index c53099d0f..ba51f1aa7 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java @@ -24,7 +24,7 @@ import org.apache.commons.io.FileUtils; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java index d478c934b..96bfd29ac 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java @@ -28,7 +28,7 @@ import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java index 1f4af7292..46aa072ce 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java @@ -20,6 +20,7 @@ package org.datavec.api.transform.condition; +import org.datavec.api.Writable; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.condition.column.*; import org.datavec.api.transform.condition.sequence.SequenceLengthCondition; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java index 1f937609e..2bb3aefcf 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.condition.column.IntegerColumnCondition; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java index ad056ccef..9e290f756 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java index 6106e37a3..55e8770e8 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java @@ -20,7 +20,7 @@ package org.datavec.api.transform.ops; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java index a8c9aace5..eab573252 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java @@ -20,7 +20,7 @@ package org.datavec.api.transform.ops; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git 
a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java index b42f75f6f..c447610a7 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java @@ -21,6 +21,7 @@ package org.datavec.api.transform.reduce; import lombok.Getter; +import org.datavec.api.Writable; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.ReduceOp; import org.datavec.api.transform.condition.Condition; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java index c5867e5dd..a833ebdea 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.reduce; import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java index 6add32050..c55765ea4 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java @@ -31,7 +31,7 @@ import org.datavec.api.transform.sequence.window.WindowFunction; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.NullWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.joda.time.DateTimeZone; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java index cae0795c4..5e97fa2e5 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.transform.sequence.split.SequenceSplitTimeSeparation; import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.joda.time.DateTimeZone; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java 
b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java index f1becef3c..d9dc84897 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.sequence.window.TimeWindowFunction; import org.datavec.api.transform.sequence.window.WindowFunction; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.LongWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.joda.time.DateTimeZone; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomCondition.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomCondition.java index 18b68e119..0156b1627 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomCondition.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomCondition.java @@ -25,7 +25,7 @@ import lombok.Data; import lombok.NoArgsConstructor; import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomFilter.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomFilter.java index afe1a9980..476db00a0 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomFilter.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomFilter.java @@ -25,7 +25,7 @@ import lombok.Data; import lombok.NoArgsConstructor; import org.datavec.api.transform.filter.Filter; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomTransform.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomTransform.java index d9a157a06..62567b76a 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomTransform.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/serde/testClasses/CustomTransform.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.serde.testClasses; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.DoubleMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; public class CustomTransform extends BaseColumnTransform { diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java index 13730f588..fccfac6b2 100644 --- 
a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java @@ -23,7 +23,7 @@ package org.datavec.api.transform.stringreduce; import org.datavec.api.transform.StringReduceOp; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java index 9a6f8b1d3..f356a8505 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java @@ -20,6 +20,7 @@ package org.datavec.api.transform.transform; +import org.datavec.api.Writable; import org.datavec.api.transform.*; import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.condition.ConditionOp; @@ -57,7 +58,6 @@ import org.datavec.api.transform.transform.time.TimeMathOpTransform; import org.datavec.api.writable.*; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java index 06c75574f..892a42094 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java @@ -28,7 +28,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java index f6f5ab4b0..8a682a858 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java @@ -22,12 +22,11 @@ package org.datavec.api.transform.transform.parse; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; 
diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java index c774bc7cb..2d7c7c53e 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java @@ -33,7 +33,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.joda.time.DateTimeZone; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java index 17ffa9ea9..fd1e36c2d 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java @@ -22,7 +22,7 @@ package org.datavec.api.util; import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java index 40d74cfcf..96b6bc9ed 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java @@ -20,6 +20,7 @@ package org.datavec.api.writable; +import org.datavec.api.Writable; import org.nd4j.common.tests.BaseND4JTest; import com.google.common.collect.Lists; import org.datavec.api.transform.schema.Schema; diff --git a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java index 0d5acbbb4..7b43b94cf 100644 --- a/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java +++ b/cavis-datavec/cavis-datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java @@ -20,6 +20,7 @@ package org.datavec.api.writable; +import org.datavec.api.Writable; import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.junit.jupiter.api.Test; import org.nd4j.common.tests.BaseND4JTest; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java index 9d88cdb1c..76e4d020a 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java @@ -38,6 +38,7 @@ import org.apache.arrow.vector.types.pojo.DictionaryEncoding; import org.apache.arrow.vector.types.pojo.Field; import 
org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel; +import org.datavec.api.Writable; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.metadata.*; import org.datavec.api.transform.schema.Schema; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecord.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecord.java index 2d9f5e091..38034a6d4 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecord.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecord.java @@ -21,10 +21,10 @@ package org.datavec.arrow.recordreader; import lombok.AllArgsConstructor; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataIndex; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.net.URI; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordReader.java index e559da60b..4f82bed2d 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordReader.java @@ -24,7 +24,7 @@ import lombok.Getter; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataIndex; @@ -32,7 +32,7 @@ import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.File; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java index 322de25a5..bb2594e59 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowRecordWriter.java @@ -26,11 +26,10 @@ import org.datavec.api.split.InputSplit; import org.datavec.api.split.partition.PartitionMetaData; import org.datavec.api.split.partition.Partitioner; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.arrow.ArrowConverter; import java.io.IOException; -import java.util.Arrays; import java.util.Collections; import java.util.List; diff --git 
a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordBatch.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordBatch.java index 3d3eab195..ead790f2e 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordBatch.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordBatch.java @@ -28,7 +28,7 @@ import org.apache.arrow.vector.VectorUnloader; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.NullWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.batch.AbstractWritableRecordBatch; import org.datavec.arrow.ArrowConverter; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatch.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatch.java index 9bc926858..f5ebe9be6 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatch.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/main/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatch.java @@ -27,7 +27,7 @@ import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.batch.AbstractTimeSeriesWritableRecordBatch; import java.io.Closeable; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java index 3c52adc80..3df670e1c 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java @@ -32,7 +32,8 @@ import org.apache.arrow.vector.ipc.ArrowFileWriter; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; -import org.datavec.api.records.Record; +import org.datavec.api.Record; +import org.datavec.api.Writable; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataIndex; import org.datavec.api.records.reader.RecordReader; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java index 9666c18d7..2cd949d5c 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java +++ 
b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java @@ -31,7 +31,7 @@ import org.datavec.api.split.InputSplit; import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordWriter; import org.junit.jupiter.api.Test; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java index e3c1471fe..26427b212 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.arrow.ArrowConverter; import org.junit.jupiter.api.Test; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/BaseAudioRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/BaseAudioRecordReader.java index a98b0d1d5..4bfd06607 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/BaseAudioRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/BaseAudioRecordReader.java @@ -18,14 +18,14 @@ package org.datavec.audio.recordreader; import org.apache.commons.io.FileUtils; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.BaseRecordReader; import org.datavec.api.split.BaseInputSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.split.InputStreamInputSplit; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.File; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/NativeAudioRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/NativeAudioRecordReader.java index c2a049b9e..f1a083838 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/NativeAudioRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/NativeAudioRecordReader.java @@ -19,7 +19,7 @@ package org.datavec.audio.recordreader; import org.bytedeco.javacv.FFmpegFrameGrabber; import org.bytedeco.javacv.Frame; import 
org.datavec.api.writable.FloatWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.File; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/WavFileRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/WavFileRecordReader.java index e0fb22bbb..8200ca577 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/WavFileRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/main/java/org/datavec/audio/recordreader/WavFileRecordReader.java @@ -17,7 +17,7 @@ package org.datavec.audio.recordreader; import org.datavec.api.util.RecordUtils; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.audio.Wave; import java.io.File; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/test/java/org/datavec/audio/AudioReaderTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/test/java/org/datavec/audio/AudioReaderTest.java index 126f3566c..0714d33db 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/test/java/org/datavec/audio/AudioReaderTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-audio/src/test/java/org/datavec/audio/AudioReaderTest.java @@ -20,7 +20,7 @@ import org.bytedeco.javacv.FFmpegFrameRecorder; import org.bytedeco.javacv.Frame; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.audio.recordreader.NativeAudioRecordReader; import org.junit.jupiter.api.Test; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/BaseCodecRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/BaseCodecRecordReader.java index e2d136474..65e080aaa 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/BaseCodecRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/BaseCodecRecordReader.java @@ -23,7 +23,7 @@ import org.datavec.api.records.metadata.RecordMetaDataURI; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.FileRecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.File; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/CodecRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/CodecRecordReader.java index 475fe932b..d0c5a0767 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/CodecRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/CodecRecordReader.java @@ -19,7 +19,7 @@ package org.datavec.codec.reader; import org.apache.commons.compress.utils.IOUtils; import org.datavec.api.conf.Configuration; import org.datavec.api.util.ndarray.RecordConverter; -import org.datavec.api.writable.Writable; +import 
org.datavec.api.Writable; import org.datavec.image.loader.ImageLoader; import org.jcodec.api.FrameGrab; import org.jcodec.api.JCodecException; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/NativeCodecRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/NativeCodecRecordReader.java index e6e7844ff..300183308 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/NativeCodecRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/main/java/org/datavec/codec/reader/NativeCodecRecordReader.java @@ -21,7 +21,7 @@ import org.bytedeco.javacv.Frame; import org.bytedeco.javacv.OpenCVFrameConverter; import org.datavec.api.conf.Configuration; import org.datavec.api.util.ndarray.RecordConverter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.image.loader.NativeImageLoader; import java.io.File; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/test/java/org/datavec/codec/reader/CodecReaderTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/test/java/org/datavec/codec/reader/CodecReaderTest.java index ca80949ef..9d3a81d28 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/test/java/org/datavec/codec/reader/CodecReaderTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/src/test/java/org/datavec/codec/reader/CodecReaderTest.java @@ -22,7 +22,7 @@ import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.ArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java index 2ec31f426..87ee8963a 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordReader.java @@ -22,14 +22,14 @@ package org.datavec.poi.excel; import org.apache.poi.ss.usermodel.*; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaDataIndex; import org.datavec.api.records.reader.impl.FileRecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.BooleanWritable; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.IOException; import java.io.InputStream; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordWriter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordWriter.java index 33e691c57..0f65c5c44 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordWriter.java +++ 
b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/main/java/org/datavec/poi/excel/ExcelRecordWriter.java @@ -26,6 +26,7 @@ import org.apache.poi.ss.usermodel.Row; import org.apache.poi.ss.usermodel.Sheet; import org.apache.poi.ss.usermodel.Workbook; import org.apache.poi.xssf.usermodel.XSSFWorkbook; +import org.datavec.api.Writable; import org.datavec.api.conf.Configuration; import org.datavec.api.records.writer.impl.FileRecordWriter; import org.datavec.api.split.InputSplit; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/test/java/org/datavec/poi/excel/ExcelRecordReaderTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/test/java/org/datavec/poi/excel/ExcelRecordReaderTest.java index 201d257f9..f54e7d378 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/test/java/org/datavec/poi/excel/ExcelRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/test/java/org/datavec/poi/excel/ExcelRecordReaderTest.java @@ -22,7 +22,7 @@ package org.datavec.poi.excel; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java index e977aa5d4..dc4ad0a03 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-excel/src/test/java/org/datavec/poi/excel/ExcelRecordWriterTest.java @@ -25,7 +25,7 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java index 44d6409c2..ff8e7ffd1 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.reduce.AggregableReductionUtils; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Supplier; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java index 9df491dcf..c9aa981d0 100644 --- 
a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java @@ -7,7 +7,7 @@ import org.datavec.api.transform.metadata.DoubleMetaData; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseColumnsMathOpTransform; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java index e878619fe..0211eda45 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.StringMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import com.fasterxml.jackson.annotation.JsonProperty; import java.io.File; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java index 14d89576e..d5a3718db 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java @@ -22,7 +22,7 @@ import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.reduce.geo.CoordinatesReduction; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java index d6249b756..a3ebeaef5 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java @@ -25,9 +25,8 @@ import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform; import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import 
org.datavec.api.Writable; import org.junit.jupiter.api.AfterAll; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java index 23909e2bc..5e429dd85 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java @@ -17,13 +17,13 @@ package org.datavec.hadoop.records.reader.mapfile; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataIndex; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.hadoop.records.reader.mapfile.index.LongIndexToKey; import org.datavec.hadoop.records.reader.mapfile.record.RecordWritable; import org.nd4j.common.util.MathUtils; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java index 03f071eae..bf5b7b345 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java @@ -18,14 +18,14 @@ package org.datavec.hadoop.records.reader.mapfile; import lombok.NonNull; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataIndex; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.hadoop.records.reader.mapfile.index.LongIndexToKey; import org.datavec.hadoop.records.reader.mapfile.record.SequenceRecordWritable; import org.nd4j.common.util.MathUtils; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java index 139f28ce9..ad7b4b4a8 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java +++ 
b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java
@@ -35,13 +35,13 @@ import java.util.List;
@NoArgsConstructor @Data public class RecordWritable implements Writable {
- private List<org.datavec.api.writable.Writable> record;
+ private List<org.datavec.api.Writable> record;
@Override public void write(DataOutput out) throws IOException { WritableFactory wf = WritableFactory.getInstance(); out.writeInt(record.size());
- for (org.datavec.api.writable.Writable w : record) {
+ for (org.datavec.api.Writable w : record) {
wf.writeWithType(w, out); } }
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java
index 1511de990..74ef5068c 100644
--- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java
+++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java
@@ -36,7 +36,7 @@ import java.util.List;
@NoArgsConstructor @Data public class SequenceRecordWritable implements Writable {
- private List<List<org.datavec.api.writable.Writable>> sequenceRecord;
+ private List<List<org.datavec.api.Writable>> sequenceRecord;
@Override public void write(DataOutput out) throws IOException {
@@ -47,12 +47,12 @@ public class SequenceRecordWritable implements Writable {
int valuesPerStep = sequenceRecord.get(0).size(); out.writeInt(valuesPerStep);
- for (List<org.datavec.api.writable.Writable> step : sequenceRecord) {
+ for (List<org.datavec.api.Writable> step : sequenceRecord) {
if (step.size() != valuesPerStep) { throw new IllegalStateException( "Number of values per time step vary: " + valuesPerStep + " vs. " + step.size()); }
- for (org.datavec.api.writable.Writable w : step) {
+ for (org.datavec.api.Writable w : step) {
wf.writeWithType(w, out); } }
@@ -65,10 +65,10 @@ public class SequenceRecordWritable implements Writable {
int numSteps = in.readInt(); if (numSteps > 0) { int valuesPerStep = in.readInt();
- List<List<org.datavec.api.writable.Writable>> out = new ArrayList<>(numSteps);
+ List<List<org.datavec.api.Writable>> out = new ArrayList<>(numSteps);
for (int i = 0; i < numSteps; i++) {
- List<org.datavec.api.writable.Writable> currStep = new ArrayList<>(valuesPerStep);
+ List<org.datavec.api.Writable> currStep = new ArrayList<>(valuesPerStep);
for (int j = 0; j < valuesPerStep; j++) { currStep.add(wf.readWithType(in)); }
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java
index b87db7101..b5dc8b1c5 100644
--- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java
+++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.MapFile; import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.WritableComparable;
+import org.datavec.api.Writable;
import org.datavec.api.conf.Configuration; import org.datavec.api.split.partition.PartitionMetaData; import org.datavec.api.writable.*;
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java
index bf0479805..2695b8692 100644
--- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java
+++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java
@@ -22,7 +22,7 @@ import org.datavec.api.records.writer.RecordWriter;
import org.datavec.api.split.InputSplit; import org.datavec.api.split.partition.PartitionMetaData; import org.datavec.api.split.partition.Partitioner;
-import org.datavec.api.writable.Writable;
+import org.datavec.api.Writable;
import org.datavec.api.writable.WritableType; import org.datavec.hadoop.records.reader.mapfile.record.RecordWritable;
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java
index 878bd4348..0689d385b 100644
--- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java
+++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java
@@ -18,7 +18,7 @@ package org.datavec.hadoop.records.writer.mapfile;
import lombok.NonNull; import org.datavec.api.records.writer.SequenceRecordWriter;
-import org.datavec.api.writable.Writable;
+import org.datavec.api.Writable;
import org.datavec.api.writable.WritableType;
import org.datavec.hadoop.records.reader.mapfile.record.SequenceRecordWritable;
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java
index fa159c36c..b0f5f45ed 100644
--- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java
+++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java
@@ -174,7 +174,7 @@ public class TestMapFileRecordReader {
assertTrue(seqRR.hasNext()); int count = 0; while (seqRR.hasNext()) {
- List<List<org.datavec.api.writable.Writable>> l = seqRR.sequenceRecord();
+ List<List<org.datavec.api.Writable>> l = seqRR.sequenceRecord();
assertEquals(seqMap.get(new LongWritable(count)).getSequenceRecord(), l);
@@ -198,7 +198,7 @@ public class TestMapFileRecordReader {
count = 0; while (seqRR.hasNext()) {
- List<List<org.datavec.api.writable.Writable>> l = seqRR.sequenceRecord();
+ List<List<org.datavec.api.Writable>> l = seqRR.sequenceRecord();
assertEquals(seqMap.get(new LongWritable(expOrder[count])).getSequenceRecord(), l); count++; }
@@ -214,7 +214,7 @@ public class TestMapFileRecordReader {
assertTrue(rr.hasNext()); int count = 0; while (rr.hasNext()) {
- List<org.datavec.api.writable.Writable> l = rr.next();
+ List<org.datavec.api.Writable> l = rr.next();
assertEquals(recordMap.get(new LongWritable(count)).getRecord(), l);
@@ -239,7 +239,7 @@ public class TestMapFileRecordReader {
count = 0; while (rr.hasNext()) {
- List<org.datavec.api.writable.Writable> l = rr.next();
+ List<org.datavec.api.Writable> l = rr.next();
assertEquals(recordMap.get(new LongWritable(expOrder[count])).getRecord(), l); count++; }
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java
index 81be8ce7c..2d8d5cae1 100644
--- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java
+++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java
@@ -202,7 +202,7 @@ public class TestMapFileRecordReaderMultipleParts {
assertTrue(seqRR.hasNext()); int count = 0; while (seqRR.hasNext()) {
- List<List<org.datavec.api.writable.Writable>> l = seqRR.sequenceRecord();
+ List<List<org.datavec.api.Writable>> l = seqRR.sequenceRecord();
assertEquals(seqMap.get(new LongWritable(count)).getSequenceRecord(), l);
@@ -233,7 +233,7 @@ public class TestMapFileRecordReaderMultipleParts {
count = 0; while (seqRR.hasNext()) {
- List<List<org.datavec.api.writable.Writable>> l = seqRR.sequenceRecord();
+ List<List<org.datavec.api.Writable>> l = seqRR.sequenceRecord();
assertEquals(seqMap.get(new LongWritable(expOrder[count])).getSequenceRecord(), l); count++; }
@@ -265,7 +265,7 @@ public class TestMapFileRecordReaderMultipleParts {
assertTrue(rr.hasNext()); int count = 0; while (rr.hasNext()) {
- List<org.datavec.api.writable.Writable> l = rr.next();
+ List<org.datavec.api.Writable> l = rr.next();
assertEquals(recordMap.get(new LongWritable(count)).getRecord(), l); count++; }
@@ -290,7 +290,7 @@ public class TestMapFileRecordReaderMultipleParts {
count = 0; while (rr.hasNext()) {
- List<org.datavec.api.writable.Writable> l = rr.next();
+ List<org.datavec.api.Writable> l = rr.next();
assertEquals(recordMap.get(new LongWritable(expOrder[count])).getRecord(), l); count++; }
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java
index 1a3999e05..b6cf24ce7 100644
--- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java
+++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java
@@ -213,7 +213,7 @@ public class TestMapFileRecordReaderMultiplePartsSomeEmpty {
assertTrue(seqRR.hasNext()); int count = 0; while (seqRR.hasNext()) {
- List<List<org.datavec.api.writable.Writable>> l = seqRR.sequenceRecord();
+ List<List<org.datavec.api.Writable>> l = seqRR.sequenceRecord();
assertEquals(seqMap.get(new LongWritable(count)).getSequenceRecord(), l);
@@ -244,7 +244,7 @@ public class TestMapFileRecordReaderMultiplePartsSomeEmpty {
count = 0; while (seqRR.hasNext()) {
- List<List<org.datavec.api.writable.Writable>> l = seqRR.sequenceRecord();
+ List<List<org.datavec.api.Writable>> l = seqRR.sequenceRecord();
assertEquals(seqMap.get(new LongWritable(expOrder[count])).getSequenceRecord(), l); count++; }
@@ -276,7 +276,7 @@ public class TestMapFileRecordReaderMultiplePartsSomeEmpty {
assertTrue(rr.hasNext()); int count = 0; while (rr.hasNext()) {
- List<org.datavec.api.writable.Writable> l = rr.next();
+ List<org.datavec.api.Writable> l = rr.next();
assertEquals(recordMap.get(new LongWritable(count)).getRecord(), l); count++; }
@@ -301,7 +301,7 @@ public class TestMapFileRecordReaderMultiplePartsSomeEmpty {
count = 0; while (rr.hasNext()) {
- List<org.datavec.api.writable.Writable> l = rr.next();
+ List<org.datavec.api.Writable> l = rr.next();
assertEquals(recordMap.get(new LongWritable(expOrder[count])).getRecord(), l); count++; }
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java
index 71dd9d7a6..72deed5f4 100644
--- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java
+++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java
@@ -26,7 +26,7 @@ import org.datavec.api.records.writer.RecordWriter;
import org.datavec.api.records.writer.SequenceRecordWriter; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.FloatWritable;
-import org.datavec.api.writable.Writable;
+import org.datavec.api.Writable;
import org.datavec.api.writable.WritableType; import org.datavec.hadoop.records.reader.mapfile.MapFileRecordReader; import org.datavec.hadoop.records.reader.mapfile.MapFileSequenceRecordReader;
diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/data/ImageWritable.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/data/ImageWritable.java
index 654d6ead3..40ea1b90a 100644
--- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/data/ImageWritable.java
+++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/data/ImageWritable.java
@@ -21,8 +21,7 @@ package org.datavec.image.data;
import org.bytedeco.javacv.Frame;
-import org.bytedeco.javacv.FrameConverter;
-import
org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.WritableFactory; import org.datavec.api.writable.WritableType; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java index 4a2426ac4..55705af6b 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/BaseImageRecordReader.java @@ -27,7 +27,7 @@ import lombok.extern.slf4j.Slf4j; import org.datavec.api.conf.Configuration; import org.datavec.api.io.labels.PathLabelGenerator; import org.datavec.api.io.labels.PathMultiLabelGenerator; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataURI; import org.datavec.api.records.reader.BaseRecordReader; @@ -39,7 +39,7 @@ import org.datavec.api.util.files.URIUtil; import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.datavec.image.loader.BaseImageLoader; import org.datavec.image.loader.ImageLoader; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java index c7e8657cb..51bcf04a0 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/recordreader/objdetect/ObjectDetectionRecordReader.java @@ -24,7 +24,7 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.util.files.FileFromPathIterator; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.datavec.image.data.Image; import org.datavec.image.loader.NativeImageLoader; @@ -35,7 +35,7 @@ import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaDataImageURI; import org.datavec.api.util.files.URIUtil; import org.datavec.api.util.ndarray.RecordConverter; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransformProcess.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransformProcess.java index 43f90a502..8894b971b 100644 --- 
a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransformProcess.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/main/java/org/datavec/image/transform/ImageTransformProcess.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.datavec.api.transform.serde.JsonMappers; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.image.data.ImageWritable; import org.datavec.image.loader.NativeImageLoader; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/FileBatchRecordReaderTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/FileBatchRecordReaderTest.java index 8b740de99..1e26d3388 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/FileBatchRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/FileBatchRecordReaderTest.java @@ -26,7 +26,7 @@ import org.datavec.api.io.labels.PathLabelGenerator; import org.datavec.api.records.reader.impl.filebatch.FileBatchRecordReader; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.image.loader.NativeImageLoader; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java index 44f4a31ee..11146ef13 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestImageRecordReader.java @@ -23,7 +23,7 @@ package org.datavec.image.recordreader; import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.io.labels.PathLabelGenerator; import org.datavec.api.io.labels.PathMultiLabelGenerator; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.listener.impl.LogRecordListener; import org.datavec.api.records.metadata.RecordMetaData; @@ -34,7 +34,7 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java index 483be1090..669eb83f1 100644 --- 
a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-image/src/test/java/org/datavec/image/recordreader/TestObjectDetectionRecordReader.java @@ -20,14 +20,14 @@ package org.datavec.image.recordreader; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataImageURI; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.CollectionInputSplit; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.image.recordreader.objdetect.ImageObject; import org.datavec.image.recordreader.objdetect.ImageObjectLabelProvider; import org.datavec.image.recordreader.objdetect.ObjectDetectionRecordReader; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/reader/TfidfRecordReader.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/reader/TfidfRecordReader.java index 8401b640d..1c2762d01 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/reader/TfidfRecordReader.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/reader/TfidfRecordReader.java @@ -17,14 +17,14 @@ package org.datavec.nlp.reader; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataURI; import org.datavec.api.records.reader.impl.FileRecordReader; import org.datavec.api.split.InputSplit; import org.datavec.api.vector.Vectorizer; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.nlp.vectorizer.TfidfVectorizer; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/BagOfWordsTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/BagOfWordsTransform.java index 1a1a5eaf2..5017a6f40 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/BagOfWordsTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/BagOfWordsTransform.java @@ -17,7 +17,7 @@ package org.datavec.nlp.transforms; import org.datavec.api.transform.Transform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import com.fasterxml.jackson.annotation.JsonTypeInfo; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/GazeteerTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/GazeteerTransform.java index d99f629dc..5df210e0f 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/GazeteerTransform.java +++ 
b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/GazeteerTransform.java @@ -22,7 +22,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.NDArrayMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/MultiNlpTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/MultiNlpTransform.java index 8f8c0deb0..9de2621c5 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/MultiNlpTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/MultiNlpTransform.java @@ -21,7 +21,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.metadata.NDArrayMetaData; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.list.NDArrayList; import com.fasterxml.jackson.annotation.JsonCreator; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransform.java index 29cdf9153..6d643b1ec 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransform.java @@ -24,7 +24,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.BaseColumnTransform; import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; import org.datavec.nlp.tokenization.tokenizer.Tokenizer; import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/AbstractTfidfVectorizer.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/AbstractTfidfVectorizer.java index 228774142..c8048f217 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/AbstractTfidfVectorizer.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/AbstractTfidfVectorizer.java @@ -17,7 +17,7 @@ package org.datavec.nlp.vectorizer; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.reader.RecordReader; import 
org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; import org.datavec.nlp.tokenization.tokenizer.Tokenizer; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TextVectorizer.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TextVectorizer.java index dcf588ce9..cdcffb090 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TextVectorizer.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TextVectorizer.java @@ -19,10 +19,10 @@ package org.datavec.nlp.vectorizer; import lombok.Getter; import org.nd4j.common.primitives.Counter; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.vector.Vectorizer; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.nlp.metadata.DefaultVocabCache; import org.datavec.nlp.metadata.VocabCache; import org.datavec.nlp.stopwords.StopWords; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TfidfVectorizer.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TfidfVectorizer.java index a730bc739..0b79eacc9 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TfidfVectorizer.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TfidfVectorizer.java @@ -19,7 +19,7 @@ package org.datavec.nlp.vectorizer; import org.datavec.api.conf.Configuration; import org.nd4j.common.primitives.Counter; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaDataURI; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.writable.NDArrayWritable; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/reader/TfidfRecordReaderTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/reader/TfidfRecordReaderTest.java index b3dba2b96..4b21f609c 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/reader/TfidfRecordReaderTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/reader/TfidfRecordReaderTest.java @@ -17,12 +17,12 @@ package org.datavec.nlp.reader; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.CollectionInputSplit; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.nlp.vectorizer.TfidfVectorizer; import org.junit.jupiter.api.Test; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestGazeteerTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestGazeteerTransform.java index 7bfbe4eb0..40d0fa5fe 100644 --- 
a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestGazeteerTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestGazeteerTransform.java @@ -20,7 +20,7 @@ import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.LocalTransformExecutor; import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestMultiNLPTransform.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestMultiNLPTransform.java index ebb5c52c7..803c2b0e5 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestMultiNLPTransform.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestMultiNLPTransform.java @@ -20,7 +20,7 @@ import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.LocalTransformExecutor; import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransformTest.java b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransformTest.java index dfa2e228a..4a3ddc2ae 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransformTest.java +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransformTest.java @@ -22,7 +22,7 @@ import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.LocalTransformExecutor; import org.datavec.nlp.metadata.VocabCache; import org.datavec.nlp.tokenization.tokenizer.preprocessor.LowerCasePreProcessor; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/AnalyzeLocal.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/AnalyzeLocal.java index 296c19373..857130a66 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/AnalyzeLocal.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/AnalyzeLocal.java @@ -33,7 +33,7 @@ import org.datavec.api.transform.analysis.quality.QualityAnalysisState; import org.datavec.api.transform.quality.DataQualityAnalysis; import org.datavec.api.transform.quality.columns.ColumnQuality; import org.datavec.api.transform.schema.Schema; -import 
org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.analysis.aggregate.AnalysisAddFunction; import org.datavec.local.transforms.analysis.histogram.HistogramAddFunction; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformExecutor.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformExecutor.java index 2f7328954..266d0de2d 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformExecutor.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformExecutor.java @@ -27,6 +27,7 @@ import lombok.val; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; +import org.datavec.api.Writable; import org.datavec.api.transform.DataAction; import org.datavec.api.transform.Transform; import org.datavec.api.transform.TransformProcess; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformProcessSequenceRecordReader.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformProcessSequenceRecordReader.java index 8fc5fccae..e991aed6f 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformProcessSequenceRecordReader.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/LocalTransformProcessSequenceRecordReader.java @@ -23,9 +23,8 @@ package org.datavec.local.transforms; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.transform.TransformProcessSequenceRecordReader; import org.datavec.api.transform.TransformProcess; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; -import java.util.Arrays; import java.util.Collections; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/SequenceEmptyRecordFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/SequenceEmptyRecordFunction.java index 3dea4cd4e..33a79aa52 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/SequenceEmptyRecordFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/SequenceEmptyRecordFunction.java @@ -20,7 +20,7 @@ package org.datavec.local.transforms; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/analysis/aggregate/AnalysisAddFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/analysis/aggregate/AnalysisAddFunction.java index 6e98f21b8..cbddad0e1 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/analysis/aggregate/AnalysisAddFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/analysis/aggregate/AnalysisAddFunction.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.analysis.AnalysisCounter; import org.datavec.api.transform.analysis.counter.*; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import 
org.datavec.api.Writable; import org.nd4j.common.function.BiFunction; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/analysis/histogram/HistogramAddFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/analysis/histogram/HistogramAddFunction.java index 99e323c5a..8f4eda4a0 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/analysis/histogram/HistogramAddFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/analysis/histogram/HistogramAddFunction.java @@ -25,7 +25,7 @@ import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.analysis.histogram.*; import org.datavec.api.transform.metadata.CategoricalMetaData; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.BiFunction; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/EmptyRecordFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/EmptyRecordFunction.java index df68dce5e..0b6a12d88 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/EmptyRecordFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/EmptyRecordFunction.java @@ -20,7 +20,7 @@ package org.datavec.local.transforms.functions; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/LineRecordReaderFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/LineRecordReaderFunction.java index 042ac0cd0..01a3ffae5 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/LineRecordReaderFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/LineRecordReaderFunction.java @@ -22,7 +22,7 @@ package org.datavec.local.transforms.functions; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.StringSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/RecordReaderFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/RecordReaderFunction.java index f3b1198f9..7519dda8e 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/RecordReaderFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/RecordReaderFunction.java @@ -22,7 +22,7 @@ package org.datavec.local.transforms.functions; import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/SequenceRecordReaderFunction.java 
b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/SequenceRecordReaderFunction.java index 55df2c6b2..dbbde0ab0 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/SequenceRecordReaderFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/SequenceRecordReaderFunction.java @@ -23,7 +23,7 @@ package org.datavec.local.transforms.functions; import lombok.extern.slf4j.Slf4j; import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/data/RecordReaderBytesFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/data/RecordReaderBytesFunction.java index 669a4571f..99cd4f790 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/data/RecordReaderBytesFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/data/RecordReaderBytesFunction.java @@ -23,7 +23,7 @@ package org.datavec.local.transforms.functions.data; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.writable.BytesWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/data/SequenceRecordReaderBytesFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/data/SequenceRecordReaderBytesFunction.java index 80b8d6233..86d7f9683 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/data/SequenceRecordReaderBytesFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/functions/data/SequenceRecordReaderBytesFunction.java @@ -24,7 +24,7 @@ package org.datavec.local.transforms.functions.data; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.writable.BytesWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunction.java index 4d2740194..f50584b75 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.join; import org.datavec.api.transform.join.Join; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.BaseFlatMapFunctionAdaptee; import org.nd4j.common.primitives.Pair; diff --git 
a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java index 58ed6d10e..682fd2e90 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExecuteJoinFromCoGroupFlatMapFunctionAdapter.java @@ -22,7 +22,7 @@ package org.datavec.local.transforms.join; import com.google.common.collect.Iterables; import org.datavec.api.transform.join.Join; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.functions.FlatMapFunctionAdapter; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExtractKeysFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExtractKeysFunction.java index 28072cb98..ae998b9e8 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExtractKeysFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/ExtractKeysFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.join; import lombok.AllArgsConstructor; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/FilterAndFlattenJoinedValues.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/FilterAndFlattenJoinedValues.java index e605c8f58..ee3329792 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/FilterAndFlattenJoinedValues.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/FilterAndFlattenJoinedValues.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.join; import org.datavec.api.transform.join.Join; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.BaseFlatMapFunctionAdaptee; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/FilterAndFlattenJoinedValuesAdapter.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/FilterAndFlattenJoinedValuesAdapter.java index db28f431a..dc1c6bcf2 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/FilterAndFlattenJoinedValuesAdapter.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/FilterAndFlattenJoinedValuesAdapter.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.join; import org.datavec.api.transform.join.Join; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.functions.FlatMapFunctionAdapter; import java.util.Collections; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/JoinedValue.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/JoinedValue.java index 4a1c248d8..9584d2e4a 100644 --- 
a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/JoinedValue.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/join/JoinedValue.java @@ -22,7 +22,7 @@ package org.datavec.local.transforms.join; import lombok.AllArgsConstructor; import lombok.Data; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/ColumnAsKeyPairFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/ColumnAsKeyPairFunction.java index 06ce578c5..d751dd479 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/ColumnAsKeyPairFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/ColumnAsKeyPairFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.misc; import lombok.AllArgsConstructor; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/ColumnToKeyPairTransform.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/ColumnToKeyPairTransform.java index e5ca17781..f626eb043 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/ColumnToKeyPairTransform.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/ColumnToKeyPairTransform.java @@ -22,7 +22,7 @@ package org.datavec.local.transforms.misc; import lombok.AllArgsConstructor; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/NDArrayToWritablesFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/NDArrayToWritablesFunction.java index eda700c50..a5d9773cb 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/NDArrayToWritablesFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/NDArrayToWritablesFunction.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import lombok.NoArgsConstructor; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.function.Function; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceMergeFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceMergeFunction.java index e0fdf697e..4924764f1 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceMergeFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceMergeFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.misc; import org.datavec.api.transform.sequence.merge.SequenceMerge; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import 
org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceWritablesToStringFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceWritablesToStringFunction.java index 57f2b75dd..594b72813 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceWritablesToStringFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/SequenceWritablesToStringFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.misc; import lombok.AllArgsConstructor; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/StringToWritablesFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/StringToWritablesFunction.java index 0270be039..a5448d58c 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/StringToWritablesFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/StringToWritablesFunction.java @@ -23,7 +23,7 @@ package org.datavec.local.transforms.misc; import lombok.AllArgsConstructor; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.StringSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.io.IOException; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/WritablesToNDArrayFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/WritablesToNDArrayFunction.java index 896716807..88a729232 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/WritablesToNDArrayFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/WritablesToNDArrayFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.misc; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.function.Function; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/WritablesToStringFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/WritablesToStringFunction.java index 1ec8333fe..ff43502f7 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/WritablesToStringFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/misc/WritablesToStringFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.misc; import lombok.AllArgsConstructor; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/rank/UnzipForCalculateSortedRankFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/rank/UnzipForCalculateSortedRankFunction.java index 
d7024f40e..88c981c79 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/rank/UnzipForCalculateSortedRankFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/rank/UnzipForCalculateSortedRankFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.rank; import org.datavec.api.writable.LongWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/reduce/MapToPairForReducerFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/reduce/MapToPairForReducerFunction.java index 1509f4d48..5cdfd5ca0 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/reduce/MapToPairForReducerFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/reduce/MapToPairForReducerFunction.java @@ -23,7 +23,7 @@ package org.datavec.local.transforms.reduce; import lombok.AllArgsConstructor; import org.datavec.api.transform.reduce.IAssociativeReducer; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/reduce/ReducerFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/reduce/ReducerFunction.java index a74406e1e..dde417e11 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/reduce/ReducerFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/reduce/ReducerFunction.java @@ -22,7 +22,7 @@ package org.datavec.local.transforms.reduce; import lombok.AllArgsConstructor; import org.datavec.api.transform.reduce.IAssociativeReducer; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/ConvertToSequenceLengthOne.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/ConvertToSequenceLengthOne.java index 14c0cb14f..3a3d2656a 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/ConvertToSequenceLengthOne.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/ConvertToSequenceLengthOne.java @@ -20,7 +20,7 @@ package org.datavec.local.transforms.sequence; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.util.Collections; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalGroupToSequenceFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalGroupToSequenceFunction.java index 963a340ef..4519b8d6f 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalGroupToSequenceFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalGroupToSequenceFunction.java @@ -22,7 +22,7 @@ package 
org.datavec.local.transforms.sequence; import lombok.AllArgsConstructor; import org.datavec.api.transform.sequence.SequenceComparator; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalMapToPairByColumnFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalMapToPairByColumnFunction.java index 2b803ce5e..bc0221253 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalMapToPairByColumnFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalMapToPairByColumnFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.sequence; import lombok.AllArgsConstructor; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalMapToPairByMultipleColumnsFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalMapToPairByMultipleColumnsFunction.java index 2d5c17e5f..6d41b0111 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalMapToPairByMultipleColumnsFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalMapToPairByMultipleColumnsFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.sequence; import lombok.AllArgsConstructor; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import org.nd4j.common.primitives.Pair; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalSequenceFilterFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalSequenceFilterFunction.java index 8e289d0cb..f785d5d82 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalSequenceFilterFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalSequenceFilterFunction.java @@ -22,7 +22,7 @@ package org.datavec.local.transforms.sequence; import lombok.AllArgsConstructor; import org.datavec.api.transform.filter.Filter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalSequenceTransformFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalSequenceTransformFunction.java index c7e201d07..a9eab41e6 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalSequenceTransformFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/sequence/LocalSequenceTransformFunction.java @@ -22,7 +22,7 @@ package org.datavec.local.transforms.sequence; import lombok.AllArgsConstructor; import org.datavec.api.transform.Transform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import 
org.nd4j.common.function.Function; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformFunction.java index b82cc593a..78cef5fae 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformFunction.java @@ -23,7 +23,7 @@ package org.datavec.local.transforms.transform; import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.datavec.api.transform.Transform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.LocalTransformExecutor; import org.nd4j.common.function.Function; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformProcessFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformProcessFunction.java index 0f994965c..d37b79087 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformProcessFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformProcessFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.transform; import org.datavec.api.transform.TransformProcess; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.BaseFlatMapFunctionAdaptee; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformProcessFunctionAdapter.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformProcessFunctionAdapter.java index 48ea84418..b3622cfa0 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformProcessFunctionAdapter.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/LocalTransformProcessFunctionAdapter.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.transform; import org.datavec.api.transform.TransformProcess; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.functions.FlatMapFunctionAdapter; import java.util.Collections; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/SequenceSplitFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/SequenceSplitFunction.java index e6c3c3dc6..3d86327dd 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/SequenceSplitFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/SequenceSplitFunction.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.transform; import org.datavec.api.transform.sequence.SequenceSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.BaseFlatMapFunctionAdaptee; import java.util.List; diff --git 
a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/SequenceSplitFunctionAdapter.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/SequenceSplitFunctionAdapter.java index 52f9c9fbf..003044b5a 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/SequenceSplitFunctionAdapter.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/SequenceSplitFunctionAdapter.java @@ -21,7 +21,7 @@ package org.datavec.local.transforms.transform; import org.datavec.api.transform.sequence.SequenceSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.functions.FlatMapFunctionAdapter; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/filter/FilterWritablesBySchemaFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/filter/FilterWritablesBySchemaFunction.java index 9eb828143..3074b5ec2 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/filter/FilterWritablesBySchemaFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/filter/FilterWritablesBySchemaFunction.java @@ -23,7 +23,7 @@ package org.datavec.local.transforms.transform.filter; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; public class FilterWritablesBySchemaFunction implements Function { diff --git a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/filter/LocalFilterFunction.java b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/filter/LocalFilterFunction.java index 9c722d24d..9483a2356 100644 --- a/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/filter/LocalFilterFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/main/java/org/datavec/local/transforms/transform/filter/LocalFilterFunction.java @@ -22,7 +22,7 @@ package org.datavec.local.transforms.transform.filter; import lombok.AllArgsConstructor; import org.datavec.api.transform.filter.Filter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.function.Function; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java index 25dc7b738..973650bf6 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/LocalTransformProcessRecordReaderTests.java @@ -20,7 +20,7 @@ package org.datavec.local.transforms; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.collection.CollectionRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; @@ -34,7 +34,7 @@ import 
org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.joda.time.DateTimeZone; import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java index fcab5efb5..3b2e7cb8a 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java @@ -27,7 +27,7 @@ import org.datavec.api.transform.analysis.DataAnalysis; import org.datavec.api.transform.analysis.columns.NumericalColumnAnalysis; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.AnalyzeLocal; import org.junit.jupiter.api.Test; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestLineRecordReaderFunction.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestLineRecordReaderFunction.java index 11d4672b1..ccf6e490c 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestLineRecordReaderFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestLineRecordReaderFunction.java @@ -25,7 +25,7 @@ import org.apache.commons.io.FileUtils; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java index 95b6ebfab..c57d4f609 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestNDArrayToWritablesFunction.java @@ -22,7 +22,7 @@ package org.datavec.local.transforms.functions; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.misc.NDArrayToWritablesFunction; import org.junit.jupiter.api.Test; @@ -30,7 +30,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToNDArrayFunction.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToNDArrayFunction.java index 1cc2943f8..cae352f62 100644 --- 
a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToNDArrayFunction.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToNDArrayFunction.java @@ -22,7 +22,7 @@ package org.datavec.local.transforms.functions; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.misc.WritablesToNDArrayFunction; import org.junit.jupiter.api.Test; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java index 1086866f2..a8c9d026a 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/functions/TestWritablesToStringFunctions.java @@ -25,7 +25,7 @@ package org.datavec.local.transforms.functions; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.local.transforms.misc.SequenceWritablesToStringFunction; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java index 94aa8eeaf..de390f945 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/ExecutionTest.java @@ -21,6 +21,7 @@ package org.datavec.local.transforms.transform; +import org.datavec.api.Writable; import org.datavec.api.transform.MathFunction; import org.datavec.api.transform.MathOp; import org.datavec.api.transform.ReduceOp; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java index 3cca330af..b7bdb1d3c 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestGeoTransforms.java @@ -29,7 +29,7 @@ import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform; import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java index 1dd62a88e..bc6e882d1 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java +++ 
b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java @@ -20,6 +20,7 @@ package org.datavec.local.transforms.transform; +import org.datavec.api.Writable; import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.filter.ConditionFilter; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java index b7fc564c7..946106c17 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/join/TestJoin.java @@ -21,6 +21,7 @@ package org.datavec.local.transforms.transform.join; +import org.datavec.api.Writable; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.join.Join; import org.datavec.api.transform.schema.Schema; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java index a14c5f468..ca0f7bea5 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/rank/TestCalculateSortedRank.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.comparator.DoubleWritableComparator; diff --git a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java index bd3ace8c8..8774d328e 100644 --- a/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java +++ b/cavis-datavec/cavis-datavec-local/src/test/java/org/datavec/local/transforms/transform/sequence/TestConvertToSequence.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator; import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch; diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonCondition.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonCondition.java index 62370246f..bec018f74 100644 --- a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonCondition.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonCondition.java @@ -16,6 +16,7 @@ package org.datavec.python; +import org.datavec.api.Writable; import org.datavec.api.transform.condition.Condition; import org.datavec.api.transform.schema.Schema; import 
org.datavec.api.writable.*; diff --git a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonTransform.java b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonTransform.java index 4395078d3..1d76f827e 100644 --- a/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonTransform.java +++ b/cavis-datavec/cavis-datavec-python/src/main/java/org/datavec/python/PythonTransform.java @@ -19,6 +19,7 @@ package org.datavec.python; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import org.datavec.api.Writable; import org.datavec.api.transform.Transform; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; @@ -26,7 +27,6 @@ import org.nd4j.common.base.Preconditions; import org.nd4j.common.holder.ObjectMapperHolder; import org.nd4j.linalg.api.ndarray.INDArray; import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import java.util.ArrayList; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/SequenceEmptyRecordFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/SequenceEmptyRecordFunction.java index 762b9dc0c..9616785dd 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/SequenceEmptyRecordFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/SequenceEmptyRecordFunction.java @@ -21,7 +21,7 @@ package org.datavec.spark; import org.apache.spark.api.java.function.Function; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/EmptyRecordFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/EmptyRecordFunction.java index 01e39d677..7c016fa19 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/EmptyRecordFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/EmptyRecordFunction.java @@ -21,7 +21,7 @@ package org.datavec.spark.functions; import org.apache.spark.api.java.function.Function; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/LineRecordReaderFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/LineRecordReaderFunction.java index 56b5f2b63..16b34a503 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/LineRecordReaderFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/LineRecordReaderFunction.java @@ -23,7 +23,7 @@ package org.datavec.spark.functions; import org.apache.spark.api.java.function.Function; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.StringSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git 
a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/RecordReaderFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/RecordReaderFunction.java index c4126d170..4dbafbf74 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/RecordReaderFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/RecordReaderFunction.java @@ -23,7 +23,7 @@ package org.datavec.spark.functions; import org.apache.spark.api.java.function.Function; import org.apache.spark.input.PortableDataStream; import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.io.DataInputStream; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/SequenceRecordReaderFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/SequenceRecordReaderFunction.java index 5bf6f4501..8361015e9 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/SequenceRecordReaderFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/SequenceRecordReaderFunction.java @@ -23,7 +23,7 @@ package org.datavec.spark.functions; import org.apache.spark.api.java.function.Function; import org.apache.spark.input.PortableDataStream; import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.io.DataInputStream; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/data/RecordReaderBytesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/data/RecordReaderBytesFunction.java index 367666610..cf42c81a7 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/data/RecordReaderBytesFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/data/RecordReaderBytesFunction.java @@ -24,7 +24,7 @@ import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.Text; import org.apache.spark.api.java.function.Function; import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.io.ByteArrayInputStream; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/data/SequenceRecordReaderBytesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/data/SequenceRecordReaderBytesFunction.java index 1557cd3a9..ef99cc08b 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/data/SequenceRecordReaderBytesFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/data/SequenceRecordReaderBytesFunction.java @@ -24,7 +24,7 @@ import org.apache.hadoop.io.BytesWritable; import org.apache.hadoop.io.Text; import 
org.apache.spark.api.java.function.Function; import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.io.ByteArrayInputStream; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/pairdata/PairSequenceRecordReaderBytesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/pairdata/PairSequenceRecordReaderBytesFunction.java index 45cf7caad..ea4f2e7fc 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/pairdata/PairSequenceRecordReaderBytesFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/functions/pairdata/PairSequenceRecordReaderBytesFunction.java @@ -23,7 +23,7 @@ package org.datavec.spark.functions.pairdata; import org.apache.hadoop.io.Text; import org.apache.spark.api.java.function.Function; import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.io.ByteArrayInputStream; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/SparkStorageUtils.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/SparkStorageUtils.java index 323012432..b905892d0 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/SparkStorageUtils.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/SparkStorageUtils.java @@ -30,7 +30,7 @@ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.hadoop.records.reader.mapfile.record.RecordWritable; import org.datavec.hadoop.records.reader.mapfile.record.SequenceRecordWritable; import org.datavec.spark.storage.functions.RecordLoadPairFunction; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/RecordLoadPairFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/RecordLoadPairFunction.java index 192c0e7d0..87af51220 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/RecordLoadPairFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/RecordLoadPairFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.storage.functions; import org.apache.hadoop.io.LongWritable; import org.apache.spark.api.java.function.PairFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.hadoop.records.reader.mapfile.record.RecordWritable; import scala.Tuple2; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/RecordSavePrepPairFunction.java 
b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/RecordSavePrepPairFunction.java index 048f6b191..d3ab47cba 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/RecordSavePrepPairFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/RecordSavePrepPairFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.storage.functions; import org.apache.hadoop.io.LongWritable; import org.apache.spark.api.java.function.PairFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.hadoop.records.reader.mapfile.record.RecordWritable; import scala.Tuple2; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/SequenceRecordLoadPairFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/SequenceRecordLoadPairFunction.java index a8296cd6e..4d149eabf 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/SequenceRecordLoadPairFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/SequenceRecordLoadPairFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.storage.functions; import org.apache.hadoop.io.LongWritable; import org.apache.spark.api.java.function.PairFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.hadoop.records.reader.mapfile.record.SequenceRecordWritable; import scala.Tuple2; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/SequenceRecordSavePrepPairFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/SequenceRecordSavePrepPairFunction.java index 072beb5de..bc895ee20 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/SequenceRecordSavePrepPairFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/storage/functions/SequenceRecordSavePrepPairFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.storage.functions; import org.apache.hadoop.io.LongWritable; import org.apache.spark.api.java.function.PairFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.hadoop.records.reader.mapfile.record.SequenceRecordWritable; import scala.Tuple2; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/AnalyzeSpark.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/AnalyzeSpark.java index 1af4b9d4b..26bc1d5da 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/AnalyzeSpark.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/AnalyzeSpark.java @@ -38,7 +38,7 @@ import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.transform.quality.DataQualityAnalysis; import org.datavec.api.transform.quality.columns.ColumnQuality; import 
org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.comparator.Comparators; import org.datavec.spark.transform.analysis.SelectColumnFunction; import org.datavec.spark.transform.analysis.SequenceFlatMapFunction; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/DataFrames.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/DataFrames.java index 2a9fdb9fb..2f4e842f3 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/DataFrames.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/DataFrames.java @@ -30,6 +30,7 @@ import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.datavec.api.Writable; import org.nd4j.common.primitives.Pair; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/Normalization.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/Normalization.java index de5511017..22d9eaa56 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/Normalization.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/Normalization.java @@ -26,7 +26,7 @@ import org.apache.spark.sql.Column; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.*; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/SparkTransformExecutor.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/SparkTransformExecutor.java index 9595c5a7b..c4c57ad37 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/SparkTransformExecutor.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/SparkTransformExecutor.java @@ -37,7 +37,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.schema.SequenceSchema; import org.datavec.api.transform.sequence.ConvertToSequence; import org.datavec.api.transform.sequence.SequenceSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.SequenceEmptyRecordFunction; import org.datavec.spark.functions.EmptyRecordFunction; import org.datavec.spark.transform.analysis.SequenceFlatMapFunction; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/CategoricalToPairFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/CategoricalToPairFunction.java index d8de15c26..40b14e235 100644 --- 
a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/CategoricalToPairFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/CategoricalToPairFunction.java @@ -21,7 +21,7 @@ package org.datavec.spark.transform.analysis; import org.apache.spark.api.java.function.PairFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; public class CategoricalToPairFunction implements PairFunction { diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SelectColumnFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SelectColumnFunction.java index b73d50e11..0a9db59d6 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SelectColumnFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SelectColumnFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.analysis; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java index 2685bc7ad..21f89c1d4 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SequenceFlatMapFunction.java @@ -21,7 +21,7 @@ package org.datavec.spark.transform.analysis; import org.apache.spark.api.java.function.FlatMapFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.Iterator; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SequenceLengthFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SequenceLengthFunction.java index 7bd0acfa7..6839b432b 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SequenceLengthFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/SequenceLengthFunction.java @@ -21,7 +21,7 @@ package org.datavec.spark.transform.analysis; import org.apache.spark.api.java.function.Function; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/StringLengthFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/StringLengthFunction.java index 5f3a8a17d..2ccf4f1e2 100644 --- 
a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/StringLengthFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/StringLengthFunction.java @@ -21,7 +21,7 @@ package org.datavec.spark.transform.analysis; import org.apache.spark.api.java.function.DoubleFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class StringLengthFunction implements DoubleFunction { @Override diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/WritableToDoubleFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/WritableToDoubleFunction.java index ebfad6948..e20805cc7 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/WritableToDoubleFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/WritableToDoubleFunction.java @@ -21,7 +21,7 @@ package org.datavec.spark.transform.analysis; import org.apache.spark.api.java.function.DoubleFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class WritableToDoubleFunction implements DoubleFunction { diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/WritableToStringFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/WritableToStringFunction.java index 6c793b41b..3ca802659 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/WritableToStringFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/WritableToStringFunction.java @@ -21,7 +21,7 @@ package org.datavec.spark.transform.analysis; import org.apache.spark.api.java.function.Function; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class WritableToStringFunction implements Function { @Override diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/aggregate/AnalysisAddFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/aggregate/AnalysisAddFunction.java index 7975b8009..12d5305cf 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/aggregate/AnalysisAddFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/aggregate/AnalysisAddFunction.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.analysis.AnalysisCounter; import org.datavec.api.transform.analysis.counter.*; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.ArrayList; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/histogram/HistogramAddFunction.java 
b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/histogram/HistogramAddFunction.java index 46a67f138..d7b2fc3e5 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/histogram/HistogramAddFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/histogram/HistogramAddFunction.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.analysis.histogram.*; import org.datavec.api.transform.metadata.CategoricalMetaData; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.ArrayList; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/seqlength/SequenceLengthAnalysisCounter.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/seqlength/SequenceLengthAnalysisCounter.java index 6a54521ea..249df87a6 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/seqlength/SequenceLengthAnalysisCounter.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/seqlength/SequenceLengthAnalysisCounter.java @@ -23,7 +23,7 @@ package org.datavec.spark.transform.analysis.seqlength; import lombok.AllArgsConstructor; import lombok.Data; import org.datavec.api.transform.analysis.AnalysisCounter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; @AllArgsConstructor @Data diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/unique/UniqueAddFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/unique/UniqueAddFunction.java index 35521ba04..aa080db0a 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/unique/UniqueAddFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/unique/UniqueAddFunction.java @@ -23,7 +23,7 @@ package org.datavec.spark.transform.analysis.unique; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function2; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.*; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/unique/UniqueMergeFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/unique/UniqueMergeFunction.java index 28451c335..e8e8a1b79 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/unique/UniqueMergeFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/analysis/unique/UniqueMergeFunction.java @@ -21,7 +21,7 @@ package org.datavec.spark.transform.analysis.unique; import org.apache.spark.api.java.function.Function2; -import org.datavec.api.writable.Writable; +import 
org.datavec.api.Writable; import java.util.Map; import java.util.Set; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/filter/FilterWritablesBySchemaFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/filter/FilterWritablesBySchemaFunction.java index 358f4ccd0..799cc22c0 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/filter/FilterWritablesBySchemaFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/filter/FilterWritablesBySchemaFunction.java @@ -24,7 +24,7 @@ import org.apache.spark.api.java.function.Function; import org.datavec.api.transform.metadata.ColumnMetaData; import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; public class FilterWritablesBySchemaFunction implements Function { diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/filter/SparkFilterFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/filter/SparkFilterFunction.java index ae23a3811..6f5b7c739 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/filter/SparkFilterFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/filter/SparkFilterFunction.java @@ -23,7 +23,7 @@ package org.datavec.spark.transform.filter; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; import org.datavec.api.transform.filter.Filter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java index 0edd54fb8..7aaf9a158 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/ExecuteJoinFromCoGroupFlatMapFunction.java @@ -23,7 +23,7 @@ package org.datavec.spark.transform.join; import com.google.common.collect.Iterables; import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.join.Join; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/ExtractKeysFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/ExtractKeysFunction.java index 98bc8cab6..d916cdc2d 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/ExtractKeysFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/ExtractKeysFunction.java @@ -22,7 +22,7 @@ 
package org.datavec.spark.transform.join; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.PairFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java index 58707bcef..0983772a5 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/FilterAndFlattenJoinedValues.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.join; import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.join.Join; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.Collections; import java.util.Iterator; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/JoinedValue.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/JoinedValue.java index 13891360c..fda664e1b 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/JoinedValue.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/join/JoinedValue.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.join; import lombok.AllArgsConstructor; import lombok.Data; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/ColumnAsKeyPairFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/ColumnAsKeyPairFunction.java index 268c5bd37..e3e50420a 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/ColumnAsKeyPairFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/ColumnAsKeyPairFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.misc; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.PairFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/ColumnToKeyPairTransform.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/ColumnToKeyPairTransform.java index 4f043ab0f..09346a9d5 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/ColumnToKeyPairTransform.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/ColumnToKeyPairTransform.java @@ -21,10 +21,8 @@ package org.datavec.spark.transform.misc; import lombok.AllArgsConstructor; -import 
org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.PairFunction; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/NDArrayToWritablesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/NDArrayToWritablesFunction.java index 9382d1db5..736c5bfa0 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/NDArrayToWritablesFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/NDArrayToWritablesFunction.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceMergeFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceMergeFunction.java index 933fb6d4a..1e47677ac 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceMergeFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceMergeFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.misc; import org.apache.spark.api.java.function.Function; import org.datavec.api.transform.sequence.merge.SequenceMerge; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceWritablesToStringFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceWritablesToStringFunction.java index c8d164ee1..f872bd684 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceWritablesToStringFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/SequenceWritablesToStringFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.misc; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/StringToWritablesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/StringToWritablesFunction.java index ed79a81f5..cca9f29b6 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/StringToWritablesFunction.java +++ 
b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/StringToWritablesFunction.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.StringSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.ArrayList; import java.util.Collection; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/WritablesToNDArrayFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/WritablesToNDArrayFunction.java index a55161f1a..adc2e09da 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/WritablesToNDArrayFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/WritablesToNDArrayFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.misc; import org.apache.spark.api.java.function.Function; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/WritablesToStringFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/WritablesToStringFunction.java index 4bed90ce7..4362c4837 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/WritablesToStringFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/misc/WritablesToStringFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.misc; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/rank/UnzipForCalculateSortedRankFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/rank/UnzipForCalculateSortedRankFunction.java index 662a42d7b..83dd8c7cc 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/rank/UnzipForCalculateSortedRankFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/rank/UnzipForCalculateSortedRankFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.rank; import org.apache.spark.api.java.function.Function; import org.datavec.api.writable.LongWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/reduce/MapToPairForReducerFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/reduce/MapToPairForReducerFunction.java index 
9c69e6d62..814b9ffc4 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/reduce/MapToPairForReducerFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/reduce/MapToPairForReducerFunction.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.PairFunction; import org.datavec.api.transform.reduce.IAssociativeReducer; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/reduce/ReducerFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/reduce/ReducerFunction.java index 10c409f60..6ed458848 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/reduce/ReducerFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/reduce/ReducerFunction.java @@ -23,7 +23,7 @@ package org.datavec.spark.transform.reduce; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; import org.datavec.api.transform.reduce.IAssociativeReducer; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/ConvertToSequenceLengthOne.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/ConvertToSequenceLengthOne.java index 97a7bca31..e5c9f0091 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/ConvertToSequenceLengthOne.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/ConvertToSequenceLengthOne.java @@ -21,7 +21,7 @@ package org.datavec.spark.transform.sequence; import org.apache.spark.api.java.function.Function; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.Collections; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkGroupToSequenceFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkGroupToSequenceFunction.java index 3df2558e4..0d3845e64 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkGroupToSequenceFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkGroupToSequenceFunction.java @@ -23,7 +23,7 @@ package org.datavec.spark.transform.sequence; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; import org.datavec.api.transform.sequence.SequenceComparator; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.ArrayList; import java.util.Collections; diff --git 
a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkMapToPairByColumnFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkMapToPairByColumnFunction.java index 98357e121..bc92549bf 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkMapToPairByColumnFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkMapToPairByColumnFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.sequence; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.PairFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkMapToPairByMultipleColumnsFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkMapToPairByMultipleColumnsFunction.java index aaa8d1049..b84b0f1f6 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkMapToPairByMultipleColumnsFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkMapToPairByMultipleColumnsFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.sequence; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.PairFunction; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import scala.Tuple2; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkSequenceFilterFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkSequenceFilterFunction.java index 6fb4480ed..9dc47bec2 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkSequenceFilterFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkSequenceFilterFunction.java @@ -23,7 +23,7 @@ package org.datavec.spark.transform.sequence; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; import org.datavec.api.transform.filter.Filter; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkSequenceTransformFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkSequenceTransformFunction.java index 0e1542cc1..bb7004bba 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkSequenceTransformFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sequence/SparkSequenceTransformFunction.java @@ -23,7 +23,7 @@ package org.datavec.spark.transform.sequence; import lombok.AllArgsConstructor; import 
org.apache.spark.api.java.function.Function; import org.datavec.api.transform.Transform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java index faac3ac13..a67e5a52f 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/SequenceToRows.java @@ -25,7 +25,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; import org.apache.spark.sql.types.StructType; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.transform.DataFrames; import java.util.*; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRecord.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRecord.java index 21cb459ae..64ffa20cb 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRecord.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRecord.java @@ -23,6 +23,7 @@ package org.datavec.spark.transform.sparkfunction; import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Row; +import org.datavec.api.Writable; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRow.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRow.java index 6128901ee..5fc57c011 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRow.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/ToRow.java @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema; import org.apache.spark.sql.types.StructType; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.transform.DataFrames; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceCreateCombiner.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceCreateCombiner.java index 355709cf7..b86fd3d1a 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceCreateCombiner.java +++ 
b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceCreateCombiner.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; import org.apache.spark.sql.Row; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.transform.DataFrames; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceMergeCombiner.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceMergeCombiner.java index 945d43a40..fabb3f5e1 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceMergeCombiner.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceMergeCombiner.java @@ -20,13 +20,8 @@ package org.datavec.spark.transform.sparkfunction.sequence; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function2; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; -import org.datavec.spark.transform.DataFrames; +import org.datavec.api.Writable; import java.util.ArrayList; import java.util.Collections; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceMergeValue.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceMergeValue.java index 807903a2f..fa7207792 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceMergeValue.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/sparkfunction/sequence/DataFrameToSequenceMergeValue.java @@ -21,11 +21,10 @@ package org.datavec.spark.transform.sparkfunction.sequence; import lombok.AllArgsConstructor; -import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function2; import org.apache.spark.sql.Row; import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.transform.DataFrames; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java index fbad46a33..ad2b2e9f4 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SequenceSplitFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.transform; import 
org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.sequence.SequenceSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.Iterator; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SparkTransformFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SparkTransformFunction.java index 2a4e66eb0..b595427b8 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SparkTransformFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SparkTransformFunction.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.spark.api.java.function.Function; import org.datavec.api.transform.Transform; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.transform.SparkTransformExecutor; import java.util.ArrayList; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java index be9e2d662..cc08df978 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/transform/SparkTransformProcessFunction.java @@ -22,7 +22,7 @@ package org.datavec.spark.transform.transform; import org.apache.spark.api.java.function.FlatMapFunction; import org.datavec.api.transform.TransformProcess; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.util.Collections; import java.util.Iterator; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkExport.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkExport.java index 549ff7e13..5fc8dfe94 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkExport.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkExport.java @@ -25,7 +25,7 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.transform.misc.WritablesToStringFunction; import java.io.File; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkUtils.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkUtils.java index e1d038495..453c0e46c 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkUtils.java +++ 
b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/main/java/org/datavec/spark/transform/utils/SparkUtils.java @@ -29,6 +29,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.SparkContext; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.datavec.api.Writable; import org.datavec.api.transform.analysis.DataAnalysis; import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.split.RandomSplit; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestLineRecordReaderFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestLineRecordReaderFunction.java index d7a906597..9ce6dbeca 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestLineRecordReaderFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestLineRecordReaderFunction.java @@ -25,7 +25,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.BaseSparkTest; import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java index 5143b01eb..406229a73 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestNDArrayToWritablesFunction.java @@ -22,14 +22,13 @@ package org.datavec.spark.functions; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.transform.misc.NDArrayToWritablesFunction; import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java index 0e4df00cc..151ad4601 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestPairSequenceRecordReaderBytesFunction.java @@ -30,7 +30,7 @@ import org.datavec.api.conf.Configuration; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import 
org.datavec.api.Writable; import org.datavec.codec.reader.CodecRecordReader; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.functions.pairdata.BytesPairWritable; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java index 9afc645ad..2e1ed3497 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderBytesFunction.java @@ -31,7 +31,7 @@ import org.apache.spark.input.PortableDataStream; import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.functions.data.FilesAsBytesFunction; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderFunction.java index a1695ee63..1d061ded9 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestRecordReaderFunction.java @@ -28,7 +28,7 @@ import org.datavec.api.io.labels.ParentPathLabelGenerator; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.ArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.spark.BaseSparkTest; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java index 212b9bb64..968edd3f8 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderBytesFunction.java @@ -31,7 +31,7 @@ import org.datavec.api.conf.Configuration; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.codec.reader.CodecRecordReader; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.functions.data.FilesAsBytesFunction; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderFunction.java 
index 208bc42d1..b8b94f4d8 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestSequenceRecordReaderFunction.java @@ -30,7 +30,7 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.ArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.codec.reader.CodecRecordReader; import org.datavec.spark.BaseSparkTest; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToNDArrayFunction.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToNDArrayFunction.java index 62021a252..ddbab3ce7 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToNDArrayFunction.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToNDArrayFunction.java @@ -20,6 +20,7 @@ package org.datavec.spark.functions; +import org.datavec.api.Writable; import org.datavec.api.writable.*; import org.datavec.spark.transform.misc.WritablesToNDArrayFunction; import org.junit.jupiter.api.Test; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java index 19847cec0..67ffbd1e8 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/functions/TestWritablesToStringFunctions.java @@ -25,7 +25,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.PairFunction; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.misc.SequenceWritablesToStringFunction; import org.datavec.spark.transform.misc.WritablesToStringFunction; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java index eaafa1d14..cbb49ba5b 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/storage/TestSparkStorageUtils.java @@ -24,6 +24,7 @@ import com.sun.jna.Platform; import com.google.common.io.Files; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; +import org.datavec.api.Writable; import org.datavec.api.writable.*; import org.datavec.spark.BaseSparkTest; import org.junit.jupiter.api.Test; diff --git 
a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/DataFramesTests.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/DataFramesTests.java index 4a6bca8dc..250dd1394 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/DataFramesTests.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/DataFramesTests.java @@ -28,7 +28,7 @@ import org.apache.spark.sql.Row; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.BaseSparkTest; import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.ndarray.INDArray; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/ExecutionTest.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/ExecutionTest.java index 8da6f146b..e6ce68c17 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/ExecutionTest.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/ExecutionTest.java @@ -31,7 +31,7 @@ import org.datavec.api.transform.transform.categorical.FirstDigitTransform; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.NDArrayWritable; import org.datavec.spark.BaseSparkTest; import org.datavec.python.PythonTransform; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/NormalizationTests.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/NormalizationTests.java index 7fd6d9d7f..aa48b44f1 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/NormalizationTests.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/NormalizationTests.java @@ -26,7 +26,7 @@ import org.apache.spark.sql.Row; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.BaseSparkTest; import org.junit.jupiter.api.Test; import org.nd4j.linalg.api.buffer.DataType; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java index 8ed68b55e..998f96248 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/analysis/TestAnalysis.java @@ -23,6 +23,7 @@ package org.datavec.spark.transform.analysis; import com.tdunning.math.stats.TDigest; import 
org.apache.spark.api.java.JavaRDD; import org.apache.spark.util.StatCounter; +import org.datavec.api.Writable; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/join/TestJoin.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/join/TestJoin.java index 853800a03..09b70e505 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/join/TestJoin.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/join/TestJoin.java @@ -21,6 +21,7 @@ package org.datavec.spark.transform.join; import org.apache.spark.api.java.JavaRDD; +import org.datavec.api.Writable; import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.join.Join; import org.datavec.api.transform.schema.Schema; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java index daf2794f2..e6ac5718e 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/rank/TestCalculateSortedRank.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.SparkTransformExecutor; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java index ad545172c..512970a42 100644 --- a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/transform/sequence/TestConvertToSequence.java @@ -26,7 +26,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.sequence.comparator.NumericalColumnComparator; import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.SparkTransformExecutor; import org.junit.jupiter.api.Test; diff --git a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/util/TestSparkUtil.java b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/util/TestSparkUtil.java index 1ed67934b..d33b303e6 100644 --- 
a/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/util/TestSparkUtil.java +++ b/cavis-datavec/cavis-datavec-spark/cavis-datavec-spark-core/src/test/java/org/datavec/spark/util/TestSparkUtil.java @@ -25,7 +25,7 @@ import org.apache.commons.io.IOUtils; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.BaseSparkTest; import org.datavec.spark.transform.utils.SparkUtils; import org.junit.jupiter.api.Test; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java index 9613d9141..b9ce2a9ea 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/dataset/api/DataSet.java @@ -91,7 +91,7 @@ public interface DataSet extends Iterable, Seri * Calculate and return a count of each label, by index. * Assumes labels are a one-hot INDArray, for classification * - * @return Map of countsn + * @return Map of counts */ Map labelCounts(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java index c7f354937..b5a139779 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetiteratorTest.java @@ -24,7 +24,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.datavec.api.io.labels.ParentPathLabelGenerator; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.SequenceRecordReader; @@ -38,7 +38,7 @@ import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.image.recordreader.ImageRecordReader; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java index 0341ac846..e3c4c38fa 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIteratorTest.java @@ -27,7 +27,7 @@ import org.apache.commons.io.FileUtils; import org.apache.commons.io.FilenameUtils; import org.datavec.api.conf.Configuration; import org.datavec.api.io.labels.ParentPathLabelGenerator; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.BaseRecordReader; 
import org.datavec.api.records.reader.RecordReader; @@ -42,7 +42,7 @@ import org.datavec.api.split.NumberedFileInputSplit; import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.image.recordreader.ImageRecordReader; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/tools/SpecialImageRecordReader.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/tools/SpecialImageRecordReader.java index 545be93e3..949d21b84 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/tools/SpecialImageRecordReader.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/datavec/tools/SpecialImageRecordReader.java @@ -27,7 +27,7 @@ import org.bytedeco.javacpp.Pointer; import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.image.recordreader.ImageRecordReader; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java index 024804c0c..30cb1e5ca 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java @@ -27,7 +27,7 @@ import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordRe import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.FloatWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java index 3ec50df59..dcc40e807 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestRecordReaders.java @@ -24,11 +24,10 @@ import org.datavec.api.records.reader.impl.collection.CollectionRecordReader; import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; -import org.deeplearning4j.exception.DL4JException; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.nd4j.linalg.dataset.api.DataSet; diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetIterator.java 
b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetIterator.java index 967c4249d..51711abb3 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderDataSetIterator.java @@ -26,14 +26,14 @@ import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.datavec.api.io.WritableConverter; import org.datavec.api.io.converters.SelfWritableConverter; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataComposableMap; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.records.reader.impl.ConcatenatingRecordReader; import org.datavec.api.records.reader.impl.collection.CollectionRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; diff --git a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java index 34731a8ac..282def2da 100644 --- a/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java +++ b/cavis-dnn/cavis-dnn-data/cavis-dnn-data-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java @@ -25,7 +25,7 @@ import lombok.Getter; import lombok.Setter; import lombok.val; import org.apache.commons.lang3.ArrayUtils; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.metadata.RecordMetaDataComposableMap; @@ -34,7 +34,7 @@ import org.datavec.api.records.reader.SequenceRecordReader; import org.datavec.api.util.ndarray.RecordConverter; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.deeplearning4j.datasets.datavec.exception.ZeroLengthSequenceException; import org.nd4j.common.base.Preconditions; diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java index 09c621c4d..bf7e36a33 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/sequencevectors/SequenceVectorsTest.java @@ -25,7 +25,7 @@ import lombok.Getter; import lombok.Setter; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import 
org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java index bc72ced72..41c32112d 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java @@ -25,7 +25,7 @@ import org.apache.spark.api.java.function.Function; import org.datavec.api.io.WritableConverter; import org.datavec.api.io.converters.WritableConverterException; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequenceDataSetFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequenceDataSetFunction.java index 025a1f4c0..dfc2db315 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequenceDataSetFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequenceDataSetFunction.java @@ -23,8 +23,7 @@ package org.deeplearning4j.spark.datavec; import org.apache.spark.api.java.function.Function; import org.datavec.api.io.WritableConverter; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; -import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequencePairDataSetFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequencePairDataSetFunction.java index 1ad5a7cfd..30ff1808a 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequencePairDataSetFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequencePairDataSetFunction.java @@ -23,8 +23,7 @@ package org.deeplearning4j.spark.datavec; import org.apache.spark.api.java.function.Function; import org.datavec.api.io.WritableConverter; import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; -import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.DataSetPreProcessor; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java index 48ff6d0b0..872da2db2 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java @@ -24,7 +24,7 @@ import org.apache.spark.api.java.function.Function; import org.datavec.api.io.WritableConverter; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.StringSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/export/StringToDataSetExportFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/export/StringToDataSetExportFunction.java index 9242ab798..cb9350561 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/export/StringToDataSetExportFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/export/StringToDataSetExportFunction.java @@ -29,7 +29,7 @@ import org.apache.spark.broadcast.Broadcast; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.collection.CollectionRecordReader; import org.datavec.api.split.StringSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.util.DefaultHadoopConfig; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecord.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecord.java index 7fb70736b..99360fbe6 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecord.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecord.java @@ -22,7 +22,7 @@ package org.deeplearning4j.spark.datavec.iterator; import lombok.AllArgsConstructor; import lombok.Data; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.util.List; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecords.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecords.java index a950527e1..1ef958592 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecords.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecords.java @@ -22,7 +22,7 @@ package org.deeplearning4j.spark.datavec.iterator; import lombok.AllArgsConstructor; import lombok.Data; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.Serializable; import java.util.List; diff --git 
a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/IteratorUtils.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/IteratorUtils.java index b2a7592e8..bbe7c8e47 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/IteratorUtils.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/IteratorUtils.java @@ -27,7 +27,7 @@ import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.PairFunction; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator; import org.nd4j.linalg.dataset.api.MultiDataSet; import scala.Tuple2; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/RRMDSIFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/RRMDSIFunction.java index 1c3a47039..5763c0885 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/RRMDSIFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/RRMDSIFunction.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import org.apache.spark.api.java.function.Function; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.factory.Nd4j; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummyReader.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummyReader.java index 0f5519f35..aabd6be9f 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummyReader.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummyReader.java @@ -22,12 +22,12 @@ package org.deeplearning4j.spark.datavec.iterator; import lombok.Data; import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; +import org.datavec.api.Record; import org.datavec.api.records.listener.RecordListener; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummySeqReader.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummySeqReader.java index 54671ec37..8dfa37293 100644 --- 
a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummySeqReader.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummySeqReader.java @@ -24,7 +24,7 @@ import lombok.Data; import org.datavec.api.records.SequenceRecord; import org.datavec.api.records.metadata.RecordMetaData; import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import java.io.DataInputStream; import java.io.IOException; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java index 576cda013..0f90d61b0 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java @@ -32,7 +32,7 @@ import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.InputStreamInputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.exception.ND4JArraySizeException; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java index bd2e0f389..59ee7e189 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java @@ -36,7 +36,7 @@ import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.split.NumberedFileInputSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.image.recordreader.ImageRecordReader; import org.datavec.spark.functions.SequenceRecordReaderFunction; import org.datavec.spark.functions.pairdata.*; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java index 30ce34c6b..6edb60bda 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java @@ -24,7 +24,7 @@ import org.apache.spark.api.java.JavaRDD; import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.Writable; +import org.datavec.api.Writable; import org.datavec.spark.transform.misc.StringToWritablesFunction; import 
org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator; import org.deeplearning4j.spark.BaseSparkTest; From 471657fef710d10785547eac3eabae5888b7d2f4 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 10:42:07 +0200 Subject: [PATCH 059/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 1 + 1 file changed, 1 insertion(+) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index b0e7e9b81..d8a0de833 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -1,6 +1,7 @@ plugins { id 'java-library' id 'maven-publish' + id 'com.github.johnrengelman.shadow' version '7.1.2' } /* From 0d443b8edabcfc1f9344c9c485e46dfa33d41f0d Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 10:47:43 +0200 Subject: [PATCH 060/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index d8a0de833..8197c5111 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -86,6 +86,10 @@ artifacts { } */ +shadowJar { + zip64 true //need this to support jars with more than 65535 entries +} + publishing { publications { mavenJava(MavenPublication) { From 2e074768f38502cc62efb6ce5d68b39a054e7efe Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 11:04:23 +0200 Subject: [PATCH 061/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 8197c5111..bd28ed655 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -23,7 +23,7 @@ dependencies { //TODO for the two below.. either platform specific uber jars or a single big one with all platforms api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" //api group: "org.bytedeco", name: "javacpp", version: "1.5.7" - // api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" + api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT' rootProject.getAllprojects().each { Project sproj -> if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") From cf5c9b53b9bda94f332f275ed9502e3a46c7495f Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 11:09:17 +0200 Subject: [PATCH 062/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index bd28ed655..50f728fea 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -23,7 +23,7 @@ dependencies { //TODO for the two below.. 
either platform specific uber jars or a single big one with all platforms api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" //api group: "org.bytedeco", name: "javacpp", version: "1.5.7" - api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" + //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT' rootProject.getAllprojects().each { Project sproj -> if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") From 7a105ed2075168a76984f672e97b3697e63306c8 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 11:30:50 +0200 Subject: [PATCH 063/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 6 ++++++ cavis-native/build.gradle | 2 +- cavis-native/cavis-native-lib/build.gradle | 1 - settings.gradle | 2 +- 4 files changed, 8 insertions(+), 3 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 50f728fea..78f276087 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -1,3 +1,5 @@ +import org.gradle.api.internal.artifacts.configurations.Configurations + plugins { id 'java-library' id 'maven-publish' @@ -36,6 +38,10 @@ dependencies { && !sproj.name.equals("cavis-zoo")) { //compileOnly project(""+sproj.path) api sproj + sproj.configurations.each { Configuration c -> + logger.quiet(sproj.name + ":" + c.name) + } + if(! sproj.configurations.empty) { //compileOnly project(sproj.getPath()) diff --git a/cavis-native/build.gradle b/cavis-native/build.gradle index 1519fe9d4..943c0c441 100644 --- a/cavis-native/build.gradle +++ b/cavis-native/build.gradle @@ -20,7 +20,7 @@ */ subprojects { - group = "net.brutex.cavis-native" + group = group + "cavis-native" apply plugin: "java-library" apply plugin: "maven-publish" apply plugin: "signing" diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 1d083f0ce..41e227faa 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -45,7 +45,6 @@ buildscript { return pf } - } diff --git a/settings.gradle b/settings.gradle index aaf58f336..17d2ee1b9 100644 --- a/settings.gradle +++ b/settings.gradle @@ -116,6 +116,7 @@ include ':cavis-dnn:cavis-dnn-spark:cavis-dnn-spark-parameterserver' include ':cavis-dnn:cavis-dnn-tsne' include ':cavis-datavec' include ':cavis-datavec:cavis-datavec-api' +include ':cavis-datavec:dvec-api' include ':cavis-datavec:cavis-datavec-data' include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-arrow' include ':cavis-datavec:cavis-datavec-data:cavis-datavec-data-image' @@ -150,4 +151,3 @@ include ':cavis-zoo' include ':cavis-zoo:cavis-zoo-models' include ':brutex-extended-tests' include ':cavis-full' - From c3bb9d44cd717098a24aa9810ea3a90ad255fe14 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 12:44:21 +0200 Subject: [PATCH 064/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-native/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-native/build.gradle b/cavis-native/build.gradle index 943c0c441..d460c718d 100644 --- a/cavis-native/build.gradle +++ b/cavis-native/build.gradle @@ -20,7 +20,7 @@ */ subprojects { - group = group + 
"cavis-native" + group = group + ".cavis-native" apply plugin: "java-library" apply plugin: "maven-publish" apply plugin: "signing" From 30041c8aa5d8db579709f0e4ccb1e66327e0e34e Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 12:51:01 +0200 Subject: [PATCH 065/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-native/cavis-native-cpu/build.gradle | 2 +- cavis-native/cavis-native-jcublas/build.gradle | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cavis-native/cavis-native-cpu/build.gradle b/cavis-native/cavis-native-cpu/build.gradle index a435e16f8..39dfff5bf 100644 --- a/cavis-native/cavis-native-cpu/build.gradle +++ b/cavis-native/cavis-native-cpu/build.gradle @@ -16,7 +16,7 @@ dependencies { implementation (projects.cavisNative.cavisNativeLib) { capabilities { - it.requireCapability group: "net.brutex.cavis-native", name:"cavis-native-lib-cpu-support" + it.requireCapability group: "net.brutex.cavis.cavis-native", name:"cavis-native-lib-cpu-support", version: project.version } } diff --git a/cavis-native/cavis-native-jcublas/build.gradle b/cavis-native/cavis-native-jcublas/build.gradle index b9b3c37e4..0e0a9dd22 100644 --- a/cavis-native/cavis-native-jcublas/build.gradle +++ b/cavis-native/cavis-native-jcublas/build.gradle @@ -22,7 +22,7 @@ dependencies { implementation(project(path: ":cavis-native:cavis-native-lib")) { capabilities { - it.requireCapability("net.brutex.cavis-native:cavis-native-lib-cuda-support:1.0.0-SNAPSHOT") + it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version:project.version) } } implementation project(":cavis-native:cavis-native-common") From a9bcb7f0c88f309045bba6cf7ac87fa0e9d764ab Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 14:22:22 +0200 Subject: [PATCH 066/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 39 ++++------------------ cavis-native/cavis-native-lib/build.gradle | 2 -- 2 files changed, 7 insertions(+), 34 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 78f276087..4136e7069 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -1,20 +1,9 @@ -import org.gradle.api.internal.artifacts.configurations.Configurations - plugins { id 'java-library' id 'maven-publish' id 'com.github.johnrengelman.shadow' version '7.1.2' } -/* -configurations.archives.artifacts.with { archives -> - - archives.each { - println(it.name) - } -} -*/ - dependencies { //Todo clean this api platform(project(":cavis-common-platform")) @@ -24,40 +13,26 @@ dependencies { api 'org.slf4j:slf4j-api:2.0.3' //TODO for the two below.. 
either platform specific uber jars or a single big one with all platforms api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" - //api group: "org.bytedeco", name: "javacpp", version: "1.5.7" - //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT', classifier: "linux-x86_64-avx2-cpu" - //api group: 'net.brutex.cavis-native', name: 'cavis-native-lib', version: '1.0.0-SNAPSHOT' + rootProject.getAllprojects().each { Project sproj -> - if(!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") + if (!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") && !sproj.name.equals("Cavis") && !sproj.name.equals("cavis-datavec") && !sproj.name.equals("cavis-dnn") && !sproj.name.equals("cavis-native") + && !sproj.name.equals("cavis-native-lib") && !sproj.name.equals("cavis-nd4j") && !sproj.name.equals("cavis-ui") && !sproj.name.equals("cavis-zoo")) { - //compileOnly project(""+sproj.path) - api sproj - sproj.configurations.each { Configuration c -> - logger.quiet(sproj.name + ":" + c.name) - } - if(! sproj.configurations.empty) { - //compileOnly project(sproj.getPath()) - - /* - sproj.configurations.each {Configuration conf -> - conf.dependencies.each {Dependency dep -> - compileOnly dep - } + implementation(projects.cavisNative.cavisNativeLib) { + capabilities { + it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) + it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) } - - */ } } } - - } /* diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 41e227faa..ac2aa987e 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -169,8 +169,6 @@ dependencies { implementation "org.apache.commons:commons-lang3" implementation "org.apache.commons:commons-math3" implementation "com.google.flatbuffers:flatbuffers-java" - - //javacppPlatform project(":cavis-native:cavis-native-blas") } From 460205101c1494f84915d9910467867118a1c9d4 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 14:24:21 +0200 Subject: [PATCH 067/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 4136e7069..edfa7137d 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -27,7 +27,7 @@ dependencies { implementation(projects.cavisNative.cavisNativeLib) { capabilities { - it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) + //it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) } } From f1695eb8aed825f0d2a7ca36ba876e6a8e614786 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 14:29:10 +0200 Subject: [PATCH 068/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-common-platform/build.gradle | 3 ++- cavis-full/build.gradle | 17 ++++++++++------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/cavis-common-platform/build.gradle 
b/cavis-common-platform/build.gradle index 81851f931..aaf070d84 100644 --- a/cavis-common-platform/build.gradle +++ b/cavis-common-platform/build.gradle @@ -160,6 +160,7 @@ dependencies { } } +/* publishing { publications { myPlatform(MavenPublication) { @@ -167,7 +168,7 @@ publishing { } } } - +*/ tasks.withType(GenerateModuleMetadata).configureEach { // The value 'enforced-platform' is provided in the validation diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index edfa7137d..5805abe6a 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -24,17 +24,20 @@ dependencies { && !sproj.name.equals("cavis-nd4j") && !sproj.name.equals("cavis-ui") && !sproj.name.equals("cavis-zoo")) { - - implementation(projects.cavisNative.cavisNativeLib) { - capabilities { - //it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) - it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) - } - } + api sproj } } + + implementation(projects.cavisNative.cavisNativeLib) { + capabilities { + //it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) + it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) + } + } + } + /* tasks.getByName("jar") { From 1e681e7c054858dec39a6d0093038135f12bce0e Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 15:36:39 +0200 Subject: [PATCH 069/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 5805abe6a..2433755d4 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -4,6 +4,8 @@ plugins { id 'com.github.johnrengelman.shadow' version '7.1.2' } +apply from: rootProject.projectDir.path+"/chooseBackend.gradle" + dependencies { //Todo clean this api platform(project(":cavis-common-platform")) @@ -28,10 +30,10 @@ dependencies { } } - implementation(projects.cavisNative.cavisNativeLib) { + api(projects.cavisNative.cavisNativeLib) { capabilities { - //it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) - it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) + if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) + if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) } } @@ -71,6 +73,7 @@ artifacts { */ shadowJar { + enabled false; zip64 true //need this to support jars with more than 65535 entries } From 6044c1c53a6cc3f46894fdd6f1be512877d0b5f5 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 17:32:35 +0200 Subject: [PATCH 070/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 19 +++++++++++++------ cavis-native/cavis-native-cpu/build.gradle | 2 +- cavis-native/cavis-native-lib/build.gradle | 1 + 3 files changed, 15 insertions(+), 7 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 2433755d4..e722d42bf 100644 --- a/cavis-full/build.gradle +++ 
b/cavis-full/build.gradle @@ -10,11 +10,11 @@ dependencies { //Todo clean this api platform(project(":cavis-common-platform")) //api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise - api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5" - api 'org.slf4j:slf4j-simple:2.0.3' - api 'org.slf4j:slf4j-api:2.0.3' + //api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5" + //api 'org.slf4j:slf4j-simple:2.0.3' + //api 'org.slf4j:slf4j-api:2.0.3' //TODO for the two below.. either platform specific uber jars or a single big one with all platforms - api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" + //api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" rootProject.getAllprojects().each { Project sproj -> if (!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") @@ -22,20 +22,27 @@ dependencies { && !sproj.name.equals("cavis-datavec") && !sproj.name.equals("cavis-dnn") && !sproj.name.equals("cavis-native") - && !sproj.name.equals("cavis-native-lib") && !sproj.name.equals("cavis-nd4j") && !sproj.name.equals("cavis-ui") && !sproj.name.equals("cavis-zoo")) { api sproj } } - +/* api(projects.cavisNative.cavisNativeLib) { capabilities { if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) } } +*/ + + api (project(':cavis-native:cavis-native-lib')) { + capabilities { + if(withCpu()) requireCapability("net.brutex.cavis.cavis-native:cavis-native-lib-cpu-support") + //if(withCuda()) requireCapability("net.brutex.cavis.cavis-native:cavis-native-lib-cuda-support") + } + } } diff --git a/cavis-native/cavis-native-cpu/build.gradle b/cavis-native/cavis-native-cpu/build.gradle index 39dfff5bf..cdb610ed8 100644 --- a/cavis-native/cavis-native-cpu/build.gradle +++ b/cavis-native/cavis-native-cpu/build.gradle @@ -14,6 +14,7 @@ dependencies { implementation projects.cavisDnn.cavisDnnApi implementation projects.cavisDnn.cavisDnnCommon + implementation (projects.cavisNative.cavisNativeLib) { capabilities { it.requireCapability group: "net.brutex.cavis.cavis-native", name:"cavis-native-lib-cpu-support", version: project.version @@ -28,5 +29,4 @@ dependencies { implementation "com.google.flatbuffers:flatbuffers-java" implementation "org.slf4j:slf4j-api" implementation "org.apache.commons:commons-math3" - } diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index ac2aa987e..989bc7c25 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -99,6 +99,7 @@ java { chipList.each {thisChip -> registerFeature("${thisChip}Support") { usingSourceSet(sourceSets.findByName("${thisChip}Support")) + capability(project.group, "cavis-native-lib-${thisChip}-support", project.version) //withJavadocJar() //withSourcesJar() } From 638f13e681ca1ccfae1c002a635572b8602aca57 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 17:32:47 +0200 Subject: [PATCH 071/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .../java/net/brutex/cavis/dvec/api/Field.java | 66 +++++++++++++++++++ .../brutex/cavis/dvec/api/FieldMetadata.java | 31 +++++++++ .../dvec/api/exceptions/DVecException.java | 32 
+++++++++ .../brutex/cavis/dvec/api/package-info.java | 35 ++++++++++ 4 files changed, 164 insertions(+) create mode 100644 cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java create mode 100644 cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldMetadata.java create mode 100644 cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/exceptions/DVecException.java create mode 100644 cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/package-info.java diff --git a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java new file mode 100644 index 000000000..a3be6313f --- /dev/null +++ b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java @@ -0,0 +1,66 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.cavis.dvec.api; + +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import net.brutex.cavis.dvec.api.exceptions.DVecException; + +/** + * Abtract implementation of the Field interface {@see FieldInterface}, that handles all data storage + * in memory and adds basic error handling. + * + * @author Brian Rosenberger + * @since 1.0 + */ +public abstract class Field implements FieldInterface { + + /** + * {@inheritDoc} + * + * @param start Index of starting position, zero based + * @param length how many fields to read + * @return the list of Buffer + */ + @Override + public T read(long start, long length) throws DVecException { + if (start<0 || start>internalStorage.capacity()-1 ) { + throw new DVecException("Read on Field start position is out of bounds."); + } + if (start+length> internalStorage.capacity()) { + throw new DVecException("Read on Field exceeds field length"); + } + return null; + } + + @Override + public void write(long pos, T buffer) { + + } + + private ByteBuffer internalStorage = null; + + + +} diff --git a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldMetadata.java b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldMetadata.java new file mode 100644 index 000000000..9dcd12cdf --- /dev/null +++ b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldMetadata.java @@ -0,0 +1,31 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. 
+ * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.cavis.dvec.api; + +/** + * tbd. + * @author Brian Rosenberger + * + */ +public interface FieldMetadata { + +} diff --git a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/exceptions/DVecException.java b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/exceptions/DVecException.java new file mode 100644 index 000000000..e1eb4ec28 --- /dev/null +++ b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/exceptions/DVecException.java @@ -0,0 +1,32 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.cavis.dvec.api.exceptions; + +import lombok.Getter; + +public class DVecException extends Exception { + + @Getter private final String message; + public DVecException(String message) { + this.message = message; + } +} diff --git a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/package-info.java b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/package-info.java new file mode 100644 index 000000000..190b22052 --- /dev/null +++ b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/package-info.java @@ -0,0 +1,35 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +/** + *

+ * <p>The data vectorization api (dvec-api) defines a data structure in analogy to Hadoop and is
+ * derived from the dl4j datavec library. The main concept is around</p>
+ * <ul>
+ *   <li>InputFormat</li>
+ *   <li>InputSplit</li>
+ *   <li>RecordReader, Records and Writable</li>
+ * </ul>
      + * + * @author Brian Rosenberger <bru@brutex.de> + */ +package net.brutex.cavis.dvec.api; \ No newline at end of file From 4857b71181bf5d8344f8bc11c95c08424b68bdb6 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 17:35:50 +0200 Subject: [PATCH 072/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .../cavis-datavec-data/cavis-datavec-data-codec/build.gradle | 1 + 1 file changed, 1 insertion(+) diff --git a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/build.gradle b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/build.gradle index b5de498e2..3f8076717 100644 --- a/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/build.gradle +++ b/cavis-datavec/cavis-datavec-data/cavis-datavec-data-codec/build.gradle @@ -28,6 +28,7 @@ dependencies { implementation "org.bytedeco:javacv" implementation "org.apache.commons:commons-compress" implementation "org.jcodec:jcodec:0.1.5" + implementation "com.fasterxml.jackson.core:jackson-annotations" implementation "org.slf4j:slf4j-api" From 42a27480e6a214319fd73d213d1ae6b5050093b3 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 17:47:41 +0200 Subject: [PATCH 073/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .../brutex/cavis/dvec/api/FieldInterface.java | 79 +++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldInterface.java diff --git a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldInterface.java b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldInterface.java new file mode 100644 index 000000000..92705bea8 --- /dev/null +++ b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldInterface.java @@ -0,0 +1,79 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.cavis.dvec.api; + +import java.io.Serializable; +import java.nio.Buffer; +import java.nio.LongBuffer; +import java.util.List; +import net.brutex.cavis.dvec.api.exceptions.DVecException; + +/** + * A Field can be considered a "column" in a {@code Record}, as such a Field may refer to multiple + * entries of that "column". Fields are typed as Buffers. Some of them defined in the dvec core api, + * other (i.e. Image or Arrow) require dvec extensions accordingly. + * + * @author Brian Rosenberger + * @since 1.0 + */ +public interface FieldInterface extends Serializable { + + /** + * Get a reference to the metadata for this Field. + * + * @return the {@link FieldMetadata} + */ + FieldMetadata getFieldMetadata(); + + /** + * Get the 1st field as Buffer. 
This deserializes the data from the underlying storage. + * + * @return T underlying Buffer + */ + default T read() throws DVecException { + return read(0, 1); + } + + /** + * Get a range of fields as a {@code Buffer} + * + * @param start Index of starting position, zero based + * @param length how many fields to read + * @return the buffers + */ + T read(long start, long length) throws DVecException; + + /** + * Write the data into the underlying storage. + */ + default void write(T buffer) { + write(0, buffer); + } + + /** + * Write the data into the underyling storage starting at a position + * + * @param pos the position to start + */ + void write(long pos, T buffer); + +} From 6af05fc8a5be55e611067381036b4b8ba9b1c491 Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 18:04:38 +0200 Subject: [PATCH 074/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 5 ++++- cavis-native/cavis-native-lib/build.gradle | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index e722d42bf..c360cfe56 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -36,14 +36,17 @@ dependencies { } } */ + if(withCpu()) api project(path: ":cavis-native:cavi-native-lib", configuration: "cpuSupportCompileClasspath") + if(withCuda()) api project(path: ":cavis-native:cavi-native-lib", configuration: "cudaSupportCompileClasspath") + /* api (project(':cavis-native:cavis-native-lib')) { capabilities { if(withCpu()) requireCapability("net.brutex.cavis.cavis-native:cavis-native-lib-cpu-support") //if(withCuda()) requireCapability("net.brutex.cavis.cavis-native:cavis-native-lib-cuda-support") } } - +*/ } diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 989bc7c25..a83faf398 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -315,7 +315,6 @@ chipList.each { thisChip -> classOrPackageNames = ["org.nd4j.nativeblas.${thisChip}.Nd4j${thisChip.capitalize()}Presets"] outputDirectory = file("${buildDir}/generated/sources/javacpp/${thisChip}/${javacppPlatform}${javacppPlatformExtension}/") - classPath = sourceSets.getByName("${thisChip}Support").getRuntimeClasspath() classPath += ["${buildDir}/classes/java/${thisChip}Support/"] } From 0bd8f072c0ed176f3dc3983d796e43100530080e Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 24 Oct 2022 18:07:39 +0200 Subject: [PATCH 075/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index c360cfe56..b6f082a52 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -36,8 +36,8 @@ dependencies { } } */ - if(withCpu()) api project(path: ":cavis-native:cavi-native-lib", configuration: "cpuSupportCompileClasspath") - if(withCuda()) api project(path: ":cavis-native:cavi-native-lib", configuration: "cudaSupportCompileClasspath") + if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportCompileClasspath") + if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath") /* api (project(':cavis-native:cavis-native-lib')) { From 0aea7d8e4ccdfcd046e073e4fb806e9805de6efb Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 09:03:46 +0200 Subject: [PATCH 076/126] Add 
jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 4 ++-- cavis-native/cavis-native-lib/build.gradle | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index b6f082a52..c09fed04e 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -36,8 +36,8 @@ dependencies { } } */ - if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportCompileClasspath") - if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath") + //if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportCompileClasspath") + //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath") /* api (project(':cavis-native:cavis-native-lib')) { diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index a83faf398..8614fc9df 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -428,6 +428,11 @@ chipList.each { thisChip -> } } +chipList.each{ thisChip -> + configurations { + implementation.extendsFrom findByName("${thisChip}SupportImplementation") + } +} tasks.withType(JavaCompile) { From a310b6be958805ca58a1500084e8c3a12ba196c7 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 09:16:07 +0200 Subject: [PATCH 077/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index c09fed04e..7250b7861 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -36,7 +36,7 @@ dependencies { } } */ - //if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportCompileClasspath") + if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportCompileClasspath") //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath") /* From 7cd0cd12cc99781df2d0dd4065ff97d38b27b6df Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 09:17:23 +0200 Subject: [PATCH 078/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 7250b7861..e25089501 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -36,7 +36,7 @@ dependencies { } } */ - if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportCompileClasspath") + if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "implementation") //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath") /* From 46b04cc0b4ae8e1469b5d6118abcd4ea05c5a398 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 09:20:22 +0200 Subject: [PATCH 079/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- cavis-native/cavis-native-lib/build.gradle | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index e25089501..7250b7861 100644 --- a/cavis-full/build.gradle +++ 
b/cavis-full/build.gradle @@ -36,7 +36,7 @@ dependencies { } } */ - if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "implementation") + if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportCompileClasspath") //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath") /* diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 8614fc9df..e2a72cb9c 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -431,6 +431,7 @@ chipList.each { thisChip -> chipList.each{ thisChip -> configurations { implementation.extendsFrom findByName("${thisChip}SupportImplementation") + findByName("${thisChip}SupportImplementation").setCanBeConsumed(true) } } From efd106dc0ae1bd11bd802b1fb54a5e8017d2fe7b Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 09:21:58 +0200 Subject: [PATCH 080/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-native/cavis-native-lib/build.gradle | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index e2a72cb9c..d0d23f5af 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -69,6 +69,12 @@ plugins { id 'signing' } +chipList.each{ thisChip -> + configurations { + findByName("${thisChip}SupportImplementation").setCanBeConsumed(true) + } +} + chipList.each {thisChip -> sourceSets.register("${thisChip}Support") { java { From 400429a32f670771e4cc1e0c8fd43cd76634021f Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 13:20:14 +0200 Subject: [PATCH 081/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-native/cavis-native-lib/build.gradle | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index d0d23f5af..0e18c9428 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -70,7 +70,9 @@ plugins { } chipList.each{ thisChip -> + configurations { + create("${thisChip}SupportImplementation") findByName("${thisChip}SupportImplementation").setCanBeConsumed(true) } } @@ -434,14 +436,6 @@ chipList.each { thisChip -> } } -chipList.each{ thisChip -> - configurations { - implementation.extendsFrom findByName("${thisChip}SupportImplementation") - findByName("${thisChip}SupportImplementation").setCanBeConsumed(true) - } -} - - tasks.withType(JavaCompile) { // options.setCompilerArgs(Arrays.asList("-Xlint:unchecked")) } From 602220e07ee5ed20dd5543fc565ee13688ecce09 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 13:21:53 +0200 Subject: [PATCH 082/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 7250b7861..809ca5c28 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -36,7 +36,8 @@ dependencies { } } */ - if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportCompileClasspath") + if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportImplementation") + if(withCuda()) api 
project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportImplementation") //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath") /* From 6f8c14c0a3c3377df36a30c3a337aadd965b8398 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 13:37:36 +0200 Subject: [PATCH 083/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 4 +-- cavis-native/cavis-native-lib/build.gradle | 29 +++------------------- 2 files changed, 5 insertions(+), 28 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 809ca5c28..7421de739 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -36,8 +36,8 @@ dependencies { } } */ - if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportImplementation") - if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportImplementation") + //if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportImplementation") + //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportImplementation") //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath") /* diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 0e18c9428..7f81613e7 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -69,14 +69,6 @@ plugins { id 'signing' } -chipList.each{ thisChip -> - - configurations { - create("${thisChip}SupportImplementation") - findByName("${thisChip}SupportImplementation").setCanBeConsumed(true) - } -} - chipList.each {thisChip -> sourceSets.register("${thisChip}Support") { java { @@ -114,17 +106,9 @@ java { } } -/* -configurations.each(s -> { - println "Configurations: " + s.name + " " + s.artifacts.each( x -> - { println x.getFile().getName()}) -}) -*/ dependencies { api platform(project(':cavis-common-platform')) - - implementation "org.bytedeco:javacpp" implementation group: "org.bytedeco", name: "javacpp", classifier: "${javacppPlatform}" @@ -422,10 +406,6 @@ chipList.each { thisChip -> thisTask.with spec thisTask.archiveClassifier = "${javacppPlatform}${javacppPlatformExtension}-${thisChip}" } - - //tasks.getByName("${thisChip}SupportJar").dependsOn("javacpp${thisChip.capitalize()}SupportJar") - - } //Before we can compile the whole java part, we @@ -469,10 +449,7 @@ javadoc { -if(! osdetector.os.startsWith("windows")) { - //tasks.getByName("publish") { - // enabled = false - // } + tasks.getByName("generatePomFileForMavenJavaPublication") { enabled = true } @@ -481,10 +458,10 @@ if(! 
osdetector.os.startsWith("windows")) { } chipList.each { thisChip -> artifacts { - archives tasks.getByName("${thisChip}SupportJar") + artifact tasks.getByName("${thisChip}SupportJar") } } -} + chipList.each { thisChip -> From 41fd85aa6763fabd4fd0e4a6d36294423d115cd9 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 13:38:48 +0200 Subject: [PATCH 084/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-native/cavis-native-lib/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 7f81613e7..2a661314b 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -458,7 +458,7 @@ javadoc { } chipList.each { thisChip -> artifacts { - artifact tasks.getByName("${thisChip}SupportJar") + archive tasks.getByName("${thisChip}SupportJar") } } From 9229c9d0f8db125d0dfe71baca506cb81c8c34da Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 13:39:56 +0200 Subject: [PATCH 085/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-native/cavis-native-lib/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 2a661314b..5877b8db1 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -458,7 +458,7 @@ javadoc { } chipList.each { thisChip -> artifacts { - archive tasks.getByName("${thisChip}SupportJar") + archives tasks.getByName("${thisChip}SupportJar") } } From c912c4ece1e6ccfaeb81d31bb323e58d2780cbc5 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 13:51:11 +0200 Subject: [PATCH 086/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 7421de739..57cb823df 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -93,6 +93,15 @@ publishing { mavenJava(MavenPublication) { // artifact customFatJar // from components.java + pom.withXml { + def dependencyNode = asNode().appendNode('dependencies').appendNode('dependency') + dependencyNode.appendNode('groupId', 'net.brutex.cavis') + dependencyNode.appendNode('artifactId', 'cavis-native-lib') + dependencyNode.appendNode('version', project.version) + dependencyNode.appendNode('classifier', 'linux-x86_64-avx2-cpu') + dependencyNode.appendNode('scope', 'compile') + + } } } } From 731d7d510e2654fba8c73a148af18228636b3c4b Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 13:59:14 +0200 Subject: [PATCH 087/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 57cb823df..341d918ff 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -97,7 +97,7 @@ publishing { def dependencyNode = asNode().appendNode('dependencies').appendNode('dependency') dependencyNode.appendNode('groupId', 'net.brutex.cavis') dependencyNode.appendNode('artifactId', 'cavis-native-lib') - dependencyNode.appendNode('version', project.version) + dependencyNode.appendNode('version', '1.0.0-SNAPSHOT') dependencyNode.appendNode('classifier', 'linux-x86_64-avx2-cpu') 
dependencyNode.appendNode('scope', 'compile') From 474d0726977021a786a7ea24932c44e1371bfba3 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 14:03:33 +0200 Subject: [PATCH 088/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 341d918ff..8af560cd6 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -98,9 +98,8 @@ publishing { dependencyNode.appendNode('groupId', 'net.brutex.cavis') dependencyNode.appendNode('artifactId', 'cavis-native-lib') dependencyNode.appendNode('version', '1.0.0-SNAPSHOT') - dependencyNode.appendNode('classifier', 'linux-x86_64-avx2-cpu') - dependencyNode.appendNode('scope', 'compile') - + //dependencyNode.appendNode('classifier', 'linux-x86_64-avx2-cpu') + //dependencyNode.appendNode('scope', 'compile') } } } From 55165c9e2c49b9b0bfbf746b677e64982d2fdb03 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 14:06:35 +0200 Subject: [PATCH 089/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 1 + 1 file changed, 1 insertion(+) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 8af560cd6..482504023 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -101,6 +101,7 @@ publishing { //dependencyNode.appendNode('classifier', 'linux-x86_64-avx2-cpu') //dependencyNode.appendNode('scope', 'compile') } + pom.println() } } } From b023808b1d3502a3abeea73ef2e129fafba00653 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 20:01:57 +0200 Subject: [PATCH 090/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 482504023..f046688ec 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -94,14 +94,14 @@ publishing { // artifact customFatJar // from components.java pom.withXml { - def dependencyNode = asNode().appendNode('dependencies').appendNode('dependency') + def dependenciesNode = asNode().appendNode('dependencies') + def dependencyNode = appendNode('dependency') dependencyNode.appendNode('groupId', 'net.brutex.cavis') dependencyNode.appendNode('artifactId', 'cavis-native-lib') dependencyNode.appendNode('version', '1.0.0-SNAPSHOT') //dependencyNode.appendNode('classifier', 'linux-x86_64-avx2-cpu') //dependencyNode.appendNode('scope', 'compile') } - pom.println() } } } From d31457b54554b4795953296058a3b850e53fd0b4 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 20:13:08 +0200 Subject: [PATCH 091/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index f046688ec..ccdb86c9e 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -95,7 +95,7 @@ publishing { // from components.java pom.withXml { def dependenciesNode = asNode().appendNode('dependencies') - def dependencyNode = appendNode('dependency') + def dependencyNode = appendNode('dependency','') dependencyNode.appendNode('groupId', 'net.brutex.cavis') dependencyNode.appendNode('artifactId', 'cavis-native-lib') dependencyNode.appendNode('version', '1.0.0-SNAPSHOT') From 
258e8b448661fd91b9ab1eb08305f426641fd55d Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 20:16:39 +0200 Subject: [PATCH 092/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index ccdb86c9e..1a2c7ba4f 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -95,7 +95,7 @@ publishing { // from components.java pom.withXml { def dependenciesNode = asNode().appendNode('dependencies') - def dependencyNode = appendNode('dependency','') + def dependencyNode = dependenciesNode.appendNode('dependency','') dependencyNode.appendNode('groupId', 'net.brutex.cavis') dependencyNode.appendNode('artifactId', 'cavis-native-lib') dependencyNode.appendNode('version', '1.0.0-SNAPSHOT') From d4b53afe89c1d6f42841be820a0c5ec93cb44624 Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 20:21:36 +0200 Subject: [PATCH 093/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 1a2c7ba4f..231fa8e03 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -95,7 +95,7 @@ publishing { // from components.java pom.withXml { def dependenciesNode = asNode().appendNode('dependencies') - def dependencyNode = dependenciesNode.appendNode('dependency','') + def dependencyNode = dependenciesNode.appendNode('dependency') dependencyNode.appendNode('groupId', 'net.brutex.cavis') dependencyNode.appendNode('artifactId', 'cavis-native-lib') dependencyNode.appendNode('version', '1.0.0-SNAPSHOT') From ef4f3a98415f5a4cb1d5a93d80f0635d68b162cc Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 20:30:59 +0200 Subject: [PATCH 094/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 231fa8e03..d08b190b4 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -94,7 +94,7 @@ publishing { // artifact customFatJar // from components.java pom.withXml { - def dependenciesNode = asNode().appendNode('dependencies') + def dependenciesNode = asNode().dependencies def dependencyNode = dependenciesNode.appendNode('dependency') dependencyNode.appendNode('groupId', 'net.brutex.cavis') dependencyNode.appendNode('artifactId', 'cavis-native-lib') From a6d60a4cdbafc901709f3b7a389f882ad5cdb45e Mon Sep 17 00:00:00 2001 From: brian Date: Tue, 25 Oct 2022 20:47:28 +0200 Subject: [PATCH 095/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index d08b190b4..03cf9e177 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -93,15 +93,17 @@ publishing { mavenJava(MavenPublication) { // artifact customFatJar // from components.java - pom.withXml { + /* pom.withXml { def dependenciesNode = asNode().dependencies - def dependencyNode = dependenciesNode.appendNode('dependency') + def dependencyNode = dependenciesNode.appendNode() + dependencyNode.appendNode('groupId', 'net.brutex.cavis') dependencyNode.appendNode('artifactId', 'cavis-native-lib') 
dependencyNode.appendNode('version', '1.0.0-SNAPSHOT') //dependencyNode.appendNode('classifier', 'linux-x86_64-avx2-cpu') //dependencyNode.appendNode('scope', 'compile') } + */ } } } From 6abe96a1daf22778c91c27da1d426ae9e6178e15 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 26 Oct 2022 10:19:22 +0200 Subject: [PATCH 096/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- ...inux-x86_64-docker-all-publish.jenkinsfile | 65 +++++++++++++++++++ cavis-native/cavis-native-lib/build.gradle | 20 +++--- 2 files changed, 75 insertions(+), 10 deletions(-) create mode 100644 .jenkins/linux-x86_64-docker-all-publish.jenkinsfile diff --git a/.jenkins/linux-x86_64-docker-all-publish.jenkinsfile b/.jenkins/linux-x86_64-docker-all-publish.jenkinsfile new file mode 100644 index 000000000..f79dfd59d --- /dev/null +++ b/.jenkins/linux-x86_64-docker-all-publish.jenkinsfile @@ -0,0 +1,65 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +pipeline { + agent none + + stages { + stage { + parallel { + agent { + dockerfile { + filename 'Dockerfile' + dir '.docker' + label 'linux && docker && cuda' + //additionalBuildArgs '--build-arg version=1.0.2' + //args '--gpus all' --needed for test only, you can build without GPU + } + } + stage('prep-build-environment-linux-cuda') { + steps { + checkout scm + //sh 'nvidia-smi' + sh 'nvcc --version' + sh 'gcc --version' + sh 'cmake --version' + sh 'sh ./gradlew --version' + } + } + stage('build-linux-cuda') { + environment { + MAVEN = credentials('Internal Archiva') + OSSRH = credentials('OSSRH') + } + + steps { + withGradle { + sh 'sh ./gradlew build --stacktrace -x test -PCAVIS_CHIP=cuda \ + -Pmavenuser=$MAVEN_USR -Pmavenpass=$MAVEN_PSW \ + -PossrhUsername=$OSSRH_USR -PossrhPassword=$OSSRH_PSW' + } + //stash includes: '/cavis-native/cavis-native-lib/build/lib/*.jar', name: 'cuda-build' + } + } + } + } + } +} diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 5877b8db1..36fa1f765 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -456,19 +456,19 @@ javadoc { tasks.getByName("publishMavenJavaPublicationToLocalRemoteRepository") { enabled = true } - chipList.each { thisChip -> - artifacts { + +artifacts { + archives jar + chipList.each { thisChip -> archives tasks.getByName("${thisChip}SupportJar") } - } +} - - -chipList.each { thisChip -> - publishing { - publications { - mavenJava(MavenPublication) { - artifact jar +publishing { + publications { + mavenJava(MavenPublication) { + artifact jar + chipList.each { thisChip -> artifact tasks.getByName("${thisChip}SupportJar") } } From 
6107c7efef89b1980dfef6c7dd9be3a6823acba8 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 26 Oct 2022 10:20:53 +0200 Subject: [PATCH 097/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .jenkins/linux-x86_64-docker-all-publish.jenkinsfile | 2 -- 1 file changed, 2 deletions(-) diff --git a/.jenkins/linux-x86_64-docker-all-publish.jenkinsfile b/.jenkins/linux-x86_64-docker-all-publish.jenkinsfile index f79dfd59d..f853fdbd2 100644 --- a/.jenkins/linux-x86_64-docker-all-publish.jenkinsfile +++ b/.jenkins/linux-x86_64-docker-all-publish.jenkinsfile @@ -23,7 +23,6 @@ pipeline { agent none stages { - stage { parallel { agent { dockerfile { @@ -60,6 +59,5 @@ pipeline { } } } - } } } From 286cf061ab9a8d127c8e75852e7aa3535c8e4a1c Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 26 Oct 2022 10:25:27 +0200 Subject: [PATCH 098/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- ...inux-x86_64-docker-all-publish.jenkinsfile | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/.jenkins/linux-x86_64-docker-all-publish.jenkinsfile b/.jenkins/linux-x86_64-docker-all-publish.jenkinsfile index f853fdbd2..2e1708e57 100644 --- a/.jenkins/linux-x86_64-docker-all-publish.jenkinsfile +++ b/.jenkins/linux-x86_64-docker-all-publish.jenkinsfile @@ -20,20 +20,22 @@ */ pipeline { - agent none + agent { + dockerfile { + filename 'Dockerfile' + dir '.docker' + label 'linux && docker && cuda' + //additionalBuildArgs '--build-arg version=1.0.2' + //args '--gpus all' --needed for test only, you can build without GPU + } + } stages { + stage("Build all chip") { parallel { - agent { - dockerfile { - filename 'Dockerfile' - dir '.docker' - label 'linux && docker && cuda' - //additionalBuildArgs '--build-arg version=1.0.2' - //args '--gpus all' --needed for test only, you can build without GPU - } - } + stage('prep-build-environment-linux-cuda') { + steps { checkout scm //sh 'nvidia-smi' @@ -59,5 +61,6 @@ pipeline { } } } + } } } From 75e1fb9005445f894634eadd41d5562ed7e7e802 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 26 Oct 2022 12:59:25 +0200 Subject: [PATCH 099/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- .gitignore | 16 +- {.github => .old/.github}/ISSUE_TEMPLATE.md | 0 .../.github}/PULL_REQUEST_TEMPLATE.md | 0 .../action.yml | 0 .../action.yml | 0 .../install-arm-cross-compile/action.yml | 0 .../actions/install-cmake-linux/action.yml | 0 .../actions/install-protobuf-linux/action.yml | 0 .../actions/msys2-base-setup/action.yml | 0 .../actions/publish-gh-packages/action.yml | 0 .../actions/update-deps-linux/action.yml | 0 .../workflows/build-android-x86_64.yml | 0 .../workflows/build-deploy-android-arm32.yml | 0 .../workflows/build-deploy-android-arm64.yml | 0 .../workflows/build-deploy-linux-arm32.yml | 0 .../workflows/build-deploy-linux-arm64.yml | 0 .../build-deploy-linux-cuda-11.0.yml | 0 .../build-deploy-linux-cuda-11.2.yml | 0 .../workflows/build-deploy-linux-x86_64.yml | 0 .../.github}/workflows/build-deploy-mac.yml | 0 .../build-deploy-windows-cuda-11.0.yml | 0 .../build-deploy-windows-cuda-11.2.yml | 0 .../workflows/build-deploy-windows.yml | 0 .../workflows/cpu-integration-tests.yaml | 0 .../workflows/cpu-sanity-check-tests.yaml | 0 .../workflows/run-cpu-tests-sanity-checks.yml | 0 .../workflows/run-gpu-tests-sanity-checks.yml | 0 .../workflows/test_multiple_arch.yaml | 0 .../ADRs}/0001-SameDiff_File_Format.md | 0 {ADRs => .old/ADRs}/0002-ONNX_Runtime.md | 
0 {ADRs => .old/ADRs}/0003-Import_IR.md | 0 .../ADRs}/0003-NdArray_Strides_ArmCompute.md | 0 {ADRs => .old/ADRs}/0004-Mapping_IR.md | 0 {ADRs => .old/ADRs}/0005-Interpreter.md | 0 .../change-cuda-versions.sh | 0 .../change-scala-versions.sh | 0 {contrib => .old/contrib}/README.md | 0 .../adr/0001-kotlin_dsl_as_source_of_truth.md | 0 ...separate_object_graph_for_serialization.md | 0 ...ing_with_inconsistencies_in_java_naming.md | 0 ...o_initialization_for_inplace_operations.md | 0 ...0005-optional_parameters_and_signatures.md | 0 .../codegen/adr/0006-op_specific_enums.md | 0 .../codegen/adr/0007-configuration_objects.md | 0 .../codegen/adr/0008-inheritance.md | 0 .../codegen/adr/0009-aliasing.md | 0 .../codegen/adr/0010-ir-codegen.md | 0 .../main/java/org/nd4j/codegen/Namespace.java | 0 .../main/java/org/nd4j/codegen/cli/CLI.java | 0 .../nd4j/codegen/impl/cpp/CppGenerator.java | 0 .../nd4j/codegen/impl/java/DocsGenerator.java | 0 .../codegen/impl/java/JavaPoetGenerator.java | 0 .../impl/java/Nd4jNamespaceGenerator.java | 0 .../codegen/impl/python/PythonGenerator.java | 0 .../nd4j/codegen/ir/SerializationTest.java | 0 .../java/org/nd4j/codegen/util/GenUtil.java | 0 .../org/nd4j/codegen/util/JsonMapper.java | 0 .../org/nd4j/codegen/api/CodeComponent.kt | 0 .../kotlin/org/nd4j/codegen/api/DataType.kt | 0 .../kotlin/org/nd4j/codegen/api/Language.kt | 0 .../kotlin/org/nd4j/codegen/api/LossReduce.kt | 0 .../kotlin/org/nd4j/codegen/api/Namespace.kt | 0 .../org/nd4j/codegen/api/NamespaceOps.kt | 0 .../main/kotlin/org/nd4j/codegen/api/Op.kt | 0 .../kotlin/org/nd4j/codegen/api/Registry.kt | 0 .../kotlin/org/nd4j/codegen/api/Variables.kt | 0 .../org/nd4j/codegen/api/doc/DocScope.kt | 0 .../org/nd4j/codegen/api/doc/DocSection.kt | 0 .../org/nd4j/codegen/api/doc/DocTokens.kt | 0 .../api/generator/ConstraintCodeGenerator.kt | 0 .../nd4j/codegen/api/generator/Generator.kt | 0 .../codegen/api/generator/GeneratorConfig.kt | 0 .../kotlin/org/nd4j/codegen/dsl/OpBuilder.kt | 0 .../impl/java/JavaConstraintCodeGenerator.kt | 0 .../python/KotlinExamplePythonGenerator.kt | 0 .../org/nd4j/codegen/util/GenerateOps.kt | 0 .../util/extract/ExtractFromExisting.kt | 0 .../util/extract/FindUsedParameterTypes.kt | 0 .../ops/org/nd4j/codegen/mixins/Mixins.kt | 0 .../main/ops/org/nd4j/codegen/ops/Bitwise.kt | 0 .../src/main/ops/org/nd4j/codegen/ops/CNN.kt | 0 .../main/ops/org/nd4j/codegen/ops/Image.kt | 0 .../main/ops/org/nd4j/codegen/ops/Linalg.kt | 0 .../src/main/ops/org/nd4j/codegen/ops/Math.kt | 0 .../ops/org/nd4j/codegen/ops/NeuralNetwork.kt | 0 .../src/main/ops/org/nd4j/codegen/ops/RNN.kt | 0 .../main/ops/org/nd4j/codegen/ops/Random.kt | 0 .../ops/org/nd4j/codegen/ops/SDBaseOps.kt | 0 .../main/ops/org/nd4j/codegen/ops/SDLoss.kt | 0 .../codegen/src/main/resources/logback.xml | 0 .../src/main/resources/namespaces/math.json | 0 .../src/main/resources/nd4j-op-defs-2.proto | 0 .../src/main/resources/onnx-op-defs.pb | Bin .../codegen/src/main/resources/onnx.pbtxt | 0 .../main/resources/tensorflowOpMappings.csv | 0 .../nd4j/codegen/dsl/DocsGeneratorTest.java | 0 .../org/nd4j/codegen/dsl/TestGeneration.java | 0 .../kotlin/org/nd4j/codegen/dsl/ConfigTest.kt | 0 .../org/nd4j/codegen/dsl/ConstraintTest.kt | 0 .../codegen/dsl/NamespaceInvariantTest.kt | 0 .../org/nd4j/codegen/dsl/OpBuilderTest.kt | 0 .../org/nd4j/codegen/dsl/OpInvariantTest.kt | 0 .../org/nd4j/codegen/ops/ConstructionTest.kt | 0 .../codegen/src/test/resources/lenet.onnx | Bin .../src/test/resources/lenet_frozen.pb | Bin .../codegen-tools/libnd4j-gen/README.md | 0 
.../codegen-tools/libnd4j-gen/op-ir.proto | 0 .../codegen-tools/libnd4j-gen/pom.xml | 0 .../descriptor/OpDeclarationDescriptor.java | 0 .../java/org/nd4j/descriptor/ParseOpFile.java | 0 .../proposal/ArgDescriptorProposal.java | 0 .../proposal/ArgDescriptorSource.java | 0 .../impl/ArgDescriptorParserUtils.java | 0 .../impl/JavaSourceArgDescriptorSource.java | 0 .../impl/Libnd4jArgDescriptorSource.java | 0 .../codegen-tools/onnx-def-gen/README.md | 0 .../codegen-tools/onnx-def-gen/lenet.onnx | Bin .../onnx-def-gen/onnx-op-defs.pb | Bin .../codegen-tools/onnx-def-gen/onnx.pbtxt | 0 .../onnx-def-gen/onnx_def_gen.py | 0 .../codegen-tools/onnx-def-gen/save_test.py | 0 .../onnx-def-gen/test_onnx_lenet.py | 0 .../onnx-def-gen/test_op_def_gen.py | 0 {contrib => .old/contrib}/formatter.xml | 0 .../deeplearning4j}/.codeclimate.yml | 0 .../deeplearning4j}/CONTRIBUTORS.md | 0 .../deeplearning4j}/GITTER_GUIDELINES.md | 0 .../deeplearning4j}/README.md | 0 .../buildmultiplescalaversions.sh | 0 .../deeplearning4j-dataimport-solrj/pom.xml | 0 .../io/stream/TupleStreamDataSetIterator.java | 0 .../TupleStreamDataSetIteratorTest.java | 0 .../test/resources/solr/collection1/README | 0 .../solr/configsets/mini/conf/schema.xml | 0 .../solr/configsets/mini/conf/solrconfig.xml | 0 .../deeplearning4j-graph/pom.xml | 0 .../deeplearning4j/graph/api/BaseGraph.java | 0 .../org/deeplearning4j/graph/api/Edge.java | 0 .../org/deeplearning4j/graph/api/IGraph.java | 0 .../graph/api/IVertexSequence.java | 0 .../graph/api/NoEdgeHandling.java | 0 .../org/deeplearning4j/graph/api/Vertex.java | 0 .../graph/data/EdgeLineProcessor.java | 0 .../graph/data/GraphLoader.java | 0 .../graph/data/VertexLoader.java | 0 .../data/impl/DelimitedEdgeLineProcessor.java | 0 .../data/impl/DelimitedVertexLoader.java | 0 .../data/impl/WeightedEdgeLineProcessor.java | 0 .../graph/exception/NoEdgesException.java | 0 .../graph/exception/ParseException.java | 0 .../org/deeplearning4j/graph/graph/Graph.java | 0 .../graph/graph/VertexSequence.java | 0 .../graph/iterator/GraphWalkIterator.java | 0 .../graph/iterator/RandomWalkIterator.java | 0 .../iterator/WeightedRandomWalkIterator.java | 0 .../parallel/GraphWalkIteratorProvider.java | 0 .../RandomWalkGraphIteratorProvider.java | 0 ...ightedRandomWalkGraphIteratorProvider.java | 0 .../graph/models/BinaryTree.java | 0 .../graph/models/GraphVectors.java | 0 .../graph/models/deepwalk/DeepWalk.java | 0 .../graph/models/deepwalk/GraphHuffman.java | 0 .../embeddings/GraphVectorLookupTable.java | 0 .../models/embeddings/GraphVectorsImpl.java | 0 .../embeddings/InMemoryGraphLookupTable.java | 0 .../models/loader/GraphVectorSerializer.java | 0 .../vertexfactory/IntegerVertexFactory.java | 0 .../vertexfactory/StringVertexFactory.java | 0 .../graph/vertexfactory/VertexFactory.java | 0 .../vertexfactory/VoidVertexFactory.java | 0 .../graph/AssertTestsExtendedBaseClass.java | 0 .../graph/data/TestGraphLoading.java | 0 .../graph/data/TestGraphLoadingWeighted.java | 0 .../deeplearning4j/graph/graph/TestGraph.java | 0 .../deepwalk/DeepWalkGradientCheck.java | 0 .../graph/models/deepwalk/TestDeepWalk.java | 0 .../models/deepwalk/TestGraphHuffman.java | 0 .../deeplearning4j-tsne/pom.xml | 0 .../deeplearning4j/plot/BarnesHutTsne.java | 0 .../java/org/deeplearning4j/plot/Tsne.java | 0 .../org/deeplearning4j/plot/Test6058.java | 0 .../org/deeplearning4j/plot/TsneTest.java | 0 .../deeplearning4j-manifold/pom.xml | 0 .../deeplearning4j-modelexport-solr/pom.xml | 0 .../solr/handler/ModelTupleStream.java | 0 
.../solr/ltr/model/ScoringModel.java | 0 .../ModelTupleStreamIntegrationTest.java | 0 .../solr/handler/ModelTupleStreamTest.java | 0 .../solr/ltr/model/ScoringModelTest.java | 0 .../test/resources/solr/collection1/README | 0 .../mini-expressible/conf/schema.xml | 0 .../mini-expressible/conf/solrconfig.xml | 0 .../deeplearning4j-json-server/pom.xml | 0 .../deeplearning4j/remote/DL4jServlet.java | 0 .../remote/JsonModelServer.java | 0 .../remote/AssertTestsExtendBaseClass.java | 0 .../remote/BinaryModelServerTest.java | 0 .../remote/JsonModelServerTest.java | 0 .../deeplearning4j/remote/ServletTest.java | 0 .../deeplearning4j/remote/helpers/House.java | 0 .../helpers/HouseToPredictedPriceAdapter.java | 0 .../remote/helpers/ImageConversionUtils.java | 0 .../remote/helpers/PredictedPrice.java | 0 .../src/test/resources/logback.xml | 0 .../deeplearning4j-remote/pom.xml | 0 .../deeplearning4j-scaleout/pom.xml | 0 .../spark/dl4j-spark-nlp-java8/pom.xml | 0 .../SparkParagraphVectors.java | 0 .../DocumentSequenceConvertFunction.java | 0 .../functions/KeySequenceConvertFunction.java | 0 .../sequencevectors/SparkSequenceVectors.java | 0 .../export/ExportContainer.java | 0 .../export/SparkModelExporter.java | 0 .../export/impl/HdfsModelExporter.java | 0 .../export/impl/VocabCacheExporter.java | 0 .../functions/BaseTokenizerFunction.java | 0 .../functions/CountFunction.java | 0 .../functions/DistributedFunction.java | 0 .../ElementsFrequenciesAccumulator.java | 0 .../functions/ExportFunction.java | 0 .../functions/ExtraCountFunction.java | 0 .../ExtraElementsFrequenciesAccumulator.java | 0 .../ListSequenceConvertFunction.java | 0 .../functions/PartitionTrainingFunction.java | 0 .../functions/TokenizerFunction.java | 0 .../functions/TrainingFunction.java | 0 .../functions/VocabRddFunctionFlat.java | 0 .../SparkElementsLearningAlgorithm.java | 0 .../SparkSequenceLearningAlgorithm.java | 0 .../elements/BaseSparkLearningAlgorithm.java | 0 .../learning/elements/SparkCBOW.java | 0 .../learning/elements/SparkSkipGram.java | 0 .../BaseSparkSequenceLearningAlgorithm.java | 0 .../learning/sequence/SparkDBOW.java | 0 .../learning/sequence/SparkDM.java | 0 .../primitives/ExtraCounter.java | 0 .../spark/models/word2vec/SparkWord2Vec.java | 0 .../SparkSequenceVectorsTest.java | 0 .../export/ExportContainerTest.java | 0 .../models/word2vec/SparkWord2VecTest.java | 0 .../src/test/resources/log4j.properties | 0 .../src/test/resources/logback.xml | 0 .../spark/dl4j-spark-nlp/.gitignore | 0 .../spark/dl4j-spark-nlp/pom.xml | 0 .../word2vec/FirstIterationFunction.java | 0 .../word2vec/MapToPairFunction.java | 0 .../embeddings/word2vec/NegativeHolder.java | 0 .../word2vec/SecondIterationFunction.java | 0 .../embeddings/word2vec/SentenceBatch.java | 0 .../embeddings/word2vec/VocabHolder.java | 0 .../models/embeddings/word2vec/Word2Vec.java | 0 .../embeddings/word2vec/Word2VecChange.java | 0 .../embeddings/word2vec/Word2VecFuncCall.java | 0 .../embeddings/word2vec/Word2VecParam.java | 0 .../word2vec/Word2VecPerformer.java | 0 .../word2vec/Word2VecPerformerVoid.java | 0 .../embeddings/word2vec/Word2VecSetup.java | 0 .../word2vec/Word2VecVariables.java | 0 .../MaxPerPartitionAccumulator.java | 0 .../accumulators/WordFreqAccumulator.java | 0 .../spark/text/functions/CountCumSum.java | 0 .../FoldBetweenPartitionFunction.java | 0 .../FoldWithinPartitionFunction.java | 0 .../functions/GetSentenceCountFunction.java | 0 .../MapPerPartitionVoidFunction.java | 0 .../text/functions/ReduceSentenceCount.java | 0 
.../spark/text/functions/TextPipeline.java | 0 .../text/functions/TokenizerFunction.java | 0 .../UpdateWordFreqAccumulatorFunction.java | 0 .../WordsListToVocabWordsFunction.java | 0 .../embeddings/word2vec/Word2VecTest.java | 0 .../spark/text/BaseSparkTest.java | 0 .../spark/text/TestFunction.java | 0 .../spark/text/TextPipelineTest.java | 0 .../src/test/resources/log4j.properties | 0 .../src/test/resources/logback.xml | 0 .../spark/dl4j-spark-parameterserver/pom.xml | 0 .../ParameterServerSubscriber.java | 0 .../ParameterServerTrainingHook.java | 0 .../SharedTrainingAccumulationFunction.java | 0 .../SharedTrainingAccumulationTuple.java | 0 .../SharedTrainingAggregateFunction.java | 0 .../DataSetDeserializationCallback.java | 0 .../MultiDataSetDeserializationCallback.java | 0 .../callbacks/PortableDataStreamCallback.java | 0 .../PortableDataStreamMDSCallback.java | 0 .../conf/SharedTrainingConfiguration.java | 0 .../functions/SharedFlatMapDataSet.java | 0 .../functions/SharedFlatMapMultiDataSet.java | 0 .../functions/SharedFlatMapPaths.java | 0 .../functions/SharedFlatMapPathsMDS.java | 0 .../iterators/MultiPdsIterator.java | 0 .../iterators/PdsIterator.java | 0 .../iterators/VirtualDataSetIterator.java | 0 .../iterators/VirtualIterator.java | 0 .../VirtualMultiDataSetIterator.java | 0 .../elephas/ElephasModelImport.java | 0 .../networking/v1/SilentTrainingDriver.java | 0 .../networking/v1/WiredEncodingHandler.java | 0 .../SilentIntroductoryConfirmation.java | 0 .../messages/SilentIntroductoryMessage.java | 0 .../v1/messages/SilentUpdatesMessage.java | 0 .../networking/v2/ModelParamsConsumer.java | 0 .../networking/v2/UpdaterParamsConsumer.java | 0 .../networking/v2/UpdatesConsumer.java | 0 .../networking/v2/WiredEncodingHandler.java | 0 .../pw/SharedTrainingWrapper.java | 0 .../python/ArrayDescriptor.java | 0 .../python/DataSetDescriptor.java | 0 .../spark/parameterserver/python/Utils.java | 0 .../training/SharedTrainingMaster.java | 0 .../training/SharedTrainingResult.java | 0 .../training/SharedTrainingWorker.java | 0 .../util/BlockingObserver.java | 0 .../util/CountingIterator.java | 0 .../spark/parameterserver/BaseSparkTest.java | 0 ...haredTrainingAccumulationFunctionTest.java | 0 .../SharedTrainingAggregateFunctionTest.java | 0 .../iterators/VirtualDataSetIteratorTest.java | 0 .../iterators/VirtualIteratorTest.java | 0 .../elephas/TestElephasImport.java | 0 .../train/GradientSharingTrainingTest.java | 0 .../src/test/resources/log4j.properties | 0 .../src/test/resources/logback.xml | 0 .../spark/dl4j-spark/nd4j-native.properties | 0 .../spark/dl4j-spark/pom.xml | 0 .../spark/api/RDDTrainingApproach.java | 0 .../deeplearning4j/spark/api/Repartition.java | 0 .../spark/api/RepartitionStrategy.java | 0 .../spark/api/Repartitioner.java | 0 .../spark/api/TrainingHook.java | 0 .../spark/api/TrainingMaster.java | 0 .../spark/api/TrainingResult.java | 0 .../spark/api/TrainingWorker.java | 0 .../spark/api/WorkerConfiguration.java | 0 .../api/stats/CommonSparkTrainingStats.java | 0 .../spark/api/stats/SparkTrainingStats.java | 0 .../api/stats/StatsCalculationHelper.java | 0 .../api/worker/ExecuteWorkerFlatMap.java | 0 .../ExecuteWorkerMultiDataSetFlatMap.java | 0 .../api/worker/ExecuteWorkerPDSFlatMap.java | 0 .../worker/ExecuteWorkerPDSMDSFlatMap.java | 0 .../api/worker/ExecuteWorkerPathFlatMap.java | 0 .../worker/ExecuteWorkerPathMDSFlatMap.java | 0 .../spark/api/worker/NetBroadcastTuple.java | 0 .../data/BatchAndExportDataSetsFunction.java | 0 .../BatchAndExportMultiDataSetsFunction.java | 
0 .../spark/data/BatchDataSetsFunction.java | 0 .../spark/data/DataSetExportFunction.java | 0 .../spark/data/DataSetProvider.java | 0 .../data/MultiDataSetExportFunction.java | 0 .../spark/data/MultiDataSetProvider.java | 0 .../spark/data/PathToDataSetFunction.java | 0 .../data/PathToMultiDataSetFunction.java | 0 .../spark/data/SplitDataSetsFunction.java | 0 .../spark/data/loader/RemoteFileSource.java | 0 .../data/loader/RemoteFileSourceFactory.java | 0 ...litDataSetExamplesPairFlatMapFunction.java | 0 .../datavec/DataVecByteDataSetFunction.java | 0 .../spark/datavec/DataVecDataSetFunction.java | 0 .../DataVecSequenceDataSetFunction.java | 0 .../DataVecSequencePairDataSetFunction.java | 0 .../spark/datavec/RDDMiniBatches.java | 0 .../spark/datavec/RecordReaderFunction.java | 0 .../export/StringToDataSetExportFunction.java | 0 .../spark/datavec/iterator/DataVecRecord.java | 0 .../datavec/iterator/DataVecRecords.java | 0 .../spark/datavec/iterator/IteratorUtils.java | 0 .../datavec/iterator/RRMDSIFunction.java | 0 .../iterator/SparkSourceDummyReader.java | 0 .../iterator/SparkSourceDummySeqReader.java | 0 .../BaseSparkEarlyStoppingTrainer.java | 0 .../SparkDataSetLossCalculator.java | 0 .../SparkEarlyStoppingGraphTrainer.java | 0 .../SparkEarlyStoppingTrainer.java | 0 .../SparkLossCalculatorComputationGraph.java | 0 .../spark/impl/SparkListenable.java | 0 .../deeplearning4j/spark/impl/common/Add.java | 0 .../impl/common/CountPartitionsFunction.java | 0 .../impl/common/LoadDataSetFunction.java | 0 .../impl/common/SplitPartitionsFunction.java | 0 .../impl/common/SplitPartitionsFunction2.java | 0 .../reduce/IntDoubleReduceFunction.java | 0 .../reduce/LongDoubleReduceFunction.java | 0 .../repartition/BalancedPartitioner.java | 0 .../common/repartition/EqualPartitioner.java | 0 .../HashingBalancedPartitioner.java | 0 .../repartition/MapTupleToPairFlatMap.java | 0 ...eVaeReconstructionProbWithKeyFunction.java | 0 .../score/BaseVaeScoreWithKeyFunction.java | 0 .../impl/evaluation/EvaluationRunner.java | 0 .../impl/graph/SparkComputationGraph.java | 0 .../dataset/DataSetToMultiDataSetFn.java | 0 .../dataset/PairDataSetToMultiDataSetFn.java | 0 .../IEvaluateMDSFlatMapFunction.java | 0 .../IEvaluateMDSPathsFlatMapFunction.java | 0 .../impl/graph/scoring/ArrayPairToPair.java | 0 ...VaeReconstructionErrorWithKeyFunction.java | 0 ...GVaeReconstructionProbWithKeyFunction.java | 0 .../GraphFeedForwardWithKeyFunction.java | 0 .../impl/graph/scoring/PairToArrayPair.java | 0 .../graph/scoring/ScoreExamplesFunction.java | 0 .../scoring/ScoreExamplesWithKeyFunction.java | 0 .../ScoreFlatMapFunctionCGDataSet.java | 0 .../ScoreFlatMapFunctionCGMultiDataSet.java | 0 .../listeners/VanillaStatsStorageRouter.java | 0 .../VanillaStatsStorageRouterProvider.java | 0 .../impl/multilayer/SparkDl4jMultiLayer.java | 0 .../IEvaluateAggregateFunction.java | 0 .../evaluation/IEvaluateFlatMapFunction.java | 0 .../evaluation/IEvaluationReduceFunction.java | 0 .../scoring/FeedForwardWithKeyFunction.java | 0 .../scoring/ScoreExamplesFunction.java | 0 .../scoring/ScoreExamplesWithKeyFunction.java | 0 .../scoring/ScoreFlatMapFunction.java | 0 .../scoring/SingleToPairFunction.java | 0 ...VaeReconstructionErrorWithKeyFunction.java | 0 .../VaeReconstructionProbWithKeyFunction.java | 0 .../impl/paramavg/BaseTrainingMaster.java | 0 .../impl/paramavg/BaseTrainingResult.java | 0 .../impl/paramavg/BaseTrainingWorker.java | 0 .../ParameterAveragingTrainingMaster.java | 0 .../ParameterAveragingTrainingResult.java | 0 
.../ParameterAveragingTrainingWorker.java | 0 .../ParameterAveragingAggregationTuple.java | 0 .../ParameterAveragingElementAddFunction.java | 0 ...ameterAveragingElementCombineFunction.java | 0 ...ParameterAveragingTrainingMasterStats.java | 0 ...ParameterAveragingTrainingWorkerStats.java | 0 .../impl/paramavg/util/ExportSupport.java | 0 .../repartitioner/DefaultRepartitioner.java | 0 .../repartitioner/EqualRepartitioner.java | 0 .../impl/repartitioner/NoOpRepartitioner.java | 0 .../spark/iterator/BaseDataSetIterator.java | 0 .../iterator/PathSparkDataSetIterator.java | 0 .../PathSparkMultiDataSetIterator.java | 0 .../PortableDataStreamDataSetIterator.java | 0 ...ortableDataStreamMultiDataSetIterator.java | 0 .../spark/iterator/SparkADSI.java | 0 .../spark/iterator/SparkAMDSI.java | 0 .../spark/ordering/DataSetOrdering.java | 0 .../spark/stats/BaseEventStats.java | 0 .../spark/stats/EventStats.java | 0 .../spark/stats/ExampleCountEventStats.java | 0 .../spark/stats/PartitionCountEventStats.java | 0 .../spark/stats/StatsUtils.java | 0 .../spark/time/NTPTimeSource.java | 0 .../spark/time/SystemClockTimeSource.java | 0 .../deeplearning4j/spark/time/TimeSource.java | 0 .../spark/time/TimeSourceProvider.java | 0 .../deeplearning4j/spark/util/MLLibUtil.java | 0 .../spark/util/SparkDataUtils.java | 0 .../deeplearning4j/spark/util/SparkUtils.java | 0 .../spark/util/data/SparkDataValidation.java | 0 .../spark/util/data/ValidationResult.java | 0 .../data/validation/ValidateDataSetFn.java | 0 .../validation/ValidateMultiDataSetFn.java | 0 .../validation/ValidationResultReduceFn.java | 0 .../util/serde/StorageLevelDeserializer.java | 0 .../util/serde/StorageLevelSerializer.java | 0 .../org/apache/spark/TaskContextHelper.scala | 0 .../spark/BaseSparkKryoTest.java | 0 .../deeplearning4j/spark/BaseSparkTest.java | 0 .../spark/TestEarlyStoppingSpark.java | 0 .../TestEarlyStoppingSparkCompGraph.java | 0 .../org/deeplearning4j/spark/TestKryo.java | 0 .../deeplearning4j/spark/common/AddTest.java | 0 .../spark/data/TestShuffleExamples.java | 0 .../spark/data/TestSparkDataUtils.java | 0 .../spark/datavec/MiniBatchTests.java | 0 .../datavec/TestDataVecDataSetFunctions.java | 0 .../spark/datavec/TestExport.java | 0 .../spark/datavec/TestPreProcessedData.java | 0 .../datavec/iterator/TestIteratorUtils.java | 0 .../spark/impl/TestKryoWarning.java | 0 .../repartition/BalancedPartitionerTest.java | 0 .../HashingBalancedPartitionerTest.java | 0 .../impl/customlayer/TestCustomLayer.java | 0 .../impl/customlayer/layer/CustomLayer.java | 0 .../customlayer/layer/CustomLayerImpl.java | 0 .../impl/graph/TestSparkComputationGraph.java | 0 .../spark/impl/misc/TestFrozenLayers.java | 0 .../impl/multilayer/TestMiscFunctions.java | 0 .../multilayer/TestSparkDl4jMultiLayer.java | 0 ...arameterAveragingSparkVsSingleMachine.java | 0 .../spark/impl/paramavg/TestJsonYaml.java | 0 ...TestSparkMultiLayerParameterAveraging.java | 0 .../impl/paramavg/util/ExportSupportTest.java | 0 .../stats/TestTrainingStatsCollection.java | 0 .../spark/time/TestTimeSource.java | 0 .../spark/ui/TestListeners.java | 0 .../spark/util/MLLIbUtilTest.java | 0 .../spark/util/TestRepartitioning.java | 0 .../spark/util/TestValidation.java | 0 .../src/test/resources/log4j.properties | 0 .../dl4j-spark/src/test/resources/logback.xml | 0 .../deeplearning4j-scaleout/spark/pom.xml | 0 .../deeplearning4j}/pom.xml | 0 {nd4j => .old/nd4j}/README.md | 0 {nd4j => .old/nd4j}/RaspberryPi.md | 0 .../nd4j}/buildmultiplescalaversions.sh | 0 
.../nd4j}/nd4j-jdbc/nd4j-jdbc-api/pom.xml | 0 .../nd4j/jdbc/driverfinder/DriverFinder.java | 0 .../nd4j/jdbc/loader/api/JDBCNDArrayIO.java | 0 .../org/nd4j/jdbc/loader/impl/BaseLoader.java | 0 .../nd4j}/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml | 0 .../java/org/nd4j/jdbc/hsql/HsqlLoader.java | 0 .../src/main/resources/nd4j.jdbc.properties | 0 .../org/nd4j/jdbc/hsql/HSqlLoaderTest.java | 0 .../nd4j}/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml | 0 .../java/org/nd4j/jdbc/mysql/MysqlLoader.java | 0 .../src/main/resources/nd4j.jdbc.properties | 0 .../org/nd4j/jdbc/mysql/MysqlLoaderTest.java | 0 {nd4j => .old/nd4j}/nd4j-jdbc/pom.xml | 0 {nd4j => .old/nd4j}/nd4j-onnxruntime/pom.xml | 0 .../onnxruntime/runner/OnnxRuntimeRunner.java | 0 .../org/nd4j/onnxruntime/util/ONNXUtils.java | 0 .../runner/OnnxRuntimeRunnerTests.java | 0 .../src/test/resources/add.onnx | 0 {nd4j => .old/nd4j}/nd4j-remote/README.md | 0 .../nd4j-remote/nd4j-grpc-client/pom.xml | 0 .../remote/grpc/GraphInferenceGrpcClient.java | 0 .../grpc/grpc/GraphInferenceServerGrpc.java | 0 .../graph/GraphInferenceGrpcClientTest.java | 0 .../nd4j-remote/nd4j-json-client/pom.xml | 0 .../remote/clients/JsonRemoteInference.java | 0 .../clients/serde/BinaryDeserializer.java | 0 .../clients/serde/BinarySerializer.java | 0 .../clients/serde/JsonDeserializer.java | 0 .../remote/clients/serde/JsonSerializer.java | 0 .../clients/serde/impl/AbstractSerDe.java | 0 .../clients/serde/impl/BooleanSerde.java | 0 .../clients/serde/impl/DoubleArraySerde.java | 0 .../clients/serde/impl/DoubleSerde.java | 0 .../clients/serde/impl/FloatArraySerde.java | 0 .../remote/clients/serde/impl/FloatSerde.java | 0 .../clients/serde/impl/IntegerSerde.java | 0 .../clients/serde/impl/StringSerde.java | 0 .../nd4j-remote/nd4j-json-server/README.md | 0 .../nd4j-remote/nd4j-json-server/pom.xml | 0 .../nd4j/remote/SameDiffJsonModelServer.java | 0 .../remote/serving/ModelServingServlet.java | 0 .../nd4j/remote/serving/SameDiffServlet.java | 0 .../nd4j/remote/serving/ServingProcessor.java | 0 .../remote/SameDiffJsonModelServerTest.java | 0 .../org/nd4j/remote/SameDiffServletTest.java | 0 .../java/org/nd4j/remote/helpers/House.java | 0 .../helpers/HouseToPredictedPriceAdapter.java | 0 .../nd4j/remote/helpers/PredictedPrice.java | 0 .../nd4j/remote/serde/BasicSerdeTests.java | 0 .../src/test/resources/logback.xml | 0 {nd4j => .old/nd4j}/nd4j-remote/pom.xml | 0 .../nd4j}/nd4j-serde/nd4j-arrow/pom.xml | 0 .../main/java/org/nd4j/arrow/ArrowSerde.java | 0 .../java/org/nd4j/arrow/DataBufferStruct.java | 0 .../java/org/nd4j/arrow/ArrowSerdeTest.java | 0 .../nd4j}/nd4j-serde/nd4j-kryo/pom.xml | 0 .../java/org/nd4j/kryo/Nd4jRegistrator.java | 0 .../java/org/nd4j/kryo/Nd4jSerializer.java | 0 .../primitives/AtomicDoubleSerializer.java | 0 .../org/nd4j/TestNd4jKryoSerialization.java | 0 {nd4j => .old/nd4j}/nd4j-serde/pom.xml | 0 {nd4j => .old/nd4j}/nd4j-tvm/pom.xml | 0 .../java/org/nd4j/tvm/runner/TvmRunner.java | 0 .../main/java/org/nd4j/tvm/util/TVMUtils.java | 0 .../org/nd4j/tvm/runner/TvmRunnerTests.java | 0 {nd4j => .old/nd4j}/pom.xml | 0 {nd4j => .old/nd4j}/samediff-import/pom.xml | 0 .../samediff-import-api/pom.xml | 0 .../frameworkimport/FrameworkImporter.kt | 0 .../frameworkimport/IRProtobufExtensions.kt | 0 .../samediff/frameworkimport/ImportGraph.kt | 0 .../frameworkimport/ImportGraphFactory.kt | 0 .../frameworkimport/ImportGraphHolder.kt | 0 .../context/AbstractMappingContext.kt | 0 .../frameworkimport/context/MappingContext.kt | 0 .../frameworkimport/hooks/PostImportHook.kt | 0 
.../frameworkimport/hooks/PreImportHook.kt | 0 .../hooks/annotations/HookResult.kt | 0 .../hooks/annotations/PostHookRule.kt | 0 .../hooks/annotations/PreHookRule.kt | 0 .../samediff/frameworkimport/ir/IRArgDef.kt | 0 .../frameworkimport/ir/IRAttribute.kt | 0 .../samediff/frameworkimport/ir/IRDataType.kt | 0 .../frameworkimport/ir/IRDataTypeValue.kt | 0 .../frameworkimport/ir/IRFunctions.kt | 0 .../samediff/frameworkimport/ir/IRGraph.kt | 0 .../samediff/frameworkimport/ir/IRNode.kt | 0 .../samediff/frameworkimport/ir/IROpDef.kt | 0 .../samediff/frameworkimport/ir/IRTensor.kt | 0 .../mapper/MapperExtensions.kt | 0 .../opdefs/OpDescriptorLoader.kt | 0 .../opdefs/OpDescriptorLoaderHolder.kt | 0 .../process/AbstractMappingProcess.kt | 0 .../process/AbstractMappingProcessLoader.kt | 0 .../frameworkimport/process/MappingProcess.kt | 0 .../process/MappingProcessLoader.kt | 0 .../reflect/ImportReflectionCache.kt | 0 .../registry/ObjectRegistryHolder.kt | 0 .../registry/OpMappingRegistry.kt | 0 .../frameworkimport/rule/MappingRule.kt | 0 .../rule/attribute/ArgDescriptorConstant.kt | 0 .../rule/attribute/AttributeMappingRule.kt | 0 .../AttributeNDArrayToScalarAttribute.kt | 0 .../attribute/AttributeNumberListNDArray.kt | 0 .../AttributeScalarNDArrayAttribute.kt | 0 .../rule/attribute/AttributeValueType.kt | 0 .../attribute/BaseAttributeExtractionRule.kt | 0 .../ConditionalFieldValueIntIndexArrayRule.kt | 0 ...onditionalFieldValueIntIndexNDArrayRule.kt | 0 .../rule/attribute/DataTypeToInt.kt | 0 .../rule/attribute/FlattenDims.kt | 0 .../rule/attribute/IRMappingFunctions.kt | 0 .../rule/attribute/InvertBooleanNumber.kt | 0 .../ListAttributeValueLookupToIndex.kt | 0 .../rule/attribute/ListNumberToListNumber.kt | 0 .../rule/attribute/ListNumberToNDArray.kt | 0 .../rule/attribute/MapStringToInt.kt | 0 .../NDArrayAttributeToNDArrayInput.kt | 0 .../attribute/NDArrayExtractScalarValue.kt | 0 .../NDArrayInputToNumericalAttribute.kt | 0 .../rule/attribute/NDArraySizeAtRule.kt | 0 .../attribute/NDArrayToIntAttributeValue.kt | 0 .../rule/attribute/NumberToBoolean.kt | 0 .../SizeThresholdIntArrayIntIndexRule.kt | 0 .../attribute/StringAttributeToNDArray.kt | 0 .../attribute/StringContainsAdapterRule.kt | 0 .../rule/attribute/StringEqualsAdapterRule.kt | 0 .../attribute/StringNotEqualsAdapterRule.kt | 0 .../rule/attribute/StringToInt.kt | 0 .../rule/attribute/ValueMapping.kt | 0 .../rule/tensor/BaseNDArrayMappingRule.kt | 0 .../rule/tensor/MultiInputIndexMappingRule.kt | 0 .../tensor/PassThroughMultiTensorMapping.kt | 0 .../rule/tensor/TensorMappingRule.kt | 0 .../runner/DefaultImportRunner.kt | 0 .../frameworkimport/runner/IRGraphRunner.kt | 0 .../frameworkimport/runner/ImportRunner.kt | 0 .../src/main/resources/nd4j-op-def.pbtxt | 0 .../samediff-import-onnx/onnx-processes.pbtxt | 0 .../samediff-import-onnx/ops-added-new.txt | 0 .../samediff-import-onnx/ops-imported-new.txt | 0 .../samediff-import-onnx/ops-removed-new.txt | 0 .../samediff-import-onnx/pom.xml | 0 .../samediff/frameworkimport/onnx/OnnxIR.kt | 0 .../frameworkimport/onnx/OnnxImportGraph.kt | 0 .../onnx/OnnxImportGraphHolder.kt | 0 .../onnx/OnnxProtobufExtensions.kt | 0 .../onnx/OnnxRuleDeclarations.kt | 0 .../onnx/context/OnnxMappingContext.kt | 0 .../onnx/definitions/OnnxOpDeclarations.kt | 0 .../onnx/importer/OnnxFrameworkImporter.kt | 0 .../frameworkimport/onnx/ir/OnnxIRArgDef.kt | 0 .../frameworkimport/onnx/ir/OnnxIRAttr.kt | 0 .../frameworkimport/onnx/ir/OnnxIRDataType.kt | 0 .../frameworkimport/onnx/ir/OnnxIRGraph.kt | 0 
.../onnx/ir/OnnxIRGraphRunner.kt | 0 .../frameworkimport/onnx/ir/OnnxIRNode.kt | 0 .../frameworkimport/onnx/ir/OnnxIROp.kt | 0 .../frameworkimport/onnx/ir/OnnxIRTensor.kt | 0 .../onnx/opdefs/OnnxOpDescriptorLoader.kt | 0 .../onnx/process/OnnxMappingProcess.kt | 0 .../onnx/process/OnnxMappingProcessLoader.kt | 0 .../attribute/OnnxArgDescriptorConstant.kt | 0 .../OnnxAttributeNDArrayToScalarAttribute.kt | 0 .../OnnxAttributeNumberListNDArray.kt | 0 .../OnnxAttributeScalarNDArrayAttribute.kt | 0 ...xConditionalFieldValueIntIndexArrayRule.kt | 0 ...onditionalFieldValueIntIndexNDArrayRule.kt | 0 .../onnx/rule/attribute/OnnxDataTypeToInt.kt | 0 .../onnx/rule/attribute/OnnxFlattenDims.kt | 0 .../rule/attribute/OnnxInvertBooleanNumber.kt | 0 .../OnnxListAttributeValueLookupToIndex.kt | 0 .../attribute/OnnxListNumberToListNumber.kt | 0 .../rule/attribute/OnnxListNumberToNDArray.kt | 0 .../onnx/rule/attribute/OnnxMapStringToInt.kt | 0 .../OnnxNDArrayAttributeToNDArrayInput.kt | 0 .../OnnxNDArrayExtractScalarValue.kt | 0 .../OnnxNDArrayInputToNumericalAttribute.kt | 0 .../onnx/rule/attribute/OnnxNDArraySizeAt.kt | 0 .../OnnxNDArrayToIntAttributeValue.kt | 0 .../OnnxSizeThresholdIntArrayIntIndexRule.kt | 0 .../attribute/OnnxStringAttributeToNDArray.kt | 0 .../OnnxStringContainsAdapterRule.kt | 0 .../attribute/OnnxStringEqualsAdapterRule.kt | 0 .../OnnxStringNotEqualsAdapterRule.kt | 0 .../onnx/rule/attribute/OnnxStringToIndex.kt | 0 .../onnx/rule/attribute/OnnxValueMapping.kt | 0 .../onnx/rule/tensor/NDArrayMappingRule.kt | 0 .../tensor/OnnxMultiInputIndexMappingRule.kt | 0 .../OnnxPassThroughMultiInputTensorMapping.kt | 0 ...samediff.frameworkimport.ImportGraphHolder | 0 ....frameworkimport.opdefs.OpDescriptorLoader | 0 .../main/resources/onnx-mapping-ruleset.pbtxt | 0 .../src/main/resources/onnx-op-def.pbtxt | 0 .../src/main/resources/onnx-op-defs.pb | Bin .../frameworkimport/onnx/TestOnnxIR.kt | 0 .../importer/TestOnnxFrameworkImporter.kt | 0 .../onnx/loader/TestOnnxProcessLoader.kt | 0 .../onnx/modelzoo/TestPretrainedModels.kt | 0 .../processing/GroupConvPreProcessingRule.kt | 0 .../src/test/resources/lenet.onnx | Bin .../variables-added-new.txt | 0 ...c93c-4ac9-867f-580443a45bb3-container.json | 0 ...47a3-4de9-9fc7-6691ea41aee0-container.json | 0 ...33-461b-4d6f-b6a8-a210cef103ff-result.json | 0 ...2d-1de4-4e85-b844-d90d20eea9fb-result.json | 0 ...c7-64e1-4e2c-a56e-05d390b459d7-result.json | 0 ...4d-ef52-441c-976b-3ef06799a362-result.json | 0 ...56da-4b92-983d-7385c888c80b-container.json | 0 ...053b-4dc9-a686-2040bb4f7fd3-container.json | 0 ...f7-cd22-4de1-80a6-b9890ce473fc-result.json | 0 ...ed9c-40b7-853c-d9462d2a67c0-container.json | 0 ...43-7f5f-4f5b-81db-e942139be1a7-result.json | 0 ...5f0d-4fe6-80ba-0628e9f3057b-container.json | 0 ...a05a-4382-a2fc-0835d2da893a-container.json | 0 ...d2-4a46-47d3-9748-578e7aae7121-result.json | 0 ...1d-1e98-4c83-b1ec-0812d352141d-result.json | 0 ...e3-ed5c-482a-89cf-0b4f46026b31-result.json | 0 ...d9-1d25-419e-a196-4a42f20fd8aa-result.json | 0 ...7c-ec8d-4567-9f3a-3b9bcb3d21f8-result.json | 0 ...2344-4dcc-a3d9-d23b5acbfe81-container.json | 0 ...e3-b7a8-492b-b0ab-ae77f112e105-result.json | 0 ...b0e3-446f-9e05-4fd86f350b83-container.json | 0 ...8b63-499e-9dc4-0206d0c38b29-container.json | 0 ...1a38-4b55-bc80-185feab4c978-container.json | 0 ...05d2-4692-ad17-af4aa388cb31-container.json | 0 ...25-9c6a-40b1-b825-916525e2cb24-result.json | 0 ...a732-411c-a65e-00111c6b550e-container.json | 0 ...1cb5-4edd-873b-c923d04905ec-container.json | 0 
...9a41-44aa-9b00-2bb9633a53be-container.json | 0 ...c322-458b-82e9-efd5494d37fc-container.json | 0 ...37-5ce9-4970-aa3d-7eaec8c8091a-result.json | 0 ...aa78-4897-a810-297802cccdfc-container.json | 0 ...6cd2-4a8d-82c7-9f45d15e8a73-container.json | 0 ...69-8c1d-440c-a135-174d7b873d11-result.json | 0 ...16-0d58-450b-85bb-ec61080f012f-result.json | 0 .../nd4j-op-def.pbtxt | 0 .../ops-added-new.txt | 0 .../ops-added-old.txt | 0 .../ops-imported-new.txt | 0 .../ops-imported-old.txt | 0 .../ops-removed-new.txt | 0 .../ops-removed-old.txt | 0 .../samediff-import-tensorflow/pom.xml | 0 .../tensorflow/TensorflowImportGraph.kt | 0 .../tensorflow/TensorflowImportGraphHolder.kt | 0 .../TensorflowProtobufExtensions.kt | 0 .../tensorflow/TensorflowRuleDeclarations.kt | 0 .../context/TensorflowMappingContext.kt | 0 .../definitions/TensorflowOpDeclarations.kt | 0 .../importer/TensorflowFrameworkImporter.kt | 0 .../tensorflow/ir/TensorflowIR.kt | 0 .../tensorflow/ir/TensorflowIRArgDef.kt | 0 .../tensorflow/ir/TensorflowIRAttr.kt | 0 .../tensorflow/ir/TensorflowIRDataType.kt | 0 .../tensorflow/ir/TensorflowIRGraph.kt | 0 .../tensorflow/ir/TensorflowIRGraphRunner.kt | 0 .../tensorflow/ir/TensorflowIRNode.kt | 0 .../tensorflow/ir/TensorflowIROp.kt | 0 .../tensorflow/ir/TensorflowIRTensor.kt | 0 .../opdefs/TensorflowOpDescriptorLoader.kt | 0 .../process/TensorflowMappingProcess.kt | 0 .../process/TensorflowMappingProcessLoader.kt | 0 .../TensorflowArgDescriptorConstant.kt | 0 ...orflowAttributeNDArrayToScalarAttribute.kt | 0 .../TensorflowAttributeNumberListNDArray.kt | 0 ...nsorflowAttributeScalarNDArrayAttribute.kt | 0 ...wConditionalFieldValueIntIndexArrayRule.kt | 0 ...onditionalFieldValueIntIndexNDArrayRule.kt | 0 .../rule/attribute/TensorflowDataTypeToInt.kt | 0 .../rule/attribute/TensorflowFlattenDims.kt | 0 .../TensorflowInvertBooleanNumber.kt | 0 ...nsorflowListAttributeValueLookupToIndex.kt | 0 .../TensorflowListNumberToListNumber.kt | 0 .../TensorflowListNumberToNDArray.kt | 0 .../attribute/TensorflowMapStringToInt.kt | 0 ...ensorflowNDArrayAttributeToNDArrayInput.kt | 0 .../TensorflowNDArrayExtractScalarValue.kt | 0 ...sorflowNDArrayInputToNumericalAttribute.kt | 0 .../rule/attribute/TensorflowNDArraySizeAt.kt | 0 .../TensorflowNDArrayToIntAttributeValue.kt | 0 .../TensorflowNdArrayToStringIndex.kt | 0 .../TensorflowStringAttributeToNDArray.kt | 0 .../TensorflowStringContainsAdapterRule.kt | 0 .../TensorflowStringEqualsAdapterRule.kt | 0 .../TensorflowStringNotEqualsAdapterRule.kt | 0 .../attribute/TensorflowValueMappingRule.kt | 0 .../rule/tensor/NDArrayMappingRule.kt | 0 .../TensorflowMultiInputIndexMappingRule.kt | 0 ...TensorflowPassThroughMultiTensorMapping.kt | 0 ...samediff.frameworkimport.ImportGraphHolder | 0 ....frameworkimport.opdefs.OpDescriptorLoader | 0 .../tensorflow-mapping-ruleset.pbtxt | 0 .../main/resources/tensorflow-op-def.pbtxt | 0 .../tensorflow/TestTensorflowIR.kt | 0 .../tensorflow/TestTensorflowUtils.kt | 0 .../importer/TestTensorflowImporter.kt | 0 .../loader/TestTensorflowProcessLoader.kt | 0 .../src/test/resources/lenet_frozen.pb | Bin .../src/test/resources/logback.xml | 0 .../tensorflow-processes.pbtxt | 0 .../samediff-import-tensorflow/test.pbtxt | 0 .../variables-added-new.txt | 0 .../variables-added-old.txt | 0 perform-release.sh => .old/perform-release.sh | 0 .../pydatavec}/.eggs/README.txt | 0 .../EGG-INFO/LICENSE | 0 .../EGG-INFO/PKG-INFO | 0 .../EGG-INFO/RECORD | 0 .../EGG-INFO/WHEEL | 0 .../EGG-INFO/entry_points.txt | 0 .../EGG-INFO/requires.txt | 0 
.../EGG-INFO/top_level.txt | 0 .../.eggs/pytest_runner-5.2-py3.8.egg/ptr.py | 0 .../pydatavec}/pydatavec.egg-info/PKG-INFO | 0 .../pydatavec}/pydatavec.egg-info/SOURCES.txt | 0 .../pydatavec.egg-info/dependency_links.txt | 0 .../pydatavec.egg-info/requires.txt | 0 .../pydatavec.egg-info/top_level.txt | 0 {rl4j => .old/rl4j}/README.md | 0 {rl4j => .old/rl4j}/docs/images/cartpole.gif | Bin {rl4j => .old/rl4j}/docs/images/doom.gif | Bin {rl4j => .old/rl4j}/docs/images/malmo.gif | Bin {rl4j => .old/rl4j}/pom.xml | 0 {rl4j => .old/rl4j}/rl4j-ale/pom.xml | 0 .../deeplearning4j/rl4j/mdp/ale/ALEMDP.java | 0 {rl4j => .old/rl4j}/rl4j-api/pom.xml | 0 .../org/deeplearning4j/gym/StepReply.java | 0 .../java/org/deeplearning4j/rl4j/mdp/MDP.java | 0 .../rl4j/space/ActionSpace.java | 0 .../rl4j/space/ArrayObservationSpace.java | 0 .../org/deeplearning4j/rl4j/space/Box.java | 0 .../rl4j/space/DiscreteSpace.java | 0 .../deeplearning4j/rl4j/space/Encodable.java | 0 .../rl4j/space/HighLowDiscrete.java | 0 .../rl4j/space/ObservationSpace.java | 0 .../rl4j}/rl4j-core/nd4j-native.properties | 0 {rl4j => .old/rl4j}/rl4j-core/pom.xml | 0 .../org/deeplearning4j/rl4j/agent/Agent.java | 0 .../rl4j/agent/AgentLearner.java | 0 .../org/deeplearning4j/rl4j/agent/IAgent.java | 0 .../rl4j/agent/IAgentLearner.java | 0 .../learning/algorithm/IUpdateAlgorithm.java | 0 .../actorcritic/ActorCriticHelper.java | 0 .../actorcritic/AdvantageActorCritic.java | 0 .../NonRecurrentActorCriticHelper.java | 0 .../RecurrentActorCriticHelper.java | 0 .../algorithm/dqn/BaseDQNAlgorithm.java | 0 .../dqn/BaseTransitionTDAlgorithm.java | 0 .../learning/algorithm/dqn/DoubleDQN.java | 0 .../learning/algorithm/dqn/StandardDQN.java | 0 .../nstepqlearning/NStepQLearning.java | 0 .../nstepqlearning/NStepQLearningHelper.java | 0 .../NonRecurrentNStepQLearningHelper.java | 0 .../RecurrentNStepQLearningHelper.java | 0 .../learning/behavior/ILearningBehavior.java | 0 .../learning/behavior/LearningBehavior.java | 0 .../rl4j/agent/learning/update/Features.java | 0 .../learning/update/FeaturesBuilder.java | 0 .../agent/learning/update/FeaturesLabels.java | 0 .../rl4j/agent/learning/update/Gradients.java | 0 .../agent/learning/update/IUpdateRule.java | 0 .../agent/learning/update/UpdateRule.java | 0 .../update/updater/INeuralNetUpdater.java | 0 .../NeuralNetUpdaterConfiguration.java | 0 .../async/AsyncGradientsNeuralNetUpdater.java | 0 .../async/AsyncLabelsNeuralNetUpdater.java | 0 .../AsyncSharedNetworksUpdateHandler.java | 0 .../async/BaseAsyncNeuralNetUpdater.java | 0 .../sync/BaseSyncNeuralNetUpdater.java | 0 .../sync/SyncGradientsNeuralNetUpdater.java | 0 .../sync/SyncLabelsNeuralNetUpdater.java | 0 .../rl4j/agent/listener/AgentListener.java | 0 .../agent/listener/AgentListenerList.java | 0 .../builder/AdvantageActorCriticBuilder.java | 0 .../rl4j/builder/AsyncNetworkHandler.java | 0 .../rl4j/builder/BaseAgentLearnerBuilder.java | 0 .../builder/BaseAsyncAgentLearnerBuilder.java | 0 .../builder/BaseDQNAgentLearnerBuilder.java | 0 .../rl4j/builder/DoubleDQNBuilder.java | 0 .../rl4j/builder/INetworksHandler.java | 0 .../rl4j/builder/NStepQLearningBuilder.java | 0 .../rl4j/builder/StandardDQNBuilder.java | 0 .../rl4j/builder/SyncNetworkHandler.java | 0 .../rl4j/environment/Environment.java | 0 .../rl4j/environment/IActionSchema.java | 0 .../rl4j/environment/IntegerActionSchema.java | 0 .../rl4j/environment/Schema.java | 0 .../rl4j/environment/StepResult.java | 0 .../rl4j/experience/ExperienceHandler.java | 0 .../ReplayMemoryExperienceHandler.java | 0 
.../StateActionExperienceHandler.java | 0 .../rl4j/experience/StateActionReward.java | 0 .../experience/StateActionRewardState.java | 0 .../rl4j/helper/INDArrayHelper.java | 0 .../rl4j/learning/HistoryProcessor.java | 0 .../rl4j/learning/IEpochTrainer.java | 0 .../rl4j/learning/IHistoryProcessor.java | 0 .../rl4j/learning/ILearning.java | 0 .../rl4j/learning/Learning.java | 0 .../rl4j/learning/NeuralNetFetchable.java | 0 .../rl4j/learning/async/AsyncGlobal.java | 0 .../rl4j/learning/async/AsyncLearning.java | 0 .../rl4j/learning/async/AsyncThread.java | 0 .../learning/async/AsyncThreadDiscrete.java | 0 .../rl4j/learning/async/IAsyncGlobal.java | 0 .../rl4j/learning/async/IAsyncLearning.java | 0 .../rl4j/learning/async/UpdateAlgorithm.java | 0 .../async/a3c/discrete/A3CDiscrete.java | 0 .../async/a3c/discrete/A3CDiscreteConv.java | 0 .../async/a3c/discrete/A3CDiscreteDense.java | 0 .../async/a3c/discrete/A3CThreadDiscrete.java | 0 .../AdvantageActorCriticUpdateAlgorithm.java | 0 .../discrete/AsyncNStepQLearningDiscrete.java | 0 .../AsyncNStepQLearningDiscreteConv.java | 0 .../AsyncNStepQLearningDiscreteDense.java | 0 .../AsyncNStepQLearningThreadDiscrete.java | 0 .../discrete/QLearningUpdateAlgorithm.java | 0 .../A3CLearningConfiguration.java | 0 .../AsyncQLearningConfiguration.java | 0 .../IAsyncLearningConfiguration.java | 0 .../configuration/ILearningConfiguration.java | 0 .../configuration/LearningConfiguration.java | 0 .../configuration/QLearningConfiguration.java | 0 .../learning/listener/TrainingListener.java | 0 .../listener/TrainingListenerList.java | 0 .../rl4j/learning/sync/ExpReplay.java | 0 .../rl4j/learning/sync/IExpReplay.java | 0 .../rl4j/learning/sync/SyncLearning.java | 0 .../learning/sync/qlearning/QLearning.java | 0 .../qlearning/discrete/QLearningDiscrete.java | 0 .../discrete/QLearningDiscreteConv.java | 0 .../discrete/QLearningDiscreteDense.java | 0 .../rl4j/mdp/CartpoleEnvironment.java | 0 .../rl4j/mdp/CartpoleNative.java | 0 .../rl4j/mdp/DoAsISayOrDont.java | 0 .../rl4j/mdp/TMazeEnvironment.java | 0 .../rl4j/mdp/robotlake/RobotLake.java | 0 .../rl4j/mdp/robotlake/RobotLakeHelper.java | 0 .../rl4j/mdp/robotlake/RobotLakeMap.java | 0 .../rl4j/mdp/robotlake/RobotLakeState.java | 0 .../rl4j/mdp/toy/HardDeteministicToy.java | 0 .../rl4j/mdp/toy/HardToyState.java | 0 .../rl4j/mdp/toy/SimpleToy.java | 0 .../rl4j/mdp/toy/SimpleToyState.java | 0 .../rl4j/network/ActorCriticNetwork.java | 0 .../rl4j/network/BaseNetwork.java | 0 .../network/ChannelToNetworkInputMapper.java | 0 .../rl4j/network/CommonGradientNames.java | 0 .../rl4j/network/CommonLabelNames.java | 0 .../rl4j/network/CommonOutputNames.java | 0 .../rl4j/network/CompoundNetworkHandler.java | 0 .../rl4j/network/ComputationGraphHandler.java | 0 .../rl4j/network/INetworkHandler.java | 0 .../rl4j/network/IOutputNeuralNet.java | 0 .../rl4j/network/ITrainableNeuralNet.java | 0 .../network/MultiLayerNetworkHandler.java | 0 .../rl4j/network/NetworkHelper.java | 0 .../rl4j/network/NeuralNet.java | 0 .../rl4j/network/NeuralNetOutput.java | 0 .../deeplearning4j/rl4j/network/QNetwork.java | 0 .../rl4j/network/ac/ActorCriticCompGraph.java | 0 .../ac/ActorCriticFactoryCompGraph.java | 0 .../ActorCriticFactoryCompGraphStdConv.java | 0 .../ActorCriticFactoryCompGraphStdDense.java | 0 .../ac/ActorCriticFactorySeparate.java | 0 .../ActorCriticFactorySeparateStdDense.java | 0 .../rl4j/network/ac/ActorCriticLoss.java | 0 .../rl4j/network/ac/ActorCriticSeparate.java | 0 .../rl4j/network/ac/IActorCritic.java | 0 
.../ActorCriticDenseNetworkConfiguration.java | 0 .../ActorCriticNetworkConfiguration.java | 0 .../DQNDenseNetworkConfiguration.java | 0 .../configuration/NetworkConfiguration.java | 0 .../deeplearning4j/rl4j/network/dqn/DQN.java | 0 .../rl4j/network/dqn/DQNFactory.java | 0 .../rl4j/network/dqn/DQNFactoryStdConv.java | 0 .../rl4j/network/dqn/DQNFactoryStdDense.java | 0 .../deeplearning4j/rl4j/network/dqn/IDQN.java | 0 .../rl4j/observation/IObservationSource.java | 0 .../rl4j/observation/Observation.java | 0 .../EncodableToINDArrayTransform.java | 0 .../transform/FilterOperation.java | 0 .../transform/ResettableOperation.java | 0 .../transform/TransformProcess.java | 0 .../filter/UniformSkippingFilter.java | 0 .../EncodableToImageWritableTransform.java | 0 .../ImageWritableToINDArrayTransform.java | 0 .../operation/ArrayToINDArrayTransform.java | 0 .../operation/HistoryMergeTransform.java | 0 .../SimpleNormalizationTransform.java | 0 .../historymerge/CircularFifoStore.java | 0 .../historymerge/HistoryMergeAssembler.java | 0 .../HistoryMergeElementStore.java | 0 .../historymerge/HistoryStackAssembler.java | 0 .../deeplearning4j/rl4j/policy/ACPolicy.java | 0 .../rl4j/policy/BoltzmannQ.java | 0 .../deeplearning4j/rl4j/policy/DQNPolicy.java | 0 .../deeplearning4j/rl4j/policy/EpsGreedy.java | 0 .../rl4j/policy/INeuralNetPolicy.java | 0 .../deeplearning4j/rl4j/policy/IPolicy.java | 0 .../deeplearning4j/rl4j/policy/Policy.java | 0 .../rl4j/trainer/AsyncTrainer.java | 0 .../deeplearning4j/rl4j/trainer/ITrainer.java | 0 .../rl4j/trainer/SyncTrainer.java | 0 .../deeplearning4j/rl4j/util/Constants.java | 0 .../deeplearning4j/rl4j/util/DataManager.java | 0 .../util/DataManagerTrainingListener.java | 0 .../rl4j/util/IDataManager.java | 0 .../rl4j/util/LegacyMDPWrapper.java | 0 .../rl4j/util/VideoRecorder.java | 0 .../rl4j/AgentLearnerCartpole.java | 0 .../org/deeplearning4j/rl4j/NStepRnn.java | 0 .../deeplearning4j/rl4j/RobotLakeExample.java | 0 .../org/deeplearning4j/rl4j/TMazeExample.java | 0 .../rl4j/agent/AgentLearnerTest.java | 0 .../deeplearning4j/rl4j/agent/AgentTest.java | 0 .../NonRecurrentActorCriticHelperTest.java | 0 .../NonRecurrentAdvantageActorCriticTest.java | 0 .../RecurrentActorCriticHelperTest.java | 0 .../RecurrentAdvantageActorCriticTest.java | 0 .../learning/algorithm/dqn/DoubleDQNTest.java | 0 .../algorithm/dqn/StandardDQNTest.java | 0 .../NonRecurrentNStepQLearningHelperTest.java | 0 .../NonRecurrentNStepQLearningTest.java | 0 .../RecurrentNStepQLearningHelperTest.java | 0 .../RecurrentNStepQLearningTest.java | 0 .../behavior/LearningBehaviorTest.java | 0 .../learning/update/FeaturesBuilderTest.java | 0 .../learning/update/FeaturesLabelsTest.java | 0 .../agent/learning/update/FeaturesTest.java | 0 .../agent/learning/update/GradientsTest.java | 0 .../agent/learning/update/UpdateRuleTest.java | 0 .../AsyncGradientsNeuralNetUpdaterTest.java | 0 .../AsyncLabelsNeuralNetUpdaterTest.java | 0 .../AsyncSharedNetworksUpdateHandlerTest.java | 0 .../SyncGradientsNeuralNetUpdaterTest.java | 0 .../sync/SyncLabelsNeuralNetUpdaterTest.java | 0 .../builder/BaseAgentLearnerBuilderTest.java | 0 .../ReplayMemoryExperienceHandlerTest.java | 0 .../StateActionExperienceHandlerTest.java | 0 .../rl4j/helper/INDArrayHelperTest.java | 0 .../rl4j/learning/HistoryProcessorTest.java | 0 .../learning/async/AsyncLearningTest.java | 0 .../async/AsyncThreadDiscreteTest.java | 0 .../rl4j/learning/async/AsyncThreadTest.java | 0 ...vantageActorCriticUpdateAlgorithmTest.java | 0 
.../AsyncTrainingListenerListTest.java | 0 .../QLearningUpdateAlgorithmTest.java | 0 .../listener/TrainingListenerListTest.java | 0 .../rl4j/learning/sync/ExpReplayTest.java | 0 .../sync/StateActionRewardStateTest.java | 0 .../rl4j/learning/sync/SyncLearningTest.java | 0 .../qlearning/QLearningConfigurationTest.java | 0 .../discrete/QLearningDiscreteTest.java | 0 .../rl4j/learning/sync/support/MockDQN.java | 0 .../learning/sync/support/MockStatEntry.java | 0 .../rl4j/network/ActorCriticNetworkTest.java | 0 .../rl4j/network/BaseNetworkTest.java | 0 .../ChannelToNetworkInputMapperTest.java | 0 .../network/CompoundNetworkHandlerTest.java | 0 .../network/ComputationGraphHandlerTest.java | 0 .../network/MultiLayerNetworkHandlerTest.java | 0 .../rl4j/network/NetworkHelperTest.java | 0 .../rl4j/network/QNetworkTest.java | 0 .../rl4j/network/ac/ActorCriticTest.java | 0 .../rl4j/network/dqn/DQNTest.java | 0 .../transform/TransformProcessTest.java | 0 .../filter/UniformSkippingFilterTest.java | 0 .../ArrayToINDArrayTransformTest.java | 0 .../operation/HistoryMergeTransformTest.java | 0 .../SimpleNormalizationTransformTest.java | 0 .../historymerge/CircularFifoStoreTest.java | 0 .../HistoryStackAssemblerTest.java | 0 .../rl4j/policy/PolicyTest.java | 0 .../deeplearning4j/rl4j/support/MockDQN.java | 0 .../rl4j/support/MockDataManager.java | 0 .../rl4j/support/MockHistoryProcessor.java | 0 .../deeplearning4j/rl4j/support/MockMDP.java | 0 .../rl4j/support/MockNeuralNet.java | 0 .../rl4j/support/MockObservation.java | 0 .../rl4j/support/MockObservationSpace.java | 0 .../rl4j/support/MockPolicy.java | 0 .../rl4j/support/MockRandom.java | 0 .../rl4j/trainer/AsyncTrainerTest.java | 0 .../rl4j/trainer/SyncTrainerTest.java | 0 .../util/DataManagerTrainingListenerTest.java | 0 {rl4j => .old/rl4j}/rl4j-doom/pom.xml | 0 .../rl4j/mdp/vizdoom/Basic.java | 0 .../rl4j/mdp/vizdoom/DeadlyCorridor.java | 0 .../rl4j/mdp/vizdoom/PredictPosition.java | 0 .../rl4j/mdp/vizdoom/TakeCover.java | 0 .../rl4j/mdp/vizdoom/VizDoom.java | 0 .../src/main/java/vizdoom/AutomapMode.java | 0 .../src/main/java/vizdoom/Button.java | 0 .../src/main/java/vizdoom/DoomGame.java | 0 .../vizdoom/FileDoesNotExistException.java | 0 .../src/main/java/vizdoom/GameState.java | 0 .../src/main/java/vizdoom/GameVariable.java | 0 .../src/main/java/vizdoom/Label.java | 0 .../java/vizdoom/MessageQueueException.java | 0 .../rl4j-doom/src/main/java/vizdoom/Mode.java | 0 .../src/main/java/vizdoom/ScreenFormat.java | 0 .../main/java/vizdoom/ScreenResolution.java | 0 .../java/vizdoom/SharedMemoryException.java | 0 .../main/java/vizdoom/SignalException.java | 0 .../java/vizdoom/ViZDoomErrorException.java | 0 .../vizdoom/ViZDoomIsNotRunningException.java | 0 .../ViZDoomUnexpectedExitException.java | 0 {rl4j => .old/rl4j}/rl4j-gym/pom.xml | 0 .../rl4j/mdp/gym/ActionTransformer.java | 0 .../deeplearning4j/rl4j/mdp/gym/GymEnv.java | 0 .../rl4j/mdp/gym/GymEnvTest.java | 0 {rl4j => .old/rl4j}/rl4j-malmo/pom.xml | 0 .../malmo/MalmoActionSpace.java | 0 .../malmo/MalmoActionSpaceDiscrete.java | 0 .../org/deeplearning4j/malmo/MalmoBox.java | 0 .../malmo/MalmoConnectionError.java | 0 .../malmo/MalmoDescretePositionPolicy.java | 0 .../org/deeplearning4j/malmo/MalmoEnv.java | 0 .../malmo/MalmoObservationPolicy.java | 0 .../malmo/MalmoObservationSpace.java | 0 .../malmo/MalmoObservationSpaceGrid.java | 0 .../malmo/MalmoObservationSpacePixels.java | 0 .../malmo/MalmoObservationSpacePosition.java | 0 .../malmo/MalmoResetHandler.java | 0 .../tensorflow-processes.pbtxt | 0 
arbiter/.travis.yml | 24 - arbiter/README.md | 45 - arbiter/arbiter-core/pom.xml | 97 -- arbiter/arbiter-core/src/assembly/bin.xml | 91 -- .../optimize/api/AbstractParameterSpace.java | 74 - .../arbiter/optimize/api/Candidate.java | 57 - .../optimize/api/CandidateGenerator.java | 68 - .../optimize/api/OptimizationResult.java | 60 - .../arbiter/optimize/api/ParameterSpace.java | 81 - .../arbiter/optimize/api/TaskCreator.java | 62 - .../optimize/api/TaskCreatorProvider.java | 43 - .../api/adapter/ParameterSpaceAdapter.java | 82 - .../optimize/api/data/DataProvider.java | 54 - .../data/DataSetIteratorFactoryProvider.java | 89 -- .../arbiter/optimize/api/data/DataSource.java | 57 - .../api/evaluation/ModelEvaluator.java | 40 - .../api/saving/InMemoryResultSaver.java | 63 - .../optimize/api/saving/ResultReference.java | 37 - .../optimize/api/saving/ResultSaver.java | 57 - .../optimize/api/score/ScoreFunction.java | 75 - .../termination/MaxCandidatesCondition.java | 50 - .../api/termination/MaxTimeCondition.java | 81 - .../api/termination/TerminationCondition.java | 45 - .../config/OptimizationConfiguration.java | 226 --- .../DegenerateIntegerDistribution.java | 96 -- .../distribution/DistributionUtils.java | 149 -- .../distribution/LogUniformDistribution.java | 155 -- .../generator/BaseCandidateGenerator.java | 91 -- .../GeneticSearchCandidateGenerator.java | 187 --- .../GridSearchCandidateGenerator.java | 232 --- .../generator/RandomSearchGenerator.java | 93 -- .../generator/genetic/Chromosome.java | 42 - .../generator/genetic/ChromosomeFactory.java | 51 - .../crossover/ArithmeticCrossover.java | 120 -- .../genetic/crossover/CrossoverOperator.java | 45 - .../genetic/crossover/CrossoverResult.java | 43 - .../genetic/crossover/KPointCrossover.java | 178 --- .../crossover/SinglePointCrossover.java | 123 -- .../TwoParentsCrossoverOperator.java | 46 - .../genetic/crossover/UniformCrossover.java | 136 -- .../parentselection/ParentSelection.java | 44 - .../RandomTwoParentSelection.java | 65 - .../parentselection/TwoParentSelection.java | 25 - .../utils/CrossoverPointsGenerator.java | 68 - .../genetic/culling/CullOperator.java | 41 - .../genetic/culling/LeastFitCullOperator.java | 50 - .../genetic/culling/RatioCullOperator.java | 70 - .../GeneticGenerationException.java | 23 - .../genetic/mutation/MutationOperator.java | 33 - .../mutation/RandomMutationOperator.java | 93 -- .../EmptyPopulationInitializer.java | 41 - .../population/PopulationInitializer.java | 36 - .../population/PopulationListener.java | 35 - .../genetic/population/PopulationModel.java | 182 --- .../selection/GeneticSelectionOperator.java | 197 --- .../genetic/selection/SelectionOperator.java | 44 - .../generator/util/SerializedSupplier.java | 46 - .../optimize/parameter/BooleanSpace.java | 76 - .../optimize/parameter/FixedValue.java | 90 -- .../continuous/ContinuousParameterSpace.java | 137 -- .../discrete/DiscreteParameterSpace.java | 113 -- .../integer/IntegerParameterSpace.java | 151 -- .../optimize/parameter/math/MathOp.java | 69 - .../arbiter/optimize/parameter/math/Op.java | 76 - .../optimize/parameter/math/PairMathOp.java | 79 - .../runner/BaseOptimizationRunner.java | 383 ----- .../optimize/runner/CandidateInfo.java | 41 - .../optimize/runner/CandidateStatus.java | 24 - .../optimize/runner/IOptimizationRunner.java | 67 - .../runner/LocalOptimizationRunner.java | 150 -- .../runner/listener/BaseStatusListener.java | 54 - .../runner/listener/StatusChangeType.java | 26 - .../runner/listener/StatusListener.java | 60 - 
.../listener/impl/LoggingStatusListener.java | 57 - .../serde/jackson/FixedValueDeserializer.java | 52 - .../serde/jackson/FixedValueSerializer.java | 52 - .../IntegerDistributionDeserializer.java | 59 - .../IntegerDistributionSerializer.java | 74 - .../optimize/serde/jackson/JsonMapper.java | 77 - .../jackson/RealDistributionDeserializer.java | 78 - .../jackson/RealDistributionSerializer.java | 107 -- .../optimize/serde/jackson/YamlMapper.java | 52 - .../arbiter/util/ClassPathResource.java | 233 --- .../arbiter/util/CollectionUtils.java | 49 - .../arbiter/util/LeafUtils.java | 73 - .../arbiter/util/ObjectUtils.java | 61 - .../optimize/AssertTestsExtendBaseClass.java | 49 - .../arbiter/optimize/BraninFunction.java | 156 -- .../arbiter/optimize/TestGeneticSearch.java | 118 -- .../arbiter/optimize/TestGridSearch.java | 104 -- .../arbiter/optimize/TestJson.java | 122 -- .../arbiter/optimize/TestRandomSearch.java | 61 - .../optimize/distribution/TestLogUniform.java | 70 - .../genetic/TestCrossoverOperator.java | 40 - .../genetic/TestMutationOperator.java | 34 - .../optimize/genetic/TestParentSelection.java | 52 - .../genetic/TestPopulationInitializer.java | 30 - .../optimize/genetic/TestRandomGenerator.java | 88 -- .../crossover/ArithmeticCrossoverTests.java | 68 - .../crossover/CrossoverOperatorTests.java | 43 - .../CrossoverPointsGeneratorTests.java | 45 - .../crossover/KPointCrossoverTests.java | 67 - .../crossover/ParentSelectionTests.java | 39 - .../RandomTwoParentSelectionTests.java | 47 - .../crossover/SinglePointCrossoverTests.java | 68 - .../TwoParentsCrossoverOperatorTests.java | 71 - .../crossover/UniformCrossoverTests.java | 68 - .../culling/LeastFitCullOperatorTests.java | 62 - .../culling/RatioCullOperatorTests.java | 78 - .../mutation/RandomMutationOperatorTests.java | 73 - .../population/PopulationModelTests.java | 195 --- .../GeneticSelectionOperatorTests.java | 255 ---- .../selection/SelectionOperatorTests.java | 60 - .../parameter/TestParameterSpaces.java | 103 -- .../src/test/resources/logback.xml | 51 - arbiter/arbiter-deeplearning4j/pom.xml | 78 - .../arbiter/BaseNetworkSpace.java | 615 -------- .../arbiter/ComputationGraphSpace.java | 316 ---- .../arbiter/DL4JConfiguration.java | 73 - .../arbiter/GraphConfiguration.java | 67 - .../arbiter/MultiLayerSpace.java | 320 ---- .../ActivationParameterSpaceAdapter.java | 58 - .../LossFunctionParameterSpaceAdapter.java | 60 - .../arbiter/conf/dropout/DropoutSpace.java | 63 - .../arbiter/conf/updater/AdaGradSpace.java | 66 - .../arbiter/conf/updater/AdaMaxSpace.java | 83 -- .../arbiter/conf/updater/AdamSpace.java | 83 -- .../conf/updater/BaseUpdaterSpace.java | 70 - .../arbiter/conf/updater/NadamSpace.java | 83 -- .../arbiter/conf/updater/NesterovsSpace.java | 100 -- .../arbiter/conf/updater/RmsPropSpace.java | 54 - .../arbiter/conf/updater/SgdSpace.java | 54 - .../schedule/ExponentialScheduleSpace.java | 92 -- .../schedule/InverseScheduleSpace.java | 106 -- .../updater/schedule/PolyScheduleSpace.java | 106 -- .../schedule/SigmoidScheduleSpace.java | 106 -- .../updater/schedule/StepScheduleSpace.java | 106 -- .../data/DataSetIteratorFactoryProvider.java | 85 -- .../arbiter/data/MnistDataProvider.java | 80 - .../arbiter/dropout/AlphaDropoutSpace.java | 67 - .../arbiter/dropout/DropoutSpace.java | 67 - .../arbiter/dropout/GaussianDropoutSpace.java | 68 - .../arbiter/dropout/GaussianNoiseSpace.java | 67 - .../multilayer/ClassificationEvaluator.java | 68 - .../multilayer/RegressionDataEvaluator.java | 62 - 
.../layers/AbstractLSTMLayerSpace.java | 108 -- .../arbiter/layers/ActivationLayerSpace.java | 94 -- .../arbiter/layers/AutoEncoderLayerSpace.java | 107 -- .../layers/BaseConvolutionLayerSpace.java | 162 -- .../arbiter/layers/BaseLayerSpace.java | 292 ---- .../arbiter/layers/BaseOutputLayerSpace.java | 87 -- .../layers/BasePretrainNetworkLayerSpace.java | 57 - .../layers/BatchNormalizationSpace.java | 214 --- .../arbiter/layers/Bidirectional.java | 67 - .../layers/CenterLossOutputLayerSpace.java | 87 -- .../arbiter/layers/ConvolutionLayerSpace.java | 172 --- .../layers/Deconvolution2DLayerSpace.java | 52 - .../arbiter/layers/DenseLayerSpace.java | 90 -- .../arbiter/layers/DropoutLayerSpace.java | 89 -- .../arbiter/layers/EmbeddingLayerSpace.java | 88 -- .../arbiter/layers/FeedForwardLayerSpace.java | 154 -- .../layers/GlobalPoolingLayerSpace.java | 135 -- .../GravesBidirectionalLSTMLayerSpace.java | 97 -- .../arbiter/layers/GravesLSTMLayerSpace.java | 76 - .../arbiter/layers/LSTMLayerSpace.java | 77 - .../arbiter/layers/LayerSpace.java | 138 -- .../LocalResponseNormalizationLayerSpace.java | 119 -- .../arbiter/layers/LossLayerSpace.java | 105 -- .../arbiter/layers/OCNNLayerSpace.java | 153 -- .../arbiter/layers/OutputLayerSpace.java | 71 - .../arbiter/layers/RnnOutputLayerSpace.java | 71 - .../SeparableConvolution2DLayerSpace.java | 101 -- .../arbiter/layers/SubsamplingLayerSpace.java | 208 --- .../VariationalAutoencoderLayerSpace.java | 182 --- .../arbiter/layers/fixed/FixedLayerSpace.java | 71 - .../DL4JArbiterStatusReportingListener.java | 49 - .../arbiter/saver/local/FileModelSaver.java | 147 -- .../local/LocalFileNetResultReference.java | 103 -- .../arbiter/scoring/RegressionValue.java | 32 - .../arbiter/scoring/ScoreFunctions.java | 66 - .../scoring/impl/BaseNetScoreFunction.java | 103 -- .../scoring/impl/EvaluationScoreFunction.java | 86 -- .../scoring/impl/ROCScoreFunction.java | 122 -- .../scoring/impl/RegressionScoreFunction.java | 92 -- .../impl/TestSetAccuracyScoreFunction.java | 72 - .../scoring/impl/TestSetF1ScoreFunction.java | 72 - .../impl/TestSetLossScoreFunction.java | 78 - .../impl/TestSetRegressionScoreFunction.java | 85 -- .../arbiter/scoring/util/ScoreUtil.java | 328 ---- .../task/ComputationGraphTaskCreator.java | 267 ---- .../task/MultiLayerNetworkTaskCreator.java | 265 ---- .../arbiter/task/TaskListener.java | 49 - .../arbiter/AssertTestsExtendBaseClass.java | 50 - .../org/deeplearning4j/arbiter/TestUtils.java | 243 --- .../TestComputationGraphSpace.java | 168 --- .../TestGraphLocalExecution.java | 373 ----- .../TestGraphLocalExecutionGenetic.java | 212 --- .../deeplearning4j/arbiter/json/TestJson.java | 268 ---- .../MNISTOptimizationTest.java | 166 --- .../MnistDataSetIteratorFactory.java | 42 - .../TestDL4JLocalExecution.java | 381 ----- .../arbiter/multilayernetwork/TestErrors.java | 158 -- .../multilayernetwork/TestLayerSpace.java | 314 ---- .../TestMultiLayerSpace.java | 819 ---------- .../multilayernetwork/TestScoreFunctions.java | 220 --- .../util/TestDataFactoryProviderMnist.java | 46 - .../arbiter/util/TestDataProviderMnist.java | 61 - .../src/test/resources/logback.xml | 51 - arbiter/arbiter-server/pom.xml | 63 - .../arbiter/server/ArbiterCliGenerator.java | 286 ---- .../arbiter/server/ArbiterCliRunner.java | 152 -- .../server/cli/NeuralNetTypeValidator.java | 41 - .../server/cli/ProblemTypeValidator.java | 41 - .../arbiter/server/ArbiterCLIRunnerTest.java | 121 -- .../server/AssertTestsExtendBaseClass.java | 50 - 
.../server/MnistDataSetIteratorFactory.java | 43 - .../server/TestDataFactoryProviderMnist.java | 44 - arbiter/arbiter-ui/pom.xml | 73 - .../arbiter/ui/UpdateStatus.java | 33 - .../arbiter/ui/data/BaseJavaPersistable.java | 159 -- .../ui/data/GlobalConfigPersistable.java | 119 -- .../arbiter/ui/data/ModelInfoPersistable.java | 163 -- .../ui/listener/ArbiterStatusListener.java | 238 --- .../arbiter/ui/misc/JsonMapper.java | 76 - .../arbiter/ui/misc/UIUtils.java | 112 -- .../arbiter/ui/module/ArbiterModule.java | 943 ------------ .../org.deeplearning4j.ui.api.UIModule | 17 - .../deeplearning4jUiAssets/dl4j-ui.js | 1319 ----------------- .../deeplearning4jUiAssets/dl4j-ui.js.map | 1 - .../main/resources/templates/ArbiterUI.html | 638 -------- .../optimize/AssertTestsExtendBaseClass.java | 50 - .../arbiter/optimize/TestBasic.java | 791 ---------- .../arbiter-ui/src/test/resources/logback.xml | 51 - arbiter/buildmultiplescalaversions.sh | 53 - arbiter/contrib/formatter.xml | 353 ----- arbiter/pom.xml | 182 --- cavis-datavec/cavis-datavec-api/build.gradle | 74 +- .../java/org/datavec/api/package-info.java | 6 + .../api/records/reader/RecordReader.java | 213 +-- .../records/reader/impl/FileRecordReader.java | 370 ++--- .../datavec/api/writable/WritableFactory.java | 162 +- .../datavec/api/writable/WritableType.java | 119 +- cavis-full/build.gradle | 11 + cavis-native/cavis-native-lib/build.gradle | 25 +- chooseBackend.gradle | 11 +- createTestBackends.gradle | 10 +- 1374 files changed, 501 insertions(+), 28638 deletions(-) rename {.github => .old/.github}/ISSUE_TEMPLATE.md (100%) rename {.github => .old/.github}/PULL_REQUEST_TEMPLATE.md (100%) rename {.github => .old/.github}/actions/download-dl4j-test-resources-linux/action.yml (100%) rename {.github => .old/.github}/actions/download-dl4j-test-resources-windows/action.yml (100%) rename {.github => .old/.github}/actions/install-arm-cross-compile/action.yml (100%) rename {.github => .old/.github}/actions/install-cmake-linux/action.yml (100%) rename {.github => .old/.github}/actions/install-protobuf-linux/action.yml (100%) rename {.github => .old/.github}/actions/msys2-base-setup/action.yml (100%) rename {.github => .old/.github}/actions/publish-gh-packages/action.yml (100%) rename {.github => .old/.github}/actions/update-deps-linux/action.yml (100%) rename {.github => .old/.github}/workflows/build-android-x86_64.yml (100%) rename {.github => .old/.github}/workflows/build-deploy-android-arm32.yml (100%) rename {.github => .old/.github}/workflows/build-deploy-android-arm64.yml (100%) rename {.github => .old/.github}/workflows/build-deploy-linux-arm32.yml (100%) rename {.github => .old/.github}/workflows/build-deploy-linux-arm64.yml (100%) rename {.github => .old/.github}/workflows/build-deploy-linux-cuda-11.0.yml (100%) rename {.github => .old/.github}/workflows/build-deploy-linux-cuda-11.2.yml (100%) rename {.github => .old/.github}/workflows/build-deploy-linux-x86_64.yml (100%) rename {.github => .old/.github}/workflows/build-deploy-mac.yml (100%) rename {.github => .old/.github}/workflows/build-deploy-windows-cuda-11.0.yml (100%) rename {.github => .old/.github}/workflows/build-deploy-windows-cuda-11.2.yml (100%) rename {.github => .old/.github}/workflows/build-deploy-windows.yml (100%) rename {.github => .old/.github}/workflows/cpu-integration-tests.yaml (100%) rename {.github => .old/.github}/workflows/cpu-sanity-check-tests.yaml (100%) rename {.github => .old/.github}/workflows/run-cpu-tests-sanity-checks.yml (100%) rename {.github => 
.old/.github}/workflows/run-gpu-tests-sanity-checks.yml (100%) rename {.github => .old/.github}/workflows/test_multiple_arch.yaml (100%) rename {ADRs => .old/ADRs}/0001-SameDiff_File_Format.md (100%) rename {ADRs => .old/ADRs}/0002-ONNX_Runtime.md (100%) rename {ADRs => .old/ADRs}/0003-Import_IR.md (100%) rename {ADRs => .old/ADRs}/0003-NdArray_Strides_ArmCompute.md (100%) rename {ADRs => .old/ADRs}/0004-Mapping_IR.md (100%) rename {ADRs => .old/ADRs}/0005-Interpreter.md (100%) rename change-cuda-versions.sh => .old/change-cuda-versions.sh (100%) rename change-scala-versions.sh => .old/change-scala-versions.sh (100%) rename {contrib => .old/contrib}/README.md (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/adr/0001-kotlin_dsl_as_source_of_truth.md (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/adr/0002-separate_object_graph_for_serialization.md (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/adr/0003-dealing_with_inconsistencies_in_java_naming.md (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/adr/0004-auto_initialization_for_inplace_operations.md (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/adr/0005-optional_parameters_and_signatures.md (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/adr/0006-op_specific_enums.md (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/adr/0007-configuration_objects.md (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/adr/0008-inheritance.md (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/adr/0009-aliasing.md (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/adr/0010-ir-codegen.md (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/java/org/nd4j/codegen/Namespace.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/java/org/nd4j/codegen/cli/CLI.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/cpp/CppGenerator.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/DocsGenerator.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/JavaPoetGenerator.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/Nd4jNamespaceGenerator.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/python/PythonGenerator.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/java/org/nd4j/codegen/util/GenUtil.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/java/org/nd4j/codegen/util/JsonMapper.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/CodeComponent.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/DataType.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Language.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/LossReduce.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Namespace.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/NamespaceOps.kt (100%) rename 
{contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Op.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Registry.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Variables.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocScope.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocSection.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocTokens.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/ConstraintCodeGenerator.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/Generator.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/GeneratorConfig.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/impl/java/JavaConstraintCodeGenerator.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/impl/python/KotlinExamplePythonGenerator.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/GenerateOps.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/extract/ExtractFromExisting.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/extract/FindUsedParameterTypes.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/mixins/Mixins.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Bitwise.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/CNN.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Image.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Linalg.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Math.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/RNN.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Random.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/SDLoss.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/resources/logback.xml (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/resources/namespaces/math.json (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/resources/nd4j-op-defs-2.proto (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/resources/onnx-op-defs.pb (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/main/resources/onnx.pbtxt (100%) rename {contrib => 
.old/contrib}/codegen-tools/codegen/src/main/resources/tensorflowOpMappings.csv (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/TestGeneration.java (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/ConfigTest.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/ConstraintTest.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/NamespaceInvariantTest.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/OpBuilderTest.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/OpInvariantTest.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/ops/ConstructionTest.kt (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/test/resources/lenet.onnx (100%) rename {contrib => .old/contrib}/codegen-tools/codegen/src/test/resources/lenet_frozen.pb (100%) rename {contrib => .old/contrib}/codegen-tools/libnd4j-gen/README.md (100%) rename {contrib => .old/contrib}/codegen-tools/libnd4j-gen/op-ir.proto (100%) rename {contrib => .old/contrib}/codegen-tools/libnd4j-gen/pom.xml (100%) rename {contrib => .old/contrib}/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/OpDeclarationDescriptor.java (100%) rename {contrib => .old/contrib}/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/ParseOpFile.java (100%) rename {contrib => .old/contrib}/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/ArgDescriptorProposal.java (100%) rename {contrib => .old/contrib}/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/ArgDescriptorSource.java (100%) rename {contrib => .old/contrib}/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/ArgDescriptorParserUtils.java (100%) rename {contrib => .old/contrib}/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/JavaSourceArgDescriptorSource.java (100%) rename {contrib => .old/contrib}/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/Libnd4jArgDescriptorSource.java (100%) rename {contrib => .old/contrib}/codegen-tools/onnx-def-gen/README.md (100%) rename {contrib => .old/contrib}/codegen-tools/onnx-def-gen/lenet.onnx (100%) rename {contrib => .old/contrib}/codegen-tools/onnx-def-gen/onnx-op-defs.pb (100%) rename {contrib => .old/contrib}/codegen-tools/onnx-def-gen/onnx.pbtxt (100%) rename {contrib => .old/contrib}/codegen-tools/onnx-def-gen/onnx_def_gen.py (100%) rename {contrib => .old/contrib}/codegen-tools/onnx-def-gen/save_test.py (100%) rename {contrib => .old/contrib}/codegen-tools/onnx-def-gen/test_onnx_lenet.py (100%) rename {contrib => .old/contrib}/codegen-tools/onnx-def-gen/test_op_def_gen.py (100%) rename {contrib => .old/contrib}/formatter.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/.codeclimate.yml (100%) rename {deeplearning4j => .old/deeplearning4j}/CONTRIBUTORS.md (100%) rename {deeplearning4j => .old/deeplearning4j}/GITTER_GUIDELINES.md (100%) rename {deeplearning4j => .old/deeplearning4j}/README.md (100%) rename {deeplearning4j => .old/deeplearning4j}/buildmultiplescalaversions.sh (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-dataimport-solrj/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-dataimport-solrj/src/main/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-dataimport-solrj/src/test/resources/solr/collection1/README (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-dataimport-solrj/src/test/resources/solr/configsets/mini/conf/schema.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-dataimport-solrj/src/test/resources/solr/configsets/mini/conf/solrconfig.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/BaseGraph.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/Edge.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/IGraph.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/IVertexSequence.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/NoEdgeHandling.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/Vertex.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/EdgeLineProcessor.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/GraphLoader.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/VertexLoader.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/DelimitedEdgeLineProcessor.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/DelimitedVertexLoader.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/WeightedEdgeLineProcessor.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/exception/NoEdgesException.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/exception/ParseException.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/graph/Graph.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/graph/VertexSequence.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/GraphWalkIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/RandomWalkIterator.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/WeightedRandomWalkIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/GraphWalkIteratorProvider.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/RandomWalkGraphIteratorProvider.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/WeightedRandomWalkGraphIteratorProvider.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/BinaryTree.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/GraphVectors.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/deepwalk/DeepWalk.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/deepwalk/GraphHuffman.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/GraphVectorLookupTable.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/GraphVectorsImpl.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/InMemoryGraphLookupTable.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/loader/GraphVectorSerializer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/IntegerVertexFactory.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/StringVertexFactory.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/VertexFactory.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/VoidVertexFactory.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/TsneTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-manifold/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-modelexport-solr/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-modelexport-solr/src/main/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStream.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-modelexport-solr/src/main/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModel.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-modelexport-solr/src/test/resources/solr/collection1/README (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-modelexport-solr/src/test/resources/solr/configsets/mini-expressible/conf/schema.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-modelexport-solr/src/test/resources/solr/configsets/mini-expressible/conf/solrconfig.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/ImageConversionUtils.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-remote/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/SparkParagraphVectors.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/functions/DocumentSequenceConvertFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/functions/KeySequenceConvertFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectors.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/SparkModelExporter.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/impl/HdfsModelExporter.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/impl/VocabCacheExporter.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/BaseTokenizerFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/CountFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/DistributedFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ElementsFrequenciesAccumulator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExportFunction.java (100%) rename 
{deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExtraCountFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExtraElementsFrequenciesAccumulator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ListSequenceConvertFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/PartitionTrainingFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TokenizerFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TrainingFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/SparkElementsLearningAlgorithm.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/SparkSequenceLearningAlgorithm.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/BaseSparkLearningAlgorithm.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/SparkCBOW.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/SparkSkipGram.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/BaseSparkSequenceLearningAlgorithm.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/SparkDBOW.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/SparkDM.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/primitives/ExtraCounter.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/word2vec/SparkWord2Vec.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/.gitignore (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/MapToPairFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SentenceBatch.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/VocabHolder.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2Vec.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecChange.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecFuncCall.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecParam.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecSetup.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecVariables.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/accumulators/MaxPerPartitionAccumulator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/accumulators/WordFreqAccumulator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/CountCumSum.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldBetweenPartitionFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldWithinPartitionFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/GetSentenceCountFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/MapPerPartitionVoidFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/ReduceSentenceCount.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TextPipeline.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/UpdateWordFreqAccumulatorFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/WordsListToVocabWordsFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TestFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerSubscriber.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerTrainingHook.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationTuple.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/DataSetDeserializationCallback.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/MultiDataSetDeserializationCallback.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/PortableDataStreamCallback.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/PortableDataStreamMDSCallback.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/conf/SharedTrainingConfiguration.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/MultiPdsIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/PdsIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIterator.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualMultiDataSetIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/ElephasModelImport.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/SilentTrainingDriver.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/WiredEncodingHandler.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentIntroductoryConfirmation.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentIntroductoryMessage.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentUpdatesMessage.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/ModelParamsConsumer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdaterParamsConsumer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdatesConsumer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/WiredEncodingHandler.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/ArrayDescriptor.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/DataSetDescriptor.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/Utils.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingResult.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/util/BlockingObserver.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/util/CountingIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunctionTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunctionTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIteratorTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIteratorTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/TestElephasImport.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/nd4j-native.properties (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/RDDTrainingApproach.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/Repartition.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/RepartitionStrategy.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/Repartitioner.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingHook.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingMaster.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingResult.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingWorker.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/WorkerConfiguration.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/CommonSparkTrainingStats.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/SparkTrainingStats.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/NetBroadcastTuple.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchAndExportDataSetsFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchAndExportMultiDataSetsFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/DataSetExportFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/DataSetProvider.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/MultiDataSetExportFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/MultiDataSetProvider.java (100%) rename 
{deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/PathToDataSetFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/PathToMultiDataSetFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/SplitDataSetsFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/loader/RemoteFileSource.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/loader/RemoteFileSourceFactory.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequenceDataSetFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequencePairDataSetFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/export/StringToDataSetExportFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecord.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecords.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/IteratorUtils.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/RRMDSIFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummyReader.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummySeqReader.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkDataSetLossCalculator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingGraphTrainer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingTrainer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkLossCalculatorComputationGraph.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/SparkListenable.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/Add.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/CountPartitionsFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/LoadDataSetFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/SplitPartitionsFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/SplitPartitionsFunction2.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/IntDoubleReduceFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/LongDoubleReduceFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitioner.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/EqualPartitioner.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java (100%) rename 
{deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/dataset/DataSetToMultiDataSetFn.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/dataset/PairDataSetToMultiDataSetFn.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ArrayPairToPair.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/PairToArrayPair.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/listeners/VanillaStatsStorageRouter.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/listeners/VanillaStatsStorageRouterProvider.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateAggregateFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluationReduceFunction.java (100%) rename 
{deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesWithKeyFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/SingleToPairFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingMaster.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingResult.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingWorker.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingResult.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingAggregationTuple.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementAddFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementCombineFunction.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingMasterStats.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingWorkerStats.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupport.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/DefaultRepartitioner.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/EqualRepartitioner.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/NoOpRepartitioner.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/BaseDataSetIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkMultiDataSetIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamDataSetIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamMultiDataSetIterator.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkADSI.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/ordering/DataSetOrdering.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/BaseEventStats.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/EventStats.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/ExampleCountEventStats.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/PartitionCountEventStats.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/StatsUtils.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/NTPTimeSource.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/SystemClockTimeSource.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSource.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSourceProvider.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/SparkDataUtils.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/SparkUtils.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/SparkDataValidation.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/ValidationResult.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidateDataSetFn.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidateMultiDataSetFn.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidationResultReduceFn.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelDeserializer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelSerializer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/main/scala/org/apache/spark/TaskContextHelper.scala (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/common/AddTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestShuffleExamples.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestSparkDataUtils.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestExport.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitionerTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestJsonYaml.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/time/TestTimeSource.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java (100%) rename {deeplearning4j => 
.old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/MLLIbUtilTest.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestValidation.java (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/log4j.properties (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/logback.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/deeplearning4j-scaleout/spark/pom.xml (100%) rename {deeplearning4j => .old/deeplearning4j}/pom.xml (100%) rename {nd4j => .old/nd4j}/README.md (100%) rename {nd4j => .old/nd4j}/RaspberryPi.md (100%) rename {nd4j => .old/nd4j}/buildmultiplescalaversions.sh (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-api/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/loader/api/JDBCNDArrayIO.java (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/loader/impl/BaseLoader.java (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-hsql/src/main/java/org/nd4j/jdbc/hsql/HsqlLoader.java (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-hsql/src/main/resources/nd4j.jdbc.properties (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-mysql/src/main/java/org/nd4j/jdbc/mysql/MysqlLoader.java (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-mysql/src/main/resources/nd4j.jdbc.properties (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java (100%) rename {nd4j => .old/nd4j}/nd4j-jdbc/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-onnxruntime/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-onnxruntime/src/main/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunner.java (100%) rename {nd4j => .old/nd4j}/nd4j-onnxruntime/src/main/java/org/nd4j/onnxruntime/util/ONNXUtils.java (100%) rename {nd4j => .old/nd4j}/nd4j-onnxruntime/src/test/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunnerTests.java (100%) rename {nd4j => .old/nd4j}/nd4j-onnxruntime/src/test/resources/add.onnx (100%) rename {nd4j => .old/nd4j}/nd4j-remote/README.md (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-grpc-client/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/remote/grpc/GraphInferenceGrpcClient.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/remote/grpc/grpc/GraphInferenceServerGrpc.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/JsonRemoteInference.java (100%) rename {nd4j => 
.old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/BinaryDeserializer.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/BinarySerializer.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonDeserializer.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonSerializer.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/AbstractSerDe.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/BooleanSerde.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleSerde.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatArraySerde.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatSerde.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/IntegerSerde.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/StringSerde.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/README.md (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/SameDiffJsonModelServer.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ModelServingServlet.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ServingProcessor.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/House.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/HouseToPredictedPriceAdapter.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/PredictedPrice.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java (100%) rename {nd4j => .old/nd4j}/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml (100%) rename {nd4j => .old/nd4j}/nd4j-remote/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-serde/nd4j-arrow/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-serde/nd4j-arrow/src/main/java/org/nd4j/arrow/ArrowSerde.java (100%) rename {nd4j => .old/nd4j}/nd4j-serde/nd4j-arrow/src/main/java/org/nd4j/arrow/DataBufferStruct.java (100%) rename {nd4j => .old/nd4j}/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java (100%) rename {nd4j => .old/nd4j}/nd4j-serde/nd4j-kryo/pom.xml (100%) rename {nd4j => 
.old/nd4j}/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/Nd4jRegistrator.java (100%) rename {nd4j => .old/nd4j}/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/Nd4jSerializer.java (100%) rename {nd4j => .old/nd4j}/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/primitives/AtomicDoubleSerializer.java (100%) rename {nd4j => .old/nd4j}/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java (100%) rename {nd4j => .old/nd4j}/nd4j-serde/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-tvm/pom.xml (100%) rename {nd4j => .old/nd4j}/nd4j-tvm/src/main/java/org/nd4j/tvm/runner/TvmRunner.java (100%) rename {nd4j => .old/nd4j}/nd4j-tvm/src/main/java/org/nd4j/tvm/util/TVMUtils.java (100%) rename {nd4j => .old/nd4j}/nd4j-tvm/src/test/java/org/nd4j/tvm/runner/TvmRunnerTests.java (100%) rename {nd4j => .old/nd4j}/pom.xml (100%) rename {nd4j => .old/nd4j}/samediff-import/pom.xml (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/pom.xml (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/FrameworkImporter.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/IRProtobufExtensions.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraph.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraphFactory.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraphHolder.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/context/AbstractMappingContext.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/context/MappingContext.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/PostImportHook.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/PreImportHook.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/HookResult.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/PostHookRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/PreHookRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRArgDef.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRAttribute.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRDataType.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRDataTypeValue.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRFunctions.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRGraph.kt (100%) 
rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRNode.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IROpDef.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRTensor.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/mapper/MapperExtensions.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/opdefs/OpDescriptorLoader.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/opdefs/OpDescriptorLoaderHolder.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/AbstractMappingProcess.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/AbstractMappingProcessLoader.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/MappingProcess.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/MappingProcessLoader.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/reflect/ImportReflectionCache.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/registry/ObjectRegistryHolder.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/registry/OpMappingRegistry.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/MappingRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ArgDescriptorConstant.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeMappingRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeNDArrayToScalarAttribute.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeNumberListNDArray.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeScalarNDArrayAttribute.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeValueType.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/BaseAttributeExtractionRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ConditionalFieldValueIntIndexArrayRule.kt (100%) rename {nd4j => 
.old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ConditionalFieldValueIntIndexNDArrayRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/DataTypeToInt.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/FlattenDims.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/IRMappingFunctions.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/InvertBooleanNumber.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListAttributeValueLookupToIndex.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListNumberToListNumber.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListNumberToNDArray.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/MapStringToInt.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayAttributeToNDArrayInput.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayExtractScalarValue.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayInputToNumericalAttribute.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArraySizeAtRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayToIntAttributeValue.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NumberToBoolean.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/SizeThresholdIntArrayIntIndexRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringAttributeToNDArray.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringContainsAdapterRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringEqualsAdapterRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringNotEqualsAdapterRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringToInt.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ValueMapping.kt (100%) rename {nd4j => 
.old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/BaseNDArrayMappingRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/MultiInputIndexMappingRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/PassThroughMultiTensorMapping.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/TensorMappingRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/DefaultImportRunner.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/IRGraphRunner.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/ImportRunner.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-api/src/main/resources/nd4j-op-def.pbtxt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/onnx-processes.pbtxt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/ops-added-new.txt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/ops-imported-new.txt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/ops-removed-new.txt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/pom.xml (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxIR.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxImportGraph.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxImportGraphHolder.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxProtobufExtensions.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxRuleDeclarations.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/context/OnnxMappingContext.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/definitions/OnnxOpDeclarations.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/OnnxFrameworkImporter.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRArgDef.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRAttr.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRDataType.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRGraph.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRGraphRunner.kt (100%) rename 
{nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRNode.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIROp.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRTensor.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/opdefs/OnnxOpDescriptorLoader.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/process/OnnxMappingProcess.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/process/OnnxMappingProcessLoader.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxArgDescriptorConstant.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeNDArrayToScalarAttribute.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeNumberListNDArray.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeScalarNDArrayAttribute.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxConditionalFieldValueIntIndexArrayRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxConditionalFieldValueIntIndexNDArrayRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxDataTypeToInt.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxFlattenDims.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxInvertBooleanNumber.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListAttributeValueLookupToIndex.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListNumberToListNumber.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListNumberToNDArray.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxMapStringToInt.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayAttributeToNDArrayInput.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayExtractScalarValue.kt (100%) rename {nd4j => 
.old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayInputToNumericalAttribute.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArraySizeAt.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayToIntAttributeValue.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxSizeThresholdIntArrayIntIndexRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringAttributeToNDArray.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringContainsAdapterRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringEqualsAdapterRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringNotEqualsAdapterRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringToIndex.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxValueMapping.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/NDArrayMappingRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/OnnxMultiInputIndexMappingRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/OnnxPassThroughMultiInputTensorMapping.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.ImportGraphHolder (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoader (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/resources/onnx-mapping-ruleset.pbtxt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/resources/onnx-op-def.pbtxt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/main/resources/onnx-op-defs.pb (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/TestOnnxIR.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/TestOnnxFrameworkImporter.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/loader/TestOnnxProcessLoader.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/modelzoo/TestPretrainedModels.kt (100%) rename {nd4j => 
.old/nd4j}/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/processing/GroupConvPreProcessingRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/src/test/resources/lenet.onnx (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-onnx/variables-added-new.txt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/00c6b5c8-c93c-4ac9-867f-580443a45bb3-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/03195bad-47a3-4de9-9fc7-6691ea41aee0-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/04c57933-461b-4d6f-b6a8-a210cef103ff-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/0541722d-1de4-4e85-b844-d90d20eea9fb-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/154f67c7-64e1-4e2c-a56e-05d390b459d7-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/2026534d-ef52-441c-976b-3ef06799a362-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/22943ac9-56da-4b92-983d-7385c888c80b-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/281f7eda-053b-4dc9-a686-2040bb4f7fd3-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/2cfad6f7-cd22-4de1-80a6-b9890ce473fc-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/36bc053b-ed9c-40b7-853c-d9462d2a67c0-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/445afe43-7f5f-4f5b-81db-e942139be1a7-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/505e442f-5f0d-4fe6-80ba-0628e9f3057b-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/52bec90b-a05a-4382-a2fc-0835d2da893a-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/58aad3d2-4a46-47d3-9748-578e7aae7121-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/5b55351d-1e98-4c83-b1ec-0812d352141d-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/629659e3-ed5c-482a-89cf-0b4f46026b31-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/6b35dfd9-1d25-419e-a196-4a42f20fd8aa-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/6f446d7c-ec8d-4567-9f3a-3b9bcb3d21f8-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/75db525b-2344-4dcc-a3d9-d23b5acbfe81-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/767f4fe3-b7a8-492b-b0ab-ae77f112e105-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/7b58c709-b0e3-446f-9e05-4fd86f350b83-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/7ce20a5d-8b63-499e-9dc4-0206d0c38b29-container.json (100%) rename {nd4j => 
.old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/8b28bcc6-1a38-4b55-bc80-185feab4c978-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/9d55689b-05d2-4692-ad17-af4aa388cb31-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/9f6c1a25-9c6a-40b1-b825-916525e2cb24-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/a2400879-a732-411c-a65e-00111c6b550e-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/a24baaa5-1cb5-4edd-873b-c923d04905ec-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/b3b58b3b-9a41-44aa-9b00-2bb9633a53be-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/c0f861a5-c322-458b-82e9-efd5494d37fc-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/ca8a7a37-5ce9-4970-aa3d-7eaec8c8091a-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/e25409f2-aa78-4897-a810-297802cccdfc-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/e82fe95c-6cd2-4a8d-82c7-9f45d15e8a73-container.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/eb70b069-8c1d-440c-a135-174d7b873d11-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/allure-results/ef515f16-0d58-450b-85bb-ec61080f012f-result.json (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/nd4j-op-def.pbtxt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/ops-added-new.txt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/ops-added-old.txt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/ops-imported-new.txt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/ops-imported-old.txt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/ops-removed-new.txt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/ops-removed-old.txt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/pom.xml (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowImportGraph.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowImportGraphHolder.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowProtobufExtensions.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowRuleDeclarations.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/context/TensorflowMappingContext.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt (100%) rename {nd4j => 
.old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TensorflowFrameworkImporter.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIR.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRArgDef.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRAttr.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRDataType.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRGraph.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRGraphRunner.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRNode.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIROp.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRTensor.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/opdefs/TensorflowOpDescriptorLoader.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/process/TensorflowMappingProcess.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/process/TensorflowMappingProcessLoader.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowArgDescriptorConstant.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeNDArrayToScalarAttribute.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeNumberListNDArray.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeScalarNDArrayAttribute.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowConditionalFieldValueIntIndexArrayRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowConditionalFieldValueIntIndexNDArrayRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowDataTypeToInt.kt (100%) rename {nd4j => 
.old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowFlattenDims.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowInvertBooleanNumber.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListAttributeValueLookupToIndex.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListNumberToListNumber.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListNumberToNDArray.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowMapStringToInt.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayAttributeToNDArrayInput.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayExtractScalarValue.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayInputToNumericalAttribute.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArraySizeAt.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayToIntAttributeValue.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNdArrayToStringIndex.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringAttributeToNDArray.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringContainsAdapterRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringEqualsAdapterRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringNotEqualsAdapterRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowValueMappingRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/NDArrayMappingRule.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/TensorflowMultiInputIndexMappingRule.kt 
(100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/TensorflowPassThroughMultiTensorMapping.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.ImportGraphHolder (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoader (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-op-def.pbtxt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowUtils.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TestTensorflowImporter.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/loader/TestTensorflowProcessLoader.kt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/test/resources/lenet_frozen.pb (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/src/test/resources/logback.xml (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/test.pbtxt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/variables-added-new.txt (100%) rename {nd4j => .old/nd4j}/samediff-import/samediff-import-tensorflow/variables-added-old.txt (100%) rename perform-release.sh => .old/perform-release.sh (100%) rename {pydatavec => .old/pydatavec}/.eggs/README.txt (100%) rename {pydatavec => .old/pydatavec}/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/LICENSE (100%) rename {pydatavec => .old/pydatavec}/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/PKG-INFO (100%) rename {pydatavec => .old/pydatavec}/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/RECORD (100%) rename {pydatavec => .old/pydatavec}/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/WHEEL (100%) rename {pydatavec => .old/pydatavec}/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/entry_points.txt (100%) rename {pydatavec => .old/pydatavec}/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/requires.txt (100%) rename {pydatavec => .old/pydatavec}/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/top_level.txt (100%) rename {pydatavec => .old/pydatavec}/.eggs/pytest_runner-5.2-py3.8.egg/ptr.py (100%) rename {pydatavec => .old/pydatavec}/pydatavec.egg-info/PKG-INFO (100%) rename {pydatavec => .old/pydatavec}/pydatavec.egg-info/SOURCES.txt (100%) rename {pydatavec => .old/pydatavec}/pydatavec.egg-info/dependency_links.txt (100%) rename {pydatavec => .old/pydatavec}/pydatavec.egg-info/requires.txt (100%) rename {pydatavec => .old/pydatavec}/pydatavec.egg-info/top_level.txt (100%) rename {rl4j => .old/rl4j}/README.md (100%) rename {rl4j => .old/rl4j}/docs/images/cartpole.gif (100%) rename {rl4j => .old/rl4j}/docs/images/doom.gif (100%) rename {rl4j => .old/rl4j}/docs/images/malmo.gif 
(100%) rename {rl4j => .old/rl4j}/pom.xml (100%) rename {rl4j => .old/rl4j}/rl4j-ale/pom.xml (100%) rename {rl4j => .old/rl4j}/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java (100%) rename {rl4j => .old/rl4j}/rl4j-api/pom.xml (100%) rename {rl4j => .old/rl4j}/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java (100%) rename {rl4j => .old/rl4j}/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java (100%) rename {rl4j => .old/rl4j}/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java (100%) rename {rl4j => .old/rl4j}/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java (100%) rename {rl4j => .old/rl4j}/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java (100%) rename {rl4j => .old/rl4j}/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java (100%) rename {rl4j => .old/rl4j}/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java (100%) rename {rl4j => .old/rl4j}/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java (100%) rename {rl4j => .old/rl4j}/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/nd4j-native.properties (100%) rename {rl4j => .old/rl4j}/rl4j-core/pom.xml (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/AgentLearner.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgent.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgentLearner.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/IUpdateAlgorithm.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/ActorCriticHelper.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/AdvantageActorCritic.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelper.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelper.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseDQNAlgorithm.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseTransitionTDAlgorithm.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQN.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQN.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NStepQLearning.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NStepQLearningHelper.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelper.java (100%) rename {rl4j => 
.old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelper.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/behavior/ILearningBehavior.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehavior.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/Features.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilder.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabels.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/Gradients.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/IUpdateRule.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRule.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/INeuralNetUpdater.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/NeuralNetUpdaterConfiguration.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdater.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdater.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandler.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/BaseAsyncNeuralNetUpdater.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/BaseSyncNeuralNetUpdater.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdater.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdater.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/AdvantageActorCriticBuilder.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/AsyncNetworkHandler.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilder.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseAsyncAgentLearnerBuilder.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseDQNAgentLearnerBuilder.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/DoubleDQNBuilder.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/INetworksHandler.java (100%) rename {rl4j => 
.old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/NStepQLearningBuilder.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/StandardDQNBuilder.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/SyncNetworkHandler.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IActionSchema.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IntegerActionSchema.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionReward.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionRewardState.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/NeuralNetFetchable.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java (100%) rename {rl4j => 
.old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithm.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListener.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerList.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/DoAsISayOrDont.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/TMazeEnvironment.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLake.java (100%) rename {rl4j => 
.old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeHelper.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeMap.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeState.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ActorCriticNetwork.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/BaseNetwork.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapper.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonGradientNames.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonLabelNames.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonOutputNames.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandler.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ComputationGraphHandler.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/INetworkHandler.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ITrainableNeuralNet.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandler.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NetworkHelper.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNet.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNetOutput.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/QNetwork.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraph.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparate.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticLoss.java (100%) rename {rl4j => 
.old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/IActorCritic.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactory.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/IObservationSource.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/EncodableToINDArrayTransform.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/ResettableOperation.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/TransformProcess.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransform.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransform.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssembler.java (100%) 
rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/INeuralNetPolicy.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/AsyncTrainer.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/ITrainer.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/SyncTrainer.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/Constants.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/AgentLearnerCartpole.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/NStepRnn.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/RobotLakeExample.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/TMazeExample.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelperTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentAdvantageActorCriticTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelperTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentAdvantageActorCriticTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQNTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQNTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelperTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningTest.java (100%) rename {rl4j => 
.old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelperTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehaviorTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilderTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabelsTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/GradientsTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRuleTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdaterTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdaterTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandlerTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdaterTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdaterTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilderTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java (100%) rename {rl4j => 
.old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/StateActionRewardStateTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockStatEntry.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ActorCriticNetworkTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/BaseNetworkTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapperTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandlerTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ComputationGraphHandlerTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/NetworkHelperTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/QNetworkTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransformTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java (100%) rename {rl4j => 
.old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockRandom.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/AsyncTrainerTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/SyncTrainerTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/pom.xml (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/Basic.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/DeadlyCorridor.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/PredictPosition.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/TakeCover.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/AutomapMode.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/Button.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/DoomGame.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/FileDoesNotExistException.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/GameState.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/GameVariable.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/Label.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/MessageQueueException.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/Mode.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/ScreenFormat.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/ScreenResolution.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/SharedMemoryException.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/SignalException.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/ViZDoomErrorException.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/ViZDoomIsNotRunningException.java (100%) rename {rl4j => .old/rl4j}/rl4j-doom/src/main/java/vizdoom/ViZDoomUnexpectedExitException.java (100%) rename {rl4j => .old/rl4j}/rl4j-gym/pom.xml (100%) rename {rl4j => .old/rl4j}/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/ActionTransformer.java (100%) rename {rl4j => .old/rl4j}/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java (100%) rename {rl4j => .old/rl4j}/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/pom.xml (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpace.java (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpaceDiscrete.java (100%) rename {rl4j => 
.old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoConnectionError.java (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationPolicy.java (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java (100%) rename {rl4j => .old/rl4j}/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoResetHandler.java (100%) rename tensorflow-processes.pbtxt => .old/tensorflow-processes.pbtxt (100%) delete mode 100644 arbiter/.travis.yml delete mode 100644 arbiter/README.md delete mode 100644 arbiter/arbiter-core/pom.xml delete mode 100644 arbiter/arbiter-core/src/assembly/bin.xml delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/AbstractParameterSpace.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/Candidate.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/CandidateGenerator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/OptimizationResult.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/ParameterSpace.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/TaskCreator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/TaskCreatorProvider.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/adapter/ParameterSpaceAdapter.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataProvider.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataSetIteratorFactoryProvider.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataSource.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/evaluation/ModelEvaluator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/InMemoryResultSaver.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/ResultReference.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/ResultSaver.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/score/ScoreFunction.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/MaxCandidatesCondition.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/MaxTimeCondition.java delete mode 100644 
arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/TerminationCondition.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/config/OptimizationConfiguration.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/DegenerateIntegerDistribution.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/DistributionUtils.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/BaseCandidateGenerator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/GeneticSearchCandidateGenerator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/GridSearchCandidateGenerator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/RandomSearchGenerator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/Chromosome.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/ChromosomeFactory.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/ArithmeticCrossover.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/CrossoverOperator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/CrossoverResult.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/KPointCrossover.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/SinglePointCrossover.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/TwoParentsCrossoverOperator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/UniformCrossover.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/ParentSelection.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/RandomTwoParentSelection.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/TwoParentSelection.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/utils/CrossoverPointsGenerator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/CullOperator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/LeastFitCullOperator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/RatioCullOperator.java delete mode 100644 
arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/exceptions/GeneticGenerationException.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/mutation/MutationOperator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/mutation/RandomMutationOperator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/EmptyPopulationInitializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationInitializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationListener.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationModel.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/selection/GeneticSelectionOperator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/selection/SelectionOperator.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/util/SerializedSupplier.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/BooleanSpace.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/continuous/ContinuousParameterSpace.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/discrete/DiscreteParameterSpace.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/integer/IntegerParameterSpace.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/MathOp.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/Op.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/PairMathOp.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/CandidateInfo.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/CandidateStatus.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/IOptimizationRunner.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/BaseStatusListener.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/StatusChangeType.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/StatusListener.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/impl/LoggingStatusListener.java delete mode 
100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/IntegerDistributionDeserializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/IntegerDistributionSerializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/RealDistributionDeserializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/RealDistributionSerializer.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ClassPathResource.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/CollectionUtils.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/LeafUtils.java delete mode 100644 arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ObjectUtils.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/BraninFunction.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestCrossoverOperator.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestMutationOperator.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestParentSelection.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestPopulationInitializer.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestRandomGenerator.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java delete mode 100644 
arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java delete mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java delete mode 100644 arbiter/arbiter-core/src/test/resources/logback.xml delete mode 100644 arbiter/arbiter-deeplearning4j/pom.xml delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/BaseNetworkSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/ComputationGraphSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/DL4JConfiguration.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/GraphConfiguration.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/MultiLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/adapter/ActivationParameterSpaceAdapter.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/adapter/LossFunctionParameterSpaceAdapter.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/dropout/DropoutSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdaGradSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdaMaxSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdamSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/BaseUpdaterSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/NadamSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/NesterovsSpace.java delete mode 100644 
arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/RmsPropSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/SgdSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/ExponentialScheduleSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/InverseScheduleSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/PolyScheduleSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/SigmoidScheduleSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/StepScheduleSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/data/DataSetIteratorFactoryProvider.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/data/MnistDataProvider.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/AlphaDropoutSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/DropoutSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/GaussianDropoutSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/GaussianNoiseSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/evaluator/multilayer/ClassificationEvaluator.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/evaluator/multilayer/RegressionDataEvaluator.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/AbstractLSTMLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/ActivationLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/AutoEncoderLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseConvolutionLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BasePretrainNetworkLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BatchNormalizationSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/Bidirectional.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/CenterLossOutputLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/ConvolutionLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/Deconvolution2DLayerSpace.java delete mode 100644 
arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/DenseLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/DropoutLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/EmbeddingLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/FeedForwardLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GlobalPoolingLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GravesBidirectionalLSTMLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GravesLSTMLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LSTMLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LocalResponseNormalizationLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LossLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/OCNNLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/OutputLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/RnnOutputLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/SeparableConvolution2DLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/SubsamplingLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/VariationalAutoencoderLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/fixed/FixedLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/listener/DL4JArbiterStatusReportingListener.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/saver/local/FileModelSaver.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/saver/local/LocalFileNetResultReference.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/RegressionValue.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/ScoreFunctions.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/BaseNetScoreFunction.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/EvaluationScoreFunction.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetAccuracyScoreFunction.java delete mode 100644 
arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetF1ScoreFunction.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetLossScoreFunction.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetRegressionScoreFunction.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/util/ScoreUtil.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/ComputationGraphTaskCreator.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/MultiLayerNetworkTaskCreator.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/TaskListener.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/TestUtils.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MnistDataSetIteratorFactory.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/util/TestDataFactoryProviderMnist.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/util/TestDataProviderMnist.java delete mode 100644 arbiter/arbiter-deeplearning4j/src/test/resources/logback.xml delete mode 100644 arbiter/arbiter-server/pom.xml delete mode 100644 arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliGenerator.java delete mode 100644 arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliRunner.java delete mode 100644 arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/cli/NeuralNetTypeValidator.java delete mode 100644 arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/cli/ProblemTypeValidator.java delete mode 100644 arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java delete mode 
100644 arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java
 delete mode 100644 arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java
 delete mode 100644 arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java
 delete mode 100644 arbiter/arbiter-ui/pom.xml
 delete mode 100644 arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/UpdateStatus.java
 delete mode 100644 arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/BaseJavaPersistable.java
 delete mode 100644 arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java
 delete mode 100644 arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/ModelInfoPersistable.java
 delete mode 100644 arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java
 delete mode 100644 arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/JsonMapper.java
 delete mode 100644 arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/UIUtils.java
 delete mode 100644 arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java
 delete mode 100644 arbiter/arbiter-ui/src/main/resources/META-INF/services/org.deeplearning4j.ui.api.UIModule
 delete mode 100644 arbiter/arbiter-ui/src/main/resources/deeplearning4jUiAssets/dl4j-ui.js
 delete mode 100644 arbiter/arbiter-ui/src/main/resources/deeplearning4jUiAssets/dl4j-ui.js.map
 delete mode 100644 arbiter/arbiter-ui/src/main/resources/templates/ArbiterUI.html
 delete mode 100644 arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java
 delete mode 100644 arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java
 delete mode 100644 arbiter/arbiter-ui/src/test/resources/logback.xml
 delete mode 100644 arbiter/buildmultiplescalaversions.sh
 delete mode 100644 arbiter/contrib/formatter.xml
 delete mode 100644 arbiter/pom.xml
 create mode 100644 cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/package-info.java

diff --git a/.gitignore b/.gitignore
index fbe938d6a..6c39e54b9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -50,12 +50,12 @@ release.properties
 *.dylib
 .vs/
 .vscode/
-nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/resources/bin
-nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/writeNumpy.csv
-nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/data-all*
-nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/checkpoint
-nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/onnx/
-nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/tensorflow/
+.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/resources/bin
+.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/test/resources/writeNumpy.csv
+.old/nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/data-all*
+.old/nd4j/nd4j-backends/nd4j-tests/src/test/resources/tf_graphs/examples/**/checkpoint
+.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/onnx/
+.old/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/tensorflow/
 doc_sources/
 doc_sources_*
@@ -67,8 +67,8 @@ venv/
 venv2/
 # Ignore the nd4j files that are created by javacpp at build to stop merge conflicts
-nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
-nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
+.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
+.old/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java
 # Ignore meld temp files
 *.orig
diff --git a/.github/ISSUE_TEMPLATE.md b/.old/.github/ISSUE_TEMPLATE.md
similarity index 100%
rename from .github/ISSUE_TEMPLATE.md
rename to .old/.github/ISSUE_TEMPLATE.md
diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.old/.github/PULL_REQUEST_TEMPLATE.md
similarity index 100%
rename from .github/PULL_REQUEST_TEMPLATE.md
rename to .old/.github/PULL_REQUEST_TEMPLATE.md
diff --git a/.github/actions/download-dl4j-test-resources-linux/action.yml b/.old/.github/actions/download-dl4j-test-resources-linux/action.yml
similarity index 100%
rename from .github/actions/download-dl4j-test-resources-linux/action.yml
rename to .old/.github/actions/download-dl4j-test-resources-linux/action.yml
diff --git a/.github/actions/download-dl4j-test-resources-windows/action.yml b/.old/.github/actions/download-dl4j-test-resources-windows/action.yml
similarity index 100%
rename from .github/actions/download-dl4j-test-resources-windows/action.yml
rename to .old/.github/actions/download-dl4j-test-resources-windows/action.yml
diff --git a/.github/actions/install-arm-cross-compile/action.yml b/.old/.github/actions/install-arm-cross-compile/action.yml
similarity index 100%
rename from .github/actions/install-arm-cross-compile/action.yml
rename to .old/.github/actions/install-arm-cross-compile/action.yml
diff --git a/.github/actions/install-cmake-linux/action.yml b/.old/.github/actions/install-cmake-linux/action.yml
similarity index 100%
rename from .github/actions/install-cmake-linux/action.yml
rename to .old/.github/actions/install-cmake-linux/action.yml
diff --git a/.github/actions/install-protobuf-linux/action.yml b/.old/.github/actions/install-protobuf-linux/action.yml
similarity index 100%
rename from .github/actions/install-protobuf-linux/action.yml
rename to .old/.github/actions/install-protobuf-linux/action.yml
diff --git a/.github/actions/msys2-base-setup/action.yml b/.old/.github/actions/msys2-base-setup/action.yml
similarity index 100%
rename from .github/actions/msys2-base-setup/action.yml
rename to .old/.github/actions/msys2-base-setup/action.yml
diff --git a/.github/actions/publish-gh-packages/action.yml b/.old/.github/actions/publish-gh-packages/action.yml
similarity index 100%
rename from .github/actions/publish-gh-packages/action.yml
rename to .old/.github/actions/publish-gh-packages/action.yml
diff --git a/.github/actions/update-deps-linux/action.yml b/.old/.github/actions/update-deps-linux/action.yml
similarity index 100%
rename from .github/actions/update-deps-linux/action.yml
rename to .old/.github/actions/update-deps-linux/action.yml
diff --git a/.github/workflows/build-android-x86_64.yml b/.old/.github/workflows/build-android-x86_64.yml
similarity index 100%
rename from .github/workflows/build-android-x86_64.yml
rename to .old/.github/workflows/build-android-x86_64.yml
diff --git a/.github/workflows/build-deploy-android-arm32.yml b/.old/.github/workflows/build-deploy-android-arm32.yml
similarity index 100%
rename from .github/workflows/build-deploy-android-arm32.yml
rename to .old/.github/workflows/build-deploy-android-arm32.yml
diff --git a/.github/workflows/build-deploy-android-arm64.yml b/.old/.github/workflows/build-deploy-android-arm64.yml
similarity index 100%
rename from .github/workflows/build-deploy-android-arm64.yml rename to .old/.github/workflows/build-deploy-android-arm64.yml diff --git a/.github/workflows/build-deploy-linux-arm32.yml b/.old/.github/workflows/build-deploy-linux-arm32.yml similarity index 100% rename from .github/workflows/build-deploy-linux-arm32.yml rename to .old/.github/workflows/build-deploy-linux-arm32.yml diff --git a/.github/workflows/build-deploy-linux-arm64.yml b/.old/.github/workflows/build-deploy-linux-arm64.yml similarity index 100% rename from .github/workflows/build-deploy-linux-arm64.yml rename to .old/.github/workflows/build-deploy-linux-arm64.yml diff --git a/.github/workflows/build-deploy-linux-cuda-11.0.yml b/.old/.github/workflows/build-deploy-linux-cuda-11.0.yml similarity index 100% rename from .github/workflows/build-deploy-linux-cuda-11.0.yml rename to .old/.github/workflows/build-deploy-linux-cuda-11.0.yml diff --git a/.github/workflows/build-deploy-linux-cuda-11.2.yml b/.old/.github/workflows/build-deploy-linux-cuda-11.2.yml similarity index 100% rename from .github/workflows/build-deploy-linux-cuda-11.2.yml rename to .old/.github/workflows/build-deploy-linux-cuda-11.2.yml diff --git a/.github/workflows/build-deploy-linux-x86_64.yml b/.old/.github/workflows/build-deploy-linux-x86_64.yml similarity index 100% rename from .github/workflows/build-deploy-linux-x86_64.yml rename to .old/.github/workflows/build-deploy-linux-x86_64.yml diff --git a/.github/workflows/build-deploy-mac.yml b/.old/.github/workflows/build-deploy-mac.yml similarity index 100% rename from .github/workflows/build-deploy-mac.yml rename to .old/.github/workflows/build-deploy-mac.yml diff --git a/.github/workflows/build-deploy-windows-cuda-11.0.yml b/.old/.github/workflows/build-deploy-windows-cuda-11.0.yml similarity index 100% rename from .github/workflows/build-deploy-windows-cuda-11.0.yml rename to .old/.github/workflows/build-deploy-windows-cuda-11.0.yml diff --git a/.github/workflows/build-deploy-windows-cuda-11.2.yml b/.old/.github/workflows/build-deploy-windows-cuda-11.2.yml similarity index 100% rename from .github/workflows/build-deploy-windows-cuda-11.2.yml rename to .old/.github/workflows/build-deploy-windows-cuda-11.2.yml diff --git a/.github/workflows/build-deploy-windows.yml b/.old/.github/workflows/build-deploy-windows.yml similarity index 100% rename from .github/workflows/build-deploy-windows.yml rename to .old/.github/workflows/build-deploy-windows.yml diff --git a/.github/workflows/cpu-integration-tests.yaml b/.old/.github/workflows/cpu-integration-tests.yaml similarity index 100% rename from .github/workflows/cpu-integration-tests.yaml rename to .old/.github/workflows/cpu-integration-tests.yaml diff --git a/.github/workflows/cpu-sanity-check-tests.yaml b/.old/.github/workflows/cpu-sanity-check-tests.yaml similarity index 100% rename from .github/workflows/cpu-sanity-check-tests.yaml rename to .old/.github/workflows/cpu-sanity-check-tests.yaml diff --git a/.github/workflows/run-cpu-tests-sanity-checks.yml b/.old/.github/workflows/run-cpu-tests-sanity-checks.yml similarity index 100% rename from .github/workflows/run-cpu-tests-sanity-checks.yml rename to .old/.github/workflows/run-cpu-tests-sanity-checks.yml diff --git a/.github/workflows/run-gpu-tests-sanity-checks.yml b/.old/.github/workflows/run-gpu-tests-sanity-checks.yml similarity index 100% rename from .github/workflows/run-gpu-tests-sanity-checks.yml rename to .old/.github/workflows/run-gpu-tests-sanity-checks.yml diff --git 
a/.github/workflows/test_multiple_arch.yaml b/.old/.github/workflows/test_multiple_arch.yaml similarity index 100% rename from .github/workflows/test_multiple_arch.yaml rename to .old/.github/workflows/test_multiple_arch.yaml diff --git a/ADRs/0001-SameDiff_File_Format.md b/.old/ADRs/0001-SameDiff_File_Format.md similarity index 100% rename from ADRs/0001-SameDiff_File_Format.md rename to .old/ADRs/0001-SameDiff_File_Format.md diff --git a/ADRs/0002-ONNX_Runtime.md b/.old/ADRs/0002-ONNX_Runtime.md similarity index 100% rename from ADRs/0002-ONNX_Runtime.md rename to .old/ADRs/0002-ONNX_Runtime.md diff --git a/ADRs/0003-Import_IR.md b/.old/ADRs/0003-Import_IR.md similarity index 100% rename from ADRs/0003-Import_IR.md rename to .old/ADRs/0003-Import_IR.md diff --git a/ADRs/0003-NdArray_Strides_ArmCompute.md b/.old/ADRs/0003-NdArray_Strides_ArmCompute.md similarity index 100% rename from ADRs/0003-NdArray_Strides_ArmCompute.md rename to .old/ADRs/0003-NdArray_Strides_ArmCompute.md diff --git a/ADRs/0004-Mapping_IR.md b/.old/ADRs/0004-Mapping_IR.md similarity index 100% rename from ADRs/0004-Mapping_IR.md rename to .old/ADRs/0004-Mapping_IR.md diff --git a/ADRs/0005-Interpreter.md b/.old/ADRs/0005-Interpreter.md similarity index 100% rename from ADRs/0005-Interpreter.md rename to .old/ADRs/0005-Interpreter.md diff --git a/change-cuda-versions.sh b/.old/change-cuda-versions.sh similarity index 100% rename from change-cuda-versions.sh rename to .old/change-cuda-versions.sh diff --git a/change-scala-versions.sh b/.old/change-scala-versions.sh similarity index 100% rename from change-scala-versions.sh rename to .old/change-scala-versions.sh diff --git a/contrib/README.md b/.old/contrib/README.md similarity index 100% rename from contrib/README.md rename to .old/contrib/README.md diff --git a/contrib/codegen-tools/codegen/adr/0001-kotlin_dsl_as_source_of_truth.md b/.old/contrib/codegen-tools/codegen/adr/0001-kotlin_dsl_as_source_of_truth.md similarity index 100% rename from contrib/codegen-tools/codegen/adr/0001-kotlin_dsl_as_source_of_truth.md rename to .old/contrib/codegen-tools/codegen/adr/0001-kotlin_dsl_as_source_of_truth.md diff --git a/contrib/codegen-tools/codegen/adr/0002-separate_object_graph_for_serialization.md b/.old/contrib/codegen-tools/codegen/adr/0002-separate_object_graph_for_serialization.md similarity index 100% rename from contrib/codegen-tools/codegen/adr/0002-separate_object_graph_for_serialization.md rename to .old/contrib/codegen-tools/codegen/adr/0002-separate_object_graph_for_serialization.md diff --git a/contrib/codegen-tools/codegen/adr/0003-dealing_with_inconsistencies_in_java_naming.md b/.old/contrib/codegen-tools/codegen/adr/0003-dealing_with_inconsistencies_in_java_naming.md similarity index 100% rename from contrib/codegen-tools/codegen/adr/0003-dealing_with_inconsistencies_in_java_naming.md rename to .old/contrib/codegen-tools/codegen/adr/0003-dealing_with_inconsistencies_in_java_naming.md diff --git a/contrib/codegen-tools/codegen/adr/0004-auto_initialization_for_inplace_operations.md b/.old/contrib/codegen-tools/codegen/adr/0004-auto_initialization_for_inplace_operations.md similarity index 100% rename from contrib/codegen-tools/codegen/adr/0004-auto_initialization_for_inplace_operations.md rename to .old/contrib/codegen-tools/codegen/adr/0004-auto_initialization_for_inplace_operations.md diff --git a/contrib/codegen-tools/codegen/adr/0005-optional_parameters_and_signatures.md b/.old/contrib/codegen-tools/codegen/adr/0005-optional_parameters_and_signatures.md 
similarity index 100% rename from contrib/codegen-tools/codegen/adr/0005-optional_parameters_and_signatures.md rename to .old/contrib/codegen-tools/codegen/adr/0005-optional_parameters_and_signatures.md diff --git a/contrib/codegen-tools/codegen/adr/0006-op_specific_enums.md b/.old/contrib/codegen-tools/codegen/adr/0006-op_specific_enums.md similarity index 100% rename from contrib/codegen-tools/codegen/adr/0006-op_specific_enums.md rename to .old/contrib/codegen-tools/codegen/adr/0006-op_specific_enums.md diff --git a/contrib/codegen-tools/codegen/adr/0007-configuration_objects.md b/.old/contrib/codegen-tools/codegen/adr/0007-configuration_objects.md similarity index 100% rename from contrib/codegen-tools/codegen/adr/0007-configuration_objects.md rename to .old/contrib/codegen-tools/codegen/adr/0007-configuration_objects.md diff --git a/contrib/codegen-tools/codegen/adr/0008-inheritance.md b/.old/contrib/codegen-tools/codegen/adr/0008-inheritance.md similarity index 100% rename from contrib/codegen-tools/codegen/adr/0008-inheritance.md rename to .old/contrib/codegen-tools/codegen/adr/0008-inheritance.md diff --git a/contrib/codegen-tools/codegen/adr/0009-aliasing.md b/.old/contrib/codegen-tools/codegen/adr/0009-aliasing.md similarity index 100% rename from contrib/codegen-tools/codegen/adr/0009-aliasing.md rename to .old/contrib/codegen-tools/codegen/adr/0009-aliasing.md diff --git a/contrib/codegen-tools/codegen/adr/0010-ir-codegen.md b/.old/contrib/codegen-tools/codegen/adr/0010-ir-codegen.md similarity index 100% rename from contrib/codegen-tools/codegen/adr/0010-ir-codegen.md rename to .old/contrib/codegen-tools/codegen/adr/0010-ir-codegen.md diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/Namespace.java b/.old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/Namespace.java similarity index 100% rename from contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/Namespace.java rename to .old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/Namespace.java diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/cli/CLI.java b/.old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/cli/CLI.java similarity index 100% rename from contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/cli/CLI.java rename to .old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/cli/CLI.java diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/cpp/CppGenerator.java b/.old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/cpp/CppGenerator.java similarity index 100% rename from contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/cpp/CppGenerator.java rename to .old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/cpp/CppGenerator.java diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/DocsGenerator.java b/.old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/DocsGenerator.java similarity index 100% rename from contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/DocsGenerator.java rename to .old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/DocsGenerator.java diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/JavaPoetGenerator.java b/.old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/JavaPoetGenerator.java similarity index 100% rename from 
contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/JavaPoetGenerator.java rename to .old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/JavaPoetGenerator.java diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/Nd4jNamespaceGenerator.java b/.old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/Nd4jNamespaceGenerator.java similarity index 100% rename from contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/Nd4jNamespaceGenerator.java rename to .old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/java/Nd4jNamespaceGenerator.java diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/python/PythonGenerator.java b/.old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/python/PythonGenerator.java similarity index 100% rename from contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/python/PythonGenerator.java rename to .old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/impl/python/PythonGenerator.java diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java b/.old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java similarity index 100% rename from contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java rename to .old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/ir/SerializationTest.java diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/util/GenUtil.java b/.old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/util/GenUtil.java similarity index 100% rename from contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/util/GenUtil.java rename to .old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/util/GenUtil.java diff --git a/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/util/JsonMapper.java b/.old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/util/JsonMapper.java similarity index 100% rename from contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/util/JsonMapper.java rename to .old/contrib/codegen-tools/codegen/src/main/java/org/nd4j/codegen/util/JsonMapper.java diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/CodeComponent.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/CodeComponent.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/CodeComponent.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/CodeComponent.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/DataType.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/DataType.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/DataType.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/DataType.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Language.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Language.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Language.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Language.kt diff --git 
a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/LossReduce.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/LossReduce.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/LossReduce.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/LossReduce.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Namespace.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Namespace.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Namespace.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Namespace.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/NamespaceOps.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/NamespaceOps.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/NamespaceOps.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/NamespaceOps.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Op.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Op.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Op.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Op.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Registry.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Registry.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Registry.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Registry.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Variables.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Variables.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Variables.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/Variables.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocScope.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocScope.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocScope.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocScope.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocSection.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocSection.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocSection.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocSection.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocTokens.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocTokens.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocTokens.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/doc/DocTokens.kt diff --git 
a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/ConstraintCodeGenerator.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/ConstraintCodeGenerator.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/ConstraintCodeGenerator.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/ConstraintCodeGenerator.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/Generator.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/Generator.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/Generator.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/Generator.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/GeneratorConfig.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/GeneratorConfig.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/GeneratorConfig.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/api/generator/GeneratorConfig.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/dsl/OpBuilder.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/impl/java/JavaConstraintCodeGenerator.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/impl/java/JavaConstraintCodeGenerator.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/impl/java/JavaConstraintCodeGenerator.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/impl/java/JavaConstraintCodeGenerator.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/impl/python/KotlinExamplePythonGenerator.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/impl/python/KotlinExamplePythonGenerator.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/impl/python/KotlinExamplePythonGenerator.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/impl/python/KotlinExamplePythonGenerator.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/GenerateOps.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/GenerateOps.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/GenerateOps.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/GenerateOps.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/extract/ExtractFromExisting.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/extract/ExtractFromExisting.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/extract/ExtractFromExisting.kt rename to 
.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/extract/ExtractFromExisting.kt diff --git a/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/extract/FindUsedParameterTypes.kt b/.old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/extract/FindUsedParameterTypes.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/extract/FindUsedParameterTypes.kt rename to .old/contrib/codegen-tools/codegen/src/main/kotlin/org/nd4j/codegen/util/extract/FindUsedParameterTypes.kt diff --git a/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/mixins/Mixins.kt b/.old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/mixins/Mixins.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/mixins/Mixins.kt rename to .old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/mixins/Mixins.kt diff --git a/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Bitwise.kt b/.old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Bitwise.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Bitwise.kt rename to .old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Bitwise.kt diff --git a/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/CNN.kt b/.old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/CNN.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/CNN.kt rename to .old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/CNN.kt diff --git a/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Image.kt b/.old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Image.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Image.kt rename to .old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Image.kt diff --git a/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Linalg.kt b/.old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Linalg.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Linalg.kt rename to .old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Linalg.kt diff --git a/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Math.kt b/.old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Math.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Math.kt rename to .old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Math.kt diff --git a/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt b/.old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt rename to .old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/NeuralNetwork.kt diff --git a/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/RNN.kt b/.old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/RNN.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/RNN.kt rename to .old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/RNN.kt diff --git 
a/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Random.kt b/.old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Random.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Random.kt rename to .old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/Random.kt diff --git a/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt b/.old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt rename to .old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/SDBaseOps.kt diff --git a/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/SDLoss.kt b/.old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/SDLoss.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/SDLoss.kt rename to .old/contrib/codegen-tools/codegen/src/main/ops/org/nd4j/codegen/ops/SDLoss.kt diff --git a/contrib/codegen-tools/codegen/src/main/resources/logback.xml b/.old/contrib/codegen-tools/codegen/src/main/resources/logback.xml similarity index 100% rename from contrib/codegen-tools/codegen/src/main/resources/logback.xml rename to .old/contrib/codegen-tools/codegen/src/main/resources/logback.xml diff --git a/contrib/codegen-tools/codegen/src/main/resources/namespaces/math.json b/.old/contrib/codegen-tools/codegen/src/main/resources/namespaces/math.json similarity index 100% rename from contrib/codegen-tools/codegen/src/main/resources/namespaces/math.json rename to .old/contrib/codegen-tools/codegen/src/main/resources/namespaces/math.json diff --git a/contrib/codegen-tools/codegen/src/main/resources/nd4j-op-defs-2.proto b/.old/contrib/codegen-tools/codegen/src/main/resources/nd4j-op-defs-2.proto similarity index 100% rename from contrib/codegen-tools/codegen/src/main/resources/nd4j-op-defs-2.proto rename to .old/contrib/codegen-tools/codegen/src/main/resources/nd4j-op-defs-2.proto diff --git a/contrib/codegen-tools/codegen/src/main/resources/onnx-op-defs.pb b/.old/contrib/codegen-tools/codegen/src/main/resources/onnx-op-defs.pb similarity index 100% rename from contrib/codegen-tools/codegen/src/main/resources/onnx-op-defs.pb rename to .old/contrib/codegen-tools/codegen/src/main/resources/onnx-op-defs.pb diff --git a/contrib/codegen-tools/codegen/src/main/resources/onnx.pbtxt b/.old/contrib/codegen-tools/codegen/src/main/resources/onnx.pbtxt similarity index 100% rename from contrib/codegen-tools/codegen/src/main/resources/onnx.pbtxt rename to .old/contrib/codegen-tools/codegen/src/main/resources/onnx.pbtxt diff --git a/contrib/codegen-tools/codegen/src/main/resources/tensorflowOpMappings.csv b/.old/contrib/codegen-tools/codegen/src/main/resources/tensorflowOpMappings.csv similarity index 100% rename from contrib/codegen-tools/codegen/src/main/resources/tensorflowOpMappings.csv rename to .old/contrib/codegen-tools/codegen/src/main/resources/tensorflowOpMappings.csv diff --git a/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java b/.old/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java similarity index 100% rename from contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java rename to .old/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/DocsGeneratorTest.java diff --git 
a/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/TestGeneration.java b/.old/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/TestGeneration.java similarity index 100% rename from contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/TestGeneration.java rename to .old/contrib/codegen-tools/codegen/src/test/java/org/nd4j/codegen/dsl/TestGeneration.java diff --git a/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/ConfigTest.kt b/.old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/ConfigTest.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/ConfigTest.kt rename to .old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/ConfigTest.kt diff --git a/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/ConstraintTest.kt b/.old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/ConstraintTest.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/ConstraintTest.kt rename to .old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/ConstraintTest.kt diff --git a/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/NamespaceInvariantTest.kt b/.old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/NamespaceInvariantTest.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/NamespaceInvariantTest.kt rename to .old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/NamespaceInvariantTest.kt diff --git a/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/OpBuilderTest.kt b/.old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/OpBuilderTest.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/OpBuilderTest.kt rename to .old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/OpBuilderTest.kt diff --git a/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/OpInvariantTest.kt b/.old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/OpInvariantTest.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/OpInvariantTest.kt rename to .old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/dsl/OpInvariantTest.kt diff --git a/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/ops/ConstructionTest.kt b/.old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/ops/ConstructionTest.kt similarity index 100% rename from contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/ops/ConstructionTest.kt rename to .old/contrib/codegen-tools/codegen/src/test/kotlin/org/nd4j/codegen/ops/ConstructionTest.kt diff --git a/contrib/codegen-tools/codegen/src/test/resources/lenet.onnx b/.old/contrib/codegen-tools/codegen/src/test/resources/lenet.onnx similarity index 100% rename from contrib/codegen-tools/codegen/src/test/resources/lenet.onnx rename to .old/contrib/codegen-tools/codegen/src/test/resources/lenet.onnx diff --git a/contrib/codegen-tools/codegen/src/test/resources/lenet_frozen.pb b/.old/contrib/codegen-tools/codegen/src/test/resources/lenet_frozen.pb similarity index 100% rename from contrib/codegen-tools/codegen/src/test/resources/lenet_frozen.pb rename to .old/contrib/codegen-tools/codegen/src/test/resources/lenet_frozen.pb diff --git 
a/contrib/codegen-tools/libnd4j-gen/README.md b/.old/contrib/codegen-tools/libnd4j-gen/README.md similarity index 100% rename from contrib/codegen-tools/libnd4j-gen/README.md rename to .old/contrib/codegen-tools/libnd4j-gen/README.md diff --git a/contrib/codegen-tools/libnd4j-gen/op-ir.proto b/.old/contrib/codegen-tools/libnd4j-gen/op-ir.proto similarity index 100% rename from contrib/codegen-tools/libnd4j-gen/op-ir.proto rename to .old/contrib/codegen-tools/libnd4j-gen/op-ir.proto diff --git a/contrib/codegen-tools/libnd4j-gen/pom.xml b/.old/contrib/codegen-tools/libnd4j-gen/pom.xml similarity index 100% rename from contrib/codegen-tools/libnd4j-gen/pom.xml rename to .old/contrib/codegen-tools/libnd4j-gen/pom.xml diff --git a/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/OpDeclarationDescriptor.java b/.old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/OpDeclarationDescriptor.java similarity index 100% rename from contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/OpDeclarationDescriptor.java rename to .old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/OpDeclarationDescriptor.java diff --git a/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/ParseOpFile.java b/.old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/ParseOpFile.java similarity index 100% rename from contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/ParseOpFile.java rename to .old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/ParseOpFile.java diff --git a/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/ArgDescriptorProposal.java b/.old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/ArgDescriptorProposal.java similarity index 100% rename from contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/ArgDescriptorProposal.java rename to .old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/ArgDescriptorProposal.java diff --git a/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/ArgDescriptorSource.java b/.old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/ArgDescriptorSource.java similarity index 100% rename from contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/ArgDescriptorSource.java rename to .old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/ArgDescriptorSource.java diff --git a/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/ArgDescriptorParserUtils.java b/.old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/ArgDescriptorParserUtils.java similarity index 100% rename from contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/ArgDescriptorParserUtils.java rename to .old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/ArgDescriptorParserUtils.java diff --git a/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/JavaSourceArgDescriptorSource.java b/.old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/JavaSourceArgDescriptorSource.java similarity index 100% rename from contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/JavaSourceArgDescriptorSource.java rename to 
.old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/JavaSourceArgDescriptorSource.java diff --git a/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/Libnd4jArgDescriptorSource.java b/.old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/Libnd4jArgDescriptorSource.java similarity index 100% rename from contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/Libnd4jArgDescriptorSource.java rename to .old/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/Libnd4jArgDescriptorSource.java diff --git a/contrib/codegen-tools/onnx-def-gen/README.md b/.old/contrib/codegen-tools/onnx-def-gen/README.md similarity index 100% rename from contrib/codegen-tools/onnx-def-gen/README.md rename to .old/contrib/codegen-tools/onnx-def-gen/README.md diff --git a/contrib/codegen-tools/onnx-def-gen/lenet.onnx b/.old/contrib/codegen-tools/onnx-def-gen/lenet.onnx similarity index 100% rename from contrib/codegen-tools/onnx-def-gen/lenet.onnx rename to .old/contrib/codegen-tools/onnx-def-gen/lenet.onnx diff --git a/contrib/codegen-tools/onnx-def-gen/onnx-op-defs.pb b/.old/contrib/codegen-tools/onnx-def-gen/onnx-op-defs.pb similarity index 100% rename from contrib/codegen-tools/onnx-def-gen/onnx-op-defs.pb rename to .old/contrib/codegen-tools/onnx-def-gen/onnx-op-defs.pb diff --git a/contrib/codegen-tools/onnx-def-gen/onnx.pbtxt b/.old/contrib/codegen-tools/onnx-def-gen/onnx.pbtxt similarity index 100% rename from contrib/codegen-tools/onnx-def-gen/onnx.pbtxt rename to .old/contrib/codegen-tools/onnx-def-gen/onnx.pbtxt diff --git a/contrib/codegen-tools/onnx-def-gen/onnx_def_gen.py b/.old/contrib/codegen-tools/onnx-def-gen/onnx_def_gen.py similarity index 100% rename from contrib/codegen-tools/onnx-def-gen/onnx_def_gen.py rename to .old/contrib/codegen-tools/onnx-def-gen/onnx_def_gen.py diff --git a/contrib/codegen-tools/onnx-def-gen/save_test.py b/.old/contrib/codegen-tools/onnx-def-gen/save_test.py similarity index 100% rename from contrib/codegen-tools/onnx-def-gen/save_test.py rename to .old/contrib/codegen-tools/onnx-def-gen/save_test.py diff --git a/contrib/codegen-tools/onnx-def-gen/test_onnx_lenet.py b/.old/contrib/codegen-tools/onnx-def-gen/test_onnx_lenet.py similarity index 100% rename from contrib/codegen-tools/onnx-def-gen/test_onnx_lenet.py rename to .old/contrib/codegen-tools/onnx-def-gen/test_onnx_lenet.py diff --git a/contrib/codegen-tools/onnx-def-gen/test_op_def_gen.py b/.old/contrib/codegen-tools/onnx-def-gen/test_op_def_gen.py similarity index 100% rename from contrib/codegen-tools/onnx-def-gen/test_op_def_gen.py rename to .old/contrib/codegen-tools/onnx-def-gen/test_op_def_gen.py diff --git a/contrib/formatter.xml b/.old/contrib/formatter.xml similarity index 100% rename from contrib/formatter.xml rename to .old/contrib/formatter.xml diff --git a/deeplearning4j/.codeclimate.yml b/.old/deeplearning4j/.codeclimate.yml similarity index 100% rename from deeplearning4j/.codeclimate.yml rename to .old/deeplearning4j/.codeclimate.yml diff --git a/deeplearning4j/CONTRIBUTORS.md b/.old/deeplearning4j/CONTRIBUTORS.md similarity index 100% rename from deeplearning4j/CONTRIBUTORS.md rename to .old/deeplearning4j/CONTRIBUTORS.md diff --git a/deeplearning4j/GITTER_GUIDELINES.md b/.old/deeplearning4j/GITTER_GUIDELINES.md similarity index 100% rename from deeplearning4j/GITTER_GUIDELINES.md rename to .old/deeplearning4j/GITTER_GUIDELINES.md diff --git 
a/deeplearning4j/README.md b/.old/deeplearning4j/README.md similarity index 100% rename from deeplearning4j/README.md rename to .old/deeplearning4j/README.md diff --git a/deeplearning4j/buildmultiplescalaversions.sh b/.old/deeplearning4j/buildmultiplescalaversions.sh similarity index 100% rename from deeplearning4j/buildmultiplescalaversions.sh rename to .old/deeplearning4j/buildmultiplescalaversions.sh diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml b/.old/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml rename to .old/deeplearning4j/deeplearning4j-dataimport-solrj/pom.xml diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/src/main/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIterator.java b/.old/deeplearning4j/deeplearning4j-dataimport-solrj/src/main/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-dataimport-solrj/src/main/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIterator.java rename to .old/deeplearning4j/deeplearning4j-dataimport-solrj/src/main/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIterator.java diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java b/.old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java rename to .old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/collection1/README b/.old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/collection1/README similarity index 100% rename from deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/collection1/README rename to .old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/collection1/README diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/configsets/mini/conf/schema.xml b/.old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/configsets/mini/conf/schema.xml similarity index 100% rename from deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/configsets/mini/conf/schema.xml rename to .old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/configsets/mini/conf/schema.xml diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/configsets/mini/conf/solrconfig.xml b/.old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/configsets/mini/conf/solrconfig.xml similarity index 100% rename from deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/configsets/mini/conf/solrconfig.xml rename to .old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/resources/solr/configsets/mini/conf/solrconfig.xml diff --git a/deeplearning4j/deeplearning4j-graph/pom.xml 
b/.old/deeplearning4j/deeplearning4j-graph/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-graph/pom.xml rename to .old/deeplearning4j/deeplearning4j-graph/pom.xml diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/BaseGraph.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/BaseGraph.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/BaseGraph.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/BaseGraph.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/Edge.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/Edge.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/Edge.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/Edge.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/IGraph.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/IGraph.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/IGraph.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/IGraph.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/IVertexSequence.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/IVertexSequence.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/IVertexSequence.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/IVertexSequence.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/NoEdgeHandling.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/NoEdgeHandling.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/NoEdgeHandling.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/NoEdgeHandling.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/Vertex.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/Vertex.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/Vertex.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/api/Vertex.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/EdgeLineProcessor.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/EdgeLineProcessor.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/EdgeLineProcessor.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/EdgeLineProcessor.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/GraphLoader.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/GraphLoader.java 
similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/GraphLoader.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/GraphLoader.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/VertexLoader.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/VertexLoader.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/VertexLoader.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/VertexLoader.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/DelimitedEdgeLineProcessor.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/DelimitedEdgeLineProcessor.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/DelimitedEdgeLineProcessor.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/DelimitedEdgeLineProcessor.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/DelimitedVertexLoader.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/DelimitedVertexLoader.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/DelimitedVertexLoader.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/DelimitedVertexLoader.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/WeightedEdgeLineProcessor.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/WeightedEdgeLineProcessor.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/WeightedEdgeLineProcessor.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/data/impl/WeightedEdgeLineProcessor.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/exception/NoEdgesException.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/exception/NoEdgesException.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/exception/NoEdgesException.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/exception/NoEdgesException.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/exception/ParseException.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/exception/ParseException.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/exception/ParseException.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/exception/ParseException.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/graph/Graph.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/graph/Graph.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/graph/Graph.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/graph/Graph.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/graph/VertexSequence.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/graph/VertexSequence.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/graph/VertexSequence.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/graph/VertexSequence.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/GraphWalkIterator.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/GraphWalkIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/GraphWalkIterator.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/GraphWalkIterator.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/RandomWalkIterator.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/RandomWalkIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/RandomWalkIterator.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/RandomWalkIterator.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/WeightedRandomWalkIterator.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/WeightedRandomWalkIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/WeightedRandomWalkIterator.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/WeightedRandomWalkIterator.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/GraphWalkIteratorProvider.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/GraphWalkIteratorProvider.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/GraphWalkIteratorProvider.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/GraphWalkIteratorProvider.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/RandomWalkGraphIteratorProvider.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/RandomWalkGraphIteratorProvider.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/RandomWalkGraphIteratorProvider.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/RandomWalkGraphIteratorProvider.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/WeightedRandomWalkGraphIteratorProvider.java 
b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/WeightedRandomWalkGraphIteratorProvider.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/WeightedRandomWalkGraphIteratorProvider.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/iterator/parallel/WeightedRandomWalkGraphIteratorProvider.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/BinaryTree.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/BinaryTree.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/BinaryTree.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/BinaryTree.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/GraphVectors.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/GraphVectors.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/GraphVectors.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/GraphVectors.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/deepwalk/DeepWalk.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/deepwalk/DeepWalk.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/deepwalk/DeepWalk.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/deepwalk/DeepWalk.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/deepwalk/GraphHuffman.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/deepwalk/GraphHuffman.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/deepwalk/GraphHuffman.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/deepwalk/GraphHuffman.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/GraphVectorLookupTable.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/GraphVectorLookupTable.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/GraphVectorLookupTable.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/GraphVectorLookupTable.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/GraphVectorsImpl.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/GraphVectorsImpl.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/GraphVectorsImpl.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/GraphVectorsImpl.java diff --git 
a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/InMemoryGraphLookupTable.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/InMemoryGraphLookupTable.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/InMemoryGraphLookupTable.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/embeddings/InMemoryGraphLookupTable.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/loader/GraphVectorSerializer.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/loader/GraphVectorSerializer.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/loader/GraphVectorSerializer.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/models/loader/GraphVectorSerializer.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/IntegerVertexFactory.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/IntegerVertexFactory.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/IntegerVertexFactory.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/IntegerVertexFactory.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/StringVertexFactory.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/StringVertexFactory.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/StringVertexFactory.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/StringVertexFactory.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/VertexFactory.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/VertexFactory.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/VertexFactory.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/VertexFactory.java diff --git a/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/VoidVertexFactory.java b/.old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/VoidVertexFactory.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/VoidVertexFactory.java rename to .old/deeplearning4j/deeplearning4j-graph/src/main/java/org/deeplearning4j/graph/vertexfactory/VoidVertexFactory.java diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java b/.old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java rename to 
.old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java b/.old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java rename to .old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoading.java diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java b/.old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java rename to .old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/data/TestGraphLoadingWeighted.java diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java b/.old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java rename to .old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/graph/TestGraph.java diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java b/.old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java rename to .old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java b/.old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java rename to .old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestDeepWalk.java diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java b/.old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java similarity index 100% rename from deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java rename to .old/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/TestGraphHuffman.java diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml b/.old/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml rename to .old/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/pom.xml diff --git 
a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java b/.old/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java similarity index 100% rename from deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java rename to .old/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java b/.old/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java similarity index 100% rename from deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java rename to .old/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/Tsne.java diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java b/.old/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java similarity index 100% rename from deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java rename to .old/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/Test6058.java diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/TsneTest.java b/.old/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/TsneTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/TsneTest.java rename to .old/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/test/java/org/deeplearning4j/plot/TsneTest.java diff --git a/deeplearning4j/deeplearning4j-manifold/pom.xml b/.old/deeplearning4j/deeplearning4j-manifold/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-manifold/pom.xml rename to .old/deeplearning4j/deeplearning4j-manifold/pom.xml diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml b/.old/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-modelexport-solr/pom.xml rename to .old/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/main/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStream.java b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/main/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStream.java similarity index 100% rename from deeplearning4j/deeplearning4j-modelexport-solr/src/main/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStream.java rename to .old/deeplearning4j/deeplearning4j-modelexport-solr/src/main/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStream.java diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/main/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModel.java b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/main/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModel.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-modelexport-solr/src/main/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModel.java rename to .old/deeplearning4j/deeplearning4j-modelexport-solr/src/main/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModel.java diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java rename to .old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java rename to .old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java rename to .old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/collection1/README b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/collection1/README similarity index 100% rename from deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/collection1/README rename to .old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/collection1/README diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/configsets/mini-expressible/conf/schema.xml b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/configsets/mini-expressible/conf/schema.xml similarity index 100% rename from deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/configsets/mini-expressible/conf/schema.xml rename to .old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/configsets/mini-expressible/conf/schema.xml diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/configsets/mini-expressible/conf/solrconfig.xml b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/configsets/mini-expressible/conf/solrconfig.xml similarity index 100% rename from deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/configsets/mini-expressible/conf/solrconfig.xml rename to 
.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/resources/solr/configsets/mini-expressible/conf/solrconfig.xml diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/pom.xml diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/DL4jServlet.java diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/main/java/org/deeplearning4j/remote/JsonModelServer.java diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/BinaryModelServerTest.java diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java 
b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/ServletTest.java diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/House.java diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/HouseToPredictedPriceAdapter.java diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/ImageConversionUtils.java b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/ImageConversionUtils.java similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/ImageConversionUtils.java rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/ImageConversionUtils.java diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/helpers/PredictedPrice.java diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml similarity index 100% rename from deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml rename to .old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/resources/logback.xml diff --git a/deeplearning4j/deeplearning4j-remote/pom.xml b/.old/deeplearning4j/deeplearning4j-remote/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-remote/pom.xml rename to .old/deeplearning4j/deeplearning4j-remote/pom.xml diff 
--git a/deeplearning4j/deeplearning4j-scaleout/pom.xml b/.old/deeplearning4j/deeplearning4j-scaleout/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/pom.xml rename to .old/deeplearning4j/deeplearning4j-scaleout/pom.xml diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/pom.xml diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/SparkParagraphVectors.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/SparkParagraphVectors.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/SparkParagraphVectors.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/SparkParagraphVectors.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/functions/DocumentSequenceConvertFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/functions/DocumentSequenceConvertFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/functions/DocumentSequenceConvertFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/functions/DocumentSequenceConvertFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/functions/KeySequenceConvertFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/functions/KeySequenceConvertFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/functions/KeySequenceConvertFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/paragraphvectors/functions/KeySequenceConvertFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectors.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectors.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectors.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectors.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/SparkModelExporter.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/SparkModelExporter.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/SparkModelExporter.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/SparkModelExporter.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/impl/HdfsModelExporter.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/impl/HdfsModelExporter.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/impl/HdfsModelExporter.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/impl/HdfsModelExporter.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/impl/VocabCacheExporter.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/impl/VocabCacheExporter.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/impl/VocabCacheExporter.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/export/impl/VocabCacheExporter.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/BaseTokenizerFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/BaseTokenizerFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/BaseTokenizerFunction.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/BaseTokenizerFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/CountFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/CountFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/CountFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/CountFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/DistributedFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/DistributedFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/DistributedFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/DistributedFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ElementsFrequenciesAccumulator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ElementsFrequenciesAccumulator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ElementsFrequenciesAccumulator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ElementsFrequenciesAccumulator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExportFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExportFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExportFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExportFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExtraCountFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExtraCountFunction.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExtraCountFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExtraCountFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExtraElementsFrequenciesAccumulator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExtraElementsFrequenciesAccumulator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExtraElementsFrequenciesAccumulator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ExtraElementsFrequenciesAccumulator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ListSequenceConvertFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ListSequenceConvertFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ListSequenceConvertFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/ListSequenceConvertFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/PartitionTrainingFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/PartitionTrainingFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/PartitionTrainingFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/PartitionTrainingFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TokenizerFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TokenizerFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TokenizerFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TokenizerFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TrainingFunction.java 
b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TrainingFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TrainingFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/TrainingFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/functions/VocabRddFunctionFlat.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/SparkElementsLearningAlgorithm.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/SparkElementsLearningAlgorithm.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/SparkElementsLearningAlgorithm.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/SparkElementsLearningAlgorithm.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/SparkSequenceLearningAlgorithm.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/SparkSequenceLearningAlgorithm.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/SparkSequenceLearningAlgorithm.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/SparkSequenceLearningAlgorithm.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/BaseSparkLearningAlgorithm.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/BaseSparkLearningAlgorithm.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/BaseSparkLearningAlgorithm.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/BaseSparkLearningAlgorithm.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/SparkCBOW.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/SparkCBOW.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/SparkCBOW.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/SparkCBOW.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/SparkSkipGram.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/SparkSkipGram.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/SparkSkipGram.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/elements/SparkSkipGram.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/BaseSparkSequenceLearningAlgorithm.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/BaseSparkSequenceLearningAlgorithm.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/BaseSparkSequenceLearningAlgorithm.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/BaseSparkSequenceLearningAlgorithm.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/SparkDBOW.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/SparkDBOW.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/SparkDBOW.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/SparkDBOW.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/SparkDM.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/SparkDM.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/SparkDM.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/learning/sequence/SparkDM.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/primitives/ExtraCounter.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/primitives/ExtraCounter.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/primitives/ExtraCounter.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/sequencevectors/primitives/ExtraCounter.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/word2vec/SparkWord2Vec.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/word2vec/SparkWord2Vec.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/word2vec/SparkWord2Vec.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/main/java/org/deeplearning4j/spark/models/word2vec/SparkWord2Vec.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/SparkSequenceVectorsTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/sequencevectors/export/ExportContainerTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java 
rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/log4j.properties diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/resources/logback.xml diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/.gitignore b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/.gitignore similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/.gitignore rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/.gitignore diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/pom.xml diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/FirstIterationFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/MapToPairFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/MapToPairFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/MapToPairFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/MapToPairFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/NegativeHolder.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SecondIterationFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SentenceBatch.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SentenceBatch.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SentenceBatch.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/SentenceBatch.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/VocabHolder.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/VocabHolder.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/VocabHolder.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/VocabHolder.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2Vec.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2Vec.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2Vec.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2Vec.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecChange.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecChange.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecChange.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecChange.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecFuncCall.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecFuncCall.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecFuncCall.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecFuncCall.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecParam.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecParam.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecParam.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecParam.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecSetup.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecSetup.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecSetup.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecSetup.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecVariables.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecVariables.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecVariables.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecVariables.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/accumulators/MaxPerPartitionAccumulator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/accumulators/MaxPerPartitionAccumulator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/accumulators/MaxPerPartitionAccumulator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/accumulators/MaxPerPartitionAccumulator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/accumulators/WordFreqAccumulator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/accumulators/WordFreqAccumulator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/accumulators/WordFreqAccumulator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/accumulators/WordFreqAccumulator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/CountCumSum.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/CountCumSum.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/CountCumSum.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/CountCumSum.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldBetweenPartitionFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldBetweenPartitionFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldBetweenPartitionFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldBetweenPartitionFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldWithinPartitionFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldWithinPartitionFunction.java similarity index 100% rename 
from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldWithinPartitionFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/FoldWithinPartitionFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/GetSentenceCountFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/GetSentenceCountFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/GetSentenceCountFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/GetSentenceCountFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/MapPerPartitionVoidFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/MapPerPartitionVoidFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/MapPerPartitionVoidFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/MapPerPartitionVoidFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/ReduceSentenceCount.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/ReduceSentenceCount.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/ReduceSentenceCount.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/ReduceSentenceCount.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TextPipeline.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TextPipeline.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TextPipeline.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TextPipeline.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/TokenizerFunction.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/UpdateWordFreqAccumulatorFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/UpdateWordFreqAccumulatorFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/UpdateWordFreqAccumulatorFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/UpdateWordFreqAccumulatorFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/WordsListToVocabWordsFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/WordsListToVocabWordsFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/WordsListToVocabWordsFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/text/functions/WordsListToVocabWordsFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TestFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TestFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TestFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TestFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/log4j.properties diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/resources/logback.xml diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/pom.xml diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerSubscriber.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerSubscriber.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerSubscriber.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerSubscriber.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerTrainingHook.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerTrainingHook.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerTrainingHook.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerTrainingHook.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunction.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationTuple.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationTuple.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationTuple.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationTuple.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/DataSetDeserializationCallback.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/DataSetDeserializationCallback.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/DataSetDeserializationCallback.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/DataSetDeserializationCallback.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/MultiDataSetDeserializationCallback.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/MultiDataSetDeserializationCallback.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/MultiDataSetDeserializationCallback.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/MultiDataSetDeserializationCallback.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/PortableDataStreamCallback.java 
b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/PortableDataStreamCallback.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/PortableDataStreamCallback.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/PortableDataStreamCallback.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/PortableDataStreamMDSCallback.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/PortableDataStreamMDSCallback.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/PortableDataStreamMDSCallback.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/callbacks/PortableDataStreamMDSCallback.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/conf/SharedTrainingConfiguration.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/conf/SharedTrainingConfiguration.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/conf/SharedTrainingConfiguration.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/conf/SharedTrainingConfiguration.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapDataSet.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapMultiDataSet.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPaths.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/functions/SharedFlatMapPathsMDS.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/MultiPdsIterator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/MultiPdsIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/MultiPdsIterator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/MultiPdsIterator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/PdsIterator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/PdsIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/PdsIterator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/PdsIterator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIterator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIterator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIterator.java diff 
--git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIterator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIterator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIterator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualMultiDataSetIterator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualMultiDataSetIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualMultiDataSetIterator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualMultiDataSetIterator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/ElephasModelImport.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/ElephasModelImport.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/ElephasModelImport.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/ElephasModelImport.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/SilentTrainingDriver.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/SilentTrainingDriver.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/SilentTrainingDriver.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/SilentTrainingDriver.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/WiredEncodingHandler.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/WiredEncodingHandler.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/WiredEncodingHandler.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/WiredEncodingHandler.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentIntroductoryConfirmation.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentIntroductoryConfirmation.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentIntroductoryConfirmation.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentIntroductoryConfirmation.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentIntroductoryMessage.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentIntroductoryMessage.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentIntroductoryMessage.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentIntroductoryMessage.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentUpdatesMessage.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentUpdatesMessage.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentUpdatesMessage.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v1/messages/SilentUpdatesMessage.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/ModelParamsConsumer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/ModelParamsConsumer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/ModelParamsConsumer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/ModelParamsConsumer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdaterParamsConsumer.java 
b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdaterParamsConsumer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdaterParamsConsumer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdaterParamsConsumer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdatesConsumer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdatesConsumer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdatesConsumer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/UpdatesConsumer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/WiredEncodingHandler.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/WiredEncodingHandler.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/WiredEncodingHandler.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/networking/v2/WiredEncodingHandler.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/ArrayDescriptor.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/ArrayDescriptor.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/ArrayDescriptor.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/ArrayDescriptor.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/DataSetDescriptor.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/DataSetDescriptor.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/DataSetDescriptor.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/DataSetDescriptor.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/Utils.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/Utils.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/Utils.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/python/Utils.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingResult.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingResult.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingResult.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingResult.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/util/BlockingObserver.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/util/BlockingObserver.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/util/BlockingObserver.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/util/BlockingObserver.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/util/CountingIterator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/util/CountingIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/util/CountingIterator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/util/CountingIterator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunctionTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunctionTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunctionTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAccumulationFunctionTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunctionTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunctionTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunctionTest.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/accumulation/SharedTrainingAggregateFunctionTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIteratorTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIteratorTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIteratorTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualDataSetIteratorTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIteratorTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIteratorTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIteratorTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/iterators/VirtualIteratorTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/TestElephasImport.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/TestElephasImport.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/TestElephasImport.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/modelimport/elephas/TestElephasImport.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/log4j.properties diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/resources/logback.xml diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/nd4j-native.properties b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/nd4j-native.properties similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/nd4j-native.properties rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/nd4j-native.properties diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/RDDTrainingApproach.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/RDDTrainingApproach.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/RDDTrainingApproach.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/RDDTrainingApproach.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/Repartition.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/Repartition.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/Repartition.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/Repartition.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/RepartitionStrategy.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/RepartitionStrategy.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/RepartitionStrategy.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/RepartitionStrategy.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/Repartitioner.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/Repartitioner.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/Repartitioner.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/Repartitioner.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingHook.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingHook.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingHook.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingHook.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingMaster.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingMaster.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingMaster.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingMaster.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingResult.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingResult.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingResult.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingResult.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingWorker.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingWorker.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingWorker.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/TrainingWorker.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/WorkerConfiguration.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/WorkerConfiguration.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/WorkerConfiguration.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/WorkerConfiguration.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/CommonSparkTrainingStats.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/CommonSparkTrainingStats.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/CommonSparkTrainingStats.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/CommonSparkTrainingStats.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/SparkTrainingStats.java 
b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/SparkTrainingStats.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/SparkTrainingStats.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/SparkTrainingStats.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/stats/StatsCalculationHelper.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerFlatMap.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerMultiDataSetFlatMap.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSFlatMap.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPDSMDSFlatMap.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathFlatMap.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/ExecuteWorkerPathMDSFlatMap.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/NetBroadcastTuple.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/NetBroadcastTuple.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/NetBroadcastTuple.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/api/worker/NetBroadcastTuple.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchAndExportDataSetsFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchAndExportDataSetsFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchAndExportDataSetsFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchAndExportDataSetsFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchAndExportMultiDataSetsFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchAndExportMultiDataSetsFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchAndExportMultiDataSetsFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchAndExportMultiDataSetsFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/BatchDataSetsFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/DataSetExportFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/DataSetExportFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/DataSetExportFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/DataSetExportFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/DataSetProvider.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/DataSetProvider.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/DataSetProvider.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/DataSetProvider.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/MultiDataSetExportFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/MultiDataSetExportFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/MultiDataSetExportFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/MultiDataSetExportFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/MultiDataSetProvider.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/MultiDataSetProvider.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/MultiDataSetProvider.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/MultiDataSetProvider.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/PathToDataSetFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/PathToDataSetFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/PathToDataSetFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/PathToDataSetFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/PathToMultiDataSetFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/PathToMultiDataSetFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/PathToMultiDataSetFunction.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/PathToMultiDataSetFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/SplitDataSetsFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/SplitDataSetsFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/SplitDataSetsFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/SplitDataSetsFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/loader/RemoteFileSource.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/loader/RemoteFileSource.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/loader/RemoteFileSource.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/loader/RemoteFileSource.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/loader/RemoteFileSourceFactory.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/loader/RemoteFileSourceFactory.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/loader/RemoteFileSourceFactory.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/loader/RemoteFileSourceFactory.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/data/shuffle/SplitDataSetExamplesPairFlatMapFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecByteDataSetFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecDataSetFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequenceDataSetFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequenceDataSetFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequenceDataSetFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequenceDataSetFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequencePairDataSetFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequencePairDataSetFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequencePairDataSetFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/DataVecSequencePairDataSetFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RDDMiniBatches.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/RecordReaderFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/export/StringToDataSetExportFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/export/StringToDataSetExportFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/export/StringToDataSetExportFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/export/StringToDataSetExportFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecord.java 
b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecord.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecord.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecord.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecords.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecords.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecords.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/DataVecRecords.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/IteratorUtils.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/IteratorUtils.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/IteratorUtils.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/IteratorUtils.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/RRMDSIFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/RRMDSIFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/RRMDSIFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/RRMDSIFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummyReader.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummyReader.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummyReader.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummyReader.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummySeqReader.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummySeqReader.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummySeqReader.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/datavec/iterator/SparkSourceDummySeqReader.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkDataSetLossCalculator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkDataSetLossCalculator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkDataSetLossCalculator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkDataSetLossCalculator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingGraphTrainer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingGraphTrainer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingGraphTrainer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingGraphTrainer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingTrainer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingTrainer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingTrainer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkEarlyStoppingTrainer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkLossCalculatorComputationGraph.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkLossCalculatorComputationGraph.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkLossCalculatorComputationGraph.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/earlystopping/SparkLossCalculatorComputationGraph.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/SparkListenable.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/SparkListenable.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/SparkListenable.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/SparkListenable.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/Add.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/Add.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/Add.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/Add.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/CountPartitionsFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/CountPartitionsFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/CountPartitionsFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/CountPartitionsFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/LoadDataSetFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/LoadDataSetFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/LoadDataSetFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/LoadDataSetFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/SplitPartitionsFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/SplitPartitionsFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/SplitPartitionsFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/SplitPartitionsFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/SplitPartitionsFunction2.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/SplitPartitionsFunction2.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/SplitPartitionsFunction2.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/SplitPartitionsFunction2.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/IntDoubleReduceFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/IntDoubleReduceFunction.java similarity index 
100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/IntDoubleReduceFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/IntDoubleReduceFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/LongDoubleReduceFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/LongDoubleReduceFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/LongDoubleReduceFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/reduce/LongDoubleReduceFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitioner.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitioner.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitioner.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitioner.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/EqualPartitioner.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/EqualPartitioner.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/EqualPartitioner.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/EqualPartitioner.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitioner.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/repartition/MapTupleToPairFlatMap.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/dataset/DataSetToMultiDataSetFn.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/dataset/DataSetToMultiDataSetFn.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/dataset/DataSetToMultiDataSetFn.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/dataset/DataSetToMultiDataSetFn.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/dataset/PairDataSetToMultiDataSetFn.java 
b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/dataset/PairDataSetToMultiDataSetFn.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/dataset/PairDataSetToMultiDataSetFn.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/dataset/PairDataSetToMultiDataSetFn.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSFlatMapFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/evaluation/IEvaluateMDSPathsFlatMapFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ArrayPairToPair.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ArrayPairToPair.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ArrayPairToPair.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ArrayPairToPair.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/GraphFeedForwardWithKeyFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/PairToArrayPair.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/PairToArrayPair.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/PairToArrayPair.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/PairToArrayPair.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreExamplesWithKeyFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGDataSet.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/ScoreFlatMapFunctionCGMultiDataSet.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/listeners/VanillaStatsStorageRouter.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/listeners/VanillaStatsStorageRouter.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/listeners/VanillaStatsStorageRouter.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/listeners/VanillaStatsStorageRouter.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/listeners/VanillaStatsStorageRouterProvider.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/listeners/VanillaStatsStorageRouterProvider.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/listeners/VanillaStatsStorageRouterProvider.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/listeners/VanillaStatsStorageRouterProvider.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateAggregateFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateAggregateFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateAggregateFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateAggregateFunction.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluateFlatMapFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluationReduceFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluationReduceFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluationReduceFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/evaluation/IEvaluationReduceFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesWithKeyFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesWithKeyFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesWithKeyFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java 
b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreFlatMapFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/SingleToPairFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/SingleToPairFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/SingleToPairFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/SingleToPairFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingMaster.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingMaster.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingMaster.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingMaster.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingResult.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingResult.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingResult.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingResult.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingWorker.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingWorker.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingWorker.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/BaseTrainingWorker.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingResult.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingResult.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingResult.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingResult.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingAggregationTuple.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingAggregationTuple.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingAggregationTuple.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingAggregationTuple.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementAddFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementAddFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementAddFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementAddFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementCombineFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementCombineFunction.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementCombineFunction.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/aggregator/ParameterAveragingElementCombineFunction.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingMasterStats.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingMasterStats.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingMasterStats.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingMasterStats.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingWorkerStats.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingWorkerStats.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingWorkerStats.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/stats/ParameterAveragingTrainingWorkerStats.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupport.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupport.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupport.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupport.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/DefaultRepartitioner.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/DefaultRepartitioner.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/DefaultRepartitioner.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/DefaultRepartitioner.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/EqualRepartitioner.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/EqualRepartitioner.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/EqualRepartitioner.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/EqualRepartitioner.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/NoOpRepartitioner.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/NoOpRepartitioner.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/NoOpRepartitioner.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/repartitioner/NoOpRepartitioner.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/BaseDataSetIterator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/BaseDataSetIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/BaseDataSetIterator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/BaseDataSetIterator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkDataSetIterator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkMultiDataSetIterator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkMultiDataSetIterator.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkMultiDataSetIterator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PathSparkMultiDataSetIterator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamDataSetIterator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamDataSetIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamDataSetIterator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamDataSetIterator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamMultiDataSetIterator.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamMultiDataSetIterator.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamMultiDataSetIterator.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/PortableDataStreamMultiDataSetIterator.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkADSI.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkADSI.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkADSI.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkADSI.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/ordering/DataSetOrdering.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/ordering/DataSetOrdering.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/ordering/DataSetOrdering.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/ordering/DataSetOrdering.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/BaseEventStats.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/BaseEventStats.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/BaseEventStats.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/BaseEventStats.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/EventStats.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/EventStats.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/EventStats.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/EventStats.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/ExampleCountEventStats.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/ExampleCountEventStats.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/ExampleCountEventStats.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/ExampleCountEventStats.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/PartitionCountEventStats.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/PartitionCountEventStats.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/PartitionCountEventStats.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/PartitionCountEventStats.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/StatsUtils.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/StatsUtils.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/StatsUtils.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/stats/StatsUtils.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/NTPTimeSource.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/NTPTimeSource.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/NTPTimeSource.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/NTPTimeSource.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/SystemClockTimeSource.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/SystemClockTimeSource.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/SystemClockTimeSource.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/SystemClockTimeSource.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSource.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSource.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSource.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSource.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSourceProvider.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSourceProvider.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSourceProvider.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/time/TimeSourceProvider.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/MLLibUtil.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/SparkDataUtils.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/SparkDataUtils.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/SparkDataUtils.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/SparkDataUtils.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/SparkUtils.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/SparkUtils.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/SparkUtils.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/SparkUtils.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/SparkDataValidation.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/SparkDataValidation.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/SparkDataValidation.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/SparkDataValidation.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/ValidationResult.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/ValidationResult.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/ValidationResult.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/ValidationResult.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidateDataSetFn.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidateDataSetFn.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidateDataSetFn.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidateDataSetFn.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidateMultiDataSetFn.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidateMultiDataSetFn.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidateMultiDataSetFn.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidateMultiDataSetFn.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidationResultReduceFn.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidationResultReduceFn.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidationResultReduceFn.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/data/validation/ValidationResultReduceFn.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelDeserializer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelDeserializer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelDeserializer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelDeserializer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelSerializer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelSerializer.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelSerializer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/util/serde/StorageLevelSerializer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/scala/org/apache/spark/TaskContextHelper.scala b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/scala/org/apache/spark/TaskContextHelper.scala similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/scala/org/apache/spark/TaskContextHelper.scala rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/scala/org/apache/spark/TaskContextHelper.scala diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/common/AddTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/common/AddTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/common/AddTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/common/AddTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestShuffleExamples.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestShuffleExamples.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestShuffleExamples.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestShuffleExamples.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestSparkDataUtils.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestSparkDataUtils.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestSparkDataUtils.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/data/TestSparkDataUtils.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/MiniBatchTests.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestDataVecDataSetFunctions.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestExport.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestExport.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestExport.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestExport.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java 
b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/iterator/TestIteratorUtils.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/BalancedPartitionerTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitionerTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitionerTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitionerTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/common/repartition/HashingBalancedPartitionerTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java similarity index 100% rename from 
deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestJsonYaml.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestJsonYaml.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestJsonYaml.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestJsonYaml.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/util/ExportSupportTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java diff --git 
a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/time/TestTimeSource.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/time/TestTimeSource.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/time/TestTimeSource.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/time/TestTimeSource.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/MLLIbUtilTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/MLLIbUtilTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/MLLIbUtilTest.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/MLLIbUtilTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestRepartitioning.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestValidation.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestValidation.java similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestValidation.java rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/util/TestValidation.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/log4j.properties b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/log4j.properties similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/log4j.properties rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/log4j.properties diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/logback.xml b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/logback.xml similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/logback.xml rename to 
.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/resources/logback.xml diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml b/.old/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml similarity index 100% rename from deeplearning4j/deeplearning4j-scaleout/spark/pom.xml rename to .old/deeplearning4j/deeplearning4j-scaleout/spark/pom.xml diff --git a/deeplearning4j/pom.xml b/.old/deeplearning4j/pom.xml similarity index 100% rename from deeplearning4j/pom.xml rename to .old/deeplearning4j/pom.xml diff --git a/nd4j/README.md b/.old/nd4j/README.md similarity index 100% rename from nd4j/README.md rename to .old/nd4j/README.md diff --git a/nd4j/RaspberryPi.md b/.old/nd4j/RaspberryPi.md similarity index 100% rename from nd4j/RaspberryPi.md rename to .old/nd4j/RaspberryPi.md diff --git a/nd4j/buildmultiplescalaversions.sh b/.old/nd4j/buildmultiplescalaversions.sh similarity index 100% rename from nd4j/buildmultiplescalaversions.sh rename to .old/nd4j/buildmultiplescalaversions.sh diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-api/pom.xml diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/driverfinder/DriverFinder.java diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/loader/api/JDBCNDArrayIO.java b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/loader/api/JDBCNDArrayIO.java similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/loader/api/JDBCNDArrayIO.java rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/loader/api/JDBCNDArrayIO.java diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/loader/impl/BaseLoader.java b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/loader/impl/BaseLoader.java similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/loader/impl/BaseLoader.java rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-api/src/main/java/org/nd4j/jdbc/loader/impl/BaseLoader.java diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/pom.xml diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/main/java/org/nd4j/jdbc/hsql/HsqlLoader.java b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/main/java/org/nd4j/jdbc/hsql/HsqlLoader.java similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/main/java/org/nd4j/jdbc/hsql/HsqlLoader.java rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/main/java/org/nd4j/jdbc/hsql/HsqlLoader.java diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/main/resources/nd4j.jdbc.properties b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/main/resources/nd4j.jdbc.properties similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/main/resources/nd4j.jdbc.properties rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/main/resources/nd4j.jdbc.properties diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java 
b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-hsql/src/test/java/org/nd4j/jdbc/hsql/HSqlLoaderTest.java diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/main/java/org/nd4j/jdbc/mysql/MysqlLoader.java b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/main/java/org/nd4j/jdbc/mysql/MysqlLoader.java similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/main/java/org/nd4j/jdbc/mysql/MysqlLoader.java rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/main/java/org/nd4j/jdbc/mysql/MysqlLoader.java diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/main/resources/nd4j.jdbc.properties b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/main/resources/nd4j.jdbc.properties similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/main/resources/nd4j.jdbc.properties rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/main/resources/nd4j.jdbc.properties diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java b/.old/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java similarity index 100% rename from nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java rename to .old/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/src/test/java/org/nd4j/jdbc/mysql/MysqlLoaderTest.java diff --git a/nd4j/nd4j-jdbc/pom.xml b/.old/nd4j/nd4j-jdbc/pom.xml similarity index 100% rename from nd4j/nd4j-jdbc/pom.xml rename to .old/nd4j/nd4j-jdbc/pom.xml diff --git a/nd4j/nd4j-onnxruntime/pom.xml b/.old/nd4j/nd4j-onnxruntime/pom.xml similarity index 100% rename from nd4j/nd4j-onnxruntime/pom.xml rename to .old/nd4j/nd4j-onnxruntime/pom.xml diff --git a/nd4j/nd4j-onnxruntime/src/main/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunner.java b/.old/nd4j/nd4j-onnxruntime/src/main/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunner.java similarity index 100% rename from nd4j/nd4j-onnxruntime/src/main/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunner.java rename to .old/nd4j/nd4j-onnxruntime/src/main/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunner.java diff --git a/nd4j/nd4j-onnxruntime/src/main/java/org/nd4j/onnxruntime/util/ONNXUtils.java b/.old/nd4j/nd4j-onnxruntime/src/main/java/org/nd4j/onnxruntime/util/ONNXUtils.java similarity index 100% rename from nd4j/nd4j-onnxruntime/src/main/java/org/nd4j/onnxruntime/util/ONNXUtils.java rename to .old/nd4j/nd4j-onnxruntime/src/main/java/org/nd4j/onnxruntime/util/ONNXUtils.java diff --git a/nd4j/nd4j-onnxruntime/src/test/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunnerTests.java b/.old/nd4j/nd4j-onnxruntime/src/test/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunnerTests.java similarity index 100% rename from nd4j/nd4j-onnxruntime/src/test/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunnerTests.java rename to .old/nd4j/nd4j-onnxruntime/src/test/java/org/nd4j/onnxruntime/runner/OnnxRuntimeRunnerTests.java diff --git a/nd4j/nd4j-onnxruntime/src/test/resources/add.onnx b/.old/nd4j/nd4j-onnxruntime/src/test/resources/add.onnx similarity index 100% rename from nd4j/nd4j-onnxruntime/src/test/resources/add.onnx rename to .old/nd4j/nd4j-onnxruntime/src/test/resources/add.onnx diff --git 
a/nd4j/nd4j-remote/README.md b/.old/nd4j/nd4j-remote/README.md similarity index 100% rename from nd4j/nd4j-remote/README.md rename to .old/nd4j/nd4j-remote/README.md diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml b/.old/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml similarity index 100% rename from nd4j/nd4j-remote/nd4j-grpc-client/pom.xml rename to .old/nd4j/nd4j-remote/nd4j-grpc-client/pom.xml diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/remote/grpc/GraphInferenceGrpcClient.java b/.old/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/remote/grpc/GraphInferenceGrpcClient.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/remote/grpc/GraphInferenceGrpcClient.java rename to .old/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/remote/grpc/GraphInferenceGrpcClient.java diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/remote/grpc/grpc/GraphInferenceServerGrpc.java b/.old/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/remote/grpc/grpc/GraphInferenceServerGrpc.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/remote/grpc/grpc/GraphInferenceServerGrpc.java rename to .old/nd4j/nd4j-remote/nd4j-grpc-client/src/main/java/org/nd4j/remote/grpc/grpc/GraphInferenceServerGrpc.java diff --git a/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java b/.old/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java rename to .old/nd4j/nd4j-remote/nd4j-grpc-client/src/test/java/org/nd4j/graph/GraphInferenceGrpcClientTest.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/pom.xml b/.old/nd4j/nd4j-remote/nd4j-json-client/pom.xml similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/pom.xml rename to .old/nd4j/nd4j-remote/nd4j-json-client/pom.xml diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/JsonRemoteInference.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/JsonRemoteInference.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/JsonRemoteInference.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/JsonRemoteInference.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/BinaryDeserializer.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/BinaryDeserializer.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/BinaryDeserializer.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/BinaryDeserializer.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/BinarySerializer.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/BinarySerializer.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/BinarySerializer.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/BinarySerializer.java diff --git 
a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonDeserializer.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonDeserializer.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonDeserializer.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonDeserializer.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonSerializer.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonSerializer.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonSerializer.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/JsonSerializer.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/AbstractSerDe.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/AbstractSerDe.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/AbstractSerDe.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/AbstractSerDe.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/BooleanSerde.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/BooleanSerde.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/BooleanSerde.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/BooleanSerde.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleArraySerde.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleSerde.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleSerde.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleSerde.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/DoubleSerde.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatArraySerde.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatArraySerde.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatArraySerde.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatArraySerde.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatSerde.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatSerde.java similarity index 100% 
rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatSerde.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/FloatSerde.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/IntegerSerde.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/IntegerSerde.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/IntegerSerde.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/IntegerSerde.java diff --git a/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/StringSerde.java b/.old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/StringSerde.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/StringSerde.java rename to .old/nd4j/nd4j-remote/nd4j-json-client/src/main/java/org/nd4j/remote/clients/serde/impl/StringSerde.java diff --git a/nd4j/nd4j-remote/nd4j-json-server/README.md b/.old/nd4j/nd4j-remote/nd4j-json-server/README.md similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/README.md rename to .old/nd4j/nd4j-remote/nd4j-json-server/README.md diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/.old/nd4j/nd4j-remote/nd4j-json-server/pom.xml similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/pom.xml rename to .old/nd4j/nd4j-remote/nd4j-json-server/pom.xml diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/SameDiffJsonModelServer.java b/.old/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/SameDiffJsonModelServer.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/SameDiffJsonModelServer.java rename to .old/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/SameDiffJsonModelServer.java diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ModelServingServlet.java b/.old/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ModelServingServlet.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ModelServingServlet.java rename to .old/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ModelServingServlet.java diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java b/.old/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java rename to .old/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/SameDiffServlet.java diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ServingProcessor.java b/.old/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ServingProcessor.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ServingProcessor.java rename to .old/nd4j/nd4j-remote/nd4j-json-server/src/main/java/org/nd4j/remote/serving/ServingProcessor.java diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java 
b/.old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java rename to .old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffJsonModelServerTest.java diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java b/.old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java rename to .old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/SameDiffServletTest.java diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/House.java b/.old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/House.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/House.java rename to .old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/House.java diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/HouseToPredictedPriceAdapter.java b/.old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/HouseToPredictedPriceAdapter.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/HouseToPredictedPriceAdapter.java rename to .old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/HouseToPredictedPriceAdapter.java diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/PredictedPrice.java b/.old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/PredictedPrice.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/PredictedPrice.java rename to .old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/helpers/PredictedPrice.java diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java b/.old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java rename to .old/nd4j/nd4j-remote/nd4j-json-server/src/test/java/org/nd4j/remote/serde/BasicSerdeTests.java diff --git a/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml b/.old/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml similarity index 100% rename from nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml rename to .old/nd4j/nd4j-remote/nd4j-json-server/src/test/resources/logback.xml diff --git a/nd4j/nd4j-remote/pom.xml b/.old/nd4j/nd4j-remote/pom.xml similarity index 100% rename from nd4j/nd4j-remote/pom.xml rename to .old/nd4j/nd4j-remote/pom.xml diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/.old/nd4j/nd4j-serde/nd4j-arrow/pom.xml similarity index 100% rename from nd4j/nd4j-serde/nd4j-arrow/pom.xml rename to .old/nd4j/nd4j-serde/nd4j-arrow/pom.xml diff --git a/nd4j/nd4j-serde/nd4j-arrow/src/main/java/org/nd4j/arrow/ArrowSerde.java b/.old/nd4j/nd4j-serde/nd4j-arrow/src/main/java/org/nd4j/arrow/ArrowSerde.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-arrow/src/main/java/org/nd4j/arrow/ArrowSerde.java rename to 
.old/nd4j/nd4j-serde/nd4j-arrow/src/main/java/org/nd4j/arrow/ArrowSerde.java diff --git a/nd4j/nd4j-serde/nd4j-arrow/src/main/java/org/nd4j/arrow/DataBufferStruct.java b/.old/nd4j/nd4j-serde/nd4j-arrow/src/main/java/org/nd4j/arrow/DataBufferStruct.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-arrow/src/main/java/org/nd4j/arrow/DataBufferStruct.java rename to .old/nd4j/nd4j-serde/nd4j-arrow/src/main/java/org/nd4j/arrow/DataBufferStruct.java diff --git a/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java b/.old/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java rename to .old/nd4j/nd4j-serde/nd4j-arrow/src/test/java/org/nd4j/arrow/ArrowSerdeTest.java diff --git a/nd4j/nd4j-serde/nd4j-kryo/pom.xml b/.old/nd4j/nd4j-serde/nd4j-kryo/pom.xml similarity index 100% rename from nd4j/nd4j-serde/nd4j-kryo/pom.xml rename to .old/nd4j/nd4j-serde/nd4j-kryo/pom.xml diff --git a/nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/Nd4jRegistrator.java b/.old/nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/Nd4jRegistrator.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/Nd4jRegistrator.java rename to .old/nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/Nd4jRegistrator.java diff --git a/nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/Nd4jSerializer.java b/.old/nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/Nd4jSerializer.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/Nd4jSerializer.java rename to .old/nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/Nd4jSerializer.java diff --git a/nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/primitives/AtomicDoubleSerializer.java b/.old/nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/primitives/AtomicDoubleSerializer.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/primitives/AtomicDoubleSerializer.java rename to .old/nd4j/nd4j-serde/nd4j-kryo/src/main/java/org/nd4j/kryo/primitives/AtomicDoubleSerializer.java diff --git a/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java b/.old/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java similarity index 100% rename from nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java rename to .old/nd4j/nd4j-serde/nd4j-kryo/src/test/java/org/nd4j/TestNd4jKryoSerialization.java diff --git a/nd4j/nd4j-serde/pom.xml b/.old/nd4j/nd4j-serde/pom.xml similarity index 100% rename from nd4j/nd4j-serde/pom.xml rename to .old/nd4j/nd4j-serde/pom.xml diff --git a/nd4j/nd4j-tvm/pom.xml b/.old/nd4j/nd4j-tvm/pom.xml similarity index 100% rename from nd4j/nd4j-tvm/pom.xml rename to .old/nd4j/nd4j-tvm/pom.xml diff --git a/nd4j/nd4j-tvm/src/main/java/org/nd4j/tvm/runner/TvmRunner.java b/.old/nd4j/nd4j-tvm/src/main/java/org/nd4j/tvm/runner/TvmRunner.java similarity index 100% rename from nd4j/nd4j-tvm/src/main/java/org/nd4j/tvm/runner/TvmRunner.java rename to .old/nd4j/nd4j-tvm/src/main/java/org/nd4j/tvm/runner/TvmRunner.java diff --git a/nd4j/nd4j-tvm/src/main/java/org/nd4j/tvm/util/TVMUtils.java b/.old/nd4j/nd4j-tvm/src/main/java/org/nd4j/tvm/util/TVMUtils.java similarity index 100% rename from nd4j/nd4j-tvm/src/main/java/org/nd4j/tvm/util/TVMUtils.java rename to .old/nd4j/nd4j-tvm/src/main/java/org/nd4j/tvm/util/TVMUtils.java 
diff --git a/nd4j/nd4j-tvm/src/test/java/org/nd4j/tvm/runner/TvmRunnerTests.java b/.old/nd4j/nd4j-tvm/src/test/java/org/nd4j/tvm/runner/TvmRunnerTests.java similarity index 100% rename from nd4j/nd4j-tvm/src/test/java/org/nd4j/tvm/runner/TvmRunnerTests.java rename to .old/nd4j/nd4j-tvm/src/test/java/org/nd4j/tvm/runner/TvmRunnerTests.java diff --git a/nd4j/pom.xml b/.old/nd4j/pom.xml similarity index 100% rename from nd4j/pom.xml rename to .old/nd4j/pom.xml diff --git a/nd4j/samediff-import/pom.xml b/.old/nd4j/samediff-import/pom.xml similarity index 100% rename from nd4j/samediff-import/pom.xml rename to .old/nd4j/samediff-import/pom.xml diff --git a/nd4j/samediff-import/samediff-import-api/pom.xml b/.old/nd4j/samediff-import/samediff-import-api/pom.xml similarity index 100% rename from nd4j/samediff-import/samediff-import-api/pom.xml rename to .old/nd4j/samediff-import/samediff-import-api/pom.xml diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/FrameworkImporter.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/FrameworkImporter.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/FrameworkImporter.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/FrameworkImporter.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/IRProtobufExtensions.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/IRProtobufExtensions.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/IRProtobufExtensions.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/IRProtobufExtensions.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraph.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraph.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraph.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraph.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraphFactory.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraphFactory.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraphFactory.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraphFactory.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraphHolder.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraphHolder.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraphHolder.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraphHolder.kt diff --git 
a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/context/AbstractMappingContext.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/context/AbstractMappingContext.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/context/AbstractMappingContext.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/context/AbstractMappingContext.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/context/MappingContext.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/context/MappingContext.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/context/MappingContext.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/context/MappingContext.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/PostImportHook.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/PostImportHook.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/PostImportHook.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/PostImportHook.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/PreImportHook.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/PreImportHook.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/PreImportHook.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/PreImportHook.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/HookResult.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/HookResult.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/HookResult.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/HookResult.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/PostHookRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/PostHookRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/PostHookRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/PostHookRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/PreHookRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/PreHookRule.kt 
similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/PreHookRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/hooks/annotations/PreHookRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRArgDef.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRArgDef.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRArgDef.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRArgDef.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRAttribute.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRAttribute.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRAttribute.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRAttribute.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRDataType.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRDataType.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRDataType.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRDataType.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRDataTypeValue.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRDataTypeValue.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRDataTypeValue.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRDataTypeValue.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRFunctions.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRFunctions.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRFunctions.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRFunctions.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRGraph.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRGraph.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRGraph.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRGraph.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRNode.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRNode.kt similarity 
index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRNode.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRNode.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IROpDef.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IROpDef.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IROpDef.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IROpDef.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRTensor.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRTensor.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRTensor.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ir/IRTensor.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/mapper/MapperExtensions.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/mapper/MapperExtensions.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/mapper/MapperExtensions.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/mapper/MapperExtensions.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/opdefs/OpDescriptorLoader.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/opdefs/OpDescriptorLoader.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/opdefs/OpDescriptorLoader.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/opdefs/OpDescriptorLoader.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/opdefs/OpDescriptorLoaderHolder.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/opdefs/OpDescriptorLoaderHolder.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/opdefs/OpDescriptorLoaderHolder.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/opdefs/OpDescriptorLoaderHolder.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/AbstractMappingProcess.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/AbstractMappingProcess.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/AbstractMappingProcess.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/AbstractMappingProcess.kt diff --git 
a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/AbstractMappingProcessLoader.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/AbstractMappingProcessLoader.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/AbstractMappingProcessLoader.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/AbstractMappingProcessLoader.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/MappingProcess.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/MappingProcess.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/MappingProcess.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/MappingProcess.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/MappingProcessLoader.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/MappingProcessLoader.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/MappingProcessLoader.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/process/MappingProcessLoader.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/reflect/ImportReflectionCache.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/reflect/ImportReflectionCache.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/reflect/ImportReflectionCache.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/reflect/ImportReflectionCache.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/registry/ObjectRegistryHolder.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/registry/ObjectRegistryHolder.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/registry/ObjectRegistryHolder.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/registry/ObjectRegistryHolder.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/registry/OpMappingRegistry.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/registry/OpMappingRegistry.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/registry/OpMappingRegistry.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/registry/OpMappingRegistry.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/MappingRule.kt 
b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/MappingRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/MappingRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/MappingRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ArgDescriptorConstant.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ArgDescriptorConstant.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ArgDescriptorConstant.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ArgDescriptorConstant.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeMappingRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeMappingRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeMappingRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeMappingRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeNDArrayToScalarAttribute.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeNDArrayToScalarAttribute.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeNDArrayToScalarAttribute.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeNDArrayToScalarAttribute.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeNumberListNDArray.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeNumberListNDArray.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeNumberListNDArray.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeNumberListNDArray.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeScalarNDArrayAttribute.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeScalarNDArrayAttribute.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeScalarNDArrayAttribute.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeScalarNDArrayAttribute.kt diff --git 
a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeValueType.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeValueType.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeValueType.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/AttributeValueType.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/BaseAttributeExtractionRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/BaseAttributeExtractionRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/BaseAttributeExtractionRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/BaseAttributeExtractionRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ConditionalFieldValueIntIndexArrayRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ConditionalFieldValueIntIndexArrayRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ConditionalFieldValueIntIndexArrayRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ConditionalFieldValueIntIndexArrayRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ConditionalFieldValueIntIndexNDArrayRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ConditionalFieldValueIntIndexNDArrayRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ConditionalFieldValueIntIndexNDArrayRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ConditionalFieldValueIntIndexNDArrayRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/DataTypeToInt.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/DataTypeToInt.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/DataTypeToInt.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/DataTypeToInt.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/FlattenDims.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/FlattenDims.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/FlattenDims.kt rename to 
.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/FlattenDims.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/IRMappingFunctions.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/IRMappingFunctions.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/IRMappingFunctions.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/IRMappingFunctions.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/InvertBooleanNumber.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/InvertBooleanNumber.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/InvertBooleanNumber.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/InvertBooleanNumber.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListAttributeValueLookupToIndex.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListAttributeValueLookupToIndex.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListAttributeValueLookupToIndex.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListAttributeValueLookupToIndex.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListNumberToListNumber.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListNumberToListNumber.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListNumberToListNumber.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListNumberToListNumber.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListNumberToNDArray.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListNumberToNDArray.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListNumberToNDArray.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ListNumberToNDArray.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/MapStringToInt.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/MapStringToInt.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/MapStringToInt.kt rename to 
.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/MapStringToInt.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayAttributeToNDArrayInput.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayAttributeToNDArrayInput.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayAttributeToNDArrayInput.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayAttributeToNDArrayInput.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayExtractScalarValue.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayExtractScalarValue.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayExtractScalarValue.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayExtractScalarValue.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayInputToNumericalAttribute.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayInputToNumericalAttribute.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayInputToNumericalAttribute.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayInputToNumericalAttribute.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArraySizeAtRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArraySizeAtRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArraySizeAtRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArraySizeAtRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayToIntAttributeValue.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayToIntAttributeValue.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayToIntAttributeValue.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NDArrayToIntAttributeValue.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NumberToBoolean.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NumberToBoolean.kt similarity index 100% rename from 
nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NumberToBoolean.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/NumberToBoolean.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/SizeThresholdIntArrayIntIndexRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/SizeThresholdIntArrayIntIndexRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/SizeThresholdIntArrayIntIndexRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/SizeThresholdIntArrayIntIndexRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringAttributeToNDArray.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringAttributeToNDArray.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringAttributeToNDArray.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringAttributeToNDArray.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringContainsAdapterRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringContainsAdapterRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringContainsAdapterRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringContainsAdapterRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringEqualsAdapterRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringEqualsAdapterRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringEqualsAdapterRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringEqualsAdapterRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringNotEqualsAdapterRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringNotEqualsAdapterRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringNotEqualsAdapterRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringNotEqualsAdapterRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringToInt.kt 
b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringToInt.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringToInt.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/StringToInt.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ValueMapping.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ValueMapping.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ValueMapping.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/attribute/ValueMapping.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/BaseNDArrayMappingRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/BaseNDArrayMappingRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/BaseNDArrayMappingRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/BaseNDArrayMappingRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/MultiInputIndexMappingRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/MultiInputIndexMappingRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/MultiInputIndexMappingRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/MultiInputIndexMappingRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/PassThroughMultiTensorMapping.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/PassThroughMultiTensorMapping.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/PassThroughMultiTensorMapping.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/PassThroughMultiTensorMapping.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/TensorMappingRule.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/TensorMappingRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/TensorMappingRule.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/rule/tensor/TensorMappingRule.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/DefaultImportRunner.kt 
b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/DefaultImportRunner.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/DefaultImportRunner.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/DefaultImportRunner.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/IRGraphRunner.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/IRGraphRunner.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/IRGraphRunner.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/IRGraphRunner.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/ImportRunner.kt b/.old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/ImportRunner.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/ImportRunner.kt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/runner/ImportRunner.kt diff --git a/nd4j/samediff-import/samediff-import-api/src/main/resources/nd4j-op-def.pbtxt b/.old/nd4j/samediff-import/samediff-import-api/src/main/resources/nd4j-op-def.pbtxt similarity index 100% rename from nd4j/samediff-import/samediff-import-api/src/main/resources/nd4j-op-def.pbtxt rename to .old/nd4j/samediff-import/samediff-import-api/src/main/resources/nd4j-op-def.pbtxt diff --git a/nd4j/samediff-import/samediff-import-onnx/onnx-processes.pbtxt b/.old/nd4j/samediff-import/samediff-import-onnx/onnx-processes.pbtxt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/onnx-processes.pbtxt rename to .old/nd4j/samediff-import/samediff-import-onnx/onnx-processes.pbtxt diff --git a/nd4j/samediff-import/samediff-import-onnx/ops-added-new.txt b/.old/nd4j/samediff-import/samediff-import-onnx/ops-added-new.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/ops-added-new.txt rename to .old/nd4j/samediff-import/samediff-import-onnx/ops-added-new.txt diff --git a/nd4j/samediff-import/samediff-import-onnx/ops-imported-new.txt b/.old/nd4j/samediff-import/samediff-import-onnx/ops-imported-new.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/ops-imported-new.txt rename to .old/nd4j/samediff-import/samediff-import-onnx/ops-imported-new.txt diff --git a/nd4j/samediff-import/samediff-import-onnx/ops-removed-new.txt b/.old/nd4j/samediff-import/samediff-import-onnx/ops-removed-new.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/ops-removed-new.txt rename to .old/nd4j/samediff-import/samediff-import-onnx/ops-removed-new.txt diff --git a/nd4j/samediff-import/samediff-import-onnx/pom.xml b/.old/nd4j/samediff-import/samediff-import-onnx/pom.xml similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/pom.xml rename to .old/nd4j/samediff-import/samediff-import-onnx/pom.xml diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxIR.kt 
b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxIR.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxIR.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxIR.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxImportGraph.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxImportGraph.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxImportGraph.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxImportGraph.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxImportGraphHolder.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxImportGraphHolder.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxImportGraphHolder.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxImportGraphHolder.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxProtobufExtensions.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxProtobufExtensions.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxProtobufExtensions.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxProtobufExtensions.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxRuleDeclarations.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxRuleDeclarations.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxRuleDeclarations.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/OnnxRuleDeclarations.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/context/OnnxMappingContext.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/context/OnnxMappingContext.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/context/OnnxMappingContext.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/context/OnnxMappingContext.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/definitions/OnnxOpDeclarations.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/definitions/OnnxOpDeclarations.kt similarity index 100% rename from 
nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/definitions/OnnxOpDeclarations.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/definitions/OnnxOpDeclarations.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/OnnxFrameworkImporter.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/OnnxFrameworkImporter.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/OnnxFrameworkImporter.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/OnnxFrameworkImporter.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRArgDef.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRArgDef.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRArgDef.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRArgDef.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRAttr.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRAttr.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRAttr.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRAttr.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRDataType.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRDataType.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRDataType.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRDataType.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRGraph.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRGraph.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRGraph.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRGraph.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRGraphRunner.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRGraphRunner.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRGraphRunner.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRGraphRunner.kt 
diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRNode.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRNode.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRNode.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRNode.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIROp.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIROp.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIROp.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIROp.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRTensor.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRTensor.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRTensor.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/ir/OnnxIRTensor.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/opdefs/OnnxOpDescriptorLoader.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/opdefs/OnnxOpDescriptorLoader.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/opdefs/OnnxOpDescriptorLoader.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/opdefs/OnnxOpDescriptorLoader.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/process/OnnxMappingProcess.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/process/OnnxMappingProcess.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/process/OnnxMappingProcess.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/process/OnnxMappingProcess.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/process/OnnxMappingProcessLoader.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/process/OnnxMappingProcessLoader.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/process/OnnxMappingProcessLoader.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/process/OnnxMappingProcessLoader.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxArgDescriptorConstant.kt 
b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxArgDescriptorConstant.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxArgDescriptorConstant.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxArgDescriptorConstant.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeNDArrayToScalarAttribute.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeNDArrayToScalarAttribute.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeNDArrayToScalarAttribute.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeNDArrayToScalarAttribute.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeNumberListNDArray.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeNumberListNDArray.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeNumberListNDArray.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeNumberListNDArray.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeScalarNDArrayAttribute.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeScalarNDArrayAttribute.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeScalarNDArrayAttribute.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxAttributeScalarNDArrayAttribute.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxConditionalFieldValueIntIndexArrayRule.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxConditionalFieldValueIntIndexArrayRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxConditionalFieldValueIntIndexArrayRule.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxConditionalFieldValueIntIndexArrayRule.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxConditionalFieldValueIntIndexNDArrayRule.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxConditionalFieldValueIntIndexNDArrayRule.kt similarity index 100% 
rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxConditionalFieldValueIntIndexNDArrayRule.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxConditionalFieldValueIntIndexNDArrayRule.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxDataTypeToInt.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxDataTypeToInt.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxDataTypeToInt.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxDataTypeToInt.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxFlattenDims.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxFlattenDims.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxFlattenDims.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxFlattenDims.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxInvertBooleanNumber.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxInvertBooleanNumber.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxInvertBooleanNumber.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxInvertBooleanNumber.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListAttributeValueLookupToIndex.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListAttributeValueLookupToIndex.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListAttributeValueLookupToIndex.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListAttributeValueLookupToIndex.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListNumberToListNumber.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListNumberToListNumber.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListNumberToListNumber.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListNumberToListNumber.kt diff --git 
a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListNumberToNDArray.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListNumberToNDArray.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListNumberToNDArray.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxListNumberToNDArray.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxMapStringToInt.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxMapStringToInt.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxMapStringToInt.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxMapStringToInt.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayAttributeToNDArrayInput.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayAttributeToNDArrayInput.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayAttributeToNDArrayInput.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayAttributeToNDArrayInput.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayExtractScalarValue.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayExtractScalarValue.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayExtractScalarValue.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayExtractScalarValue.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayInputToNumericalAttribute.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayInputToNumericalAttribute.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayInputToNumericalAttribute.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayInputToNumericalAttribute.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArraySizeAt.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArraySizeAt.kt similarity index 100% rename from 
nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArraySizeAt.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArraySizeAt.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayToIntAttributeValue.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayToIntAttributeValue.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayToIntAttributeValue.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxNDArrayToIntAttributeValue.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxSizeThresholdIntArrayIntIndexRule.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxSizeThresholdIntArrayIntIndexRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxSizeThresholdIntArrayIntIndexRule.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxSizeThresholdIntArrayIntIndexRule.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringAttributeToNDArray.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringAttributeToNDArray.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringAttributeToNDArray.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringAttributeToNDArray.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringContainsAdapterRule.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringContainsAdapterRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringContainsAdapterRule.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringContainsAdapterRule.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringEqualsAdapterRule.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringEqualsAdapterRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringEqualsAdapterRule.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringEqualsAdapterRule.kt diff 
--git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringNotEqualsAdapterRule.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringNotEqualsAdapterRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringNotEqualsAdapterRule.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringNotEqualsAdapterRule.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringToIndex.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringToIndex.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringToIndex.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxStringToIndex.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxValueMapping.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxValueMapping.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxValueMapping.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/attribute/OnnxValueMapping.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/NDArrayMappingRule.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/NDArrayMappingRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/NDArrayMappingRule.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/NDArrayMappingRule.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/OnnxMultiInputIndexMappingRule.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/OnnxMultiInputIndexMappingRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/OnnxMultiInputIndexMappingRule.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/OnnxMultiInputIndexMappingRule.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/OnnxPassThroughMultiInputTensorMapping.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/OnnxPassThroughMultiInputTensorMapping.kt similarity index 100% rename from 
nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/OnnxPassThroughMultiInputTensorMapping.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/kotlin/org/nd4j/samediff/frameworkimport/onnx/rule/tensor/OnnxPassThroughMultiInputTensorMapping.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.ImportGraphHolder b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.ImportGraphHolder similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.ImportGraphHolder rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.ImportGraphHolder diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoader b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoader similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoader rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoader diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-mapping-ruleset.pbtxt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-mapping-ruleset.pbtxt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-mapping-ruleset.pbtxt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-mapping-ruleset.pbtxt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-op-def.pbtxt b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-op-def.pbtxt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-op-def.pbtxt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-op-def.pbtxt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-op-defs.pb b/.old/nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-op-defs.pb similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-op-defs.pb rename to .old/nd4j/samediff-import/samediff-import-onnx/src/main/resources/onnx-op-defs.pb diff --git a/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/TestOnnxIR.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/TestOnnxIR.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/TestOnnxIR.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/TestOnnxIR.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/TestOnnxFrameworkImporter.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/TestOnnxFrameworkImporter.kt similarity index 100% rename from 
nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/TestOnnxFrameworkImporter.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/importer/TestOnnxFrameworkImporter.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/loader/TestOnnxProcessLoader.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/loader/TestOnnxProcessLoader.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/loader/TestOnnxProcessLoader.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/loader/TestOnnxProcessLoader.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/modelzoo/TestPretrainedModels.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/modelzoo/TestPretrainedModels.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/modelzoo/TestPretrainedModels.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/modelzoo/TestPretrainedModels.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/processing/GroupConvPreProcessingRule.kt b/.old/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/processing/GroupConvPreProcessingRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/processing/GroupConvPreProcessingRule.kt rename to .old/nd4j/samediff-import/samediff-import-onnx/src/test/kotlin/org/nd4j/samediff/frameworkimport/onnx/processing/GroupConvPreProcessingRule.kt diff --git a/nd4j/samediff-import/samediff-import-onnx/src/test/resources/lenet.onnx b/.old/nd4j/samediff-import/samediff-import-onnx/src/test/resources/lenet.onnx similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/src/test/resources/lenet.onnx rename to .old/nd4j/samediff-import/samediff-import-onnx/src/test/resources/lenet.onnx diff --git a/nd4j/samediff-import/samediff-import-onnx/variables-added-new.txt b/.old/nd4j/samediff-import/samediff-import-onnx/variables-added-new.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-onnx/variables-added-new.txt rename to .old/nd4j/samediff-import/samediff-import-onnx/variables-added-new.txt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/00c6b5c8-c93c-4ac9-867f-580443a45bb3-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/00c6b5c8-c93c-4ac9-867f-580443a45bb3-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/00c6b5c8-c93c-4ac9-867f-580443a45bb3-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/00c6b5c8-c93c-4ac9-867f-580443a45bb3-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/03195bad-47a3-4de9-9fc7-6691ea41aee0-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/03195bad-47a3-4de9-9fc7-6691ea41aee0-container.json 
similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/03195bad-47a3-4de9-9fc7-6691ea41aee0-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/03195bad-47a3-4de9-9fc7-6691ea41aee0-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/04c57933-461b-4d6f-b6a8-a210cef103ff-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/04c57933-461b-4d6f-b6a8-a210cef103ff-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/04c57933-461b-4d6f-b6a8-a210cef103ff-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/04c57933-461b-4d6f-b6a8-a210cef103ff-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/0541722d-1de4-4e85-b844-d90d20eea9fb-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/0541722d-1de4-4e85-b844-d90d20eea9fb-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/0541722d-1de4-4e85-b844-d90d20eea9fb-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/0541722d-1de4-4e85-b844-d90d20eea9fb-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/154f67c7-64e1-4e2c-a56e-05d390b459d7-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/154f67c7-64e1-4e2c-a56e-05d390b459d7-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/154f67c7-64e1-4e2c-a56e-05d390b459d7-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/154f67c7-64e1-4e2c-a56e-05d390b459d7-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/2026534d-ef52-441c-976b-3ef06799a362-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/2026534d-ef52-441c-976b-3ef06799a362-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/2026534d-ef52-441c-976b-3ef06799a362-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/2026534d-ef52-441c-976b-3ef06799a362-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/22943ac9-56da-4b92-983d-7385c888c80b-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/22943ac9-56da-4b92-983d-7385c888c80b-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/22943ac9-56da-4b92-983d-7385c888c80b-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/22943ac9-56da-4b92-983d-7385c888c80b-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/281f7eda-053b-4dc9-a686-2040bb4f7fd3-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/281f7eda-053b-4dc9-a686-2040bb4f7fd3-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/281f7eda-053b-4dc9-a686-2040bb4f7fd3-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/281f7eda-053b-4dc9-a686-2040bb4f7fd3-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/2cfad6f7-cd22-4de1-80a6-b9890ce473fc-result.json 
b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/2cfad6f7-cd22-4de1-80a6-b9890ce473fc-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/2cfad6f7-cd22-4de1-80a6-b9890ce473fc-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/2cfad6f7-cd22-4de1-80a6-b9890ce473fc-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/36bc053b-ed9c-40b7-853c-d9462d2a67c0-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/36bc053b-ed9c-40b7-853c-d9462d2a67c0-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/36bc053b-ed9c-40b7-853c-d9462d2a67c0-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/36bc053b-ed9c-40b7-853c-d9462d2a67c0-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/445afe43-7f5f-4f5b-81db-e942139be1a7-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/445afe43-7f5f-4f5b-81db-e942139be1a7-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/445afe43-7f5f-4f5b-81db-e942139be1a7-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/445afe43-7f5f-4f5b-81db-e942139be1a7-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/505e442f-5f0d-4fe6-80ba-0628e9f3057b-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/505e442f-5f0d-4fe6-80ba-0628e9f3057b-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/505e442f-5f0d-4fe6-80ba-0628e9f3057b-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/505e442f-5f0d-4fe6-80ba-0628e9f3057b-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/52bec90b-a05a-4382-a2fc-0835d2da893a-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/52bec90b-a05a-4382-a2fc-0835d2da893a-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/52bec90b-a05a-4382-a2fc-0835d2da893a-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/52bec90b-a05a-4382-a2fc-0835d2da893a-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/58aad3d2-4a46-47d3-9748-578e7aae7121-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/58aad3d2-4a46-47d3-9748-578e7aae7121-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/58aad3d2-4a46-47d3-9748-578e7aae7121-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/58aad3d2-4a46-47d3-9748-578e7aae7121-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/5b55351d-1e98-4c83-b1ec-0812d352141d-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/5b55351d-1e98-4c83-b1ec-0812d352141d-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/5b55351d-1e98-4c83-b1ec-0812d352141d-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/5b55351d-1e98-4c83-b1ec-0812d352141d-result.json diff --git 
a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/629659e3-ed5c-482a-89cf-0b4f46026b31-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/629659e3-ed5c-482a-89cf-0b4f46026b31-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/629659e3-ed5c-482a-89cf-0b4f46026b31-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/629659e3-ed5c-482a-89cf-0b4f46026b31-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/6b35dfd9-1d25-419e-a196-4a42f20fd8aa-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/6b35dfd9-1d25-419e-a196-4a42f20fd8aa-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/6b35dfd9-1d25-419e-a196-4a42f20fd8aa-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/6b35dfd9-1d25-419e-a196-4a42f20fd8aa-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/6f446d7c-ec8d-4567-9f3a-3b9bcb3d21f8-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/6f446d7c-ec8d-4567-9f3a-3b9bcb3d21f8-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/6f446d7c-ec8d-4567-9f3a-3b9bcb3d21f8-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/6f446d7c-ec8d-4567-9f3a-3b9bcb3d21f8-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/75db525b-2344-4dcc-a3d9-d23b5acbfe81-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/75db525b-2344-4dcc-a3d9-d23b5acbfe81-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/75db525b-2344-4dcc-a3d9-d23b5acbfe81-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/75db525b-2344-4dcc-a3d9-d23b5acbfe81-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/767f4fe3-b7a8-492b-b0ab-ae77f112e105-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/767f4fe3-b7a8-492b-b0ab-ae77f112e105-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/767f4fe3-b7a8-492b-b0ab-ae77f112e105-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/767f4fe3-b7a8-492b-b0ab-ae77f112e105-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/7b58c709-b0e3-446f-9e05-4fd86f350b83-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/7b58c709-b0e3-446f-9e05-4fd86f350b83-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/7b58c709-b0e3-446f-9e05-4fd86f350b83-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/7b58c709-b0e3-446f-9e05-4fd86f350b83-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/7ce20a5d-8b63-499e-9dc4-0206d0c38b29-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/7ce20a5d-8b63-499e-9dc4-0206d0c38b29-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/7ce20a5d-8b63-499e-9dc4-0206d0c38b29-container.json rename to 
.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/7ce20a5d-8b63-499e-9dc4-0206d0c38b29-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/8b28bcc6-1a38-4b55-bc80-185feab4c978-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/8b28bcc6-1a38-4b55-bc80-185feab4c978-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/8b28bcc6-1a38-4b55-bc80-185feab4c978-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/8b28bcc6-1a38-4b55-bc80-185feab4c978-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/9d55689b-05d2-4692-ad17-af4aa388cb31-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/9d55689b-05d2-4692-ad17-af4aa388cb31-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/9d55689b-05d2-4692-ad17-af4aa388cb31-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/9d55689b-05d2-4692-ad17-af4aa388cb31-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/9f6c1a25-9c6a-40b1-b825-916525e2cb24-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/9f6c1a25-9c6a-40b1-b825-916525e2cb24-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/9f6c1a25-9c6a-40b1-b825-916525e2cb24-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/9f6c1a25-9c6a-40b1-b825-916525e2cb24-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/a2400879-a732-411c-a65e-00111c6b550e-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/a2400879-a732-411c-a65e-00111c6b550e-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/a2400879-a732-411c-a65e-00111c6b550e-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/a2400879-a732-411c-a65e-00111c6b550e-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/a24baaa5-1cb5-4edd-873b-c923d04905ec-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/a24baaa5-1cb5-4edd-873b-c923d04905ec-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/a24baaa5-1cb5-4edd-873b-c923d04905ec-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/a24baaa5-1cb5-4edd-873b-c923d04905ec-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/b3b58b3b-9a41-44aa-9b00-2bb9633a53be-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/b3b58b3b-9a41-44aa-9b00-2bb9633a53be-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/b3b58b3b-9a41-44aa-9b00-2bb9633a53be-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/b3b58b3b-9a41-44aa-9b00-2bb9633a53be-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/c0f861a5-c322-458b-82e9-efd5494d37fc-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/c0f861a5-c322-458b-82e9-efd5494d37fc-container.json similarity index 100% 
rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/c0f861a5-c322-458b-82e9-efd5494d37fc-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/c0f861a5-c322-458b-82e9-efd5494d37fc-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/ca8a7a37-5ce9-4970-aa3d-7eaec8c8091a-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/ca8a7a37-5ce9-4970-aa3d-7eaec8c8091a-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/ca8a7a37-5ce9-4970-aa3d-7eaec8c8091a-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/ca8a7a37-5ce9-4970-aa3d-7eaec8c8091a-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/e25409f2-aa78-4897-a810-297802cccdfc-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/e25409f2-aa78-4897-a810-297802cccdfc-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/e25409f2-aa78-4897-a810-297802cccdfc-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/e25409f2-aa78-4897-a810-297802cccdfc-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/e82fe95c-6cd2-4a8d-82c7-9f45d15e8a73-container.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/e82fe95c-6cd2-4a8d-82c7-9f45d15e8a73-container.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/e82fe95c-6cd2-4a8d-82c7-9f45d15e8a73-container.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/e82fe95c-6cd2-4a8d-82c7-9f45d15e8a73-container.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/eb70b069-8c1d-440c-a135-174d7b873d11-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/eb70b069-8c1d-440c-a135-174d7b873d11-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/eb70b069-8c1d-440c-a135-174d7b873d11-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/eb70b069-8c1d-440c-a135-174d7b873d11-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/allure-results/ef515f16-0d58-450b-85bb-ec61080f012f-result.json b/.old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/ef515f16-0d58-450b-85bb-ec61080f012f-result.json similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/allure-results/ef515f16-0d58-450b-85bb-ec61080f012f-result.json rename to .old/nd4j/samediff-import/samediff-import-tensorflow/allure-results/ef515f16-0d58-450b-85bb-ec61080f012f-result.json diff --git a/nd4j/samediff-import/samediff-import-tensorflow/nd4j-op-def.pbtxt b/.old/nd4j/samediff-import/samediff-import-tensorflow/nd4j-op-def.pbtxt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/nd4j-op-def.pbtxt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/nd4j-op-def.pbtxt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/ops-added-new.txt b/.old/nd4j/samediff-import/samediff-import-tensorflow/ops-added-new.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/ops-added-new.txt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/ops-added-new.txt diff --git 
a/nd4j/samediff-import/samediff-import-tensorflow/ops-added-old.txt b/.old/nd4j/samediff-import/samediff-import-tensorflow/ops-added-old.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/ops-added-old.txt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/ops-added-old.txt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/ops-imported-new.txt b/.old/nd4j/samediff-import/samediff-import-tensorflow/ops-imported-new.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/ops-imported-new.txt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/ops-imported-new.txt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/ops-imported-old.txt b/.old/nd4j/samediff-import/samediff-import-tensorflow/ops-imported-old.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/ops-imported-old.txt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/ops-imported-old.txt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/ops-removed-new.txt b/.old/nd4j/samediff-import/samediff-import-tensorflow/ops-removed-new.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/ops-removed-new.txt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/ops-removed-new.txt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/ops-removed-old.txt b/.old/nd4j/samediff-import/samediff-import-tensorflow/ops-removed-old.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/ops-removed-old.txt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/ops-removed-old.txt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/pom.xml b/.old/nd4j/samediff-import/samediff-import-tensorflow/pom.xml similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/pom.xml rename to .old/nd4j/samediff-import/samediff-import-tensorflow/pom.xml diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowImportGraph.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowImportGraph.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowImportGraph.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowImportGraph.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowImportGraphHolder.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowImportGraphHolder.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowImportGraphHolder.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowImportGraphHolder.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowProtobufExtensions.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowProtobufExtensions.kt similarity index 100% rename 
from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowProtobufExtensions.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowProtobufExtensions.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowRuleDeclarations.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowRuleDeclarations.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowRuleDeclarations.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TensorflowRuleDeclarations.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/context/TensorflowMappingContext.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/context/TensorflowMappingContext.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/context/TensorflowMappingContext.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/context/TensorflowMappingContext.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/definitions/TensorflowOpDeclarations.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TensorflowFrameworkImporter.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TensorflowFrameworkImporter.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TensorflowFrameworkImporter.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TensorflowFrameworkImporter.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIR.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIR.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIR.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIR.kt diff --git 
a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRArgDef.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRArgDef.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRArgDef.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRArgDef.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRAttr.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRAttr.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRAttr.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRAttr.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRDataType.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRDataType.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRDataType.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRDataType.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRGraph.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRGraph.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRGraph.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRGraph.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRGraphRunner.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRGraphRunner.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRGraphRunner.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRGraphRunner.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRNode.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRNode.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRNode.kt rename to 
.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRNode.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIROp.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIROp.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIROp.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIROp.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRTensor.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRTensor.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRTensor.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/ir/TensorflowIRTensor.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/opdefs/TensorflowOpDescriptorLoader.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/opdefs/TensorflowOpDescriptorLoader.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/opdefs/TensorflowOpDescriptorLoader.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/opdefs/TensorflowOpDescriptorLoader.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/process/TensorflowMappingProcess.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/process/TensorflowMappingProcess.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/process/TensorflowMappingProcess.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/process/TensorflowMappingProcess.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/process/TensorflowMappingProcessLoader.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/process/TensorflowMappingProcessLoader.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/process/TensorflowMappingProcessLoader.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/process/TensorflowMappingProcessLoader.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowArgDescriptorConstant.kt 
b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowArgDescriptorConstant.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowArgDescriptorConstant.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowArgDescriptorConstant.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeNDArrayToScalarAttribute.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeNDArrayToScalarAttribute.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeNDArrayToScalarAttribute.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeNDArrayToScalarAttribute.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeNumberListNDArray.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeNumberListNDArray.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeNumberListNDArray.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeNumberListNDArray.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeScalarNDArrayAttribute.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeScalarNDArrayAttribute.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeScalarNDArrayAttribute.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowAttributeScalarNDArrayAttribute.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowConditionalFieldValueIntIndexArrayRule.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowConditionalFieldValueIntIndexArrayRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowConditionalFieldValueIntIndexArrayRule.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowConditionalFieldValueIntIndexArrayRule.kt diff --git 
a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowConditionalFieldValueIntIndexNDArrayRule.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowConditionalFieldValueIntIndexNDArrayRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowConditionalFieldValueIntIndexNDArrayRule.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowConditionalFieldValueIntIndexNDArrayRule.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowDataTypeToInt.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowDataTypeToInt.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowDataTypeToInt.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowDataTypeToInt.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowFlattenDims.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowFlattenDims.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowFlattenDims.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowFlattenDims.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowInvertBooleanNumber.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowInvertBooleanNumber.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowInvertBooleanNumber.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowInvertBooleanNumber.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListAttributeValueLookupToIndex.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListAttributeValueLookupToIndex.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListAttributeValueLookupToIndex.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListAttributeValueLookupToIndex.kt diff 
--git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListNumberToListNumber.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListNumberToListNumber.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListNumberToListNumber.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListNumberToListNumber.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListNumberToNDArray.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListNumberToNDArray.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListNumberToNDArray.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowListNumberToNDArray.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowMapStringToInt.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowMapStringToInt.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowMapStringToInt.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowMapStringToInt.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayAttributeToNDArrayInput.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayAttributeToNDArrayInput.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayAttributeToNDArrayInput.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayAttributeToNDArrayInput.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayExtractScalarValue.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayExtractScalarValue.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayExtractScalarValue.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayExtractScalarValue.kt diff --git 
a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayInputToNumericalAttribute.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayInputToNumericalAttribute.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayInputToNumericalAttribute.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayInputToNumericalAttribute.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArraySizeAt.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArraySizeAt.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArraySizeAt.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArraySizeAt.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayToIntAttributeValue.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayToIntAttributeValue.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayToIntAttributeValue.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNDArrayToIntAttributeValue.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNdArrayToStringIndex.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNdArrayToStringIndex.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNdArrayToStringIndex.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowNdArrayToStringIndex.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringAttributeToNDArray.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringAttributeToNDArray.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringAttributeToNDArray.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringAttributeToNDArray.kt 
diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringContainsAdapterRule.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringContainsAdapterRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringContainsAdapterRule.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringContainsAdapterRule.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringEqualsAdapterRule.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringEqualsAdapterRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringEqualsAdapterRule.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringEqualsAdapterRule.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringNotEqualsAdapterRule.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringNotEqualsAdapterRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringNotEqualsAdapterRule.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowStringNotEqualsAdapterRule.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowValueMappingRule.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowValueMappingRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowValueMappingRule.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/attribute/TensorflowValueMappingRule.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/NDArrayMappingRule.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/NDArrayMappingRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/NDArrayMappingRule.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/NDArrayMappingRule.kt diff --git 
a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/TensorflowMultiInputIndexMappingRule.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/TensorflowMultiInputIndexMappingRule.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/TensorflowMultiInputIndexMappingRule.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/TensorflowMultiInputIndexMappingRule.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/TensorflowPassThroughMultiTensorMapping.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/TensorflowPassThroughMultiTensorMapping.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/TensorflowPassThroughMultiTensorMapping.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/rule/tensor/TensorflowPassThroughMultiTensorMapping.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.ImportGraphHolder b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.ImportGraphHolder similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.ImportGraphHolder rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.ImportGraphHolder diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoader b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoader similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoader rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/META-INF/services/org.nd4j.samediff.frameworkimport.opdefs.OpDescriptorLoader diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-mapping-ruleset.pbtxt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-op-def.pbtxt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-op-def.pbtxt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-op-def.pbtxt rename to 
.old/nd4j/samediff-import/samediff-import-tensorflow/src/main/resources/tensorflow-op-def.pbtxt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowIR.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowUtils.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowUtils.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowUtils.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/TestTensorflowUtils.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TestTensorflowImporter.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TestTensorflowImporter.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TestTensorflowImporter.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/importer/TestTensorflowImporter.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/loader/TestTensorflowProcessLoader.kt b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/loader/TestTensorflowProcessLoader.kt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/loader/TestTensorflowProcessLoader.kt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/test/kotlin/org/nd4j/samediff/frameworkimport/tensorflow/loader/TestTensorflowProcessLoader.kt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/test/resources/lenet_frozen.pb b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/test/resources/lenet_frozen.pb similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/test/resources/lenet_frozen.pb rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/test/resources/lenet_frozen.pb diff --git a/nd4j/samediff-import/samediff-import-tensorflow/src/test/resources/logback.xml b/.old/nd4j/samediff-import/samediff-import-tensorflow/src/test/resources/logback.xml similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/src/test/resources/logback.xml rename to .old/nd4j/samediff-import/samediff-import-tensorflow/src/test/resources/logback.xml diff --git a/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt b/.old/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt similarity index 
100% rename from nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/test.pbtxt b/.old/nd4j/samediff-import/samediff-import-tensorflow/test.pbtxt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/test.pbtxt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/test.pbtxt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/variables-added-new.txt b/.old/nd4j/samediff-import/samediff-import-tensorflow/variables-added-new.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/variables-added-new.txt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/variables-added-new.txt diff --git a/nd4j/samediff-import/samediff-import-tensorflow/variables-added-old.txt b/.old/nd4j/samediff-import/samediff-import-tensorflow/variables-added-old.txt similarity index 100% rename from nd4j/samediff-import/samediff-import-tensorflow/variables-added-old.txt rename to .old/nd4j/samediff-import/samediff-import-tensorflow/variables-added-old.txt diff --git a/perform-release.sh b/.old/perform-release.sh similarity index 100% rename from perform-release.sh rename to .old/perform-release.sh diff --git a/pydatavec/.eggs/README.txt b/.old/pydatavec/.eggs/README.txt similarity index 100% rename from pydatavec/.eggs/README.txt rename to .old/pydatavec/.eggs/README.txt diff --git a/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/LICENSE b/.old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/LICENSE similarity index 100% rename from pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/LICENSE rename to .old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/LICENSE diff --git a/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/PKG-INFO b/.old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/PKG-INFO similarity index 100% rename from pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/PKG-INFO rename to .old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/PKG-INFO diff --git a/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/RECORD b/.old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/RECORD similarity index 100% rename from pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/RECORD rename to .old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/RECORD diff --git a/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/WHEEL b/.old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/WHEEL similarity index 100% rename from pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/WHEEL rename to .old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/WHEEL diff --git a/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/entry_points.txt b/.old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/entry_points.txt similarity index 100% rename from pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/entry_points.txt rename to .old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/entry_points.txt diff --git a/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/requires.txt b/.old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/requires.txt similarity index 100% rename from pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/requires.txt rename to .old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/requires.txt diff --git a/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/top_level.txt 
b/.old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/top_level.txt similarity index 100% rename from pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/top_level.txt rename to .old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/EGG-INFO/top_level.txt diff --git a/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/ptr.py b/.old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/ptr.py similarity index 100% rename from pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/ptr.py rename to .old/pydatavec/.eggs/pytest_runner-5.2-py3.8.egg/ptr.py diff --git a/pydatavec/pydatavec.egg-info/PKG-INFO b/.old/pydatavec/pydatavec.egg-info/PKG-INFO similarity index 100% rename from pydatavec/pydatavec.egg-info/PKG-INFO rename to .old/pydatavec/pydatavec.egg-info/PKG-INFO diff --git a/pydatavec/pydatavec.egg-info/SOURCES.txt b/.old/pydatavec/pydatavec.egg-info/SOURCES.txt similarity index 100% rename from pydatavec/pydatavec.egg-info/SOURCES.txt rename to .old/pydatavec/pydatavec.egg-info/SOURCES.txt diff --git a/pydatavec/pydatavec.egg-info/dependency_links.txt b/.old/pydatavec/pydatavec.egg-info/dependency_links.txt similarity index 100% rename from pydatavec/pydatavec.egg-info/dependency_links.txt rename to .old/pydatavec/pydatavec.egg-info/dependency_links.txt diff --git a/pydatavec/pydatavec.egg-info/requires.txt b/.old/pydatavec/pydatavec.egg-info/requires.txt similarity index 100% rename from pydatavec/pydatavec.egg-info/requires.txt rename to .old/pydatavec/pydatavec.egg-info/requires.txt diff --git a/pydatavec/pydatavec.egg-info/top_level.txt b/.old/pydatavec/pydatavec.egg-info/top_level.txt similarity index 100% rename from pydatavec/pydatavec.egg-info/top_level.txt rename to .old/pydatavec/pydatavec.egg-info/top_level.txt diff --git a/rl4j/README.md b/.old/rl4j/README.md similarity index 100% rename from rl4j/README.md rename to .old/rl4j/README.md diff --git a/rl4j/docs/images/cartpole.gif b/.old/rl4j/docs/images/cartpole.gif similarity index 100% rename from rl4j/docs/images/cartpole.gif rename to .old/rl4j/docs/images/cartpole.gif diff --git a/rl4j/docs/images/doom.gif b/.old/rl4j/docs/images/doom.gif similarity index 100% rename from rl4j/docs/images/doom.gif rename to .old/rl4j/docs/images/doom.gif diff --git a/rl4j/docs/images/malmo.gif b/.old/rl4j/docs/images/malmo.gif similarity index 100% rename from rl4j/docs/images/malmo.gif rename to .old/rl4j/docs/images/malmo.gif diff --git a/rl4j/pom.xml b/.old/rl4j/pom.xml similarity index 100% rename from rl4j/pom.xml rename to .old/rl4j/pom.xml diff --git a/rl4j/rl4j-ale/pom.xml b/.old/rl4j/rl4j-ale/pom.xml similarity index 100% rename from rl4j/rl4j-ale/pom.xml rename to .old/rl4j/rl4j-ale/pom.xml diff --git a/rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java b/.old/rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java similarity index 100% rename from rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java rename to .old/rl4j/rl4j-ale/src/main/java/org/deeplearning4j/rl4j/mdp/ale/ALEMDP.java diff --git a/rl4j/rl4j-api/pom.xml b/.old/rl4j/rl4j-api/pom.xml similarity index 100% rename from rl4j/rl4j-api/pom.xml rename to .old/rl4j/rl4j-api/pom.xml diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java b/.old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java similarity index 100% rename from rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java rename to .old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/gym/StepReply.java diff --git 
a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java b/.old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java similarity index 100% rename from rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java rename to .old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/mdp/MDP.java diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java b/.old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java similarity index 100% rename from rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java rename to .old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ActionSpace.java diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java b/.old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java similarity index 100% rename from rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java rename to .old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ArrayObservationSpace.java diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java b/.old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java similarity index 100% rename from rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java rename to .old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Box.java diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java b/.old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java similarity index 100% rename from rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java rename to .old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/DiscreteSpace.java diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java b/.old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java similarity index 100% rename from rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java rename to .old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/Encodable.java diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java b/.old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java similarity index 100% rename from rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java rename to .old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/HighLowDiscrete.java diff --git a/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java b/.old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java similarity index 100% rename from rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java rename to .old/rl4j/rl4j-api/src/main/java/org/deeplearning4j/rl4j/space/ObservationSpace.java diff --git a/rl4j/rl4j-core/nd4j-native.properties b/.old/rl4j/rl4j-core/nd4j-native.properties similarity index 100% rename from rl4j/rl4j-core/nd4j-native.properties rename to .old/rl4j/rl4j-core/nd4j-native.properties diff --git a/rl4j/rl4j-core/pom.xml b/.old/rl4j/rl4j-core/pom.xml similarity index 100% rename from rl4j/rl4j-core/pom.xml rename to .old/rl4j/rl4j-core/pom.xml diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java similarity index 100% rename from 
rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/Agent.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/AgentLearner.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/AgentLearner.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/AgentLearner.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/AgentLearner.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgent.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgent.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgent.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgent.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgentLearner.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgentLearner.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgentLearner.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/IAgentLearner.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/IUpdateAlgorithm.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/IUpdateAlgorithm.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/IUpdateAlgorithm.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/IUpdateAlgorithm.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/ActorCriticHelper.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/ActorCriticHelper.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/ActorCriticHelper.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/ActorCriticHelper.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/AdvantageActorCritic.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/AdvantageActorCritic.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/AdvantageActorCritic.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/AdvantageActorCritic.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelper.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelper.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelper.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelper.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelper.java 
b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelper.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelper.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelper.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseDQNAlgorithm.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseDQNAlgorithm.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseDQNAlgorithm.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseDQNAlgorithm.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseTransitionTDAlgorithm.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseTransitionTDAlgorithm.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseTransitionTDAlgorithm.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/BaseTransitionTDAlgorithm.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQN.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQN.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQN.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQN.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQN.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQN.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQN.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQN.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NStepQLearning.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NStepQLearning.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NStepQLearning.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NStepQLearning.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NStepQLearningHelper.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NStepQLearningHelper.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NStepQLearningHelper.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NStepQLearningHelper.java diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelper.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelper.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelper.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelper.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelper.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelper.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelper.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelper.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/behavior/ILearningBehavior.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/behavior/ILearningBehavior.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/behavior/ILearningBehavior.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/behavior/ILearningBehavior.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehavior.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehavior.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehavior.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehavior.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/Features.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/Features.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/Features.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/Features.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilder.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilder.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilder.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilder.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabels.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabels.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabels.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabels.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/Gradients.java 
b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/Gradients.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/Gradients.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/Gradients.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/IUpdateRule.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/IUpdateRule.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/IUpdateRule.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/IUpdateRule.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRule.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRule.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRule.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRule.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/INeuralNetUpdater.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/INeuralNetUpdater.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/INeuralNetUpdater.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/INeuralNetUpdater.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/NeuralNetUpdaterConfiguration.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/NeuralNetUpdaterConfiguration.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/NeuralNetUpdaterConfiguration.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/NeuralNetUpdaterConfiguration.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdater.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdater.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdater.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdater.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdater.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdater.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdater.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdater.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandler.java 
b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/BaseAsyncNeuralNetUpdater.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/BaseAsyncNeuralNetUpdater.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/BaseAsyncNeuralNetUpdater.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/BaseAsyncNeuralNetUpdater.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/BaseSyncNeuralNetUpdater.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/BaseSyncNeuralNetUpdater.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/BaseSyncNeuralNetUpdater.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/BaseSyncNeuralNetUpdater.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdater.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdater.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdater.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdater.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdater.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdater.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdater.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdater.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListener.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/listener/AgentListenerList.java diff --git 
a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/AdvantageActorCriticBuilder.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/AdvantageActorCriticBuilder.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/AdvantageActorCriticBuilder.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/AdvantageActorCriticBuilder.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/AsyncNetworkHandler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/AsyncNetworkHandler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/AsyncNetworkHandler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/AsyncNetworkHandler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilder.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilder.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilder.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilder.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseAsyncAgentLearnerBuilder.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseAsyncAgentLearnerBuilder.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseAsyncAgentLearnerBuilder.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseAsyncAgentLearnerBuilder.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseDQNAgentLearnerBuilder.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseDQNAgentLearnerBuilder.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseDQNAgentLearnerBuilder.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/BaseDQNAgentLearnerBuilder.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/DoubleDQNBuilder.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/DoubleDQNBuilder.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/DoubleDQNBuilder.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/DoubleDQNBuilder.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/INetworksHandler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/INetworksHandler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/INetworksHandler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/INetworksHandler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/NStepQLearningBuilder.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/NStepQLearningBuilder.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/NStepQLearningBuilder.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/NStepQLearningBuilder.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/StandardDQNBuilder.java 
b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/StandardDQNBuilder.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/StandardDQNBuilder.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/StandardDQNBuilder.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/SyncNetworkHandler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/SyncNetworkHandler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/SyncNetworkHandler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/builder/SyncNetworkHandler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Environment.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IActionSchema.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IActionSchema.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IActionSchema.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IActionSchema.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IntegerActionSchema.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IntegerActionSchema.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IntegerActionSchema.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/IntegerActionSchema.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/Schema.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/environment/StepResult.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ExperienceHandler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java rename to 
.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionReward.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionReward.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionReward.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionReward.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionRewardState.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionRewardState.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionRewardState.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/experience/StateActionRewardState.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/helper/INDArrayHelper.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/HistoryProcessor.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IEpochTrainer.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/IHistoryProcessor.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/ILearning.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java 
b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/Learning.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/NeuralNetFetchable.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/NeuralNetFetchable.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/NeuralNetFetchable.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/NeuralNetFetchable.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncGlobal.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThread.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscrete.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncGlobal.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/IAsyncLearning.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java rename to 
.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/UpdateAlgorithm.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteConv.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscreteDense.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithm.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithm.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithm.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithm.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java rename to 
.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteDense.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithm.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/A3CLearningConfiguration.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/AsyncQLearningConfiguration.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/IAsyncLearningConfiguration.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java similarity index 100% rename from 
rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/ILearningConfiguration.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/LearningConfiguration.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/configuration/QLearningConfiguration.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListener.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListener.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListener.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListener.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerList.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerList.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerList.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerList.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/IExpReplay.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/SyncLearning.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java similarity index 100% rename from 
rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscrete.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteConv.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteDense.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleEnvironment.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/CartpoleNative.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/DoAsISayOrDont.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/DoAsISayOrDont.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/DoAsISayOrDont.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/DoAsISayOrDont.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/TMazeEnvironment.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/TMazeEnvironment.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/TMazeEnvironment.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/TMazeEnvironment.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLake.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLake.java similarity index 100% rename from 
rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLake.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLake.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeHelper.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeHelper.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeHelper.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeHelper.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeMap.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeMap.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeMap.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeMap.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeState.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeState.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeState.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/robotlake/RobotLakeState.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardDeteministicToy.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/HardToyState.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToy.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/mdp/toy/SimpleToyState.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ActorCriticNetwork.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ActorCriticNetwork.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ActorCriticNetwork.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ActorCriticNetwork.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/BaseNetwork.java 
b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/BaseNetwork.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/BaseNetwork.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/BaseNetwork.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapper.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapper.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapper.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapper.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonGradientNames.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonGradientNames.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonGradientNames.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonGradientNames.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonLabelNames.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonLabelNames.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonLabelNames.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonLabelNames.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonOutputNames.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonOutputNames.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonOutputNames.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CommonOutputNames.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ComputationGraphHandler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ComputationGraphHandler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ComputationGraphHandler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ComputationGraphHandler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/INetworkHandler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/INetworkHandler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/INetworkHandler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/INetworkHandler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java rename to 
.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ITrainableNeuralNet.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ITrainableNeuralNet.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ITrainableNeuralNet.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ITrainableNeuralNet.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NetworkHelper.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NetworkHelper.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NetworkHelper.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NetworkHelper.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNet.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNet.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNet.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNet.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNetOutput.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNetOutput.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNetOutput.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/NeuralNetOutput.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/QNetwork.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/QNetwork.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/QNetwork.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/QNetwork.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticCompGraph.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraph.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraph.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraph.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraph.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java 
b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparate.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparate.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparate.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparate.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticLoss.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticLoss.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticLoss.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticLoss.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/IActorCritic.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/IActorCritic.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/IActorCritic.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/IActorCritic.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java rename to 
.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticDenseNetworkConfiguration.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/ActorCriticNetworkConfiguration.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/DQNDenseNetworkConfiguration.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/configuration/NetworkConfiguration.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactory.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactory.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactory.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactory.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java rename to 
.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/IObservationSource.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/IObservationSource.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/IObservationSource.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/IObservationSource.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/Observation.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/EncodableToINDArrayTransform.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/EncodableToINDArrayTransform.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/EncodableToINDArrayTransform.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/EncodableToINDArrayTransform.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/FilterOperation.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/ResettableOperation.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/ResettableOperation.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/ResettableOperation.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/ResettableOperation.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/TransformProcess.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/TransformProcess.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/TransformProcess.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/TransformProcess.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilter.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java 
similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/EncodableToImageWritableTransform.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/legacy/ImageWritableToINDArrayTransform.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransform.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransform.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransform.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransform.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransform.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransform.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransform.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransform.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransform.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStore.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeAssembler.java 
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryMergeElementStore.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssembler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssembler.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssembler.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssembler.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/ACPolicy.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/BoltzmannQ.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/DQNPolicy.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/EpsGreedy.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/INeuralNetPolicy.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/INeuralNetPolicy.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/INeuralNetPolicy.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/INeuralNetPolicy.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/IPolicy.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java 
similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/policy/Policy.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/AsyncTrainer.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/AsyncTrainer.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/AsyncTrainer.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/AsyncTrainer.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/ITrainer.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/ITrainer.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/ITrainer.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/ITrainer.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/SyncTrainer.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/SyncTrainer.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/SyncTrainer.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/trainer/SyncTrainer.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/Constants.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/Constants.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/Constants.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/Constants.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManager.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListener.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/IDataManager.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java rename to .old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/LegacyMDPWrapper.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java similarity index 100% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java rename to 
.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/util/VideoRecorder.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/AgentLearnerCartpole.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/AgentLearnerCartpole.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/AgentLearnerCartpole.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/AgentLearnerCartpole.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/NStepRnn.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/NStepRnn.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/NStepRnn.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/NStepRnn.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/RobotLakeExample.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/RobotLakeExample.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/RobotLakeExample.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/RobotLakeExample.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/TMazeExample.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/TMazeExample.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/TMazeExample.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/TMazeExample.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentLearnerTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/AgentTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelperTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelperTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelperTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentActorCriticHelperTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentAdvantageActorCriticTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentAdvantageActorCriticTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentAdvantageActorCriticTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/NonRecurrentAdvantageActorCriticTest.java diff --git 
a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelperTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelperTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelperTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentActorCriticHelperTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentAdvantageActorCriticTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentAdvantageActorCriticTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentAdvantageActorCriticTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/actorcritic/RecurrentAdvantageActorCriticTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQNTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQNTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQNTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/DoubleDQNTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQNTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQNTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQNTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/dqn/StandardDQNTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelperTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelperTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelperTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningHelperTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/NonRecurrentNStepQLearningTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelperTest.java 
b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelperTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelperTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningHelperTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/algorithm/nstepqlearning/RecurrentNStepQLearningTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehaviorTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehaviorTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehaviorTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/behavior/LearningBehaviorTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilderTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilderTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilderTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesBuilderTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabelsTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabelsTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabelsTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesLabelsTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/FeaturesTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/GradientsTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/GradientsTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/GradientsTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/GradientsTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRuleTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRuleTest.java similarity index 
100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRuleTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/UpdateRuleTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdaterTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdaterTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdaterTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncGradientsNeuralNetUpdaterTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdaterTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdaterTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdaterTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncLabelsNeuralNetUpdaterTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandlerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandlerTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandlerTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/async/AsyncSharedNetworksUpdateHandlerTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdaterTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdaterTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdaterTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncGradientsNeuralNetUpdaterTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdaterTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdaterTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdaterTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/agent/learning/update/updater/sync/SyncLabelsNeuralNetUpdaterTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilderTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilderTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilderTest.java rename to 
.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/builder/BaseAgentLearnerBuilderTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/ReplayMemoryExperienceHandlerTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/experience/StateActionExperienceHandlerTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/helper/INDArrayHelperTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/HistoryProcessorTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncLearningTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadDiscreteTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java similarity index 100% rename from 
rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/AdvantageActorCriticUpdateAlgorithmTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/listener/AsyncTrainingListenerListTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/QLearningUpdateAlgorithmTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/listener/TrainingListenerListTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/ExpReplayTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/StateActionRewardStateTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/StateActionRewardStateTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/StateActionRewardStateTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/StateActionRewardStateTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/SyncLearningTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java rename to 
.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearningConfigurationTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockDQN.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockStatEntry.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockStatEntry.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockStatEntry.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockStatEntry.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ActorCriticNetworkTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ActorCriticNetworkTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ActorCriticNetworkTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ActorCriticNetworkTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/BaseNetworkTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/BaseNetworkTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/BaseNetworkTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/BaseNetworkTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapperTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapperTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapperTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ChannelToNetworkInputMapperTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandlerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandlerTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandlerTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/CompoundNetworkHandlerTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ComputationGraphHandlerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ComputationGraphHandlerTest.java similarity index 100% rename from 
rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ComputationGraphHandlerTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ComputationGraphHandlerTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/NetworkHelperTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/NetworkHelperTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/NetworkHelperTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/NetworkHelperTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/QNetworkTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/QNetworkTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/QNetworkTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/QNetworkTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/ac/ActorCriticTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/dqn/DQNTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/TransformProcessTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/filter/UniformSkippingFilterTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransformTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransformTest.java similarity index 100% rename from 
rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransformTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/ArrayToINDArrayTransformTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/HistoryMergeTransformTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/SimpleNormalizationTransformTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/CircularFifoStoreTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/observation/transform/operation/historymerge/HistoryStackAssemblerTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDQN.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java rename to 
.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockDataManager.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockHistoryProcessor.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockMDP.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockNeuralNet.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservation.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockObservationSpace.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockPolicy.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockRandom.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockRandom.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockRandom.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/support/MockRandom.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/AsyncTrainerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/AsyncTrainerTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/AsyncTrainerTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/AsyncTrainerTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/SyncTrainerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/SyncTrainerTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/SyncTrainerTest.java rename to 
.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/trainer/SyncTrainerTest.java diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java similarity index 100% rename from rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java rename to .old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/util/DataManagerTrainingListenerTest.java diff --git a/rl4j/rl4j-doom/pom.xml b/.old/rl4j/rl4j-doom/pom.xml similarity index 100% rename from rl4j/rl4j-doom/pom.xml rename to .old/rl4j/rl4j-doom/pom.xml diff --git a/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/Basic.java b/.old/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/Basic.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/Basic.java rename to .old/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/Basic.java diff --git a/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/DeadlyCorridor.java b/.old/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/DeadlyCorridor.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/DeadlyCorridor.java rename to .old/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/DeadlyCorridor.java diff --git a/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/PredictPosition.java b/.old/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/PredictPosition.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/PredictPosition.java rename to .old/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/PredictPosition.java diff --git a/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/TakeCover.java b/.old/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/TakeCover.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/TakeCover.java rename to .old/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/TakeCover.java diff --git a/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java b/.old/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java rename to .old/rl4j/rl4j-doom/src/main/java/org/deeplearning4j/rl4j/mdp/vizdoom/VizDoom.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/AutomapMode.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/AutomapMode.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/AutomapMode.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/AutomapMode.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/Button.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/Button.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/Button.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/Button.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/DoomGame.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/DoomGame.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/DoomGame.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/DoomGame.java diff --git 
a/rl4j/rl4j-doom/src/main/java/vizdoom/FileDoesNotExistException.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/FileDoesNotExistException.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/FileDoesNotExistException.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/FileDoesNotExistException.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/GameState.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/GameState.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/GameState.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/GameState.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/GameVariable.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/GameVariable.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/GameVariable.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/GameVariable.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/Label.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/Label.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/Label.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/Label.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/MessageQueueException.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/MessageQueueException.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/MessageQueueException.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/MessageQueueException.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/Mode.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/Mode.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/Mode.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/Mode.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/ScreenFormat.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/ScreenFormat.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/ScreenFormat.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/ScreenFormat.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/ScreenResolution.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/ScreenResolution.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/ScreenResolution.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/ScreenResolution.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/SharedMemoryException.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/SharedMemoryException.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/SharedMemoryException.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/SharedMemoryException.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/SignalException.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/SignalException.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/SignalException.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/SignalException.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomErrorException.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomErrorException.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomErrorException.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomErrorException.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomIsNotRunningException.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomIsNotRunningException.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomIsNotRunningException.java 
rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomIsNotRunningException.java diff --git a/rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomUnexpectedExitException.java b/.old/rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomUnexpectedExitException.java similarity index 100% rename from rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomUnexpectedExitException.java rename to .old/rl4j/rl4j-doom/src/main/java/vizdoom/ViZDoomUnexpectedExitException.java diff --git a/rl4j/rl4j-gym/pom.xml b/.old/rl4j/rl4j-gym/pom.xml similarity index 100% rename from rl4j/rl4j-gym/pom.xml rename to .old/rl4j/rl4j-gym/pom.xml diff --git a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/ActionTransformer.java b/.old/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/ActionTransformer.java similarity index 100% rename from rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/ActionTransformer.java rename to .old/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/ActionTransformer.java diff --git a/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java b/.old/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java similarity index 100% rename from rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java rename to .old/rl4j/rl4j-gym/src/main/java/org/deeplearning4j/rl4j/mdp/gym/GymEnv.java diff --git a/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java b/.old/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java similarity index 100% rename from rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java rename to .old/rl4j/rl4j-gym/src/test/java/org/deeplearning4j/rl4j/mdp/gym/GymEnvTest.java diff --git a/rl4j/rl4j-malmo/pom.xml b/.old/rl4j/rl4j-malmo/pom.xml similarity index 100% rename from rl4j/rl4j-malmo/pom.xml rename to .old/rl4j/rl4j-malmo/pom.xml diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpace.java b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpace.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpace.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpace.java diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpaceDiscrete.java b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpaceDiscrete.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpaceDiscrete.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoActionSpaceDiscrete.java diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoBox.java diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoConnectionError.java b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoConnectionError.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoConnectionError.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoConnectionError.java diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java 
b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoDescretePositionPolicy.java diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoEnv.java diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationPolicy.java b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationPolicy.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationPolicy.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationPolicy.java diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpace.java diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpaceGrid.java diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePixels.java diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoObservationSpacePosition.java diff --git a/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoResetHandler.java b/.old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoResetHandler.java similarity index 100% rename from rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoResetHandler.java rename to .old/rl4j/rl4j-malmo/src/main/java/org/deeplearning4j/malmo/MalmoResetHandler.java diff --git a/tensorflow-processes.pbtxt b/.old/tensorflow-processes.pbtxt similarity index 100% rename from tensorflow-processes.pbtxt rename to .old/tensorflow-processes.pbtxt diff --git a/arbiter/.travis.yml b/arbiter/.travis.yml deleted file mode 100644 index 30638a6a9..000000000 --- a/arbiter/.travis.yml +++ /dev/null @@ -1,24 +0,0 @@ -branches: - only: - - master -notifications: - email: false -dist: trusty -sudo: false -cache: - 
directories: - - $HOME/.m2 -language: java -jdk: - - openjdk8 -matrix: - include: - - os: linux - env: OS=linux-x86_64 SCALA=2.10 - install: true - script: bash ./ci/build-linux-x86_64.sh - - os: linux - env: OS=linux-x86_64 SCALA=2.11 - install: true - script: bash ./ci/build-linux-x86_64.sh -
diff --git a/arbiter/README.md b/arbiter/README.md deleted file mode 100644 index 67124f30a..000000000 --- a/arbiter/README.md +++ /dev/null @@ -1,45 +0,0 @@
-# Arbiter - -A tool dedicated to tuning (hyperparameter optimization) of machine learning models. Part of the DL4J Suite of Machine Learning / Deep Learning tools for the enterprise. - -
-## Modules -Arbiter contains the following modules: - -- arbiter-core: Defines the API and core functionality, and also contains functionality for the Arbiter UI -- arbiter-deeplearning4j: For hyperparameter optimization of DL4J models (MultiLayerNetwork and ComputationGraph networks) - -
-## Hyperparameter Optimization Functionality - -The open-source version of Arbiter currently defines two methods of hyperparameter optimization: - -- Grid search - -- Random search -
-For optimization of complex models such as neural networks (those with more than a few hyperparameters), random search is superior to grid search, though Bayesian hyperparameter optimization schemes may offer better results still. -For a comparison of random and grid search methods, see [Random Search for Hyper-parameter Optimization (Bergstra and Bengio, 2012)](http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf). -
-### Core Concepts and Classes in Arbiter for Hyperparameter Optimization - -In order to conduct hyperparameter optimization in Arbiter, it is necessary for the user to understand and define the following (an illustrative wiring sketch follows the list): -
-- **Parameter Space**: A ```ParameterSpace<P>```
specifies the type and allowable values of hyperparameters for a model configuration of type ```P```. For example, ```P``` could be a MultiLayerConfiguration for DL4J
-- **Candidate Generator**: A ```CandidateGenerator``` is used to generate candidate model configurations of some type ```C```. The following implementations are defined in arbiter-core: - - ```RandomSearchCandidateGenerator``` - - ```GridSearchCandidateGenerator```
-- **Score Function**: A ```ScoreFunction``` is used to score a model of type ```M``` given data of type ```D```. For example, in DL4J a score function might be used to calculate the classification accuracy from a DataSetIterator
- - A key concept here is that the score is a single numerical (double precision) value that we either want to minimize or maximize - this is the goal of hyperparameter optimization
-- **Termination Conditions**: One or more ```TerminationCondition``` instances must be provided to the ```OptimizationConfiguration```. ```TerminationCondition``` instances are used to control when hyperparameter optimization should be stopped. Some built-in termination conditions: - - ```MaxCandidatesCondition```: Terminate if more than the specified number of candidate hyperparameter configurations have been executed - - ```MaxTimeCondition```: Terminate after a specified amount of time has elapsed since starting the optimization
-- **Result Saver**: The ```ResultSaver``` interface is used to specify how the results of each hyperparameter optimization run should be saved. For example, whether saving should be done to local disk, to a database, to HDFS, or simply stored in memory. - - Note that the ```ResultSaver.saveModel``` method returns a ```ResultReference``` object, which provides a mechanism for re-loading both the model and score from wherever it may be saved.
-- **Optimization Configuration**: An ```OptimizationConfiguration``` ties together the above configuration options in a fluent (builder) pattern.
-- **Candidate Executor**: The ```CandidateExecutor``` interface provides a layer of abstraction between the configuration and execution of each instance of learning. Currently, the only option is the ```LocalCandidateExecutor```, which is used to execute learning on a single machine (in the current JVM). In principle, other execution methods (for example, on Spark or cloud computing machines) could be implemented.
-- **Optimization Runner**: The ```OptimizationRunner``` uses an ```OptimizationConfiguration``` and a ```CandidateExecutor``` to actually run the optimization, and save the results.
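The concepts listed above compose into a single optimization run roughly as sketched below. This is a minimal, hypothetical sketch: the class names are the ones listed in this README, but every constructor and builder signature, the `buildSearchSpace()` helper, and the concrete score-function and result-saver names (`ClassificationAccuracyScoreFunction`, `InMemoryResultSaver`) are assumptions made for illustration, not APIs verified against the removed arbiter-core sources.

```java
// Hypothetical wiring of the Arbiter concepts described above. Class names follow the
// list in this README; all signatures below, the buildSearchSpace() helper and the
// ClassificationAccuracyScoreFunction / InMemoryResultSaver names are illustrative
// assumptions. Imports are omitted because the exact packages are not shown here.
public class ArbiterWiringSketch {

    public static void main(String[] args) {
        // 1. Parameter space: the type and allowable values of each hyperparameter
        //    for a model configuration of type P (here a DL4J MultiLayerConfiguration).
        ParameterSpace<MultiLayerConfiguration> space = buildSearchSpace();

        // 2. Candidate generator: draws candidate configurations from the space (random search).
        CandidateGenerator candidateGenerator = new RandomSearchCandidateGenerator(space);

        // 3. Score function (a single double to minimize or maximize) and termination conditions.
        ScoreFunction scoreFunction = new ClassificationAccuracyScoreFunction();
        TerminationCondition[] stopConditions = {
                new MaxCandidatesCondition(50),          // stop after 50 candidates have been run
                new MaxTimeCondition(2, TimeUnit.HOURS)  // or after a two-hour budget
        };

        // 4. Optimization configuration: ties generator, scoring, stopping and result saving together.
        OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
                .candidateGenerator(candidateGenerator)
                .scoreFunction(scoreFunction)
                .terminationConditions(stopConditions)
                .resultSaver(new InMemoryResultSaver())  // could also save to disk, a database, HDFS, ...
                .build();

        // 5. Runner + executor: execute each candidate in the current JVM and collect the results.
        OptimizationRunner runner = new OptimizationRunner(configuration, new LocalCandidateExecutor());
        runner.execute();
    }

    // Hypothetical helper: build the search space, e.g. ranges for learning rate and layer sizes.
    private static ParameterSpace<MultiLayerConfiguration> buildSearchSpace() {
        throw new UnsupportedOperationException("illustrative placeholder");
    }
}
```

The executor is kept separate from the configuration precisely because it is the extension point: the same configuration could in principle be executed somewhere other than the local JVM (for example on Spark), as noted in the Candidate Executor item above.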
- - -### Optimization of DeepLearning4J Models - -(This section: forthcoming) diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml deleted file mode 100644 index ab5ded1b8..000000000 --- a/arbiter/arbiter-core/pom.xml +++ /dev/null @@ -1,97 +0,0 @@ - - - - - arbiter - net.brutex.ai - 1.0.0-SNAPSHOT - - 4.0.0 - - arbiter-core - jar - - arbiter-core - - - - net.brutex.ai - nd4j-api - ${project.version} - - - com.google.code.findbugs - * - - - - - com.google.guava - guava - ${guava.jre.version} - - - org.apache.commons - commons-lang3 - ${commons.lang.version} - - - - org.apache.commons - commons-math3 - ${commons.math.version} - - - - org.slf4j - slf4j-api - ${slf4j.version} - - - - joda-time - joda-time - ${jodatime.version} - - - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - - net.brutex.ai - deeplearning4j-common-tests - ${project.version} - test - - - com.fasterxml.jackson.datatype - jackson-datatype-joda - ${jackson.version} - - - net.brutex.ai - nd4j-native - ${project.version} - test - windows-x86_64 - - - diff --git a/arbiter/arbiter-core/src/assembly/bin.xml b/arbiter/arbiter-core/src/assembly/bin.xml deleted file mode 100644 index c99d6b144..000000000 --- a/arbiter/arbiter-core/src/assembly/bin.xml +++ /dev/null @@ -1,91 +0,0 @@ - - - - bin - - - tar.gz - - - - - - - lib - - *:jar:* - - - *:sources - - - - - - - - - readme.txt - - - - - src/main/resources/bin/ - bin - - arbiter - - unix - 0755 - - - - examples - examples - - - - - - - - target - ./ - - *.jar - - - - - - \ No newline at end of file diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/AbstractParameterSpace.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/AbstractParameterSpace.java deleted file mode 100644 index 4ff9dd964..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/AbstractParameterSpace.java +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api; - -import java.lang.reflect.Field; -import java.util.ArrayList; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -/** - * Created by Alex on 23/07/2017. - */ -public abstract class AbstractParameterSpace implements ParameterSpace { - - @Override - public Map getNestedSpaces() { - Map m = new LinkedHashMap<>(); - - //Need to manually build and walk the class heirarchy... - Class currClass = this.getClass(); - List> classHeirarchy = new ArrayList<>(); - while (currClass != Object.class) { - classHeirarchy.add(currClass); - currClass = currClass.getSuperclass(); - } - - for (int i = classHeirarchy.size() - 1; i >= 0; i--) { - //Use reflection here to avoid a mass of boilerplate code... 
- Field[] allFields = classHeirarchy.get(i).getDeclaredFields(); - - for (Field f : allFields) { - - String name = f.getName(); - Class fieldClass = f.getType(); - boolean isParamSpacefield = ParameterSpace.class.isAssignableFrom(fieldClass); - - if (!isParamSpacefield) { - continue; - } - - f.setAccessible(true); - - ParameterSpace p; - try { - p = (ParameterSpace) f.get(this); - } catch (IllegalAccessException e) { - throw new RuntimeException(e); - } - - if (p != null) { - m.put(name, p); - } - } - } - - return m; - } - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/Candidate.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/Candidate.java deleted file mode 100644 index 4f00d92e7..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/Candidate.java +++ /dev/null @@ -1,57 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api; - -import lombok.AllArgsConstructor; -import lombok.Data; -import org.deeplearning4j.arbiter.optimize.generator.util.SerializedSupplier; -import org.nd4j.common.function.Supplier; - -import java.io.Serializable; -import java.util.Map; - -/** - * Candidate: a proposed hyperparameter configuration. - * Also includes a map for data parameters, to configure things like data preprocessing, etc. - */ -@Data -@AllArgsConstructor -public class Candidate implements Serializable { - - private Supplier supplier; - private int index; - private double[] flatParameters; - private Map dataParameters; - private Exception exception; - - public Candidate(C value, int index, double[] flatParameters, Map dataParameters, Exception e) { - this(new SerializedSupplier(value), index, flatParameters, dataParameters, e); - } - - public Candidate(C value, int index, double[] flatParameters) { - this(new SerializedSupplier(value), index, flatParameters); - } - - public Candidate(Supplier value, int index, double[] flatParameters) { - this(value, index, flatParameters, null, null); - } - - public C getValue(){ - return supplier.get(); - } - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/CandidateGenerator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/CandidateGenerator.java deleted file mode 100644 index 3b070fd37..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/CandidateGenerator.java +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api; - -import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonSubTypes; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -/** - * A CandidateGenerator proposes candidates (i.e., hyperparameter configurations) for evaluation. - * This abstraction allows for different ways of generating the next configuration to test; for example, - * random search, grid search, Bayesian optimization methods, etc. - * - * @author Alex Black - */ -@JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public interface CandidateGenerator { - - /** - * Is this candidate generator able to generate more candidates? This will always return true in some - * cases, but some search strategies have a limit (grid search, for example) - */ - boolean hasMoreCandidates(); - - /** - * Generate a candidate hyperparameter configuration - */ - Candidate getCandidate(); - - /** - * Report results for the candidate generator. - * - * @param result The results to report - */ - void reportResults(OptimizationResult result); - - /** - * @return Get the parameter space for this candidate generator - */ - ParameterSpace getParameterSpace(); - - /** - * @param rngSeed Set the random number generator seed for the candidate generator - */ - void setRngSeed(long rngSeed); - - /** - * @return The type (class) of the generated candidates - */ - Class getCandidateType(); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/OptimizationResult.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/OptimizationResult.java deleted file mode 100644 index 8868b73ba..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/OptimizationResult.java +++ /dev/null @@ -1,60 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api; - -import lombok.Data; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; -import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import java.io.Serializable; - -/** - * An optimization result represents the results of an optimization run, including the canditate configuration, the - * trained model, the score for that model, and index of the model - * - * @author Alex Black - */ -@Data -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -@JsonIgnoreProperties({"resultReference"}) -public class OptimizationResult implements Serializable { - @JsonProperty - private Candidate candidate; - @JsonProperty - private Double score; - @JsonProperty - private int index; - @JsonProperty - private Object modelSpecificResults; - @JsonProperty - private CandidateInfo candidateInfo; - private ResultReference resultReference; - - - public OptimizationResult(Candidate candidate, Double score, int index, Object modelSpecificResults, - CandidateInfo candidateInfo, ResultReference resultReference) { - this.candidate = candidate; - this.score = score; - this.index = index; - this.modelSpecificResults = modelSpecificResults; - this.candidateInfo = candidateInfo; - this.resultReference = resultReference; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/ParameterSpace.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/ParameterSpace.java deleted file mode 100644 index 7a2dff8e7..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/ParameterSpace.java +++ /dev/null @@ -1,81 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api; - -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import java.util.List; -import java.util.Map; - -/** - * ParameterSpace: defines the acceptable ranges of values a given parameter may take. - * Note that parameter spaces can be simple (like {@code ParameterSpace}) or complicated, including - * multiple nested ParameterSpaces - * - * @author Alex Black - */ -@JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public interface ParameterSpace

      { - - /** - * Generate a candidate given a set of values. These values are then mapped to a specific candidate, using some - * mapping function (such as the prior probability distribution) - * - * @param parameterValues A set of values, each in the range [0,1], of length {@link #numParameters()} - */ - P getValue(double[] parameterValues); - - /** - * Get the total number of parameters (hyperparameters) to be optimized. This includes optional parameters from - * different parameter subpaces. (Thus, not every parameter may be used in every candidate) - * - * @return Number of hyperparameters to be optimized - */ - int numParameters(); - - /** - * Collect a list of parameters, recursively. Note that leaf parameters are parameters that do not have any - * nested parameter spaces - */ - List collectLeaves(); - - /** - * Get a list of nested parameter spaces by name. Note that the returned parameter spaces may in turn have further - * nested parameter spaces. The map should be empty for leaf parameter spaces - * - * @return A map of nested parameter spaces - */ - Map getNestedSpaces(); - - /** - * Is this ParameterSpace a leaf? (i.e., does it contain other ParameterSpaces internally?) - */ - @JsonIgnore - boolean isLeaf(); - - /** - * For leaf ParameterSpaces: set the indices of the leaf ParameterSpace. - * Expects input of length {@link #numParameters()}. Throws exception if {@link #isLeaf()} is false. - * - * @param indices Indices to set. Length should equal {@link #numParameters()} - */ - void setIndices(int... indices); - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/TaskCreator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/TaskCreator.java deleted file mode 100644 index c6e58905d..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/TaskCreator.java +++ /dev/null @@ -1,62 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api; - -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; - -import java.util.List; -import java.util.Properties; -import java.util.concurrent.Callable; - -/** - * The TaskCreator is used to take a candidate configuration, data provider and score function, and create something - * that can be executed as a Callable - * - * @author Alex Black - */ -public interface TaskCreator { - - /** - * Generate a callable that can be executed to conduct the training of this model (given the model configuration) - * - * @param candidate Candidate (model) configuration to be trained - * @param dataProvider DataProvider, for the data - * @param scoreFunction Score function to be used to evaluate the model - * @param statusListeners Status listeners, that can be used for callbacks (to UI, for example) - * @return A callable that returns an OptimizationResult, once optimization is complete - */ - @Deprecated - Callable create(Candidate candidate, DataProvider dataProvider, ScoreFunction scoreFunction, - List statusListeners, IOptimizationRunner runner); - - /** - * Generate a callable that can be executed to conduct the training of this model (given the model configuration) - * - * @param candidate Candidate (model) configuration to be trained - * @param dataSource Data source - * @param dataSourceProperties Properties (may be null) for the data source - * @param scoreFunction Score function to be used to evaluate the model - * @param statusListeners Status listeners, that can be used for callbacks (to UI, for example) - * @return A callable that returns an OptimizationResult, once optimization is complete - */ - Callable create(Candidate candidate, Class dataSource, Properties dataSourceProperties, - ScoreFunction scoreFunction, List statusListeners, IOptimizationRunner runner); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/TaskCreatorProvider.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/TaskCreatorProvider.java deleted file mode 100644 index ea0a4f283..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/TaskCreatorProvider.java +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api; - -import java.util.HashMap; -import java.util.Map; - -public class TaskCreatorProvider { - - private static Map, Class> map = new HashMap<>(); - - public synchronized static TaskCreator defaultTaskCreatorFor(Class paramSpaceClass){ - Class c = map.get(paramSpaceClass); - try { - if(c == null){ - return null; - } - return c.newInstance(); - } catch (Exception e){ - throw new RuntimeException("Could not create new instance of task creator class: " + c + " - missing no-arg constructor?", e); - } - } - - public synchronized static void registerDefaultTaskCreatorClass(Class spaceClass, - Class creatorClass){ - map.put(spaceClass, creatorClass); - } - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/adapter/ParameterSpaceAdapter.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/adapter/ParameterSpaceAdapter.java deleted file mode 100644 index 56bd51d69..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/adapter/ParameterSpaceAdapter.java +++ /dev/null @@ -1,82 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.adapter; - -import lombok.AllArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -/** - * An abstract class used for adapting one type into another. Subclasses of this need to merely implement 2 simple methods - * - * @param Type to convert from - * @param Type to convert to - * @author Alex Black - */ -@AllArgsConstructor -public abstract class ParameterSpaceAdapter implements ParameterSpace { - - - protected abstract T convertValue(F from); - - protected abstract ParameterSpace underlying(); - - protected abstract String underlyingName(); - - - @Override - public T getValue(double[] parameterValues) { - return convertValue(underlying().getValue(parameterValues)); - } - - @Override - public int numParameters() { - return underlying().numParameters(); - } - - @Override - public List collectLeaves() { - ParameterSpace p = underlying(); - if(p.isLeaf()){ - return Collections.singletonList(p); - } - return underlying().collectLeaves(); - } - - @Override - public Map getNestedSpaces() { - return Collections.singletonMap(underlyingName(), (ParameterSpace)underlying()); - } - - @Override - public boolean isLeaf() { - return false; //Underlying may be a leaf, however - } - - @Override - public void setIndices(int... 
indices) { - underlying().setIndices(indices); - } - - @Override - public String toString() { - return underlying().toString(); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataProvider.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataProvider.java deleted file mode 100644 index 23918373f..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataProvider.java +++ /dev/null @@ -1,54 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.data; - -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import java.io.Serializable; -import java.util.Map; - -/** - * DataProvider interface abstracts out the providing of data - * @deprecated Use {@link DataSource} - */ -@JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -@Deprecated -public interface DataProvider extends Serializable { - - /** - * Get training data given some parameters for the data. - * Data parameters map is used to specify things like batch - * size data preprocessing - * - * @param dataParameters Parameters for data. May be null or empty for default data - * @return training data - */ - Object trainData(Map dataParameters); - - /** - * Get training data given some parameters for the data. Data parameters map is used to specify things like batch - * size data preprocessing - * - * @param dataParameters Parameters for data. May be null or empty for default data - * @return training data - */ - Object testData(Map dataParameters); - - Class getDataType(); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataSetIteratorFactoryProvider.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataSetIteratorFactoryProvider.java deleted file mode 100644 index 3766338a9..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataSetIteratorFactoryProvider.java +++ /dev/null @@ -1,89 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.data; - -import lombok.Data; -import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; - -import java.util.Map; - -/** - * This is a {@link DataProvider} for - * an {@link DataSetIteratorFactory} which - * based on a key of {@link DataSetIteratorFactoryProvider#FACTORY_KEY} - * will create {@link org.nd4j.linalg.dataset.api.iterator.DataSetIterator} - * for use with arbiter. - * - * This {@link DataProvider} is mainly meant for use for command line driven - * applications. - * - * @author Adam Gibson - */ -@Data -public class DataSetIteratorFactoryProvider implements DataProvider { - - public final static String FACTORY_KEY = "org.deeplearning4j.arbiter.data.data.factory"; - - /** - * Get training data given some parameters for the data. - * Data parameters map is used to specify things like batch - * size data preprocessing - * - * @param dataParameters Parameters for data. May be null or empty for default data - * @return training data - */ - @Override - public DataSetIteratorFactory trainData(Map dataParameters) { - return create(dataParameters); - } - - /** - * Get training data given some parameters for the data. Data parameters map - * is used to specify things like batch - * size data preprocessing - * - * @param dataParameters Parameters for data. May be null or empty for default data - * @return training data - */ - @Override - public DataSetIteratorFactory testData(Map dataParameters) { - return create(dataParameters); - } - - @Override - public Class getDataType() { - return DataSetIteratorFactory.class; - } - - private DataSetIteratorFactory create(Map dataParameters) { - if (dataParameters == null) - throw new IllegalArgumentException( - "Data parameters is null. Please specify a class name to create a dataset iterator."); - if (!dataParameters.containsKey(FACTORY_KEY)) - throw new IllegalArgumentException( - "No data set iterator factory class found. Please specify a class name with key " - + FACTORY_KEY); - String value = dataParameters.get(FACTORY_KEY).toString(); - try { - Class clazz = - (Class) Class.forName(value); - return clazz.newInstance(); - } catch (Exception e) { - throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e); - } - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataSource.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataSource.java deleted file mode 100644 index 0afe7bb70..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/data/DataSource.java +++ /dev/null @@ -1,57 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.data; - -import java.io.Serializable; -import java.util.Properties; - -/** - * DataSource: defines where the data should come from for training and testing. - * Note that implementations must have a no-argument contsructor - * - * @author Alex Black - */ -public interface DataSource extends Serializable { - - /** - * Configure the current data source with the specified properties - * Note: These properties are fixed for the training instance, and are optionally provided by the user - * at the configuration stage. - * The properties could be anything - and are usually specific to each DataSource implementation. - * For example, values such as batch size could be set using these properties - * @param properties Properties to apply to the data source instance - */ - void configure(Properties properties); - - /** - * Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator - */ - Object trainData(); - - /** - * Get test data to be used for the optimization. Usually a DataSetIterator or MultiDataSetIterator - */ - Object testData(); - - /** - * The type of data returned by {@link #trainData()} and {@link #testData()}. - * Usually DataSetIterator or MultiDataSetIterator - * @return Class of the objects returned by trainData and testData - */ - Class getDataType(); - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/evaluation/ModelEvaluator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/evaluation/ModelEvaluator.java deleted file mode 100644 index e5dd31d6e..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/evaluation/ModelEvaluator.java +++ /dev/null @@ -1,40 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.evaluation; - -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; - -import java.io.Serializable; -import java.util.List; - -/** - * ModelEvaluator: Used to conduct additional evaluation. 
- * For example, this may be classification performance on a test set or similar - */ -public interface ModelEvaluator extends Serializable { - Object evaluateModel(Object model, DataProvider dataProvider); - - /** - * @return The model types supported by this class - */ - List> getSupportedModelTypes(); - - /** - * @return The datatypes supported by this class - */ - List> getSupportedDataTypes(); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/InMemoryResultSaver.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/InMemoryResultSaver.java deleted file mode 100644 index 43b914cb3..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/InMemoryResultSaver.java +++ /dev/null @@ -1,63 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.saving; - -import lombok.AllArgsConstructor; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; - -import java.io.IOException; -import java.util.Collections; -import java.util.List; - -/** - * A simple class to store optimization results in-memory. - * Not recommended for large (or a large number of) models. - */ -@NoArgsConstructor -public class InMemoryResultSaver implements ResultSaver { - @Override - public ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException { - return new InMemoryResult(result, modelResult); - } - - @Override - public List> getSupportedCandidateTypes() { - return Collections.>singletonList(Object.class); - } - - @Override - public List> getSupportedModelTypes() { - return Collections.>singletonList(Object.class); - } - - @AllArgsConstructor - private static class InMemoryResult implements ResultReference { - private OptimizationResult result; - private Object modelResult; - - @Override - public OptimizationResult getResult() throws IOException { - return result; - } - - @Override - public Object getResultModel() throws IOException { - return modelResult; - } - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/ResultReference.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/ResultReference.java deleted file mode 100644 index 02e4ec453..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/ResultReference.java +++ /dev/null @@ -1,37 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.saving; - -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import java.io.IOException; - -/** - * Idea: We can't store all results in memory in general (might have thousands of candidates with millions of - * parameters each) - * So instead: return a reference to the saved result. Idea is that the result may be saved to disk or a database, - * and we can easily load it back into memory (if/when required) using the getResult() method - */ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public interface ResultReference { - - OptimizationResult getResult() throws IOException; - - Object getResultModel() throws IOException; - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/ResultSaver.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/ResultSaver.java deleted file mode 100644 index 3506d536b..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/saving/ResultSaver.java +++ /dev/null @@ -1,57 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.saving; - -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import java.io.IOException; -import java.util.List; - -/** - * The ResultSaver interface provides a means of saving models in such a way that they can be loaded back into memory later, - * regardless of where/how they are saved. 
- * - * @author Alex Black - */ -@JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public interface ResultSaver { - - /** - * Save the model (including configuration and any additional evaluation/results) - * - * @param result Optimization result for the model to save - * @param modelResult Model result to save - * @return ResultReference, such that the result can be loaded back into memory - * @throws IOException If IO error occurs during model saving - */ - ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException; - - /** - * @return The candidate types supported by this class - */ - List> getSupportedCandidateTypes(); - - /** - * @return The model types supported by this class - */ - List> getSupportedModelTypes(); - - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/score/ScoreFunction.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/score/ScoreFunction.java deleted file mode 100644 index c6ad6ed29..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/score/ScoreFunction.java +++ /dev/null @@ -1,75 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.score; - -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import java.io.Serializable; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -/** - * ScoreFunction defines the objective of hyperparameter optimization. - * Specifically, it is used to calculate a score for a given model, relative to the data set provided - * in the configuration. 
- * - */ -@JsonInclude(JsonInclude.Include.NON_NULL) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public interface ScoreFunction extends Serializable { - - /** - * Calculate and return the score, for the given model and data provider - * - * @param model Model to score - * @param dataProvider Data provider - data to use - * @param dataParameters Parameters for data - * @return Calculated score - */ - double score(Object model, DataProvider dataProvider, Map dataParameters); - - /** - * Calculate and return the score, for the given model and data provider - * - * @param model Model to score - * @param dataSource Data source - * @param dataSourceProperties data source properties - * @return Calculated score - */ - double score(Object model, Class dataSource, Properties dataSourceProperties); - - /** - * Should this score function be minimized or maximized? - * - * @return true if score should be minimized, false if score should be maximized - */ - boolean minimize(); - - /** - * @return The model types supported by this class - */ - List> getSupportedModelTypes(); - - /** - * @return The data types supported by this class - */ - List> getSupportedDataTypes(); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/MaxCandidatesCondition.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/MaxCandidatesCondition.java deleted file mode 100644 index 61b76dc90..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/MaxCandidatesCondition.java +++ /dev/null @@ -1,50 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.termination; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import com.fasterxml.jackson.annotation.JsonProperty; - -/** - * Terminate hyperparameter search when the number of candidates exceeds a specified value. - * Note that this is counted as number of completed candidates, plus number of failed candidates. 
- */ -@AllArgsConstructor -@NoArgsConstructor -@Data -public class MaxCandidatesCondition implements TerminationCondition { - @JsonProperty - private int maxCandidates; - - @Override - public void initialize(IOptimizationRunner optimizationRunner) { - //No op - } - - @Override - public boolean terminate(IOptimizationRunner optimizationRunner) { - return optimizationRunner.numCandidatesCompleted() + optimizationRunner.numCandidatesFailed() >= maxCandidates; - } - - @Override - public String toString() { - return "MaxCandidatesCondition(" + maxCandidates + ")"; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/MaxTimeCondition.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/MaxTimeCondition.java deleted file mode 100644 index c346c0ea5..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/MaxTimeCondition.java +++ /dev/null @@ -1,81 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.termination; - -import lombok.Data; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.joda.time.format.DateTimeFormat; -import org.joda.time.format.DateTimeFormatter; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.concurrent.TimeUnit; - -/** - * Terminate hyperparameter optimization after - * a fixed amount of time has passed - * @author Alex Black - */ -@NoArgsConstructor -@Data -public class MaxTimeCondition implements TerminationCondition { - private static final DateTimeFormatter formatter = DateTimeFormat.forPattern("dd-MMM HH:mm ZZ"); - - private long duration; - private TimeUnit timeUnit; - private long startTime; - private long endTime; - - - private MaxTimeCondition(@JsonProperty("duration") long duration, @JsonProperty("timeUnit") TimeUnit timeUnit, - @JsonProperty("startTime") long startTime, @JsonProperty("endTime") long endTime) { - this.duration = duration; - this.timeUnit = timeUnit; - this.startTime = startTime; - this.endTime = endTime; - } - - /** - * @param duration Duration of time - * @param timeUnit Unit that the duration is specified in - */ - public MaxTimeCondition(long duration, TimeUnit timeUnit) { - this.duration = duration; - this.timeUnit = timeUnit; - } - - @Override - public void initialize(IOptimizationRunner optimizationRunner) { - startTime = System.currentTimeMillis(); - this.endTime = startTime + timeUnit.toMillis(duration); - } - - @Override - public boolean terminate(IOptimizationRunner optimizationRunner) { - return System.currentTimeMillis() >= endTime; - } - - @Override - public String toString() { - if (startTime > 0) { - return "MaxTimeCondition(" + 
duration + "," + timeUnit + ",start=\"" + formatter.print(startTime) - + "\",end=\"" + formatter.print(endTime) + "\")"; - } else { - return "MaxTimeCondition(" + duration + "," + timeUnit + "\")"; - } - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/TerminationCondition.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/TerminationCondition.java deleted file mode 100644 index ec5e1982f..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/api/termination/TerminationCondition.java +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.api.termination; - - -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -/** - * Global termination condition for conducting hyperparameter optimization. - * Termination conditions are used to determine if/when the optimization should stop. - */ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -@JsonInclude(JsonInclude.Include.NON_NULL) -public interface TerminationCondition { - - /** - * Initialize the termination condition (such as starting timers, etc). - */ - void initialize(IOptimizationRunner optimizationRunner); - - /** - * Determine whether optimization should be terminated - * - * @param optimizationRunner Optimization runner - * @return true if learning should be terminated, false otherwise - */ - boolean terminate(IOptimizationRunner optimizationRunner); - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/config/OptimizationConfiguration.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/config/OptimizationConfiguration.java deleted file mode 100644 index 59b3e9a6a..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/config/OptimizationConfiguration.java +++ /dev/null @@ -1,226 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.config; - -import lombok.*; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition; -import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; -import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; - -import java.io.IOException; -import java.lang.reflect.Constructor; -import java.util.Arrays; -import java.util.List; -import java.util.Properties; - -/** - * OptimizationConfiguration ties together all of the various - * components (such as data, score functions, result saving etc) - * required to execute hyperparameter optimization. - * - * @author Alex Black - */ -@Data -@NoArgsConstructor -@EqualsAndHashCode(exclude = {"dataProvider", "terminationConditions", "candidateGenerator", "resultSaver"}) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public class OptimizationConfiguration { - @JsonSerialize - private DataProvider dataProvider; - @JsonSerialize - private Class dataSource; - @JsonSerialize - private Properties dataSourceProperties; - @JsonSerialize - private CandidateGenerator candidateGenerator; - @JsonSerialize - private ResultSaver resultSaver; - @JsonSerialize - private ScoreFunction scoreFunction; - @JsonSerialize - private List terminationConditions; - @JsonSerialize - private Long rngSeed; - - @Getter - @Setter - private long executionStartTime; - - - private OptimizationConfiguration(Builder builder) { - this.dataProvider = builder.dataProvider; - this.dataSource = builder.dataSource; - this.dataSourceProperties = builder.dataSourceProperties; - this.candidateGenerator = builder.candidateGenerator; - this.resultSaver = builder.resultSaver; - this.scoreFunction = builder.scoreFunction; - this.terminationConditions = builder.terminationConditions; - this.rngSeed = builder.rngSeed; - - if (rngSeed != null) - candidateGenerator.setRngSeed(rngSeed); - - //Validate the configuration: data types, score types, etc - //TODO - - //Validate that the dataSource has a no-arg constructor - if (dataSource != null) { - try { - dataSource.getConstructor(); - } catch (NoSuchMethodException e) { - throw new IllegalStateException("Data source class " + dataSource.getName() + " does not have a public no-argument constructor"); - } - } - } - - public static class Builder { - - private DataProvider dataProvider; - private Class dataSource; - private Properties dataSourceProperties; - private CandidateGenerator candidateGenerator; - private ResultSaver resultSaver; - private ScoreFunction scoreFunction; - private List terminationConditions; - private Long rngSeed; - - /** - * @deprecated Use {@link #dataSource(Class, Properties)} - */ - @Deprecated - public Builder dataProvider(DataProvider dataProvider) { - this.dataProvider = dataProvider; - return this; - } - - /** - * DataSource: defines where the data should come from for training and testing. 
- * Note that implementations must have a no-argument contsructor - * - * @param dataSource Class for the data source - * @param dataSourceProperties May be null. Properties for configuring the data source - */ - public Builder dataSource(Class dataSource, Properties dataSourceProperties) { - this.dataSource = dataSource; - this.dataSourceProperties = dataSourceProperties; - return this; - } - - public Builder candidateGenerator(CandidateGenerator candidateGenerator) { - this.candidateGenerator = candidateGenerator; - return this; - } - - public Builder modelSaver(ResultSaver resultSaver) { - this.resultSaver = resultSaver; - return this; - } - - public Builder scoreFunction(ScoreFunction scoreFunction) { - this.scoreFunction = scoreFunction; - return this; - } - - /** - * Termination conditions to use - * - * @param conditions - * @return - */ - public Builder terminationConditions(TerminationCondition... conditions) { - terminationConditions = Arrays.asList(conditions); - return this; - } - - public Builder terminationConditions(List terminationConditions) { - this.terminationConditions = terminationConditions; - return this; - } - - public Builder rngSeed(long rngSeed) { - this.rngSeed = rngSeed; - return this; - } - - public OptimizationConfiguration build() { - return new OptimizationConfiguration(this); - } - } - - - /** - * Create an optimization configuration from the json - * - * @param json the json to create the config from - * For type definitions - * @see OptimizationConfiguration - */ - public static OptimizationConfiguration fromYaml(String json) { - try { - return JsonMapper.getYamlMapper().readValue(json, OptimizationConfiguration.class); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - /** - * Create an optimization configuration from the json - * - * @param json the json to create the config from - * @see OptimizationConfiguration - */ - public static OptimizationConfiguration fromJson(String json) { - try { - return JsonMapper.getMapper().readValue(json, OptimizationConfiguration.class); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - /** - * Return a json configuration of this optimization configuration - * - * @return - */ - public String toJson() { - try { - return JsonMapper.getMapper().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - /** - * Return a yaml configuration of this optimization configuration - * - * @return - */ - public String toYaml() { - try { - return JsonMapper.getYamlMapper().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/DegenerateIntegerDistribution.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/DegenerateIntegerDistribution.java deleted file mode 100644 index c613d08b6..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/DegenerateIntegerDistribution.java +++ /dev/null @@ -1,96 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.distribution; - -import org.apache.commons.math3.distribution.IntegerDistribution; -import org.apache.commons.math3.exception.NumberIsTooLargeException; -import org.apache.commons.math3.exception.OutOfRangeException; - -/** - * Degenerate distribution: i.e., integer "distribution" that is just a fixed value - */ -public class DegenerateIntegerDistribution implements IntegerDistribution { - private int value; - - public DegenerateIntegerDistribution(int value) { - this.value = value; - } - - - @Override - public double probability(int x) { - return (x == value ? 1.0 : 0.0); - } - - @Override - public double cumulativeProbability(int x) { - return (x >= value ? 1.0 : 0.0); - } - - @Override - public double cumulativeProbability(int x0, int x1) throws NumberIsTooLargeException { - return (value >= x0 && value <= x1 ? 1.0 : 0.0); - } - - @Override - public int inverseCumulativeProbability(double p) throws OutOfRangeException { - throw new UnsupportedOperationException(); - } - - @Override - public double getNumericalMean() { - return value; - } - - @Override - public double getNumericalVariance() { - return 0; - } - - @Override - public int getSupportLowerBound() { - return value; - } - - @Override - public int getSupportUpperBound() { - return value; - } - - @Override - public boolean isSupportConnected() { - return true; - } - - @Override - public void reseedRandomGenerator(long seed) { - //no op - } - - @Override - public int sample() { - return value; - } - - @Override - public int[] sample(int sampleSize) { - int[] out = new int[sampleSize]; - for (int i = 0; i < out.length; i++) - out[i] = value; - return out; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/DistributionUtils.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/DistributionUtils.java deleted file mode 100644 index 24dafc726..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/DistributionUtils.java +++ /dev/null @@ -1,149 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.distribution; - -import org.apache.commons.math3.distribution.*; - -/** - * Distribution utils for Apache Commons math distributions - which don't provide equals, hashcode, toString methods, - * don't implement serializable etc. - * Which makes unit testing etc quite difficult. - * - * @author Alex Black - */ -public class DistributionUtils { - - private DistributionUtils() {} - - - public static boolean distributionsEqual(RealDistribution a, RealDistribution b) { - if (a.getClass() != b.getClass()) - return false; - Class c = a.getClass(); - if (c == BetaDistribution.class) { - BetaDistribution ba = (BetaDistribution) a; - BetaDistribution bb = (BetaDistribution) b; - - return ba.getAlpha() == bb.getAlpha() && ba.getBeta() == bb.getBeta(); - } else if (c == CauchyDistribution.class) { - CauchyDistribution ca = (CauchyDistribution) a; - CauchyDistribution cb = (CauchyDistribution) b; - return ca.getMedian() == cb.getMedian() && ca.getScale() == cb.getScale(); - } else if (c == ChiSquaredDistribution.class) { - ChiSquaredDistribution ca = (ChiSquaredDistribution) a; - ChiSquaredDistribution cb = (ChiSquaredDistribution) b; - return ca.getDegreesOfFreedom() == cb.getDegreesOfFreedom(); - } else if (c == ExponentialDistribution.class) { - ExponentialDistribution ea = (ExponentialDistribution) a; - ExponentialDistribution eb = (ExponentialDistribution) b; - return ea.getMean() == eb.getMean(); - } else if (c == FDistribution.class) { - FDistribution fa = (FDistribution) a; - FDistribution fb = (FDistribution) b; - return fa.getNumeratorDegreesOfFreedom() == fb.getNumeratorDegreesOfFreedom() - && fa.getDenominatorDegreesOfFreedom() == fb.getDenominatorDegreesOfFreedom(); - } else if (c == GammaDistribution.class) { - GammaDistribution ga = (GammaDistribution) a; - GammaDistribution gb = (GammaDistribution) b; - return ga.getShape() == gb.getShape() && ga.getScale() == gb.getScale(); - } else if (c == LevyDistribution.class) { - LevyDistribution la = (LevyDistribution) a; - LevyDistribution lb = (LevyDistribution) b; - return la.getLocation() == lb.getLocation() && la.getScale() == lb.getScale(); - } else if (c == LogNormalDistribution.class) { - LogNormalDistribution la = (LogNormalDistribution) a; - LogNormalDistribution lb = (LogNormalDistribution) b; - return la.getScale() == lb.getScale() && la.getShape() == lb.getShape(); - } else if (c == NormalDistribution.class) { - NormalDistribution na = (NormalDistribution) a; - NormalDistribution nb = (NormalDistribution) b; - return na.getMean() == nb.getMean() && na.getStandardDeviation() == nb.getStandardDeviation(); - } else if (c == ParetoDistribution.class) { - ParetoDistribution pa = (ParetoDistribution) a; - ParetoDistribution pb = (ParetoDistribution) b; - return pa.getScale() == pb.getScale() && pa.getShape() == pb.getShape(); - } else if (c == TDistribution.class) { - TDistribution ta = (TDistribution) a; - TDistribution tb = (TDistribution) b; - return ta.getDegreesOfFreedom() == tb.getDegreesOfFreedom(); - } else if (c == TriangularDistribution.class) { - TriangularDistribution ta = (TriangularDistribution) a; - TriangularDistribution tb = (TriangularDistribution) b; - return ta.getSupportLowerBound() == tb.getSupportLowerBound() - && ta.getSupportUpperBound() == tb.getSupportUpperBound() && ta.getMode() == tb.getMode(); - } else if (c == 
UniformRealDistribution.class) { - UniformRealDistribution ua = (UniformRealDistribution) a; - UniformRealDistribution ub = (UniformRealDistribution) b; - return ua.getSupportLowerBound() == ub.getSupportLowerBound() - && ua.getSupportUpperBound() == ub.getSupportUpperBound(); - } else if (c == WeibullDistribution.class) { - WeibullDistribution wa = (WeibullDistribution) a; - WeibullDistribution wb = (WeibullDistribution) b; - return wa.getShape() == wb.getShape() && wa.getScale() == wb.getScale(); - } else if (c == LogUniformDistribution.class ){ - LogUniformDistribution lu_a = (LogUniformDistribution)a; - LogUniformDistribution lu_b = (LogUniformDistribution)b; - return lu_a.getMin() == lu_b.getMin() && lu_a.getMax() == lu_b.getMax(); - } else { - throw new UnsupportedOperationException("Unknown or not supported RealDistribution: " + c); - } - } - - public static boolean distributionEquals(IntegerDistribution a, IntegerDistribution b) { - if (a.getClass() != b.getClass()) - return false; - Class c = a.getClass(); - - if (c == BinomialDistribution.class) { - BinomialDistribution ba = (BinomialDistribution) a; - BinomialDistribution bb = (BinomialDistribution) b; - return ba.getNumberOfTrials() == bb.getNumberOfTrials() - && ba.getProbabilityOfSuccess() == bb.getProbabilityOfSuccess(); - } else if (c == GeometricDistribution.class) { - GeometricDistribution ga = (GeometricDistribution) a; - GeometricDistribution gb = (GeometricDistribution) b; - return ga.getProbabilityOfSuccess() == gb.getProbabilityOfSuccess(); - } else if (c == HypergeometricDistribution.class) { - HypergeometricDistribution ha = (HypergeometricDistribution) a; - HypergeometricDistribution hb = (HypergeometricDistribution) b; - return ha.getPopulationSize() == hb.getPopulationSize() - && ha.getNumberOfSuccesses() == hb.getNumberOfSuccesses() - && ha.getSampleSize() == hb.getSampleSize(); - } else if (c == PascalDistribution.class) { - PascalDistribution pa = (PascalDistribution) a; - PascalDistribution pb = (PascalDistribution) b; - return pa.getNumberOfSuccesses() == pb.getNumberOfSuccesses() - && pa.getProbabilityOfSuccess() == pb.getProbabilityOfSuccess(); - } else if (c == PoissonDistribution.class) { - PoissonDistribution pa = (PoissonDistribution) a; - PoissonDistribution pb = (PoissonDistribution) b; - return pa.getMean() == pb.getMean(); - } else if (c == UniformIntegerDistribution.class) { - UniformIntegerDistribution ua = (UniformIntegerDistribution) a; - UniformIntegerDistribution ub = (UniformIntegerDistribution) b; - return ua.getSupportUpperBound() == ub.getSupportUpperBound() - && ua.getSupportUpperBound() == ub.getSupportUpperBound(); - } else if (c == ZipfDistribution.class) { - ZipfDistribution za = (ZipfDistribution) a; - ZipfDistribution zb = (ZipfDistribution) b; - return za.getNumberOfElements() == zb.getNumberOfElements() && za.getExponent() == zb.getNumberOfElements(); - } else { - throw new UnsupportedOperationException("Unknown or not supported IntegerDistribution: " + c); - } - - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java deleted file mode 100644 index da790c422..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/distribution/LogUniformDistribution.java +++ /dev/null @@ -1,155 +0,0 @@ 
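
A small usage sketch of the comparison helpers above, assuming Apache Commons Math 3 and this class are on the classpath; the demo class name is illustrative.

    import org.apache.commons.math3.distribution.NormalDistribution;
    import org.apache.commons.math3.distribution.UniformIntegerDistribution;

    public class DistributionUtilsDemo {
        public static void main(String[] args) {
            // Commons Math distributions do not override equals(), so comparison is by parameters.
            boolean sameReal = DistributionUtils.distributionsEqual(
                    new NormalDistribution(0.0, 1.0), new NormalDistribution(0.0, 1.0));          // true
            boolean sameInt = DistributionUtils.distributionEquals(
                    new UniformIntegerDistribution(1, 10), new UniformIntegerDistribution(1, 5)); // false
            System.out.println(sameReal + " " + sameInt);
        }
    }
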
-/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.distribution; - -import com.google.common.base.Preconditions; -import lombok.Getter; -import org.apache.commons.math3.distribution.RealDistribution; -import org.apache.commons.math3.exception.NumberIsTooLargeException; -import org.apache.commons.math3.exception.OutOfRangeException; - -import java.util.Random; - -/** - * Log uniform distribution, with support in range [min, max] for min > 0 - * - * Reference: https://www.vosesoftware.com/riskwiki/LogUniformdistribution.php - * - * @author Alex Black - */ -public class LogUniformDistribution implements RealDistribution { - - @Getter private final double min; - @Getter private final double max; - - private final double logMin; - private final double logMax; - - private transient Random rng = new Random(); - - /** - * - * @param min Minimum value - * @param max Maximum value - */ - public LogUniformDistribution(double min, double max) { - Preconditions.checkArgument(min > 0, "Minimum must be > 0. Got: " + min); - Preconditions.checkArgument(max > min, "Maximum must be > min. 
Got: (min, max)=(" - + min + "," + max + ")"); - this.min = min; - this.max = max; - - this.logMin = Math.log(min); - this.logMax = Math.log(max); - } - - @Override - public double probability(double x) { - if(x < min || x > max){ - return 0; - } - - return 1.0 / (x * (logMax - logMin)); - } - - @Override - public double density(double x) { - return probability(x); - } - - @Override - public double cumulativeProbability(double x) { - if(x <= min){ - return 0.0; - } else if(x >= max){ - return 1.0; - } - - return (Math.log(x)-logMin)/(logMax-logMin); - } - - @Override - public double cumulativeProbability(double x0, double x1) throws NumberIsTooLargeException { - return cumulativeProbability(x1) - cumulativeProbability(x0); - } - - @Override - public double inverseCumulativeProbability(double p) throws OutOfRangeException { - Preconditions.checkArgument(p >= 0 && p <= 1, "Invalid input: " + p); - return Math.exp(p * (logMax-logMin) + logMin); - } - - @Override - public double getNumericalMean() { - return (max-min)/(logMax-logMin); - } - - @Override - public double getNumericalVariance() { - double d1 = (logMax-logMin)*(max*max - min*min) - 2*(max-min)*(max-min); - return d1 / (2*Math.pow(logMax-logMin, 2.0)); - } - - @Override - public double getSupportLowerBound() { - return min; - } - - @Override - public double getSupportUpperBound() { - return max; - } - - @Override - public boolean isSupportLowerBoundInclusive() { - return true; - } - - @Override - public boolean isSupportUpperBoundInclusive() { - return true; - } - - @Override - public boolean isSupportConnected() { - return true; - } - - @Override - public void reseedRandomGenerator(long seed) { - rng.setSeed(seed); - } - - @Override - public double sample() { - return inverseCumulativeProbability(rng.nextDouble()); - } - - @Override - public double[] sample(int sampleSize) { - double[] d = new double[sampleSize]; - for( int i=0; i Type of candidates to generate - */ -@Data -@EqualsAndHashCode(exclude = {"rng", "candidateCounter"}) -public abstract class BaseCandidateGenerator implements CandidateGenerator { - protected ParameterSpace parameterSpace; - protected AtomicInteger candidateCounter = new AtomicInteger(0); - protected SynchronizedRandomGenerator rng = new SynchronizedRandomGenerator(new JDKRandomGenerator()); - protected Map dataParameters; - protected boolean initDone = false; - - public BaseCandidateGenerator(ParameterSpace parameterSpace, Map dataParameters, - boolean initDone) { - this.parameterSpace = parameterSpace; - this.dataParameters = dataParameters; - this.initDone = initDone; - } - - protected void initialize() { - if(!initDone) { - //First: collect leaf parameter spaces objects and remove duplicates - List noDuplicatesList = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves()); - - //Second: assign each a number - int i = 0; - for (ParameterSpace ps : noDuplicatesList) { - int np = ps.numParameters(); - if (np == 1) { - ps.setIndices(i++); - } else { - int[] values = new int[np]; - for (int j = 0; j < np; j++) - values[j] = i++; - ps.setIndices(values); - } - } - initDone = true; - } - } - - @Override - public ParameterSpace getParameterSpace() { - return parameterSpace; - } - - @Override - public void reportResults(OptimizationResult result) { - //No op - } - - @Override - public void setRngSeed(long rngSeed) { - rng.setSeed(rngSeed); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/GeneticSearchCandidateGenerator.java 
b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/GeneticSearchCandidateGenerator.java deleted file mode 100644 index 564c194ba..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/GeneticSearchCandidateGenerator.java +++ /dev/null @@ -1,187 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator; - -import lombok.Getter; -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.arbiter.optimize.api.Candidate; -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; -import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory; -import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.EmptyPopulationInitializer; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; -import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator; - -import java.util.Map; - -/** - * Uses a genetic algorithm to generate candidates. - * - * @author Alexandre Boulanger - */ -@Slf4j -public class GeneticSearchCandidateGenerator extends BaseCandidateGenerator { - - @Getter - protected final PopulationModel populationModel; - - protected final ChromosomeFactory chromosomeFactory; - protected final SelectionOperator selectionOperator; - - protected boolean hasMoreCandidates = true; - - public static class Builder { - protected final ParameterSpace parameterSpace; - - protected Map dataParameters; - protected boolean initDone; - protected boolean minimizeScore; - protected PopulationModel populationModel; - protected ChromosomeFactory chromosomeFactory; - protected SelectionOperator selectionOperator; - - /** - * @param parameterSpace ParameterSpace from which to generate candidates - * @param scoreFunction The score function that will be used in the OptimizationConfiguration - */ - public Builder(ParameterSpace parameterSpace, ScoreFunction scoreFunction) { - this.parameterSpace = parameterSpace; - this.minimizeScore = scoreFunction.minimize(); - } - - /** - * @param populationModel The PopulationModel instance to use. 
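
A hedged construction sketch for this generator; parameterSpace and scoreFunction are placeholders for whatever search space and score function the optimization run actually uses.

    GeneticSearchCandidateGenerator generator =
            new GeneticSearchCandidateGenerator.Builder(parameterSpace, scoreFunction)
                    .build();   // default population model, chromosome factory and selection operator

    while (generator.hasMoreCandidates()) {
        Candidate candidate = generator.getCandidate();
        // ... evaluate the candidate, then report the scored result back so the population can evolve:
        // generator.reportResults(result);
    }
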
- */ - public Builder populationModel(PopulationModel populationModel) { - this.populationModel = populationModel; - return this; - } - - /** - * @param selectionOperator The SelectionOperator to use. Default is GeneticSelectionOperator - */ - public Builder selectionOperator(SelectionOperator selectionOperator) { - this.selectionOperator = selectionOperator; - return this; - } - - public Builder dataParameters(Map dataParameters) { - - this.dataParameters = dataParameters; - return this; - } - - public GeneticSearchCandidateGenerator.Builder initDone(boolean initDone) { - this.initDone = initDone; - return this; - } - - /** - * @param chromosomeFactory The ChromosomeFactory to use - */ - public Builder chromosomeFactory(ChromosomeFactory chromosomeFactory) { - this.chromosomeFactory = chromosomeFactory; - return this; - } - - public GeneticSearchCandidateGenerator build() { - if (populationModel == null) { - PopulationInitializer defaultPopulationInitializer = new EmptyPopulationInitializer(); - populationModel = new PopulationModel.Builder().populationInitializer(defaultPopulationInitializer) - .build(); - } - - if (chromosomeFactory == null) { - chromosomeFactory = new ChromosomeFactory(); - } - - if (selectionOperator == null) { - selectionOperator = new GeneticSelectionOperator.Builder().build(); - } - - return new GeneticSearchCandidateGenerator(this); - } - } - - private GeneticSearchCandidateGenerator(Builder builder) { - super(builder.parameterSpace, builder.dataParameters, builder.initDone); - - initialize(); - - chromosomeFactory = builder.chromosomeFactory; - populationModel = builder.populationModel; - selectionOperator = builder.selectionOperator; - - chromosomeFactory.initializeInstance(builder.parameterSpace.numParameters()); - populationModel.initializeInstance(builder.minimizeScore); - selectionOperator.initializeInstance(populationModel, chromosomeFactory); - - } - - @Override - public boolean hasMoreCandidates() { - return hasMoreCandidates; - } - - @Override - public Candidate getCandidate() { - - double[] values = null; - Object value = null; - Exception e = null; - - try { - values = selectionOperator.buildNextGenes(); - value = parameterSpace.getValue(values); - } catch (GeneticGenerationException e2) { - log.warn("Error generating candidate", e2); - e = e2; - hasMoreCandidates = false; - } catch (Exception e2) { - log.warn("Error getting configuration for candidate", e2); - e = e2; - } - - return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e); - } - - @Override - public Class getCandidateType() { - return null; - } - - @Override - public String toString() { - return "GeneticSearchCandidateGenerator"; - } - - @Override - public void reportResults(OptimizationResult result) { - if (result.getScore() == null) { - return; - } - - Chromosome newChromosome = chromosomeFactory.createChromosome(result.getCandidate().getFlatParameters(), - result.getScore()); - populationModel.add(newChromosome); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/GridSearchCandidateGenerator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/GridSearchCandidateGenerator.java deleted file mode 100644 index 4d056087f..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/GridSearchCandidateGenerator.java +++ /dev/null @@ -1,232 +0,0 @@ -/******************************************************************************* - * 
Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator; - -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.math3.random.RandomAdaptor; -import org.deeplearning4j.arbiter.optimize.api.Candidate; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; -import org.deeplearning4j.arbiter.util.LeafUtils; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.*; -import java.util.concurrent.ConcurrentLinkedQueue; - - -/** - * GridSearchCandidateGenerator: generates candidates in an exhaustive grid search manner.
      - * Note that:
      - * - For discrete parameters: the grid size (# values to check per hyperparameter) is equal to the number of values for - * that hyperparameter
      - * - For integer parameters: the grid size is equal to {@code min(discretizationCount,max-min+1)}. Some integer ranges can - * be large, and we don't necessarily want to exhaustively search them. {@code discretizationCount} is a constructor argument
      - * - For continuous parameters: the grid size is equal to {@code discretizationCount}.
      - * In all cases, the minimum, maximum and gridSize-2 values between the min/max will be generated.
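
As a concrete illustration of the sizing rules above (the numbers are hypothetical): with discretizationCount = 5, a DiscreteParameterSpace holding 3 values contributes 3 grid points, an IntegerParameterSpace over [10, 1000] contributes min(1000 - 10 + 1, 5) = 5, and a continuous parameter contributes 5, so the exhaustive grid contains 3 * 5 * 5 = 75 candidates.
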
      - * Also note that: if a probability distribution is provided for continuous hyperparameters, this will be taken into account - * when generating candidates. This allows the grid for a hyperparameter to be non-linear: i.e., for example, linear in log space - * - * @author Alex Black - */ -@Slf4j -@EqualsAndHashCode(exclude = {"order"}, callSuper = true) -@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"}) -public class GridSearchCandidateGenerator extends BaseCandidateGenerator { - - /** - * In what order should candidates be generated?
      - * Sequential: generate candidates in order. The first hyperparameter will be changed most rapidly, and the last - * will be changed least rapidly.
      - * RandomOrder: generate candidates in a random order
      - * In both cases, the same candidates will be generated; only the order of generation is different - */ - public enum Mode { - Sequential, RandomOrder - } - - private final int discretizationCount; - private final Mode mode; - - private int[] numValuesPerParam; - @Getter - private int totalNumCandidates; - private Queue order; - - /** - * @param parameterSpace ParameterSpace from which to generate candidates - * @param discretizationCount For continuous parameters: into how many values should we discretize them into? - * For example, suppose continuous parameter is in range [0,1] with 3 bins: - * do [0.0, 0.5, 1.0]. Note that if all values - * @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order - * in which candidates should be generated. - */ - public GridSearchCandidateGenerator(@JsonProperty("parameterSpace") ParameterSpace parameterSpace, - @JsonProperty("discretizationCount") int discretizationCount, @JsonProperty("mode") Mode mode, - @JsonProperty("dataParameters") Map dataParameters, - @JsonProperty("initDone") boolean initDone) { - super(parameterSpace, dataParameters, initDone); - this.discretizationCount = discretizationCount; - this.mode = mode; - initialize(); - } - - /** - * @param parameterSpace ParameterSpace from which to generate candidates - * @param discretizationCount For continuous parameters: into how many values should we discretize them into? - * For example, suppose continuous parameter is in range [0,1] with 3 bins: - * do [0.0, 0.5, 1.0]. Note that if all values - * @param mode {@link GridSearchCandidateGenerator.Mode} specifies the order - * in which candidates should be generated. - */ - public GridSearchCandidateGenerator(ParameterSpace parameterSpace, int discretizationCount, Mode mode, - Map dataParameters){ - this(parameterSpace, discretizationCount, mode, dataParameters, false); - } - - @Override - protected void initialize() { - super.initialize(); - - List leaves = LeafUtils.getUniqueObjects(parameterSpace.collectLeaves()); - int nParams = leaves.size(); - - //Work out for each parameter: is it continuous or discrete? - // for grid search: discrete values are grid-searchable as-is - // continuous values: discretize using 'discretizationCount' bins - // integer values: use min(max-min+1, discretizationCount) values. i.e., discretize if necessary - numValuesPerParam = new int[nParams]; - long searchSize = 1; - for (int i = 0; i < nParams; i++) { - ParameterSpace ps = leaves.get(i); - if (ps instanceof DiscreteParameterSpace) { - DiscreteParameterSpace dps = (DiscreteParameterSpace) ps; - numValuesPerParam[i] = dps.numValues(); - } else if (ps instanceof IntegerParameterSpace) { - IntegerParameterSpace ips = (IntegerParameterSpace) ps; - int min = ips.getMin(); - int max = ips.getMax(); - //Discretize, as some integer ranges are much too large to search (i.e., num. neural network units, between 100 and 1000) - numValuesPerParam[i] = Math.min(max - min + 1, discretizationCount); - } else if (ps instanceof FixedValue){ - numValuesPerParam[i] = 1; - } else { - numValuesPerParam[i] = discretizationCount; - } - searchSize *= numValuesPerParam[i]; - } - - if (searchSize >= Integer.MAX_VALUE) - throw new IllegalStateException("Invalid search: cannot process search with " + searchSize - + " candidates > Integer.MAX_VALUE"); //TODO find a more reasonable upper bound? 
- - order = new ConcurrentLinkedQueue<>(); - - totalNumCandidates = (int) searchSize; - switch (mode) { - case Sequential: - for (int i = 0; i < totalNumCandidates; i++) { - order.add(i); - } - break; - case RandomOrder: - List tempList = new ArrayList<>(totalNumCandidates); - for (int i = 0; i < totalNumCandidates; i++) { - tempList.add(i); - } - - Collections.shuffle(tempList, new RandomAdaptor(rng)); - order.addAll(tempList); - break; - default: - throw new RuntimeException(); - } - - } - - @Override - public boolean hasMoreCandidates() { - return !order.isEmpty(); - } - - @Override - public Candidate getCandidate() { - int next = order.remove(); - - //Next: max integer (candidate number) to values - double[] values = indexToValues(numValuesPerParam, next, totalNumCandidates); - - Object value = null; - Exception e = null; - try { - value = parameterSpace.getValue(values); - } catch (Exception e2) { - log.warn("Error getting configuration for candidate", e2); - e = e2; - } - - return new Candidate(value, candidateCounter.getAndIncrement(), values, dataParameters, e); - } - - @Override - public Class getCandidateType() { - return null; - } - - public static double[] indexToValues(int[] numValuesPerParam, int candidateIdx, int product) { - //How? first map to index of num possible values. Then: to double values in range 0 to 1 - // 0-> [0,0,0], 1-> [1,0,0], 2-> [2,0,0], 3-> [0,1,0] etc - //Based on: Nd4j Shape.ind2sub - - int countNon1 = 0; - for( int i : numValuesPerParam) - if(i > 1) - countNon1++; - - int denom = product; - int num = candidateIdx; - int[] index = new int[numValuesPerParam.length]; - - for (int i = index.length - 1; i >= 0; i--) { - denom /= numValuesPerParam[i]; - index[i] = num / denom; - num %= denom; - } - - //Now: convert indexes to values in range [0,1] - //min value -> 0 - //max value -> 1 - double[] out = new double[countNon1]; - int outIdx = 0; - for (int i = 0; i < numValuesPerParam.length; i++) { - if (numValuesPerParam[i] > 1){ - out[outIdx++] = index[i] / ((double) (numValuesPerParam[i] - 1)); - } - } - - return out; - } - - @Override - public String toString() { - return "GridSearchCandidateGenerator(mode=" + mode + ")"; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/RandomSearchGenerator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/RandomSearchGenerator.java deleted file mode 100644 index 04b5c8da8..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/RandomSearchGenerator.java +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
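
A small worked example of the index-to-values mapping performed by indexToValues(...) above; the numbers are hypothetical.

    // Three hyperparameters with 3, 2 and 2 grid points respectively, 12 candidates in total.
    int[] numValuesPerParam = {3, 2, 2};
    double[] values = GridSearchCandidateGenerator.indexToValues(numValuesPerParam, 5, 12);
    // Candidate 5 maps to sub-indices [2, 1, 0] (the first parameter varies fastest),
    // which are then scaled into [0, 1]: values == {1.0, 1.0, 0.0}
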
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator; - -import lombok.EqualsAndHashCode; -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.arbiter.optimize.api.Candidate; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.Map; - -/** - * RandomSearchGenerator: generates candidates at random.
      - * Note: if a probability distribution is provided for continuous hyperparameters, - * this will be taken into account - * when generating candidates. This allows the search to be weighted more towards - * certain values according to a probability - * density. For example: generate samples for learning rate according to log uniform distribution - * - * @author Alex Black - */ -@Slf4j -@EqualsAndHashCode(callSuper = true) -@JsonIgnoreProperties({"numValuesPerParam", "totalNumCandidates", "order", "candidateCounter", "rng", "candidate"}) -public class RandomSearchGenerator extends BaseCandidateGenerator { - - @JsonCreator - public RandomSearchGenerator(@JsonProperty("parameterSpace") ParameterSpace parameterSpace, - @JsonProperty("dataParameters") Map dataParameters, - @JsonProperty("initDone") boolean initDone) { - super(parameterSpace, dataParameters, initDone); - initialize(); - } - - public RandomSearchGenerator(ParameterSpace parameterSpace, Map dataParameters){ - this(parameterSpace, dataParameters, false); - } - - public RandomSearchGenerator(ParameterSpace parameterSpace){ - this(parameterSpace, null, false); - } - - - @Override - public boolean hasMoreCandidates() { - return true; - } - - @Override - public Candidate getCandidate() { - double[] randomValues = new double[parameterSpace.numParameters()]; - for (int i = 0; i < randomValues.length; i++) - randomValues[i] = rng.nextDouble(); - - Object value = null; - Exception e = null; - try { - value = parameterSpace.getValue(randomValues); - } catch (Exception e2) { - log.warn("Error getting configuration for candidate", e2); - e = e2; - } - - return new Candidate(value, candidateCounter.getAndIncrement(), randomValues, dataParameters, e); - } - - @Override - public Class getCandidateType() { - return null; - } - - @Override - public String toString() { - return "RandomSearchGenerator"; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/Chromosome.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/Chromosome.java deleted file mode 100644 index 5d8d00f0f..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/Chromosome.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic; - -import lombok.Data; - -/** - * Candidates are stored as Chromosome in the population model - * - * @author Alexandre Boulanger - */ -@Data -public class Chromosome { - /** - * The fitness score of the genes. - */ - protected final double fitness; - - /** - * The genes. 
- */ - protected final double[] genes; - - public Chromosome(double[] genes, double fitness) { - this.genes = genes; - this.fitness = fitness; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/ChromosomeFactory.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/ChromosomeFactory.java deleted file mode 100644 index ede86406a..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/ChromosomeFactory.java +++ /dev/null @@ -1,51 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic; - -/** - * A factory that builds new chromosomes. Used by the GeneticSearchCandidateGenerator. - * - * @author Alexandre Boulanger - */ -public class ChromosomeFactory { - private int chromosomeLength; - - /** - * Called by the GeneticSearchCandidateGenerator. - */ - public void initializeInstance(int chromosomeLength) { - this.chromosomeLength = chromosomeLength; - } - - /** - * Create a new instance of a Chromosome - * - * @param genes The genes - * @param fitness The fitness score - * @return A new instance of Chromosome - */ - public Chromosome createChromosome(double[] genes, double fitness) { - return new Chromosome(genes, fitness); - } - - /** - * @return The number of genes in a chromosome - */ - public int getChromosomeLength() { - return chromosomeLength; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/ArithmeticCrossover.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/ArithmeticCrossover.java deleted file mode 100644 index 978e7166b..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/ArithmeticCrossover.java +++ /dev/null @@ -1,120 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover; - -import org.apache.commons.math3.random.JDKRandomGenerator; -import org.apache.commons.math3.random.RandomGenerator; -import org.apache.commons.math3.random.SynchronizedRandomGenerator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; -import org.nd4j.common.base.Preconditions; - -/** - * A crossover operator that linearly combines the genes of two parents.
      - * When a crossover is generated (with probability crossoverRate), each gene is a linear combination of the corresponding genes of the parents. - *

      - * t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene. - * - * @author Alexandre Boulanger - */ -public class ArithmeticCrossover extends TwoParentsCrossoverOperator { - private static final double DEFAULT_CROSSOVER_RATE = 0.85; - - private final double crossoverRate; - private final RandomGenerator rng; - - public static class Builder { - private double crossoverRate = DEFAULT_CROSSOVER_RATE; - private RandomGenerator rng; - private TwoParentSelection parentSelection; - - /** - * The probability that the operator generates a crossover (default 0.85). - * - * @param rate A value between 0.0 and 1.0 - */ - public Builder crossoverRate(double rate) { - Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate); - - this.crossoverRate = rate; - return this; - } - - /** - * Use a supplied RandomGenerator - * - * @param rng An instance of RandomGenerator - */ - public Builder randomGenerator(RandomGenerator rng) { - this.rng = rng; - return this; - } - - /** - * The parent selection behavior. Default is random parent selection. - * - * @param parentSelection An instance of TwoParentSelection - */ - public Builder parentSelection(TwoParentSelection parentSelection) { - this.parentSelection = parentSelection; - return this; - } - - public ArithmeticCrossover build() { - if (rng == null) { - rng = new SynchronizedRandomGenerator(new JDKRandomGenerator()); - } - - if (parentSelection == null) { - parentSelection = new RandomTwoParentSelection(); - } - - return new ArithmeticCrossover(this); - } - } - - private ArithmeticCrossover(ArithmeticCrossover.Builder builder) { - super(builder.parentSelection); - - this.crossoverRate = builder.crossoverRate; - this.rng = builder.rng; - } - - /** - * Has a probability crossoverRate of performing the crossover where each gene is a linear combination of:
      - * t*parentA + (1-t)*parentB, where t is [0, 1] and different for each gene.
      - * Otherwise, returns the genes of a random parent. - * - * @return The crossover result. See {@link CrossoverResult}. - */ - @Override - public CrossoverResult crossover() { - double[][] parents = parentSelection.selectParents(); - - double[] offspringValues = new double[parents[0].length]; - - if (rng.nextDouble() < crossoverRate) { - for (int i = 0; i < offspringValues.length; ++i) { - double t = rng.nextDouble(); - offspringValues[i] = t * parents[0][i] + (1.0 - t) * parents[1][i]; - } - return new CrossoverResult(true, offspringValues); - } - - return new CrossoverResult(false, parents[0]); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/CrossoverOperator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/CrossoverOperator.java deleted file mode 100644 index cfae61e09..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/CrossoverOperator.java +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; - -/** - * Abstract class for all crossover operators - * - * @author Alexandre Boulanger - */ -public abstract class CrossoverOperator { - protected PopulationModel populationModel; - - /** - * Will be called by the selection operator once the population model is instantiated. - */ - public void initializeInstance(PopulationModel populationModel) { - this.populationModel = populationModel; - } - - /** - * Performs the crossover - * - * @return The crossover result. See {@link CrossoverResult}. - */ - public abstract CrossoverResult crossover(); - - - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/CrossoverResult.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/CrossoverResult.java deleted file mode 100644 index 68b7bdecb..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/CrossoverResult.java +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
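
A compact, self-contained sketch of the linear-combination rule described above, stripped of the parent-selection and population plumbing; the class and method names are illustrative only.

    import java.util.Random;

    final class ArithmeticCrossoverSketch {
        // Combine two parents gene by gene: t*a + (1-t)*b, with a fresh t for every gene.
        static double[] combine(double[] parentA, double[] parentB, Random rng) {
            double[] child = new double[parentA.length];
            for (int i = 0; i < child.length; i++) {
                double t = rng.nextDouble();                        // t in [0, 1), drawn per gene
                child[i] = t * parentA[i] + (1.0 - t) * parentB[i];
            }
            return child;
        }
    }
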
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover; - -import lombok.Data; - -/** - * Returned by a crossover operator - * - * @author Alexandre Boulanger - */ -@Data -public class CrossoverResult { - /** - * If false, there was no crossover and the operator simply returned the genes of a random parent. - * If true, the genes are the result of a crossover. - */ - private final boolean isModified; - - /** - * The genes returned by the operator. - */ - private final double[] genes; - - public CrossoverResult(boolean isModified, double[] genes) { - this.isModified = isModified; - this.genes = genes; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/KPointCrossover.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/KPointCrossover.java deleted file mode 100644 index 8a7bb3a2a..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/KPointCrossover.java +++ /dev/null @@ -1,178 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover; - -import org.apache.commons.math3.random.JDKRandomGenerator; -import org.apache.commons.math3.random.RandomGenerator; -import org.apache.commons.math3.random.SynchronizedRandomGenerator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator; -import org.nd4j.common.base.Preconditions; - -import java.util.Deque; - -/** -* The K-Point crossover will select at random multiple crossover points.
      -* Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched. -*/ -public class KPointCrossover extends TwoParentsCrossoverOperator { - private static final double DEFAULT_CROSSOVER_RATE = 0.85; - private static final int DEFAULT_MIN_CROSSOVER = 1; - private static final int DEFAULT_MAX_CROSSOVER = 4; - - private final double crossoverRate; - private final int minCrossovers; - private final int maxCrossovers; - - private final RandomGenerator rng; - - public static class Builder { - private double crossoverRate = DEFAULT_CROSSOVER_RATE; - private int minCrossovers = DEFAULT_MIN_CROSSOVER; - private int maxCrossovers = DEFAULT_MAX_CROSSOVER; - private RandomGenerator rng; - private TwoParentSelection parentSelection; - - /** - * The probability that the operator generates a crossover (default 0.85). - * - * @param rate A value between 0.0 and 1.0 - */ - public Builder crossoverRate(double rate) { - Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate); - - this.crossoverRate = rate; - return this; - } - - /** - * The number of crossovers points (default is min 1, max 4) - * - * @param min The minimum number - * @param max The maximum number - */ - public Builder numCrossovers(int min, int max) { - Preconditions.checkState(max >= 0 && min >= 0, "Min and max must be positive"); - Preconditions.checkState(max >= min, "Max must be greater or equal to min"); - - this.minCrossovers = min; - this.maxCrossovers = max; - return this; - } - - /** - * Use a fixed number of crossover points - * - * @param num The number of crossovers - */ - public Builder numCrossovers(int num) { - Preconditions.checkState(num >= 0, "Num must be positive"); - - this.minCrossovers = num; - this.maxCrossovers = num; - return this; - } - - /** - * Use a supplied RandomGenerator - * - * @param rng An instance of RandomGenerator - */ - public Builder randomGenerator(RandomGenerator rng) { - this.rng = rng; - return this; - } - - /** - * The parent selection behavior. Default is random parent selection. - * - * @param parentSelection An instance of TwoParentSelection - */ - public Builder parentSelection(TwoParentSelection parentSelection) { - this.parentSelection = parentSelection; - return this; - } - - public KPointCrossover build() { - if (rng == null) { - rng = new SynchronizedRandomGenerator(new JDKRandomGenerator()); - } - - if (parentSelection == null) { - parentSelection = new RandomTwoParentSelection(); - } - - return new KPointCrossover(this); - } - } - - private CrossoverPointsGenerator crossoverPointsGenerator; - - private KPointCrossover(KPointCrossover.Builder builder) { - super(builder.parentSelection); - - this.crossoverRate = builder.crossoverRate; - this.maxCrossovers = builder.maxCrossovers; - this.minCrossovers = builder.minCrossovers; - this.rng = builder.rng; - } - - private CrossoverPointsGenerator getCrossoverPointsGenerator(int chromosomeLength) { - if (crossoverPointsGenerator == null) { - crossoverPointsGenerator = - new CrossoverPointsGenerator(chromosomeLength, minCrossovers, maxCrossovers, rng); - } - - return crossoverPointsGenerator; - } - - /** - * Has a probability crossoverRate of performing the crossover where the operator will select at random multiple crossover points.
      - * Each gene comes from one of the two parents. Each time a crossover point is reached, the parent is switched.
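
A self-contained sketch of the k-point rule described above: the child copies from one parent and switches to the other each time a crossover point is reached. The helper is illustrative, assumes k is smaller than the chromosome length, and is not the CrossoverPointsGenerator used by the class.

    import java.util.Random;
    import java.util.TreeSet;

    final class KPointCrossoverSketch {
        static double[] combine(double[] parentA, double[] parentB, int k, Random rng) {
            TreeSet<Integer> points = new TreeSet<>();
            while (points.size() < k) {
                points.add(rng.nextInt(parentA.length));   // k distinct crossover points, kept sorted
            }
            double[][] parents = {parentA, parentB};
            double[] child = new double[parentA.length];
            int current = 0;
            for (int i = 0; i < child.length; i++) {
                if (points.contains(i)) {
                    current = 1 - current;                 // switch parent at each crossover point
                }
                child[i] = parents[current][i];
            }
            return child;
        }
    }
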
      - * Otherwise, returns the genes of a random parent. - * - * @return The crossover result. See {@link CrossoverResult}. - */ - @Override - public CrossoverResult crossover() { - double[][] parents = parentSelection.selectParents(); - - boolean isModified = false; - double[] resultGenes = parents[0]; - - if (rng.nextDouble() < crossoverRate) { - // Select crossover points - Deque crossoverPoints = getCrossoverPointsGenerator(parents[0].length).getCrossoverPoints(); - - // Crossover - resultGenes = new double[parents[0].length]; - int currentParent = 0; - int nextCrossover = crossoverPoints.pop(); - for (int i = 0; i < resultGenes.length; ++i) { - if (i == nextCrossover) { - currentParent = currentParent == 0 ? 1 : 0; - nextCrossover = crossoverPoints.pop(); - } - resultGenes[i] = parents[currentParent][i]; - } - isModified = true; - } - - return new CrossoverResult(isModified, resultGenes); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/SinglePointCrossover.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/SinglePointCrossover.java deleted file mode 100644 index cbeca1232..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/SinglePointCrossover.java +++ /dev/null @@ -1,123 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover; - -import org.apache.commons.math3.random.JDKRandomGenerator; -import org.apache.commons.math3.random.RandomGenerator; -import org.apache.commons.math3.random.SynchronizedRandomGenerator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; -import org.nd4j.common.base.Preconditions; - -/** - * The single point crossover will select a random point where every genes before that point comes from one parent - * and after which every genes comes from the other parent. - * - * @author Alexandre Boulanger - */ -public class SinglePointCrossover extends TwoParentsCrossoverOperator { - private static final double DEFAULT_CROSSOVER_RATE = 0.85; - - private final RandomGenerator rng; - private final double crossoverRate; - - public static class Builder { - private double crossoverRate = DEFAULT_CROSSOVER_RATE; - private RandomGenerator rng; - private TwoParentSelection parentSelection; - - /** - * The probability that the operator generates a crossover (default 0.85). 
- * - * @param rate A value between 0.0 and 1.0 - */ - public Builder crossoverRate(double rate) { - Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate); - - this.crossoverRate = rate; - return this; - } - - /** - * Use a supplied RandomGenerator - * - * @param rng An instance of RandomGenerator - */ - public Builder randomGenerator(RandomGenerator rng) { - this.rng = rng; - return this; - } - - /** - * The parent selection behavior. Default is random parent selection. - * - * @param parentSelection An instance of TwoParentSelection - */ - public Builder parentSelection(TwoParentSelection parentSelection) { - this.parentSelection = parentSelection; - return this; - } - - public SinglePointCrossover build() { - if (rng == null) { - rng = new SynchronizedRandomGenerator(new JDKRandomGenerator()); - } - - if (parentSelection == null) { - parentSelection = new RandomTwoParentSelection(); - } - - return new SinglePointCrossover(this); - } - } - - private SinglePointCrossover(SinglePointCrossover.Builder builder) { - super(builder.parentSelection); - - this.crossoverRate = builder.crossoverRate; - this.rng = builder.rng; - } - - /** - * Has a probability crossoverRate of performing the crossover where the operator will select a random crossover point.
      - * Each gene before this point comes from one of the two parents and each gene at or after this point comes from the other parent. - * Otherwise, returns the genes of a random parent. - * - * @return The crossover result. See {@link CrossoverResult}. - */ - public CrossoverResult crossover() { - double[][] parents = parentSelection.selectParents(); - - boolean isModified = false; - double[] resultGenes = parents[0]; - - if (rng.nextDouble() < crossoverRate) { - int chromosomeLength = parents[0].length; - - // Crossover - resultGenes = new double[chromosomeLength]; - - int crossoverPoint = rng.nextInt(chromosomeLength); - for (int i = 0; i < resultGenes.length; ++i) { - resultGenes[i] = ((i < crossoverPoint) ? parents[0] : parents[1])[i]; - } - isModified = true; - } - - return new CrossoverResult(isModified, resultGenes); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/TwoParentsCrossoverOperator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/TwoParentsCrossoverOperator.java deleted file mode 100644 index 69f1fb105..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/TwoParentsCrossoverOperator.java +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; - -/** - * Abstract class for all crossover operators that applies to two parents. - * - * @author Alexandre Boulanger - */ -public abstract class TwoParentsCrossoverOperator extends CrossoverOperator { - - protected final TwoParentSelection parentSelection; - - /** - * @param parentSelection A parent selection that selects two parents. - */ - protected TwoParentsCrossoverOperator(TwoParentSelection parentSelection) { - this.parentSelection = parentSelection; - } - - /** - * Will be called by the selection operator once the population model is instantiated. 
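As a rough illustration of the rule documented above for the removed SinglePointCrossover, the following standalone sketch (hypothetical class and method names, plain java.util only) copies every gene before a random point from one parent and the remainder from the other:

import java.util.Random;

public class SinglePointCrossoverSketch {
    // Copies genes[0..point-1] from parentA and genes[point..] from parentB,
    // mirroring the behaviour of the removed SinglePointCrossover.crossover().
    static double[] crossover(double[] parentA, double[] parentB, Random rng) {
        int point = rng.nextInt(parentA.length); // random crossover point
        double[] child = new double[parentA.length];
        for (int i = 0; i < child.length; i++) {
            child[i] = (i < point) ? parentA[i] : parentB[i];
        }
        return child;
    }
}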
- */ - @Override - public void initializeInstance(PopulationModel populationModel) { - super.initializeInstance(populationModel); - parentSelection.initializeInstance(populationModel.getPopulation()); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/UniformCrossover.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/UniformCrossover.java deleted file mode 100644 index 8912a1298..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/UniformCrossover.java +++ /dev/null @@ -1,136 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover; - -import org.apache.commons.math3.random.JDKRandomGenerator; -import org.apache.commons.math3.random.RandomGenerator; -import org.apache.commons.math3.random.SynchronizedRandomGenerator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; -import org.nd4j.common.base.Preconditions; - -/** - * The uniform crossover will, for each gene, randomly select the parent that donates the gene. - * - * @author Alexandre Boulanger - */ -public class UniformCrossover extends TwoParentsCrossoverOperator { - private static final double DEFAULT_CROSSOVER_RATE = 0.85; - private static final double DEFAULT_PARENT_BIAS_FACTOR = 0.5; - - private final double crossoverRate; - private final double parentBiasFactor; - private final RandomGenerator rng; - - public static class Builder { - private double crossoverRate = DEFAULT_CROSSOVER_RATE; - private double parentBiasFactor = DEFAULT_PARENT_BIAS_FACTOR; - private RandomGenerator rng; - private TwoParentSelection parentSelection; - - /** - * The probability that the operator generates a crossover (default 0.85). - * - * @param rate A value between 0.0 and 1.0 - */ - public Builder crossoverRate(double rate) { - Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate); - - this.crossoverRate = rate; - return this; - } - - /** - * A factor that will introduce a bias in the parent selection.
      - * - * @param factor In the range [0, 1]. 0 will only select the first parent while 1 only select the second one. The default is 0.5; no bias. - */ - public Builder parentBiasFactor(double factor) { - Preconditions.checkState(factor >= 0.0 && factor <= 1.0, "Factor must be between 0.0 and 1.0, got %s", - factor); - - this.parentBiasFactor = factor; - return this; - } - - /** - * Use a supplied RandomGenerator - * - * @param rng An instance of RandomGenerator - */ - public Builder randomGenerator(RandomGenerator rng) { - this.rng = rng; - return this; - } - - /** - * The parent selection behavior. Default is random parent selection. - * - * @param parentSelection An instance of TwoParentSelection - */ - public Builder parentSelection(TwoParentSelection parentSelection) { - this.parentSelection = parentSelection; - return this; - } - - public UniformCrossover build() { - if (rng == null) { - rng = new SynchronizedRandomGenerator(new JDKRandomGenerator()); - } - if (parentSelection == null) { - parentSelection = new RandomTwoParentSelection(); - } - return new UniformCrossover(this); - } - } - - private UniformCrossover(UniformCrossover.Builder builder) { - super(builder.parentSelection); - - this.crossoverRate = builder.crossoverRate; - this.parentBiasFactor = builder.parentBiasFactor; - this.rng = builder.rng; - } - - /** - * Has a probability crossoverRate of performing the crossover where the operator will select randomly which parent donates the gene.
      - * One of the parent may be favored if the bias is different than 0.5 - * Otherwise, returns the genes of a random parent. - * - * @return The crossover result. See {@link CrossoverResult}. - */ - @Override - public CrossoverResult crossover() { - // select the parents - double[][] parents = parentSelection.selectParents(); - - double[] resultGenes = parents[0]; - boolean isModified = false; - - if (rng.nextDouble() < crossoverRate) { - // Crossover - resultGenes = new double[parents[0].length]; - - for (int i = 0; i < resultGenes.length; ++i) { - resultGenes[i] = ((rng.nextDouble() < parentBiasFactor) ? parents[0] : parents[1])[i]; - } - isModified = true; - } - - return new CrossoverResult(isModified, resultGenes); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/ParentSelection.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/ParentSelection.java deleted file mode 100644 index 4fa9ed17c..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/ParentSelection.java +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; - -import java.util.List; - -/** - * Abstract class for all parent selection behaviors - * - * @author Alexandre Boulanger - */ -public abstract class ParentSelection { - protected List population; - - /** - * Will be called by the crossover operator once the population model is instantiated. - */ - public void initializeInstance(List population) { - this.population = population; - } - - /** - * Performs the parent selection - * - * @return An array of parents genes. The outer array are the parents, and the inner array are the genes. - */ - public abstract double[][] selectParents(); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/RandomTwoParentSelection.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/RandomTwoParentSelection.java deleted file mode 100644 index 81baeb07c..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/RandomTwoParentSelection.java +++ /dev/null @@ -1,65 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. 
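The per-gene donor choice implemented by the removed UniformCrossover can be summarised by this standalone sketch (hypothetical names); the comment follows the loop shown above, where each gene comes from the first parent with probability equal to the bias factor:

import java.util.Random;

public class UniformCrossoverSketch {
    // Each gene is taken from parentA with probability 'bias' and from parentB
    // otherwise; bias = 0.5 gives the unbiased default behaviour.
    static double[] crossover(double[] parentA, double[] parentB, double bias, Random rng) {
        double[] child = new double[parentA.length];
        for (int i = 0; i < child.length; i++) {
            child[i] = (rng.nextDouble() < bias) ? parentA[i] : parentB[i];
        }
        return child;
    }
}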
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection; - -import org.apache.commons.math3.random.JDKRandomGenerator; -import org.apache.commons.math3.random.RandomGenerator; -import org.apache.commons.math3.random.SynchronizedRandomGenerator; - -/** - * A parent selection behavior that returns two random parents. - * - * @author Alexandre Boulanger - */ -public class RandomTwoParentSelection extends TwoParentSelection { - - private final RandomGenerator rng; - - public RandomTwoParentSelection() { - this(new SynchronizedRandomGenerator(new JDKRandomGenerator())); - } - - /** - * Use a supplied RandomGenerator - * - * @param rng An instance of RandomGenerator - */ - public RandomTwoParentSelection(RandomGenerator rng) { - this.rng = rng; - } - - /** - * Selects two random parents - * - * @return An array of parents genes. The outer array are the parents, and the inner array are the genes. - */ - @Override - public double[][] selectParents() { - double[][] parents = new double[2][]; - - int parent1Idx = rng.nextInt(population.size()); - int parent2Idx; - do { - parent2Idx = rng.nextInt(population.size()); - } while (parent1Idx == parent2Idx); - - parents[0] = population.get(parent1Idx).getGenes(); - parents[1] = population.get(parent2Idx).getGenes(); - - return parents; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/TwoParentSelection.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/TwoParentSelection.java deleted file mode 100644 index b4b4f4843..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/parentselection/TwoParentSelection.java +++ /dev/null @@ -1,25 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection; - -/** - * Abstract class for all parent selection behaviors that selects two parents. 
- * - * @author Alexandre Boulanger - */ -public abstract class TwoParentSelection extends ParentSelection { -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/utils/CrossoverPointsGenerator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/utils/CrossoverPointsGenerator.java deleted file mode 100644 index 7e6e799e7..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/crossover/utils/CrossoverPointsGenerator.java +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils; - -import org.apache.commons.math3.random.RandomGenerator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover; - -import java.util.*; - -/** - * A helper class used by {@link KPointCrossover} to generate the crossover points - * - * @author Alexandre Boulanger - */ -public class CrossoverPointsGenerator { - private final int minCrossovers; - private final int maxCrossovers; - private final RandomGenerator rng; - private List parameterIndexes; - - /** - * Constructor - * - * @param chromosomeLength The number of genes - * @param minCrossovers The minimum number of crossover points to generate - * @param maxCrossovers The maximum number of crossover points to generate - * @param rng A RandomGenerator instance - */ - public CrossoverPointsGenerator(int chromosomeLength, int minCrossovers, int maxCrossovers, RandomGenerator rng) { - this.minCrossovers = minCrossovers; - this.maxCrossovers = maxCrossovers; - this.rng = rng; - parameterIndexes = new ArrayList(); - for (int i = 0; i < chromosomeLength; ++i) { - parameterIndexes.add(i); - } - } - - /** - * Generate a list of crossover points. 
- * - * @return An ordered list of crossover point indexes and with Integer.MAX_VALUE as the last element - */ - public Deque getCrossoverPoints() { - Collections.shuffle(parameterIndexes); - List crossoverPointLists = - parameterIndexes.subList(0, rng.nextInt(maxCrossovers - minCrossovers) + minCrossovers); - Collections.sort(crossoverPointLists); - Deque crossoverPoints = new ArrayDeque(crossoverPointLists); - crossoverPoints.add(Integer.MAX_VALUE); - - return crossoverPoints; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/CullOperator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/CullOperator.java deleted file mode 100644 index 95452a7eb..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/CullOperator.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.culling; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; - -/** - * The cull operator will remove from the population the least desirables chromosomes. - * - * @author Alexandre Boulanger - */ -public interface CullOperator { - /** - * Will be called by the population model once created. - */ - void initializeInstance(PopulationModel populationModel); - - /** - * Cull the population to the culled size. - */ - void cullPopulation(); - - /** - * @return The target population size after culling. - */ - int getCulledSize(); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/LeastFitCullOperator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/LeastFitCullOperator.java deleted file mode 100644 index 6ec5c64df..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/LeastFitCullOperator.java +++ /dev/null @@ -1,50 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
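The crossover-point generation shown above (shuffle the gene indexes, keep a random count between the min and max, sort, and append a sentinel) can be sketched in plain Java as follows; names are hypothetical and the random count follows the same min/max convention as the removed helper:

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.Random;

public class CrossoverPointsSketch {
    // Picks between 'min' and 'max' distinct gene indexes (assumes max > min),
    // sorts them, and appends Integer.MAX_VALUE so a consumer never runs past
    // the last crossover point.
    static Deque<Integer> crossoverPoints(int chromosomeLength, int min, int max, Random rng) {
        List<Integer> indexes = new ArrayList<>();
        for (int i = 0; i < chromosomeLength; i++) {
            indexes.add(i);
        }
        Collections.shuffle(indexes, rng);
        List<Integer> chosen = new ArrayList<>(indexes.subList(0, rng.nextInt(max - min) + min));
        Collections.sort(chosen);
        Deque<Integer> points = new ArrayDeque<>(chosen);
        points.addLast(Integer.MAX_VALUE); // sentinel, as in the removed generator
        return points;
    }
}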
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.culling; - -/** - * An elitist cull operator that discards the chromosomes with the worst fitness while keeping the best ones. - * - * @author Alexandre Boulanger - */ -public class LeastFitCullOperator extends RatioCullOperator { - - /** - * The default cull ratio is 1/3. - */ - public LeastFitCullOperator() { - super(); - } - - /** - * @param cullRatio The ratio of the maximum population size to be culled.
      - * For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20. - */ - public LeastFitCullOperator(double cullRatio) { - super(cullRatio); - } - - /** - * Will discard the chromosomes with the worst fitness until the population size fall back at the culled size. - */ - @Override - public void cullPopulation() { - while (population.size() > culledSize) { - population.remove(population.size() - 1); - } - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/RatioCullOperator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/RatioCullOperator.java deleted file mode 100644 index 9c838acc8..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/culling/RatioCullOperator.java +++ /dev/null @@ -1,70 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.culling; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; -import org.nd4j.common.base.Preconditions; - -import java.util.List; - -/** - * An abstract base for cull operators that culls back the population to a ratio of its maximum size. - * - * @author Alexandre Boulanger - */ -public abstract class RatioCullOperator implements CullOperator { - private static final double DEFAULT_CULL_RATIO = 1.0 / 3.0; - protected int culledSize; - protected List population; - protected final double cullRatio; - - /** - * @param cullRatio The ratio of the maximum population size to be culled.
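The worked example in the javadoc above (cull ratio of 1/3 on a maximum population of 30 leaving 20 survivors) corresponds to rounding populationSize * (1 - cullRatio); a minimal standalone sketch of that arithmetic, with hypothetical names:

public class CulledSizeSketch {
    // culledSize = round(populationSize * (1 - cullRatio)); e.g. 30 * (1 - 1/3) = 20,
    // matching the example in the LeastFitCullOperator javadoc above.
    static int culledSize(int populationSize, double cullRatio) {
        return (int) (populationSize * (1.0 - cullRatio) + 0.5);
    }

    public static void main(String[] args) {
        System.out.println(culledSize(30, 1.0 / 3.0)); // prints 20
    }
}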
      - * For example, a ratio of 1/3 on a population with a maximum size of 30 will cull back a given population to 20. - */ - public RatioCullOperator(double cullRatio) { - Preconditions.checkState(cullRatio >= 0.0 && cullRatio <= 1.0, "Cull ratio must be between 0.0 and 1.0, got %s", - cullRatio); - - this.cullRatio = cullRatio; - } - - /** - * The default cull ratio is 1/3 - */ - public RatioCullOperator() { - this(DEFAULT_CULL_RATIO); - } - - /** - * Will be called by the population model once created. - */ - public void initializeInstance(PopulationModel populationModel) { - this.population = populationModel.getPopulation(); - culledSize = (int) (populationModel.getPopulationSize() * (1.0 - cullRatio) + 0.5); - } - - /** - * @return The target population size after culling. - */ - @Override - public int getCulledSize() { - return culledSize; - } - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/exceptions/GeneticGenerationException.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/exceptions/GeneticGenerationException.java deleted file mode 100644 index b0a9a42b3..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/exceptions/GeneticGenerationException.java +++ /dev/null @@ -1,23 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions; - -public class GeneticGenerationException extends RuntimeException { - public GeneticGenerationException(String message) { - super(message); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/mutation/MutationOperator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/mutation/MutationOperator.java deleted file mode 100644 index 56f459a73..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/mutation/MutationOperator.java +++ /dev/null @@ -1,33 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation; - -/** - * The mutation operator will apply a mutation to the given genes. - * - * @author Alexandre Boulanger - */ -public interface MutationOperator { - - /** - * Performs a mutation. - * - * @param genes The genes to be mutated - * @return True if the genes were mutated, otherwise false. - */ - boolean mutate(double[] genes); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/mutation/RandomMutationOperator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/mutation/RandomMutationOperator.java deleted file mode 100644 index ba10676b6..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/mutation/RandomMutationOperator.java +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.mutation; - -import org.apache.commons.math3.random.JDKRandomGenerator; -import org.apache.commons.math3.random.RandomGenerator; -import org.apache.commons.math3.random.SynchronizedRandomGenerator; -import org.nd4j.common.base.Preconditions; - -/** - * A mutation operator where each gene has a chance of being mutated with a mutation rate probability. - * - * @author Alexandre Boulanger - */ -public class RandomMutationOperator implements MutationOperator { - private static final double DEFAULT_MUTATION_RATE = 0.005; - - private final double mutationRate; - private final RandomGenerator rng; - - public static class Builder { - private double mutationRate = DEFAULT_MUTATION_RATE; - private RandomGenerator rng; - - /** - * Each gene will have this probability of being mutated. - * - * @param rate The mutation rate. (default 0.005) - */ - public Builder mutationRate(double rate) { - Preconditions.checkState(rate >= 0.0 && rate <= 1.0, "Rate must be between 0.0 and 1.0, got %s", rate); - - this.mutationRate = rate; - return this; - } - - /** - * Use a supplied RandomGenerator - * - * @param rng An instance of RandomGenerator - */ - public Builder randomGenerator(RandomGenerator rng) { - this.rng = rng; - return this; - } - - public RandomMutationOperator build() { - if (rng == null) { - rng = new SynchronizedRandomGenerator(new JDKRandomGenerator()); - } - return new RandomMutationOperator(this); - } - } - - private RandomMutationOperator(RandomMutationOperator.Builder builder) { - this.mutationRate = builder.mutationRate; - this.rng = builder.rng; - } - - /** - * Performs the mutation. Each gene has a mutation rate probability of being mutated. 
- * - * @param genes The genes to be mutated - * @return True if the genes were mutated, otherwise false. - */ - @Override - public boolean mutate(double[] genes) { - boolean hasMutation = false; - - for (int i = 0; i < genes.length; ++i) { - if (rng.nextDouble() < mutationRate) { - genes[i] = rng.nextDouble(); - hasMutation = true; - } - } - - return hasMutation; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/EmptyPopulationInitializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/EmptyPopulationInitializer.java deleted file mode 100644 index 20c147385..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/EmptyPopulationInitializer.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.population; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; - -import java.util.ArrayList; -import java.util.List; - -/** - * A population initializer that build an empty population. - * - * @author Alexandre Boulanger - */ -public class EmptyPopulationInitializer implements PopulationInitializer { - - /** - * Initialize an empty population - * - * @param size The maximum size of the population. - * @return The initialized population. - */ - @Override - public List getInitializedPopulation(int size) { - return new ArrayList<>(size); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationInitializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationInitializer.java deleted file mode 100644 index 40dd4f438..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationInitializer.java +++ /dev/null @@ -1,36 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
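The mutation step shown above amounts to replacing each gene, independently, with a fresh random value at the configured rate; a standalone sketch (hypothetical names) of the same loop:

import java.util.Random;

public class RandomMutationSketch {
    // Each gene is independently replaced by a new random value with probability
    // 'mutationRate', mirroring the removed RandomMutationOperator.mutate().
    static boolean mutate(double[] genes, double mutationRate, Random rng) {
        boolean mutated = false;
        for (int i = 0; i < genes.length; i++) {
            if (rng.nextDouble() < mutationRate) {
                genes[i] = rng.nextDouble();
                mutated = true;
            }
        }
        return mutated;
    }
}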
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.population; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; - -import java.util.List; - -/** - * An initializer that construct the population used by the population model. - * - * @author Alexandre Boulanger - */ -public interface PopulationInitializer { - /** - * Called by the population model to construct the population - * - * @param size The maximum size of the population - * @return An initialized population - */ - List getInitializedPopulation(int size); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationListener.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationListener.java deleted file mode 100644 index aca266b57..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationListener.java +++ /dev/null @@ -1,35 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.population; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; - -import java.util.List; - -/** - * A listener that is called when the population changes. - * - * @author Alexandre Boulanger - */ -public interface PopulationListener { - /** - * Called after the population has changed. - * - * @param population The population after it has changed. - */ - void onChanged(List population); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationModel.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationModel.java deleted file mode 100644 index 9c5a4c7e1..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/population/PopulationModel.java +++ /dev/null @@ -1,182 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.population; - -import lombok.Getter; -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; -import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; - -/** - * The population model handles all aspects of the population (initialization, additions and culling) - * - * @author Alexandre Boulanger - */ -public class PopulationModel { - private static final int DEFAULT_POPULATION_SIZE = 30; - - private final CullOperator cullOperator; - private final List populationListeners = new ArrayList<>(); - private Comparator chromosomeComparator; - - /** - * The maximum population size - */ - @Getter - private final int populationSize; - - /** - * The population - */ - @Getter - public final List population; - - /** - * A comparator used when higher fitness value is better - */ - public static class MaximizeScoreComparator implements Comparator { - @Override - public int compare(Chromosome lhs, Chromosome rhs) { - return -Double.compare(lhs.getFitness(), rhs.getFitness()); - } - } - - /** - * A comparator used when lower fitness value is better - */ - public static class MinimizeScoreComparator implements Comparator { - @Override - public int compare(Chromosome lhs, Chromosome rhs) { - return Double.compare(lhs.getFitness(), rhs.getFitness()); - } - } - - public static class Builder { - private int populationSize = DEFAULT_POPULATION_SIZE; - private PopulationInitializer populationInitializer; - private CullOperator cullOperator; - - /** - * Use an alternate population initialization behavior. Default is empty population. - * - * @param populationInitializer An instance of PopulationInitializer - */ - public Builder populationInitializer(PopulationInitializer populationInitializer) { - this.populationInitializer = populationInitializer; - return this; - } - - /** - * The maximum population size.
      - * If using a ratio based culling, using a population with culled size of around 1.5 to 2 times the number of genes generally gives good results. - * (e.g. For a chromosome having 10 genes, the culled size should be between 15 and 20. And with a cull ratio of 1/3 we should set the population size to 23 to 30. (15 / (1 - 1/3)), rounded up) - * - * @param size The maximum size of the population - */ - public Builder populationSize(int size) { - populationSize = size; - return this; - } - - /** - * Use an alternate cull operator behavior. Default is least fit culling. - * - * @param cullOperator An instance of a CullOperator - */ - public Builder cullOperator(CullOperator cullOperator) { - this.cullOperator = cullOperator; - return this; - } - - public PopulationModel build() { - if (cullOperator == null) { - cullOperator = new LeastFitCullOperator(); - } - - if (populationInitializer == null) { - populationInitializer = new EmptyPopulationInitializer(); - } - - return new PopulationModel(this); - } - - } - - public PopulationModel(PopulationModel.Builder builder) { - populationSize = builder.populationSize; - population = new ArrayList<>(builder.populationSize); - PopulationInitializer populationInitializer = builder.populationInitializer; - - List initializedPopulation = populationInitializer.getInitializedPopulation(populationSize); - population.clear(); - population.addAll(initializedPopulation); - - cullOperator = builder.cullOperator; - cullOperator.initializeInstance(this); - } - - /** - * Called by the GeneticSearchCandidateGenerator - */ - public void initializeInstance(boolean minimizeScore) { - chromosomeComparator = minimizeScore ? new MinimizeScoreComparator() : new MaximizeScoreComparator(); - } - - /** - * Add a PopulationListener to the list of change listeners - * @param listener A PopulationListener instance - */ - public void addListener(PopulationListener listener) { - populationListeners.add(listener); - } - - /** - * Add a Chromosome to the population and call the PopulationListeners. Culling may be triggered. - * - * @param element The chromosome to be added - */ - public void add(Chromosome element) { - if (population.size() == populationSize) { - cullOperator.cullPopulation(); - } - - population.add(element); - - Collections.sort(population, chromosomeComparator); - - triggerPopulationChangedListeners(population); - } - - /** - * @return Return false when the population is below the culled size, otherwise true.
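The sizing guidance in the populationSize javadoc above reduces to dividing the desired number of survivors by (1 - cullRatio) and rounding up; a small standalone sketch of that calculation, with hypothetical names:

public class PopulationSizingSketch {
    // populationSize = ceil(targetCulledSize / (1 - cullRatio)); with 15 survivors
    // and a cull ratio of 1/3 this gives 23, as in the Builder javadoc above.
    static int populationSize(int targetCulledSize, double cullRatio) {
        return (int) Math.ceil(targetCulledSize / (1.0 - cullRatio));
    }

    public static void main(String[] args) {
        System.out.println(populationSize(15, 1.0 / 3.0)); // prints 23
        System.out.println(populationSize(20, 1.0 / 3.0)); // prints 30
    }
}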
      - * Used by the selection operator to know if the population is still too small and should generate random genes. - */ - public boolean isReadyToBreed() { - return population.size() >= cullOperator.getCulledSize(); - } - - private void triggerPopulationChangedListeners(List population) { - for (PopulationListener listener : populationListeners) { - listener.onChanged(population); - } - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/selection/GeneticSelectionOperator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/selection/GeneticSelectionOperator.java deleted file mode 100644 index 40b6a49c8..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/selection/GeneticSelectionOperator.java +++ /dev/null @@ -1,197 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.selection; - -import org.apache.commons.math3.random.JDKRandomGenerator; -import org.apache.commons.math3.random.RandomGenerator; -import org.apache.commons.math3.random.SynchronizedRandomGenerator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover; -import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException; -import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; - -import java.util.Arrays; - -/** - * A selection operator that will generate random genes initially. Once the population has reached the culled size, - * will start to generate offsprings of parents selected in the population. 
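The add-with-culling behaviour shown above (cull when the population is full, insert the new candidate, then re-sort by fitness) can be sketched independently of the removed Chromosome type; here a double[] plus an external comparator stands in for a chromosome, and all names are hypothetical:

import java.util.Comparator;
import java.util.List;

public class PopulationAddSketch {
    // Mirrors the removed PopulationModel.add(): when the population is at capacity
    // it is culled first (least-fit entries dropped from the tail), then the new
    // candidate is inserted and the list is re-sorted by fitness.
    static void add(List<double[]> population, double[] candidate, int maxSize, int culledSize,
                    Comparator<double[]> byFitness) {
        if (population.size() == maxSize) {
            while (population.size() > culledSize) {
                population.remove(population.size() - 1);
            }
        }
        population.add(candidate);
        population.sort(byFitness);
    }
}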
- * - * @author Alexandre Boulanger - */ -public class GeneticSelectionOperator extends SelectionOperator { - - private final static int PREVIOUS_GENES_TO_KEEP = 100; - private final static int MAX_NUM_GENERATION_ATTEMPTS = 1024; - - private final CrossoverOperator crossoverOperator; - private final MutationOperator mutationOperator; - private final RandomGenerator rng; - private double[][] previousGenes = new double[PREVIOUS_GENES_TO_KEEP][]; - private int previousGenesIdx = 0; - - public static class Builder { - private ChromosomeFactory chromosomeFactory; - private PopulationModel populationModel; - private CrossoverOperator crossoverOperator; - private MutationOperator mutationOperator; - private RandomGenerator rng; - - /** - * Use an alternate crossover behavior. Default is SinglePointCrossover. - * - * @param crossoverOperator An instance of CrossoverOperator - */ - public Builder crossoverOperator(CrossoverOperator crossoverOperator) { - this.crossoverOperator = crossoverOperator; - return this; - } - - /** - * Use an alternate mutation behavior. Default is RandomMutationOperator. - * - * @param mutationOperator An instance of MutationOperator - */ - public Builder mutationOperator(MutationOperator mutationOperator) { - this.mutationOperator = mutationOperator; - return this; - } - - /** - * Use a supplied RandomGenerator - * - * @param rng An instance of RandomGenerator - */ - public Builder randomGenerator(RandomGenerator rng) { - this.rng = rng; - return this; - } - - public GeneticSelectionOperator build() { - if (crossoverOperator == null) { - crossoverOperator = new SinglePointCrossover.Builder().build(); - } - - if (mutationOperator == null) { - mutationOperator = new RandomMutationOperator.Builder().build(); - } - - if (rng == null) { - rng = new SynchronizedRandomGenerator(new JDKRandomGenerator()); - } - - return new GeneticSelectionOperator(crossoverOperator, mutationOperator, rng); - } - } - - private GeneticSelectionOperator(CrossoverOperator crossoverOperator, MutationOperator mutationOperator, - RandomGenerator rng) { - this.crossoverOperator = crossoverOperator; - this.mutationOperator = mutationOperator; - this.rng = rng; - } - - /** - * Called by GeneticSearchCandidateGenerator - */ - @Override - public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) { - super.initializeInstance(populationModel, chromosomeFactory); - crossoverOperator.initializeInstance(populationModel); - } - - /** - * Build a new set of genes. Has two distinct modes of operation - *

        - *
- * <ul>
- *   <li>Before the population has reached the culled size: will return a random set of genes.</li>
- *   <li>After: Parents will be selected among the population, a crossover will be applied followed by a mutation.</li>
- * </ul>
      - * @return Returns the generated set of genes - * @throws GeneticGenerationException If buildNextGenes() can't generate a set that has not already been tried, - * or if the crossover and the mutation operators can't generate a set, - * this exception is thrown. - */ - @Override - public double[] buildNextGenes() { - double[] result; - - boolean hasAlreadyBeenTried; - int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS; - do { - if (populationModel.isReadyToBreed()) { - result = buildOffspring(); - } else { - result = buildRandomGenes(); - } - - hasAlreadyBeenTried = hasAlreadyBeenTried(result); - if (hasAlreadyBeenTried && --attemptsRemaining == 0) { - throw new GeneticGenerationException("Failed to generate a set of genes not already tried."); - } - } while (hasAlreadyBeenTried); - - previousGenes[previousGenesIdx] = result; - previousGenesIdx = ++previousGenesIdx % previousGenes.length; - - return result; - } - - private boolean hasAlreadyBeenTried(double[] genes) { - for (int i = 0; i < previousGenes.length; ++i) { - double[] current = previousGenes[i]; - if (current != null && Arrays.equals(current, genes)) { - return true; - } - } - - return false; - } - - private double[] buildOffspring() { - double[] offspringValues; - - boolean isModified; - int attemptsRemaining = MAX_NUM_GENERATION_ATTEMPTS; - do { - CrossoverResult crossoverResult = crossoverOperator.crossover(); - offspringValues = crossoverResult.getGenes(); - isModified = crossoverResult.isModified(); - isModified |= mutationOperator.mutate(offspringValues); - - if (!isModified && --attemptsRemaining == 0) { - throw new GeneticGenerationException( - String.format("Crossover and mutation operators failed to generate a new set of genes after %s attempts.", - MAX_NUM_GENERATION_ATTEMPTS)); - } - } while (!isModified); - - return offspringValues; - } - - private double[] buildRandomGenes() { - double[] randomValues = new double[chromosomeFactory.getChromosomeLength()]; - for (int i = 0; i < randomValues.length; ++i) { - randomValues[i] = rng.nextDouble(); - } - - return randomValues; - } - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/selection/SelectionOperator.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/selection/SelectionOperator.java deleted file mode 100644 index 7be470ea6..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/genetic/selection/SelectionOperator.java +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
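For context, the operators deleted in this patch were typically composed through their builders, as shown above; the following sketch only uses builder methods visible in the removed sources and assumes those classes are still on the classpath:

import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.UniformCrossover;
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator;
import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator;

public class SelectionWiringSketch {
    static GeneticSelectionOperator buildSelection() {
        // Explicit operators; omitting them falls back to SinglePointCrossover and
        // RandomMutationOperator, as shown in GeneticSelectionOperator.Builder.build().
        CrossoverOperator crossover = new UniformCrossover.Builder()
                .crossoverRate(0.85)
                .parentBiasFactor(0.5)
                .build();
        MutationOperator mutation = new RandomMutationOperator.Builder()
                .mutationRate(0.005)
                .build();
        return new GeneticSelectionOperator.Builder()
                .crossoverOperator(crossover)
                .mutationOperator(mutation)
                .randomGenerator(new SynchronizedRandomGenerator(new JDKRandomGenerator()))
                .build();
    }
}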
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.genetic.selection; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; - -/** - * An abstract class for all selection operators. Used by the GeneticSearchCandidateGenerator to generate new candidates. - * - * @author Alexandre Boulanger - */ -public abstract class SelectionOperator { - protected PopulationModel populationModel; - protected ChromosomeFactory chromosomeFactory; - - /** - * Called by GeneticSearchCandidateGenerator - */ - public void initializeInstance(PopulationModel populationModel, ChromosomeFactory chromosomeFactory) { - - this.populationModel = populationModel; - this.chromosomeFactory = chromosomeFactory; - } - - /** - * Generate a new set of genes. - */ - public abstract double[] buildNextGenes(); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/util/SerializedSupplier.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/util/SerializedSupplier.java deleted file mode 100644 index 81109816d..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/generator/util/SerializedSupplier.java +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.generator.util; - -import org.nd4j.common.function.Supplier; - -import java.io.*; - -public class SerializedSupplier implements Serializable, Supplier { - - private byte[] asBytes; - - public SerializedSupplier(T obj){ - try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){ - oos.writeObject(obj); - oos.flush(); - oos.close(); - asBytes = baos.toByteArray(); - } catch (Exception e){ - throw new RuntimeException("Error serializing object - must be serializable",e); - } - } - - @Override - public T get() { - try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(asBytes))){ - return (T)ois.readObject(); - } catch (Exception e){ - throw new RuntimeException("Error deserializing object",e); - } - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/BooleanSpace.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/BooleanSpace.java deleted file mode 100644 index fd20afb47..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/BooleanSpace.java +++ /dev/null @@ -1,76 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.parameter; - -import lombok.EqualsAndHashCode; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -/** - * BooleanParameterSpace is a {@code ParameterSpace}; Defines {True, False} as a parameter space - * If argument to setValue is less than or equal to 0.5 it will return True else False - * - * @author susaneraly - */ -@EqualsAndHashCode -public class BooleanSpace implements ParameterSpace { - private int index = -1; - - @Override - public Boolean getValue(double[] input) { - if (index == -1) { - throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set"); - } - if (input[index] <= 0.5) return Boolean.TRUE; - else return Boolean.FALSE; - } - - @Override - public int numParameters() { - return 1; - } - - @Override - public List collectLeaves() { - return Collections.singletonList((ParameterSpace) this); - } - - @Override - public Map getNestedSpaces() { - return Collections.emptyMap(); - } - - @Override - public boolean isLeaf() { - return true; - } - - @Override - public void setIndices(int... 
indices) { - if (indices == null || indices.length != 1) - throw new IllegalArgumentException("Invalid index"); - this.index = indices[0]; - } - - @Override - public String toString() { - return "BooleanSpace()"; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java deleted file mode 100644 index b22f77a52..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/FixedValue.java +++ /dev/null @@ -1,90 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.parameter; - -import lombok.EqualsAndHashCode; -import lombok.Getter; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueDeserializer; -import org.deeplearning4j.arbiter.optimize.serde.jackson.FixedValueSerializer; -import org.deeplearning4j.arbiter.util.ObjectUtils; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -/** - * FixedValue is a ParameterSpace that defines only a single fixed value - * - * @param Type of (fixed) value - */ -@EqualsAndHashCode -@JsonSerialize(using = FixedValueSerializer.class) -@JsonDeserialize(using = FixedValueDeserializer.class) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public class FixedValue implements ParameterSpace { - @Getter - private Object value; - private int index; - - @JsonCreator - public FixedValue(@JsonProperty("value") T value) { - this.value = value; - } - - @Override - public String toString() { - return "FixedValue(" + ObjectUtils.valueToString(value) + ")"; - } - - @Override - public T getValue(double[] input) { - return (T) value; - } - - @Override - public int numParameters() { - return 0; - } - - @Override - public List collectLeaves() { - return Collections.emptyList(); - } - - @Override - public Map getNestedSpaces() { - return Collections.emptyMap(); - } - - @Override - public boolean isLeaf() { - return true; - } - - @Override - public void setIndices(int... 
indices) { - if (indices != null && indices.length != 0) - throw new IllegalArgumentException( - "Invalid call: FixedValue ParameterSpace " + "should not be given an index (0 params)"); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/continuous/ContinuousParameterSpace.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/continuous/ContinuousParameterSpace.java deleted file mode 100644 index c8f139ebb..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/continuous/ContinuousParameterSpace.java +++ /dev/null @@ -1,137 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.parameter.continuous; - -import org.apache.commons.math3.distribution.RealDistribution; -import org.apache.commons.math3.distribution.UniformRealDistribution; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.distribution.DistributionUtils; -import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionDeserializer; -import org.deeplearning4j.arbiter.optimize.serde.jackson.RealDistributionSerializer; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -/** - * ContinuousParametSpace is a {@code ParameterSpace} that (optionally) takes an Apache Commons - * {@link RealDistribution} when used for random sampling (such as in a RandomSearchCandidateGenerator) - * - * @author Alex Black - */ -public class ContinuousParameterSpace implements ParameterSpace { - - //Need to use custom serializers/deserializers for commons RealDistribution instances - @JsonSerialize(using = RealDistributionSerializer.class) - @JsonDeserialize(using = RealDistributionDeserializer.class) - private RealDistribution distribution; - private int index = -1; - - /** - * ContinuousParameterSpace with uniform distribution between the minimum and maximum values - * - * @param min Minimum value that can be generated - * @param max Maximum value that can be generated - */ - public ContinuousParameterSpace(double min, double max) { - this(new UniformRealDistribution(min, max)); - } - - /** - * ConditiousParameterSpcae wiht a specified probability distribution. 
The provided distribution defines the min/max - * values, and (for random search, etc) will be used when generating random values - * - * @param distribution Distribution to sample from - */ - public ContinuousParameterSpace(@JsonProperty("distribution") RealDistribution distribution) { - this.distribution = distribution; - } - - - @Override - public Double getValue(double[] input) { - if (index == -1) { - throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set"); - } - return distribution.inverseCumulativeProbability(input[index]); - } - - @Override - public int numParameters() { - return 1; - } - - @Override - public List collectLeaves() { - return Collections.singletonList((ParameterSpace) this); - } - - @Override - public Map getNestedSpaces() { - return Collections.emptyMap(); - } - - @Override - public boolean isLeaf() { - return true; - } - - @Override - public void setIndices(int... indices) { - if (indices == null || indices.length != 1) { - throw new IllegalArgumentException("Invalid index"); - } - this.index = indices[0]; - } - - - @Override - public String toString() { - if (distribution instanceof UniformRealDistribution) { - return "ContinuousParameterSpace(min=" + distribution.getSupportLowerBound() + ",max=" - + distribution.getSupportUpperBound() + ")"; - } else { - return "ContinuousParameterSpace(" + distribution + ")"; - } - } - - public boolean equals(Object o) { - if (o == this) - return true; - if (!(o instanceof ContinuousParameterSpace)) - return false; - final ContinuousParameterSpace other = (ContinuousParameterSpace) o; - if (distribution == null ? other.distribution != null - : !DistributionUtils.distributionsEqual(distribution, other.distribution)) - return false; - if (this.index != other.index) - return false; - return true; - } - - public int hashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + (distribution == null ? 43 : distribution.getClass().hashCode()); - result = result * PRIME + this.index; - return result; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/discrete/DiscreteParameterSpace.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/discrete/DiscreteParameterSpace.java deleted file mode 100644 index 3c70aaa03..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/discrete/DiscreteParameterSpace.java +++ /dev/null @@ -1,113 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.parameter.discrete; - -import lombok.EqualsAndHashCode; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.util.ObjectUtils; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; - -import java.util.*; - -/** - * A DiscreteParameterSpace is used for a set of un-ordered values - * - * @param <P> Parameter type - * @author Alex Black - */ -@EqualsAndHashCode -public class DiscreteParameterSpace<P> implements ParameterSpace<P> { - - @JsonSerialize - private List<P> values; - private int index = -1; - - public DiscreteParameterSpace(@JsonProperty("values") P... values) { - if (values != null) - this.values = Arrays.asList(values); - } - - public DiscreteParameterSpace(Collection<P>
      values) { - this.values = new ArrayList<>(values); - } - - public int numValues() { - return values.size(); - } - - @Override - public P getValue(double[] input) { - if (index == -1) { - throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set"); - } - if (values == null) - throw new IllegalStateException("Values are null."); - //Map a value in range [0,1] to one of the list of values - //First value: [0,width], second: (width,2*width], third: (3*width,4*width] etc - int size = values.size(); - if (size == 1) - return values.get(0); - double width = 1.0 / size; - int val = (int) (input[index] / width); - return values.get(Math.min(val, size - 1)); - } - - @Override - public int numParameters() { - return 1; - } - - @Override - public List collectLeaves() { - return Collections.singletonList((ParameterSpace) this); - } - - @Override - public Map getNestedSpaces() { - return Collections.emptyMap(); - } - - @Override - public boolean isLeaf() { - return true; - } - - @Override - public void setIndices(int... indices) { - if (indices == null || indices.length != 1) { - throw new IllegalArgumentException("Invalid index"); - } - this.index = indices[0]; - } - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - sb.append("DiscreteParameterSpace("); - int n = values.size(); - for (int i = 0; i < n; i++) { - P value = values.get(i); - sb.append(ObjectUtils.valueToString(value)); - sb.append((i == n - 1 ? ")" : ",")); - } - return sb.toString(); - } - - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/integer/IntegerParameterSpace.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/integer/IntegerParameterSpace.java deleted file mode 100644 index d76381244..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/integer/IntegerParameterSpace.java +++ /dev/null @@ -1,151 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.parameter.integer; - -import lombok.NoArgsConstructor; -import org.apache.commons.math3.distribution.IntegerDistribution; -import org.apache.commons.math3.distribution.UniformIntegerDistribution; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.distribution.DistributionUtils; -import org.deeplearning4j.arbiter.optimize.serde.jackson.IntegerDistributionDeserializer; -import org.deeplearning4j.arbiter.optimize.serde.jackson.IntegerDistributionSerializer; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -/** - * IntegerParameterSpace is a {@code ParameterSpace}; i.e., defines an ordered space of integers between - * some minimum and maximum value - * - * @author Alex Black - */ -@JsonIgnoreProperties({"min", "max"}) -@NoArgsConstructor -public class IntegerParameterSpace implements ParameterSpace { - - @JsonSerialize(using = IntegerDistributionSerializer.class) - @JsonDeserialize(using = IntegerDistributionDeserializer.class) - private IntegerDistribution distribution; - private int index = -1; - - /** - * Create an IntegerParameterSpace with a uniform distribution between the specified min/max (inclusive) - * - * @param min Min value, inclusive - * @param max Max value, inclusive - */ - public IntegerParameterSpace(int min, int max) { - this(new UniformIntegerDistribution(min, max)); - } - - /** - * Crate an IntegerParametSpace from the given IntegerDistribution - * - * @param distribution Distribution to use - */ - @JsonCreator - public IntegerParameterSpace(@JsonProperty("distribution") IntegerDistribution distribution) { - this.distribution = distribution; - } - - public int getMin() { - return distribution.getSupportLowerBound(); - } - - public int getMax() { - return distribution.getSupportUpperBound(); - } - - @Override - public Integer getValue(double[] input) { - if (index == -1) { - throw new IllegalStateException("Cannot get value: ParameterSpace index has not been set"); - } - return distribution.inverseCumulativeProbability(input[index]); - } - - @Override - public int numParameters() { - return 1; - } - - @Override - public List collectLeaves() { - return Collections.singletonList((ParameterSpace) this); - } - - @Override - public Map getNestedSpaces() { - return Collections.emptyMap(); - } - - @Override - public boolean isLeaf() { - return true; - } - - @Override - public void setIndices(int... 
indices) { - if (indices == null || indices.length != 1) - throw new IllegalArgumentException("Invalid index"); - this.index = indices[0]; - } - - @Override - public String toString() { - if (distribution instanceof UniformIntegerDistribution) { - return "IntegerParameterSpace(min=" + distribution.getSupportLowerBound() + ",max=" - + distribution.getSupportUpperBound() + ")"; - } else { - return "IntegerParameterSpace(" + distribution + ")"; - } - } - - public boolean equals(Object o) { - if (o == this) - return true; - if (!(o instanceof IntegerParameterSpace)) - return false; - final IntegerParameterSpace other = (IntegerParameterSpace) o; - if (!other.canEqual(this)) - return false; - if (distribution == null ? other.distribution != null - : !DistributionUtils.distributionEquals(distribution, other.distribution)) - return false; - if (this.index != other.index) - return false; - return true; - } - - public int hashCode() { - final int PRIME = 59; - int result = 1; - result = result * PRIME + (distribution == null ? 43 : distribution.getClass().hashCode()); - result = result * PRIME + this.index; - return result; - } - - protected boolean canEqual(Object other) { - return other instanceof IntegerParameterSpace; - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/MathOp.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/MathOp.java deleted file mode 100644 index 2d567536f..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/MathOp.java +++ /dev/null @@ -1,69 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.parameter.math; - -import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; - -import java.util.List; - -/** - * A simple parameter space that implements scalar mathematical operations on another parameter space. This allows you - * to do things like Y = X * 2, where X is a parameter space. 
For example, a layer size hyperparameter could be set - * using this to 2x the size of the previous layer - * - * @param <T> Type of the parameter space - * @author Alex Black - */ -public class MathOp<T extends Number> extends AbstractParameterSpace<T> { - - private ParameterSpace<T> parameterSpace; - private Op op; - private T scalar; - - public MathOp(ParameterSpace<T> parameterSpace, Op op, T scalar){ - this.parameterSpace = parameterSpace; - this.op = op; - this.scalar = scalar; - } - - @Override - public T getValue(double[] parameterValues) { - T u = parameterSpace.getValue(parameterValues); - return op.doOp(u, scalar); - } - - @Override - public int numParameters() { - return parameterSpace.numParameters(); - } - - @Override - public List<ParameterSpace> collectLeaves() { - return parameterSpace.collectLeaves(); - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - parameterSpace.setIndices(indices); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/Op.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/Op.java deleted file mode 100644 index 2102804ce..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/Op.java +++ /dev/null @@ -1,76 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.parameter.math; - -public enum Op { - ADD, SUB, MUL, DIV; - - - //Package private - <T extends Number> T doOp(T first, T second){ - if(first instanceof Integer || first instanceof Long){ - long result; - switch (this){ - case ADD: - result = Long.valueOf(first.longValue() + second.longValue()); - break; - case SUB: - result = Long.valueOf(first.longValue() - second.longValue()); - break; - case MUL: - result = Long.valueOf(first.longValue() * second.longValue()); - break; - case DIV: - result = Long.valueOf(first.longValue() / second.longValue()); - break; - default: - throw new UnsupportedOperationException("Unknown op: " + this); - } - if(first instanceof Long){ - return (T)Long.valueOf(result); - } else { - return (T)Integer.valueOf((int)result); - } - } else if(first instanceof Double || first instanceof Float){ - double result; - switch (this){ - case ADD: - result = Double.valueOf(first.doubleValue() + second.doubleValue()); - break; - case SUB: - result = Double.valueOf(first.doubleValue() - second.doubleValue()); - break; - case MUL: - result = Double.valueOf(first.doubleValue() * second.doubleValue()); - break; - case DIV: - result = Double.valueOf(first.doubleValue() / second.doubleValue()); - break; - default: - throw new UnsupportedOperationException("Unknown op: " + this); - } - if(first instanceof Double){ - return (T)Double.valueOf(result); - } else { - return (T)Float.valueOf((float)result); - } - } else { - throw new UnsupportedOperationException("Not supported type: only Integer, Long, Double, Float supported" + - " here. Got type: " + first.getClass()); - } - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/PairMathOp.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/PairMathOp.java deleted file mode 100644 index db0a9c98b..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/parameter/math/PairMathOp.java +++ /dev/null @@ -1,79 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * @param Type of the parameter space - * @author Alex Black - */ -public class PairMathOp extends AbstractParameterSpace { - - private ParameterSpace first; - private ParameterSpace second; - private Op op; - - public PairMathOp(ParameterSpace first, ParameterSpace second, Op op){ - this.first = first; - this.second = second; - this.op = op; - } - - @Override - public T getValue(double[] parameterValues) { - T f = first.getValue(parameterValues); - T s = second.getValue(parameterValues); - return op.doOp(f, s); - } - - @Override - public int numParameters() { - return first.numParameters() + second.numParameters(); - } - - @Override - public List collectLeaves() { - List l = new ArrayList<>(); - l.addAll(first.collectLeaves()); - l.addAll(second.collectLeaves()); - return l; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - int n1 = first.numParameters(); - int n2 = second.numParameters(); - int[] s1 = Arrays.copyOfRange(indices, 0, n1); - int[] s2 = Arrays.copyOfRange(indices, n1, n1+n2); - first.setIndices(s1); - second.setIndices(s2); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java deleted file mode 100644 index fa503ef6d..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/BaseOptimizationRunner.java +++ /dev/null @@ -1,383 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.runner; - -import com.google.common.util.concurrent.ListenableFuture; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.exception.ExceptionUtils; -import org.deeplearning4j.arbiter.optimize.api.Candidate; -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; - -import java.util.*; -import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; - -/** - * BaseOptimization runner: responsible for scheduling tasks, saving results using the result saver, etc. 
- * - * @author Alex Black - */ -@Slf4j -public abstract class BaseOptimizationRunner implements IOptimizationRunner { - private static final int POLLING_FREQUENCY = 1; - private static final TimeUnit POLLING_FREQUENCY_UNIT = TimeUnit.SECONDS; - - protected OptimizationConfiguration config; - protected Queue> queuedFutures = new ConcurrentLinkedQueue<>(); - protected BlockingQueue> completedFutures = new LinkedBlockingQueue<>(); - protected AtomicInteger totalCandidateCount = new AtomicInteger(); - protected AtomicInteger numCandidatesCompleted = new AtomicInteger(); - protected AtomicInteger numCandidatesFailed = new AtomicInteger(); - protected Double bestScore = null; - protected Long bestScoreTime = null; - protected AtomicInteger bestScoreCandidateIndex = new AtomicInteger(-1); - protected List allResults = new ArrayList<>(); - - protected Map currentStatus = new ConcurrentHashMap<>(); //TODO: better design possible? - - protected ExecutorService futureListenerExecutor; - - protected List statusListeners = new ArrayList<>(); - - - protected BaseOptimizationRunner(OptimizationConfiguration config) { - this.config = config; - - if (config.getTerminationConditions() == null || config.getTerminationConditions().size() == 0) { - throw new IllegalArgumentException("Cannot create BaseOptimizationRunner without TerminationConditions (" - + "termination conditions are null or empty)"); - } - - } - - protected void init() { - futureListenerExecutor = Executors.newFixedThreadPool(maxConcurrentTasks(), new ThreadFactory() { - private AtomicLong counter = new AtomicLong(0); - - @Override - public Thread newThread(Runnable r) { - Thread t = Executors.defaultThreadFactory().newThread(r); - t.setDaemon(true); - t.setName("ArbiterOptimizationRunner-" + counter.getAndIncrement()); - return t; - } - }); - } - - /** - * - */ - @Override - public void execute() { - log.info("{}: execution started", this.getClass().getSimpleName()); - config.setExecutionStartTime(System.currentTimeMillis()); - for (StatusListener listener : statusListeners) { - listener.onInitialization(this); - } - - //Initialize termination conditions (start timers, etc) - for (TerminationCondition c : config.getTerminationConditions()) { - c.initialize(this); - } - - //Queue initial tasks: - List> tempList = new ArrayList<>(100); - while (true) { - //Otherwise: add tasks if required - Future future = null; - try { - future = completedFutures.poll(POLLING_FREQUENCY, POLLING_FREQUENCY_UNIT); - } catch (InterruptedException e) { - //No op? - } - if (future != null) { - tempList.add(future); - } - completedFutures.drainTo(tempList); - - //Process results (if any) - for (Future f : tempList) { - queuedFutures.remove(f); - processReturnedTask(f); - } - - if (tempList.size() > 0) { - for (StatusListener sl : statusListeners) { - sl.onRunnerStatusChange(this); - } - } - tempList.clear(); - - //Check termination conditions: - if (terminate()) { - shutdown(true); - break; - } - - //Add additional tasks - while (config.getCandidateGenerator().hasMoreCandidates() && queuedFutures.size() < maxConcurrentTasks()) { - Candidate candidate = config.getCandidateGenerator().getCandidate(); - CandidateInfo status; - if (candidate.getException() != null) { - //Failed on generation... 
- status = processFailedCandidates(candidate); - } else { - long created = System.currentTimeMillis(); - ListenableFuture f; - if(config.getDataSource() != null){ - f = execute(candidate, config.getDataSource(), config.getDataSourceProperties(), config.getScoreFunction()); - } else { - f = execute(candidate, config.getDataProvider(), config.getScoreFunction()); - } - f.addListener(new OnCompletionListener(f), futureListenerExecutor); - queuedFutures.add(f); - totalCandidateCount.getAndIncrement(); - - status = new CandidateInfo(candidate.getIndex(), CandidateStatus.Created, null, - created, null, null, candidate.getFlatParameters(), null); - currentStatus.put(candidate.getIndex(), status); - } - - for (StatusListener listener : statusListeners) { - listener.onCandidateStatusChange(status, this, null); - } - } - } - - //Process any final (completed) tasks: - completedFutures.drainTo(tempList); - for (Future f : tempList) { - queuedFutures.remove(f); - processReturnedTask(f); - } - tempList.clear(); - - log.info("Optimization runner: execution complete"); - for (StatusListener listener : statusListeners) { - listener.onShutdown(this); - } - } - - - private CandidateInfo processFailedCandidates(Candidate candidate) { - //In case the candidate fails during the creation of the candidate - - long time = System.currentTimeMillis(); - String stackTrace = ExceptionUtils.getStackTrace(candidate.getException()); - CandidateInfo newStatus = new CandidateInfo(candidate.getIndex(), CandidateStatus.Failed, null, time, time, - time, candidate.getFlatParameters(), stackTrace); - currentStatus.put(candidate.getIndex(), newStatus); - - return newStatus; - } - - /** - * Process returned task (either completed or failed - */ - private void processReturnedTask(Future future) { - long currentTime = System.currentTimeMillis(); - OptimizationResult result; - try { - result = future.get(100, TimeUnit.MILLISECONDS); - } catch (InterruptedException e) { - throw new RuntimeException("Unexpected InterruptedException thrown for task", e); - } catch (ExecutionException e) { - //Note that most of the time, an OptimizationResult is returned even for an exception - //This is just to handle any that are missed there (or, by implementations that don't properly do this) - log.warn("Task failed", e); - - numCandidatesFailed.getAndIncrement(); - return; - } catch (TimeoutException e) { - throw new RuntimeException(e); //TODO - } - - //Update internal status: - CandidateInfo status = currentStatus.get(result.getIndex()); - CandidateInfo newStatus = new CandidateInfo(result.getIndex(), result.getCandidateInfo().getCandidateStatus(), - result.getScore(), status.getCreatedTime(), result.getCandidateInfo().getStartTime(), - currentTime, status.getFlatParams(), result.getCandidateInfo().getExceptionStackTrace()); - currentStatus.put(result.getIndex(), newStatus); - - //Listeners (on complete, etc) should be executed in underlying task - - - if (result.getCandidateInfo().getCandidateStatus() == CandidateStatus.Failed) { - log.info("Task {} failed during execution: {}", result.getIndex(), result.getCandidateInfo().getExceptionStackTrace()); - numCandidatesFailed.getAndIncrement(); - } else { - - //Report completion to candidate generator - config.getCandidateGenerator().reportResults(result); - - Double score = result.getScore(); - log.info("Completed task {}, score = {}", result.getIndex(), result.getScore()); - - boolean minimize = config.getScoreFunction().minimize(); - if (score != null && (bestScore == null - || ((minimize && score 
< bestScore) || (!minimize && score > bestScore)))) { - if (bestScore == null) { - log.info("New best score: {} (first completed model)", score); - } else { - int idx = result.getIndex(); - int lastBestIdx = bestScoreCandidateIndex.get(); - log.info("New best score: {}, model {} (prev={}, model {})", score, idx, bestScore, lastBestIdx); - } - bestScore = score; - bestScoreTime = System.currentTimeMillis(); - bestScoreCandidateIndex.set(result.getIndex()); - } - numCandidatesCompleted.getAndIncrement(); - - //Model saving is done in the optimization tasks, to avoid CUDA threading issues - ResultReference resultReference = result.getResultReference(); - - if (resultReference != null) - allResults.add(resultReference); - } - } - - @Override - public int numCandidatesTotal() { - return totalCandidateCount.get(); - } - - @Override - public int numCandidatesCompleted() { - return numCandidatesCompleted.get(); - } - - @Override - public int numCandidatesFailed() { - return numCandidatesFailed.get(); - } - - @Override - public int numCandidatesQueued() { - return queuedFutures.size(); - } - - @Override - public Double bestScore() { - return bestScore; - } - - @Override - public Long bestScoreTime() { - return bestScoreTime; - } - - @Override - public int bestScoreCandidateIndex() { - return bestScoreCandidateIndex.get(); - } - - @Override - public List getResults() { - return new ArrayList<>(allResults); - } - - @Override - public OptimizationConfiguration getConfiguration() { - return config; - } - - - @Override - public void addListeners(StatusListener... listeners) { - for (StatusListener l : listeners) { - if (!statusListeners.contains(l)) { - statusListeners.add(l); - } - } - } - - @Override - public void removeListeners(StatusListener... listeners) { - for (StatusListener l : listeners) { - if (statusListeners.contains(l)) { - statusListeners.remove(l); - } - } - } - - @Override - public void removeAllListeners() { - statusListeners.clear(); - } - - @Override - public List getCandidateStatus() { - List list = new ArrayList<>(); - list.addAll(currentStatus.values()); - return list; - } - - private boolean terminate() { - for (TerminationCondition c : config.getTerminationConditions()) { - if (c.terminate(this)) { - log.info("BaseOptimizationRunner global termination condition hit: {}", c); - return true; - } - } - return false; - } - - @AllArgsConstructor - @Data - private class FutureDetails { - private final Future future; - private final long startTime; - private final int index; - } - - @AllArgsConstructor - private class OnCompletionListener implements Runnable { - private Future future; - - @Override - public void run() { - completedFutures.add(future); - } - } - - - protected abstract int maxConcurrentTasks(); - - @Deprecated - protected abstract ListenableFuture execute(Candidate candidate, DataProvider dataProvider, - ScoreFunction scoreFunction); - @Deprecated - protected abstract List> execute(List candidates, - DataProvider dataProvider, ScoreFunction scoreFunction); - - protected abstract ListenableFuture execute(Candidate candidate, Class dataSource, - Properties dataSourceProperties, ScoreFunction scoreFunction); - - protected abstract List> execute(List candidates, Class dataSource, - Properties dataSourceProperties, ScoreFunction scoreFunction); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/CandidateInfo.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/CandidateInfo.java deleted file mode 
100644 index e8c7ccf25..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/CandidateInfo.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.runner; - -import lombok.AllArgsConstructor; -import lombok.Data; - -/** - * Simple helper class to store status of a candidate that is/has been/will be executed - */ -@AllArgsConstructor -@Data -public class CandidateInfo { - - public CandidateInfo() { - //No arg constructor for Jackson - } - - private int index; - private CandidateStatus candidateStatus; - private Double score; - private long createdTime; - private Long startTime; - private Long endTime; - private double[] flatParams; //Same as parameters in Candidate class - private String exceptionStackTrace; -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/CandidateStatus.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/CandidateStatus.java deleted file mode 100644 index a19f89a52..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/CandidateStatus.java +++ /dev/null @@ -1,24 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.runner; - -/** - * Status for candidates - */ -public enum CandidateStatus { - Created, Running, Complete, Failed, Cancelled -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/IOptimizationRunner.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/IOptimizationRunner.java deleted file mode 100644 index 50e6dc4b0..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/IOptimizationRunner.java +++ /dev/null @@ -1,67 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.runner; - -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import java.util.List; - -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public interface IOptimizationRunner { - - void execute(); - - /** Total number of candidates: created (scheduled), completed and failed */ - int numCandidatesTotal(); - - int numCandidatesCompleted(); - - int numCandidatesFailed(); - - /** Number of candidates running or queued */ - int numCandidatesQueued(); - - /** Best score found so far */ - Double bestScore(); - - /** Time that the best score was found at, or 0 if no jobs have completed successfully */ - Long bestScoreTime(); - - /** Index of the best scoring candidate, or -1 if no candidate has scored yet*/ - int bestScoreCandidateIndex(); - - List getResults(); - - OptimizationConfiguration getConfiguration(); - - void addListeners(StatusListener... listeners); - - void removeListeners(StatusListener... listeners); - - void removeAllListeners(); - - List getCandidateStatus(); - - /** - * @param awaitCompletion If true: await completion of currently scheduled tasks. If false: shutdown immediately, - * cancelling any currently executing tasks - */ - void shutdown(boolean awaitCompletion); -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java deleted file mode 100644 index a3992b09a..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/LocalOptimizationRunner.java +++ /dev/null @@ -1,150 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.runner; - -import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.ListeningExecutorService; -import com.google.common.util.concurrent.MoreExecutors; -import lombok.Setter; -import org.deeplearning4j.arbiter.optimize.api.*; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicLong; - -/** - * LocalOptimizationRunner: execute hyperparameter optimization - * locally (on current machine, in current JVM). - * - * @author Alex Black - */ -public class LocalOptimizationRunner extends BaseOptimizationRunner { - - public static final int DEFAULT_MAX_CONCURRENT_TASKS = 1; - - private final int maxConcurrentTasks; - - private TaskCreator taskCreator; - private ListeningExecutorService executor; - @Setter - private long shutdownMaxWaitMS = 2L * 24 * 60 * 60 * 1000; - - public LocalOptimizationRunner(OptimizationConfiguration config){ - this(config, null); - } - - public LocalOptimizationRunner(OptimizationConfiguration config, TaskCreator taskCreator) { - this(DEFAULT_MAX_CONCURRENT_TASKS, config, taskCreator); - } - - public LocalOptimizationRunner(int maxConcurrentTasks, OptimizationConfiguration config){ - this(maxConcurrentTasks, config, null); - } - - public LocalOptimizationRunner(int maxConcurrentTasks, OptimizationConfiguration config, TaskCreator taskCreator) { - super(config); - if (maxConcurrentTasks <= 0) - throw new IllegalArgumentException("maxConcurrentTasks must be > 0 (got: " + maxConcurrentTasks + ")"); - this.maxConcurrentTasks = maxConcurrentTasks; - - if(taskCreator == null){ - Class psClass = config.getCandidateGenerator().getParameterSpace().getClass(); - taskCreator = TaskCreatorProvider.defaultTaskCreatorFor(psClass); - if(taskCreator == null){ - throw new IllegalStateException("No TaskCreator was provided and a default TaskCreator cannot be " + - "inferred for ParameterSpace class " + psClass.getName() + ". 
Please provide a TaskCreator " + - "via the LocalOptimizationRunner constructor"); - } - } - - this.taskCreator = taskCreator; - - ExecutorService exec = Executors.newFixedThreadPool(maxConcurrentTasks, new ThreadFactory() { - private AtomicLong counter = new AtomicLong(0); - - @Override - public Thread newThread(Runnable r) { - Thread t = Executors.defaultThreadFactory().newThread(r); - t.setDaemon(true); - t.setName("LocalCandidateExecutor-" + counter.getAndIncrement()); - return t; - } - }); - executor = MoreExecutors.listeningDecorator(exec); - - init(); - } - - @Override - protected int maxConcurrentTasks() { - return maxConcurrentTasks; - } - - @Override - protected ListenableFuture execute(Candidate candidate, DataProvider dataProvider, - ScoreFunction scoreFunction) { - return execute(Collections.singletonList(candidate), dataProvider, scoreFunction).get(0); - } - - @Override - protected List> execute(List candidates, DataProvider dataProvider, - ScoreFunction scoreFunction) { - List> list = new ArrayList<>(candidates.size()); - for (Candidate candidate : candidates) { - Callable task = - taskCreator.create(candidate, dataProvider, scoreFunction, statusListeners, this); - list.add(executor.submit(task)); - } - return list; - } - - @Override - protected ListenableFuture execute(Candidate candidate, Class dataSource, Properties dataSourceProperties, ScoreFunction scoreFunction) { - return execute(Collections.singletonList(candidate), dataSource, dataSourceProperties, scoreFunction).get(0); - } - - @Override - protected List> execute(List candidates, Class dataSource, Properties dataSourceProperties, ScoreFunction scoreFunction) { - List> list = new ArrayList<>(candidates.size()); - for (Candidate candidate : candidates) { - Callable task = taskCreator.create(candidate, dataSource, dataSourceProperties, scoreFunction, statusListeners, this); - list.add(executor.submit(task)); - } - return list; - } - - @Override - public void shutdown(boolean awaitTermination) { - if(awaitTermination){ - try { - executor.shutdown(); - executor.awaitTermination(shutdownMaxWaitMS, TimeUnit.MILLISECONDS); - } catch (InterruptedException e){ - throw new RuntimeException(e); - } - } else { - executor.shutdownNow(); - } - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/BaseStatusListener.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/BaseStatusListener.java deleted file mode 100644 index aca25d95d..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/BaseStatusListener.java +++ /dev/null @@ -1,54 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.runner.listener; - -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; - -/** - * BaseStatusListener: implements all methods of {@link StatusListener} as no-op. - * Users can extend this and override only the methods actually required - * - * @author Alex Black - */ -public abstract class BaseStatusListener implements StatusListener{ - @Override - public void onInitialization(IOptimizationRunner runner) { - //No op - } - - @Override - public void onShutdown(IOptimizationRunner runner) { - //No op - } - - @Override - public void onRunnerStatusChange(IOptimizationRunner runner) { - //No op - } - - @Override - public void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner, OptimizationResult result) { - //No op - } - - @Override - public void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration) { - //No op - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/StatusChangeType.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/StatusChangeType.java deleted file mode 100644 index d8e2f429b..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/StatusChangeType.java +++ /dev/null @@ -1,26 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.runner.listener; - -/** - * Created by Alex on 20/07/2017. - */ -public enum StatusChangeType { - - CandidateCompleted, CandidateFailed, CandidateNewScheduled, CandidateNewBestScore - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/StatusListener.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/StatusListener.java deleted file mode 100644 index fa5ba25a2..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/StatusListener.java +++ /dev/null @@ -1,60 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.runner.listener; - -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; - -/** - * The status Listener interface is used to inspect/track the status of execution, both for individual candidates, - * and for the optimisation runner overall. - * - * @author Alex Black - */ -public interface StatusListener { - - /** Called when optimization runner starts execution */ - void onInitialization(IOptimizationRunner runner); - - /** Called when optimization runner terminates */ - void onShutdown(IOptimizationRunner runner); - - /** Called when any of the summary stats change, for the optimization runner: - * number scheduled, number completed, number failed, best score, etc. */ - void onRunnerStatusChange(IOptimizationRunner runner); - - /** - * Called when the status of the candidate is change. For example created, completed, failed. - * - * @param candidateInfo Candidate information - * @param runner Optimisation runner calling this method - * @param result Optimisation result. Maybe null. - */ - void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner, OptimizationResult result); - - /** - * This method may be called by tasks as they are executing. The intent of this method is to report partial results, - * such as different stages of learning, or scores/evaluations so far - * - * @param candidateInfo Candidate information - * @param candidate Current candidate value/configuration - * @param iteration Current iteration number - */ - void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration); - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/impl/LoggingStatusListener.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/impl/LoggingStatusListener.java deleted file mode 100644 index add0d4ff7..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/runner/listener/impl/LoggingStatusListener.java +++ /dev/null @@ -1,57 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.runner.listener.impl; - -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; - -/** - * Created by Alex on 20/07/2017. - */ -@Slf4j -public class LoggingStatusListener implements StatusListener { - - - @Override - public void onInitialization(IOptimizationRunner runner) { - log.info("Optimization runner: initialized"); - } - - @Override - public void onShutdown(IOptimizationRunner runner) { - log.info("Optimization runner: shut down"); - } - - @Override - public void onRunnerStatusChange(IOptimizationRunner runner) { - log.info("Optimization runner: status change"); - } - - @Override - public void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner, - OptimizationResult result) { - log.info("Candidate status change: {}", candidateInfo); - } - - @Override - public void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration) { - log.info("Candidate iteration #{} - {}", iteration, candidate); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java deleted file mode 100644 index 7ca349878..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueDeserializer.java +++ /dev/null @@ -1,52 +0,0 @@ -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import org.apache.commons.codec.binary.Base64; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; - -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.ObjectInputStream; - -/** - * A custom deserializer to be used in conjunction with {@link FixedValueSerializer} - * @author Alex Black - */ -public class FixedValueDeserializer extends JsonDeserializer { - @Override - public FixedValue deserialize(JsonParser p, DeserializationContext deserializationContext) throws IOException, JsonProcessingException { - JsonNode node = p.getCodec().readTree(p); - String className = node.get("@valueclass").asText(); - Class c; - try { - c = Class.forName(className); - } catch (Exception e) { - throw new RuntimeException(e); - } - - if(node.has("value")){ - //Number, String, Enum - JsonNode valueNode = node.get("value"); - Object o = new ObjectMapper().treeToValue(valueNode, c); - return new FixedValue<>(o); - } else { - //Everything else - JsonNode valueNode = node.get("data"); - String data = valueNode.asText(); - - byte[] b = new Base64().decode(data); - ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(b)); - try { - Object o = ois.readObject(); - return new FixedValue<>(o); - } catch (Throwable t) { - throw new RuntimeException(t); - } - } - } 
-} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java deleted file mode 100644 index 80ff7d61d..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/FixedValueSerializer.java +++ /dev/null @@ -1,52 +0,0 @@ -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import org.apache.commons.net.util.Base64; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.core.type.WritableTypeId; -import com.fasterxml.jackson.databind.JsonSerializer; -import com.fasterxml.jackson.databind.SerializerProvider; -import com.fasterxml.jackson.databind.jsontype.TypeSerializer; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.io.ObjectOutputStream; - -import static com.fasterxml.jackson.core.JsonToken.START_OBJECT; - - -/** - * A custom serializer to handle arbitrary object types - * Uses standard JSON where safe (number, string, enumerations) or Java object serialization (bytes -> base64) - * The latter is not an ideal approach, but Jackson doesn't support serialization/deserialization of arbitrary - * objects very well - * - * @author Alex Black - */ -public class FixedValueSerializer extends JsonSerializer { - @Override - public void serialize(FixedValue fixedValue, JsonGenerator j, SerializerProvider serializerProvider) throws IOException { - Object o = fixedValue.getValue(); - - j.writeStringField("@valueclass", o.getClass().getName()); - if(o instanceof Number || o instanceof String || o instanceof Enum){ - j.writeObjectField("value", o); - } else { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ObjectOutputStream oos = new ObjectOutputStream(baos); - oos.writeObject(o); - baos.close(); - byte[] b = baos.toByteArray(); - String base64 = new Base64().encodeToString(b); - j.writeStringField("data", base64); - } - } - - @Override - public void serializeWithType(FixedValue value, JsonGenerator gen, SerializerProvider serializers, TypeSerializer typeSer) throws IOException { - WritableTypeId typeId = typeSer.typeId(value, START_OBJECT); - typeSer.writeTypePrefix(gen, typeId); - serialize(value, gen, serializers); - typeSer.writeTypeSuffix(gen, typeId); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/IntegerDistributionDeserializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/IntegerDistributionDeserializer.java deleted file mode 100644 index 6700e9753..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/IntegerDistributionDeserializer.java +++ /dev/null @@ -1,59 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import org.apache.commons.math3.distribution.*; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.JsonNode; - -import java.io.IOException; - -/** - * Custom Jackson deserializer for integer distributions - * - * @author Alex Black - */ -public class IntegerDistributionDeserializer extends JsonDeserializer { - - @Override - public IntegerDistribution deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { - JsonNode node = p.getCodec().readTree(p); - String simpleName = node.get("distribution").asText(); - - switch (simpleName) { - case "BinomialDistribution": - return new BinomialDistribution(node.get("trials").asInt(), node.get("p").asDouble()); - case "GeometricDistribution": - return new GeometricDistribution(node.get("p").asDouble()); - case "HypergeometricDistribution": - return new HypergeometricDistribution(node.get("populationSize").asInt(), - node.get("numberOfSuccesses").asInt(), node.get("sampleSize").asInt()); - case "PascalDistribution": - return new PascalDistribution(node.get("r").asInt(), node.get("p").asDouble()); - case "PoissonDistribution": - return new PoissonDistribution(node.get("p").asDouble()); - case "UniformIntegerDistribution": - return new UniformIntegerDistribution(node.get("lower").asInt(), node.get("upper").asInt()); - case "ZipfDistribution": - return new ZipfDistribution(node.get("numElements").asInt(), node.get("exponent").asDouble()); - default: - throw new RuntimeException("Unknown or not supported distribution: " + simpleName); - } - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/IntegerDistributionSerializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/IntegerDistributionSerializer.java deleted file mode 100644 index 4157df2f7..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/IntegerDistributionSerializer.java +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import org.apache.commons.math3.distribution.*; -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.databind.JsonSerializer; -import com.fasterxml.jackson.databind.SerializerProvider; - -import java.io.IOException; - -/** - * Custom Jackson serializer for integer distributions - * - * @author Alex Black - */ -public class IntegerDistributionSerializer extends JsonSerializer { - @Override - public void serialize(IntegerDistribution d, JsonGenerator j, SerializerProvider serializerProvider) - throws IOException { - Class c = d.getClass(); - String s = c.getSimpleName(); - - j.writeStartObject(); - j.writeStringField("distribution", s); - - if (c == BinomialDistribution.class) { - BinomialDistribution bd = (BinomialDistribution) d; - j.writeNumberField("trials", bd.getNumberOfTrials()); - j.writeNumberField("p", bd.getProbabilityOfSuccess()); - } else if (c == GeometricDistribution.class) { - GeometricDistribution gd = (GeometricDistribution) d; - j.writeNumberField("p", gd.getProbabilityOfSuccess()); - } else if (c == HypergeometricDistribution.class) { - HypergeometricDistribution hd = (HypergeometricDistribution) d; - j.writeNumberField("populationSize", hd.getPopulationSize()); - j.writeNumberField("numberOfSuccesses", hd.getNumberOfSuccesses()); - j.writeNumberField("sampleSize", hd.getSampleSize()); - } else if (c == PascalDistribution.class) { - PascalDistribution pd = (PascalDistribution) d; - j.writeNumberField("r", pd.getNumberOfSuccesses()); - j.writeNumberField("p", pd.getProbabilityOfSuccess()); - } else if (c == PoissonDistribution.class) { - PoissonDistribution pd = (PoissonDistribution) d; - j.writeNumberField("p", pd.getMean()); - } else if (c == UniformIntegerDistribution.class) { - UniformIntegerDistribution ud = (UniformIntegerDistribution) d; - j.writeNumberField("lower", ud.getSupportLowerBound()); - j.writeNumberField("upper", ud.getSupportUpperBound()); - } else if (c == ZipfDistribution.class) { - ZipfDistribution zd = (ZipfDistribution) d; - j.writeNumberField("numElements", zd.getNumberOfElements()); - j.writeNumberField("exponent", zd.getExponent()); - } else { - throw new UnsupportedOperationException("Unknown or not supported IntegerDistribution: " + c); - } - - j.writeEndObject(); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java deleted file mode 100644 index 7ed1bfe45..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/JsonMapper.java +++ /dev/null @@ -1,77 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import com.fasterxml.jackson.annotation.JsonAutoDetect; -import com.fasterxml.jackson.annotation.PropertyAccessor; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.SerializationFeature; -import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import com.fasterxml.jackson.datatype.joda.JodaModule; - -/** - * Created by Alex on 16/11/2016. - */ -public class JsonMapper { - - private static ObjectMapper mapper; - private static ObjectMapper yamlMapper; - - static { - mapper = new ObjectMapper(); - mapper.registerModule(new JodaModule()); - mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); - mapper.enable(SerializationFeature.INDENT_OUTPUT); - mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); - mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); - mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); - mapper.setVisibility(PropertyAccessor.SETTER, JsonAutoDetect.Visibility.ANY); - yamlMapper = new ObjectMapper(new YAMLFactory()); - yamlMapper.registerModule(new JodaModule()); - yamlMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - yamlMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); - yamlMapper.enable(SerializationFeature.INDENT_OUTPUT); - yamlMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); - yamlMapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); - yamlMapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); - } - - private JsonMapper() { - } - - - /** - * Return the yaml mapper - * - * @return - */ - public static ObjectMapper getYamlMapper() { - return yamlMapper; - } - - /** - * Return a json mapper - * - * @return - */ - public static ObjectMapper getMapper() { - return mapper; - } - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/RealDistributionDeserializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/RealDistributionDeserializer.java deleted file mode 100644 index a30626560..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/RealDistributionDeserializer.java +++ /dev/null @@ -1,78 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import org.apache.commons.math3.distribution.*; -import org.deeplearning4j.arbiter.optimize.distribution.LogUniformDistribution; -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.JsonNode; - -import java.io.IOException; - -/** - * Created by Alex on 14/02/2017. - */ -public class RealDistributionDeserializer extends JsonDeserializer { - - @Override - public RealDistribution deserialize(JsonParser p, DeserializationContext ctxt) - throws IOException, JsonProcessingException { - JsonNode node = p.getCodec().readTree(p); - String simpleName = node.get("distribution").asText(); - - switch (simpleName) { - case "BetaDistribution": - return new BetaDistribution(node.get("alpha").asDouble(), node.get("beta").asDouble()); - case "CauchyDistribution": - return new CauchyDistribution(node.get("median").asDouble(), node.get("scale").asDouble()); - case "ChiSquaredDistribution": - return new ChiSquaredDistribution(node.get("dof").asDouble()); - case "ExponentialDistribution": - return new ExponentialDistribution(node.get("mean").asDouble()); - case "FDistribution": - return new FDistribution(node.get("numeratorDof").asDouble(), node.get("denominatorDof").asDouble()); - case "GammaDistribution": - return new GammaDistribution(node.get("shape").asDouble(), node.get("scale").asDouble()); - case "LevyDistribution": - return new LevyDistribution(node.get("mu").asDouble(), node.get("c").asDouble()); - case "LogNormalDistribution": - return new LogNormalDistribution(node.get("scale").asDouble(), node.get("shape").asDouble()); - case "NormalDistribution": - return new NormalDistribution(node.get("mean").asDouble(), node.get("stdev").asDouble()); - case "ParetoDistribution": - return new ParetoDistribution(node.get("scale").asDouble(), node.get("shape").asDouble()); - case "TDistribution": - return new TDistribution(node.get("dof").asDouble()); - case "TriangularDistribution": - return new TriangularDistribution(node.get("a").asDouble(), node.get("b").asDouble(), - node.get("c").asDouble()); - case "UniformRealDistribution": - return new UniformRealDistribution(node.get("lower").asDouble(), node.get("upper").asDouble()); - case "WeibullDistribution": - return new WeibullDistribution(node.get("alpha").asDouble(), node.get("beta").asDouble()); - case "LogUniformDistribution": - return new LogUniformDistribution(node.get("min").asDouble(), node.get("max").asDouble()); - default: - throw new RuntimeException("Unknown or not supported distribution: " + simpleName); - } - - - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/RealDistributionSerializer.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/RealDistributionSerializer.java deleted file mode 100644 index b108aad0a..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/RealDistributionSerializer.java +++ /dev/null @@ -1,107 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import org.apache.commons.math3.distribution.*; -import org.deeplearning4j.arbiter.optimize.distribution.LogUniformDistribution; -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.databind.JsonSerializer; -import com.fasterxml.jackson.databind.SerializerProvider; - -import java.io.IOException; - -/** - * Custom JSON serializer for Apache commons RealDistribution instances. - * The custom serializer is set up to use the built-in c - */ -public class RealDistributionSerializer extends JsonSerializer { - - @Override - public void serialize(RealDistribution d, JsonGenerator j, SerializerProvider serializerProvider) - throws IOException { - Class c = d.getClass(); - String s = c.getSimpleName(); - - j.writeStartObject(); - j.writeStringField("distribution", s); - - - if (c == BetaDistribution.class) { - BetaDistribution bd = (BetaDistribution) d; - j.writeNumberField("alpha", bd.getAlpha()); - j.writeNumberField("beta", bd.getBeta()); - } else if (c == CauchyDistribution.class) { - CauchyDistribution cd = (CauchyDistribution) d; - j.writeNumberField("median", cd.getMedian()); - j.writeNumberField("scale", cd.getScale()); - } else if (c == ChiSquaredDistribution.class) { - ChiSquaredDistribution cd = (ChiSquaredDistribution) d; - j.writeNumberField("dof", cd.getDegreesOfFreedom()); - } else if (c == ExponentialDistribution.class) { - ExponentialDistribution ed = (ExponentialDistribution) d; - j.writeNumberField("mean", ed.getMean()); - } else if (c == FDistribution.class) { - FDistribution fd = (FDistribution) d; - j.writeNumberField("numeratorDof", fd.getNumeratorDegreesOfFreedom()); - j.writeNumberField("denominatorDof", fd.getDenominatorDegreesOfFreedom()); - } else if (c == GammaDistribution.class) { - GammaDistribution gd = (GammaDistribution) d; - j.writeNumberField("shape", gd.getShape()); - j.writeNumberField("scale", gd.getScale()); - } else if (c == LevyDistribution.class) { - LevyDistribution ld = (LevyDistribution) d; - j.writeNumberField("mu", ld.getLocation()); - j.writeNumberField("c", ld.getScale()); - } else if (c == LogNormalDistribution.class) { - LogNormalDistribution ln = (LogNormalDistribution) d; - j.writeNumberField("scale", ln.getScale()); - j.writeNumberField("shape", ln.getShape()); - } else if (c == NormalDistribution.class) { - NormalDistribution nd = (NormalDistribution) d; - j.writeNumberField("mean", nd.getMean()); - j.writeNumberField("stdev", nd.getStandardDeviation()); - } else if (c == ParetoDistribution.class) { - ParetoDistribution pd = (ParetoDistribution) d; - j.writeNumberField("scale", pd.getScale()); - j.writeNumberField("shape", pd.getShape()); - } else if (c == TDistribution.class) { - TDistribution td = (TDistribution) d; - j.writeNumberField("dof", td.getDegreesOfFreedom()); - } else if (c == 
TriangularDistribution.class) { - TriangularDistribution td = (TriangularDistribution) d; - j.writeNumberField("a", td.getSupportLowerBound()); - j.writeNumberField("b", td.getMode()); - j.writeNumberField("c", td.getSupportUpperBound()); - } else if (c == UniformRealDistribution.class) { - UniformRealDistribution u = (UniformRealDistribution) d; - j.writeNumberField("lower", u.getSupportLowerBound()); - j.writeNumberField("upper", u.getSupportUpperBound()); - } else if (c == WeibullDistribution.class) { - WeibullDistribution wb = (WeibullDistribution) d; - j.writeNumberField("alpha", wb.getShape()); - j.writeNumberField("beta", wb.getScale()); - } else if (c == LogUniformDistribution.class){ - LogUniformDistribution lud = (LogUniformDistribution) d; - j.writeNumberField("min", lud.getMin()); - j.writeNumberField("max", lud.getMax()); - } else { - throw new UnsupportedOperationException("Unknown or not supported RealDistribution: " + d.getClass()); - } - - j.writeEndObject(); - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java deleted file mode 100644 index b1aae22b2..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/optimize/serde/jackson/YamlMapper.java +++ /dev/null @@ -1,52 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.serde.jackson; - -import com.fasterxml.jackson.annotation.JsonAutoDetect; -import com.fasterxml.jackson.annotation.PropertyAccessor; -import com.fasterxml.jackson.databind.DeserializationFeature; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.SerializationFeature; -import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; -import com.fasterxml.jackson.datatype.joda.JodaModule; - -/** - * Created by Alex on 16/11/2016. 
- */ -public class YamlMapper { - - private static final ObjectMapper mapper; - - static { - mapper = new ObjectMapper(new YAMLFactory()); - mapper.registerModule(new JodaModule()); - mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false); - mapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false); - mapper.enable(SerializationFeature.INDENT_OUTPUT); - mapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE); - mapper.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY); - mapper.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY); - } - - - private YamlMapper() {} - - public static ObjectMapper getMapper() { - return mapper; - } - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ClassPathResource.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ClassPathResource.java deleted file mode 100644 index d22db15c8..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ClassPathResource.java +++ /dev/null @@ -1,233 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.util; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.*; -import java.net.MalformedURLException; -import java.net.URI; -import java.net.URISyntaxException; -import java.net.URL; -import java.util.zip.ZipEntry; -import java.util.zip.ZipFile; - -/** - * Simple utility class used to get access to files at the classpath, or packed into jar. - * Based on Spring ClassPathResource implementation + jar internals access implemented. 
- * - * @author raver119@gmail.com - */ -public class ClassPathResource { - - private String resourceName; - - private static Logger log = LoggerFactory.getLogger(ClassPathResource.class); - - /** - * Builds new ClassPathResource object - * - * @param resourceName String name of resource, to be retrieved - */ - public ClassPathResource(String resourceName) { - if (resourceName == null) - throw new IllegalStateException("Resource name can't be null"); - this.resourceName = resourceName; - } - - /** - * Returns URL of the requested resource - * - * @return URL of the resource, if it's available in current Jar - */ - private URL getUrl() { - ClassLoader loader = null; - try { - loader = Thread.currentThread().getContextClassLoader(); - } catch (Exception e) { - // do nothing - } - - if (loader == null) { - loader = ClassPathResource.class.getClassLoader(); - } - - URL url = loader.getResource(this.resourceName); - if (url == null) { - // try to check for mis-used starting slash - // TODO: see TODO below - if (this.resourceName.startsWith("/")) { - url = loader.getResource(this.resourceName.replaceFirst("[\\\\/]", "")); - if (url != null) - return url; - } else { - // try to add slash, to make clear it's not an issue - // TODO: change this mechanic to actual path purifier - url = loader.getResource("/" + this.resourceName); - if (url != null) - return url; - } - throw new IllegalStateException("Resource '" + this.resourceName + "' cannot be found."); - } - return url; - } - - /** - * Returns requested ClassPathResource as File object - *
      - * Please note: if this method called from compiled jar, temporary file will be created to provide File access - * - * @return File requested at constructor call - * @throws FileNotFoundException - */ - public File getFile() throws FileNotFoundException { - URL url = this.getUrl(); - - if (isJarURL(url)) { - /* - This is actually request for file, that's packed into jar. Probably the current one, but that doesn't matters. - */ - try { - url = extractActualUrl(url); - File file = File.createTempFile("canova_temp", "file"); - file.deleteOnExit(); - - ZipFile zipFile = new ZipFile(url.getFile()); - ZipEntry entry = zipFile.getEntry(this.resourceName); - if (entry == null) { - if (this.resourceName.startsWith("/")) { - entry = zipFile.getEntry(this.resourceName.replaceFirst("/", "")); - if (entry == null) { - throw new FileNotFoundException("Resource " + this.resourceName + " not found"); - } - } else - throw new FileNotFoundException("Resource " + this.resourceName + " not found"); - } - - long size = entry.getSize(); - - InputStream stream = zipFile.getInputStream(entry); - FileOutputStream outputStream = new FileOutputStream(file); - byte[] array = new byte[1024]; - int rd = 0; - long bytesRead = 0; - do { - rd = stream.read(array); - outputStream.write(array, 0, rd); - bytesRead += rd; - } while (bytesRead < size); - - outputStream.flush(); - outputStream.close(); - - stream.close(); - zipFile.close(); - - return file; - } catch (Exception e) { - throw new RuntimeException(e); - } - - } else { - /* - It's something in the actual underlying filesystem, so we can just go for it - */ - - try { - URI uri = new URI(url.toString().replaceAll(" ", "%20")); - return new File(uri.getSchemeSpecificPart()); - } catch (URISyntaxException e) { - return new File(url.getFile()); - } - } - } - - /** - * Checks, if proposed URL is packed into archive. - * - * @param url URL to be checked - * @return True, if URL is archive entry, False otherwise - */ - private boolean isJarURL(URL url) { - String protocol = url.getProtocol(); - return "jar".equals(protocol) || "zip".equals(protocol) || "wsjar".equals(protocol) - || "code-source".equals(protocol) && url.getPath().contains("!/"); - } - - /** - * Extracts parent Jar URL from original ClassPath entry URL. 
- * - * @param jarUrl Original URL of the resource - * @return URL of the Jar file, containing requested resource - * @throws MalformedURLException - */ - private URL extractActualUrl(URL jarUrl) throws MalformedURLException { - String urlFile = jarUrl.getFile(); - int separatorIndex = urlFile.indexOf("!/"); - if (separatorIndex != -1) { - String jarFile = urlFile.substring(0, separatorIndex); - - try { - return new URL(jarFile); - } catch (MalformedURLException var5) { - if (!jarFile.startsWith("/")) { - jarFile = "/" + jarFile; - } - - return new URL("file:" + jarFile); - } - } else { - return jarUrl; - } - } - - /** - * Returns requested ClassPathResource as InputStream object - * - * @return File requested at constructor call - * @throws FileNotFoundException - */ - public InputStream getInputStream() throws FileNotFoundException { - URL url = this.getUrl(); - if (isJarURL(url)) { - try { - url = extractActualUrl(url); - ZipFile zipFile = new ZipFile(url.getFile()); - ZipEntry entry = zipFile.getEntry(this.resourceName); - - if (entry == null) { - if (this.resourceName.startsWith("/")) { - entry = zipFile.getEntry(this.resourceName.replaceFirst("/", "")); - if (entry == null) { - throw new FileNotFoundException("Resource " + this.resourceName + " not found"); - } - } else - throw new FileNotFoundException("Resource " + this.resourceName + " not found"); - } - - InputStream stream = zipFile.getInputStream(entry); - return stream; - } catch (Exception e) { - throw new RuntimeException(e); - } - } else { - File srcFile = this.getFile(); - return new FileInputStream(srcFile); - } - } -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/CollectionUtils.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/CollectionUtils.java deleted file mode 100644 index eb9275d82..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/CollectionUtils.java +++ /dev/null @@ -1,49 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.util; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; - -public class CollectionUtils { - - /** - * Count the number of unique values in a collection - */ - public static int countUnique(Collection collection) { - HashSet set = new HashSet<>(collection); - return set.size(); - } - - /** - * Returns a list containing only unique values in a collection - */ - public static List getUnique(Collection collection) { - HashSet set = new HashSet<>(); - List out = new ArrayList<>(); - for (T t : collection) { - if (!set.contains(t)) { - out.add(t); - set.add(t); - } - } - return out; - } - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/LeafUtils.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/LeafUtils.java deleted file mode 100644 index 2a86dc48f..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/LeafUtils.java +++ /dev/null @@ -1,73 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.util; - -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; - -import java.util.ArrayList; -import java.util.List; - -/** - * Created by Alex on 29/06/2017. 
- */ -public class LeafUtils { - - private LeafUtils() {} - - /** - * Returns a list of unique objects, not using the .equals() method, but rather using == - * - * @param allLeaves Leaf values to process - * @return A list of unique parameter space values - */ - public static List getUniqueObjects(List allLeaves) { - List unique = new ArrayList<>(); - for (ParameterSpace p : allLeaves) { - //This isn't especially efficient, but small number of parameters in general means it's fine - boolean found = false; - for (ParameterSpace q : unique) { - if (p == q) { - found = true; - } - } - if (!found) { - unique.add(p); - } - } - - return unique; - } - - /** - * Count the number of unique parameters in the specified leaf nodes - * - * @param allLeaves Leaf values to count the parameters fore - * @return Number of parameters for all unique objects - */ - public static int countUniqueParameters(List allLeaves) { - List unique = getUniqueObjects(allLeaves); - int count = 0; - for (ParameterSpace ps : unique) { - if (!ps.isLeaf()) { - throw new IllegalStateException("Method should only be used with leaf nodes"); - } - count += ps.numParameters(); - } - return count; - } - -} diff --git a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ObjectUtils.java b/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ObjectUtils.java deleted file mode 100644 index 9c3213430..000000000 --- a/arbiter/arbiter-core/src/main/java/org/deeplearning4j/arbiter/util/ObjectUtils.java +++ /dev/null @@ -1,61 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.util; - -import java.util.Arrays; - -/** - * @author Alex Black - */ -public class ObjectUtils { - - private ObjectUtils() {} - - /** - * Get the string representation of the object. Arrays, including primitive arrays, are printed using - * Arrays.toString(...) methods. 
- * - * @param v Value to convert to a string - * @return String representation - */ - public static String valueToString(Object v) { - if (v.getClass().isArray()) { - if (v.getClass().getComponentType().isPrimitive()) { - Class c = v.getClass().getComponentType(); - if (c == int.class) { - return Arrays.toString((int[]) v); - } else if (c == double.class) { - return Arrays.toString((double[]) v); - } else if (c == float.class) { - return Arrays.toString((float[]) v); - } else if (c == long.class) { - return Arrays.toString((long[]) v); - } else if (c == byte.class) { - return Arrays.toString((byte[]) v); - } else if (c == short.class) { - return Arrays.toString((short[]) v); - } else { - return v.toString(); - } - } else { - return Arrays.toString((Object[]) v); - } - } else { - return v.toString(); - } - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java deleted file mode 100644 index cfb5e2556..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,49 +0,0 @@ -/* ****************************************************************************** - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ -package org.deeplearning4j.arbiter.optimize; - -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.BaseDL4JTest; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import java.util.*; - -/** - * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) - * extends BaseDl4jTest - either directly or indirectly. - * Other than a small set of exceptions, all tests must extend this - * - * @author Alex Black - */ - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.deeplearning4j.arbiter.optimize"; - } - - @Override - protected Class getBaseClass() { - return BaseDL4JTest.class; - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/BraninFunction.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/BraninFunction.java deleted file mode 100644 index 4d507ee7d..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/BraninFunction.java +++ /dev/null @@ -1,156 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize; - -import lombok.AllArgsConstructor; -import lombok.Data; -import org.deeplearning4j.arbiter.optimize.api.*; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; -import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; - -import java.io.Serializable; -import java.util.*; -import java.util.concurrent.Callable; - -public class BraninFunction { - public static class BraninSpace extends AbstractParameterSpace { - private int[] indices; - private ParameterSpace first = new ContinuousParameterSpace(-5, 10); - private ParameterSpace second = new ContinuousParameterSpace(0, 15); - - @Override - public BraninConfig getValue(double[] parameterValues) { - double f = first.getValue(parameterValues); - double s = second.getValue(parameterValues); - return new BraninConfig(f, s); //-5 to +10 and 0 to 15 - } - - @Override - public int numParameters() { - return 2; - } - - @Override - public List collectLeaves() { - List list = new ArrayList<>(); - list.addAll(first.collectLeaves()); - list.addAll(second.collectLeaves()); - return list; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... 
indices) { - throw new UnsupportedOperationException(); - } - } - - @AllArgsConstructor - @Data - public static class BraninConfig implements Serializable { - private double x1; - private double x2; - } - - public static class BraninScoreFunction implements ScoreFunction { - private static final double a = 1.0; - private static final double b = 5.1 / (4.0 * Math.PI * Math.PI); - private static final double c = 5.0 / Math.PI; - private static final double r = 6.0; - private static final double s = 10.0; - private static final double t = 1.0 / (8.0 * Math.PI); - - @Override - public double score(Object m, DataProvider data, Map dataParameters) { - BraninConfig model = (BraninConfig) m; - double x1 = model.getX1(); - double x2 = model.getX2(); - - return a * Math.pow(x2 - b * x1 * x1 + c * x1 - r, 2.0) + s * (1 - t) * Math.cos(x1) + s; - } - - @Override - public double score(Object model, Class dataSource, Properties dataSourceProperties) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean minimize() { - return true; - } - - @Override - public List> getSupportedModelTypes() { - return Collections.>singletonList(BraninConfig.class); - } - - @Override - public List> getSupportedDataTypes() { - return Collections.>singletonList(Object.class); - } - } - - public static class BraninTaskCreator implements TaskCreator { - @Override - public Callable create(final Candidate c, DataProvider dataProvider, - final ScoreFunction scoreFunction, final List statusListeners, - IOptimizationRunner runner) { - - return new Callable() { - @Override - public OptimizationResult call() throws Exception { - - BraninConfig candidate = (BraninConfig) c.getValue(); - - double score = scoreFunction.score(candidate, null, (Map) null); -// System.out.println(candidate.getX1() + "\t" + candidate.getX2() + "\t" + score); - - Thread.sleep(20); - - if (statusListeners != null) { - for (StatusListener sl : statusListeners) { - sl.onCandidateIteration(null, null, 0); - } - } - - CandidateInfo ci = new CandidateInfo(-1, CandidateStatus.Complete, score, - System.currentTimeMillis(), null, null, null, null); - - return new OptimizationResult(c, score, c.getIndex(), null, ci, null); - } - }; - } - - @Override - public Callable create(Candidate candidate, Class dataSource, - Properties dataSourceProperties, ScoreFunction scoreFunction, - List statusListeners, IOptimizationRunner runner) { - throw new UnsupportedOperationException(); - } - } - -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java deleted file mode 100644 index 9410fa602..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGeneticSearch.java +++ /dev/null @@ -1,118 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.generator.GeneticSearchCandidateGenerator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException; -import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator; -import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; -import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusListener; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class TestGeneticSearch extends BaseDL4JTest { - public class TestSelectionOperator extends SelectionOperator { - - @Override - public double[] buildNextGenes() { - throw new GeneticGenerationException("Forced exception to test exception handling."); - } - } - - public class TestTerminationCondition implements TerminationCondition { - - public boolean hasAFailedCandidate = false; - public int evalCount = 0; - - @Override - public void initialize(IOptimizationRunner optimizationRunner) {} - - @Override - public boolean terminate(IOptimizationRunner optimizationRunner) { - if (++evalCount == 50) { - // Generator did not handle GeneticGenerationException - return true; - } - - for (CandidateInfo candidateInfo : optimizationRunner.getCandidateStatus()) { - if (candidateInfo.getCandidateStatus() == CandidateStatus.Failed) { - hasAFailedCandidate = true; - return true; - } - } - - return false; - } - } - - @Test - public void GeneticSearchCandidateGenerator_getCandidate_ShouldGenerateCandidates() throws Exception { - - ScoreFunction scoreFunction = new BraninFunction.BraninScoreFunction(); - - //Define configuration: - CandidateGenerator candidateGenerator = - new GeneticSearchCandidateGenerator.Builder(new BraninFunction.BraninSpace(), scoreFunction) - .build(); - - TestTerminationCondition testTerminationCondition = new TestTerminationCondition(); - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).scoreFunction(scoreFunction) - .terminationConditions(new MaxCandidatesCondition(50), testTerminationCondition).build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator()); - - runner.addListeners(new LoggingStatusListener()); - runner.execute(); - - Assertions.assertFalse(testTerminationCondition.hasAFailedCandidate); - } - - @Test - public void GeneticSearchCandidateGenerator_getCandidate_GeneticExceptionShouldMarkCandidateAsFailed() { - - ScoreFunction scoreFunction = new BraninFunction.BraninScoreFunction(); - - //Define configuration: - CandidateGenerator candidateGenerator = - new GeneticSearchCandidateGenerator.Builder(new 
BraninFunction.BraninSpace(), scoreFunction) - .selectionOperator(new TestSelectionOperator()).build(); - - TestTerminationCondition testTerminationCondition = new TestTerminationCondition(); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).scoreFunction(scoreFunction) - .terminationConditions(testTerminationCondition).build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator()); - - runner.addListeners(new LoggingStatusListener()); - runner.execute(); - - Assertions.assertTrue(testTerminationCondition.hasAFailedCandidate); - } - -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java deleted file mode 100644 index 45a9aadf5..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestGridSearch.java +++ /dev/null @@ -1,104 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; -import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator; -import org.junit.jupiter.api.Test; - -import java.util.HashMap; -import java.util.Map; - -import static org.junit.jupiter.api.Assertions.*; - -public class TestGridSearch extends BaseDL4JTest { - - @Test - public void testIndexing() { - int[] nValues = {2, 3}; - int prod = 2 * 3; - double[][] expVals = new double[][] {{0.0, 0.0}, {1.0, 0.0}, {0.0, 0.5}, {1.0, 0.5}, {0.0, 1.0}, {1.0, 1.0}}; - for (int i = 0; i < prod; i++) { - double[] out = GridSearchCandidateGenerator.indexToValues(nValues, i, prod); - double[] exp = expVals[i]; - assertArrayEquals(exp, out, 1e-4); - } - } - - @Test - public void testGeneration() throws Exception { - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>()); - - //Define configuration: - CandidateGenerator candidateGenerator = new GridSearchCandidateGenerator(new BraninFunction.BraninSpace(), 4, - GridSearchCandidateGenerator.Mode.Sequential, commands); - - //Check sequential: - double[] expValuesFirst = {-5, 0, 5, 10}; //Range: -5 to +10, with 4 values - double[] expValuesSecond = {0, 5, 10, 15}; //Range: 0 to +15, with 4 values - for (int i = 0; i < 4 * 4; i++) { - BraninFunction.BraninConfig conf = (BraninFunction.BraninConfig) candidateGenerator.getCandidate().getValue(); - double expF = expValuesFirst[i % 4]; //Changes most rapidly - double expS = expValuesSecond[i / 4]; - - 
-            double actF = conf.getX1();
-            double actS = conf.getX2();
-
-            assertEquals(expF, actF, 1e-4);
-            assertEquals(expS, actS, 1e-4);
-        }
-
-        //Check random order. specifically: check that all values are generated, in some order
-        double[][] orderedOutput = new double[16][2];
-        for (int i = 0; i < expValuesFirst.length; i++) {
-            for (int j = 0; j < expValuesSecond.length; j++) {
-                orderedOutput[4 * j + i][0] = expValuesFirst[i];
-                orderedOutput[4 * j + i][1] = expValuesSecond[j];
-            }
-        }
-
-
-        candidateGenerator = new GridSearchCandidateGenerator(new BraninFunction.BraninSpace(), 4,
-                        GridSearchCandidateGenerator.Mode.RandomOrder, commands);
-        boolean[] seen = new boolean[16];
-        int seenCount = 0;
-        for (int i = 0; i < 4 * 4; i++) {
-            assertTrue(candidateGenerator.hasMoreCandidates());
-            BraninFunction.BraninConfig config = (BraninFunction.BraninConfig) candidateGenerator.getCandidate().getValue();
-            double x1 = config.getX1();
-            double x2 = config.getX2();
-            //Work out which of the values this is...
-            boolean matched = false;
-            for (int j = 0; j < 16; j++) {
-                if (Math.abs(orderedOutput[j][0] - x1) < 1e-5 && Math.abs(orderedOutput[j][1] - x2) < 1e-5) {
-                    matched = true;
-                    if (seen[j])
-                        fail("Same candidate generated multiple times");
-                    seen[j] = true;
-                    seenCount++;
-                    break;
-                }
-            }
-            assertTrue(matched, "Candidate " + x1 + ", " + x2 + " not found; invalid?");
-        }
-        assertFalse(candidateGenerator.hasMoreCandidates());
-        assertEquals(16, seenCount);
-    }
-}
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java
deleted file mode 100644
index 225894d6f..000000000
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestJson.java
+++ /dev/null
@@ -1,122 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.arbiter.optimize;
-
-import org.apache.commons.math3.distribution.LogNormalDistribution;
-import org.apache.commons.math3.distribution.NormalDistribution;
-import org.apache.commons.math3.distribution.UniformIntegerDistribution;
-import org.deeplearning4j.BaseDL4JTest;
-import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
-import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
-import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;
-import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator;
-import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
-import org.deeplearning4j.arbiter.optimize.parameter.BooleanSpace;
-import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
-import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
-import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
-import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
-import org.junit.jupiter.api.Test;
-import com.fasterxml.jackson.annotation.JsonAutoDetect;
-import com.fasterxml.jackson.annotation.PropertyAccessor;
-import com.fasterxml.jackson.core.JsonFactory;
-import com.fasterxml.jackson.databind.DeserializationFeature;
-import com.fasterxml.jackson.databind.ObjectMapper;
-import com.fasterxml.jackson.databind.SerializationFeature;
-import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
-import com.fasterxml.jackson.datatype.joda.JodaModule;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-
-/**
- * Created by Alex on 02/02/2017.
- */
-public class TestJson extends BaseDL4JTest {
-
-    protected static ObjectMapper getObjectMapper(JsonFactory factory) {
-        ObjectMapper om = new ObjectMapper(factory);
-        om.registerModule(new JodaModule());
-        om.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
-        om.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
-        om.enable(SerializationFeature.INDENT_OUTPUT);
-        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.NONE);
-        om.setVisibility(PropertyAccessor.FIELD, JsonAutoDetect.Visibility.ANY);
-        om.setVisibility(PropertyAccessor.CREATOR, JsonAutoDetect.Visibility.ANY);
-        return om;
-    }
-
-    private static ObjectMapper jsonMapper = getObjectMapper(new JsonFactory());
-    private static ObjectMapper yamlMapper = getObjectMapper(new YAMLFactory());
-
-
-    @Test
-    public void testParameterSpaceJson() throws Exception {
-
-        List<ParameterSpace<?>> l = new ArrayList<>();
-        l.add(new FixedValue<>(1.0));
-        l.add(new FixedValue<>(1));
-        l.add(new FixedValue<>("string"));
-        l.add(new ContinuousParameterSpace(-1, 1));
-        l.add(new ContinuousParameterSpace(new LogNormalDistribution(1, 1)));
-        l.add(new ContinuousParameterSpace(new NormalDistribution(2, 0.01)));
-        l.add(new DiscreteParameterSpace<>(1, 5, 7));
-        l.add(new DiscreteParameterSpace<>("first", "second", "third"));
-        l.add(new IntegerParameterSpace(0, 10));
-        l.add(new IntegerParameterSpace(new UniformIntegerDistribution(0, 50)));
-        l.add(new BooleanSpace());
-
-        for (ParameterSpace<?> ps : l) {
-            String strJson = jsonMapper.writeValueAsString(ps);
-            String strYaml = yamlMapper.writeValueAsString(ps);
-
-            ParameterSpace<?> fromJson = jsonMapper.readValue(strJson, ParameterSpace.class);
-            ParameterSpace<?> fromYaml = yamlMapper.readValue(strYaml, ParameterSpace.class);
-
-            assertEquals(ps, fromJson);
-            assertEquals(ps, fromYaml);
-        }
-    }
-
-    @Test
-    public void testCandidateGeneratorJson() throws Exception {
-        Map<String, Object> commands = new HashMap<>();
-        commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>());
-
-        List<CandidateGenerator> l = new ArrayList<>();
-        l.add(new GridSearchCandidateGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), 10,
-                        GridSearchCandidateGenerator.Mode.Sequential, commands));
-        l.add(new GridSearchCandidateGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), 10,
-                        GridSearchCandidateGenerator.Mode.RandomOrder, commands));
-        l.add(new RandomSearchGenerator(new DiscreteParameterSpace<>(0, 1, 2, 3, 4, 5), commands));
-
-        for (CandidateGenerator cg : l) {
-            String strJson = jsonMapper.writeValueAsString(cg);
-            String strYaml = yamlMapper.writeValueAsString(cg);
-
-            CandidateGenerator fromJson = jsonMapper.readValue(strJson, CandidateGenerator.class);
-            CandidateGenerator fromYaml = yamlMapper.readValue(strYaml, CandidateGenerator.class);
-
-            assertEquals(cg, fromJson);
-            assertEquals(cg, fromYaml);
-        }
-    }
-}
diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java
deleted file mode 100644
index db7702b76..000000000
--- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/TestRandomSearch.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.listener.impl.LoggingStatusListener; -import org.junit.jupiter.api.Test; - -import java.util.HashMap; -import java.util.Map; - -/** - * - * Test random search on the Branin Function: - * http://www.sfu.ca/~ssurjano/branin.html - */ -public class TestRandomSearch extends BaseDL4JTest { - - @Test - public void test() throws Exception { - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, new HashMap<>()); - - //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(new BraninFunction.BraninSpace(), commands); - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).scoreFunction(new BraninFunction.BraninScoreFunction()) - .terminationConditions(new MaxCandidatesCondition(50)).build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new BraninFunction.BraninTaskCreator()); - - runner.addListeners(new LoggingStatusListener()); - runner.execute(); - - -// System.out.println("----- Complete -----"); - } - - -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java deleted file mode 100644 index e2a6044ce..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/distribution/TestLogUniform.java +++ /dev/null @@ -1,70 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.distribution; - -import org.apache.commons.math3.distribution.RealDistribution; -import org.deeplearning4j.BaseDL4JTest; -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -public class TestLogUniform extends BaseDL4JTest { - - @Test - public void testSimple(){ - - double min = 0.5; - double max = 3; - - double logMin = Math.log(min); - double logMax = Math.log(max); - - RealDistribution rd = new LogUniformDistribution(min, max); - - for(double d = 0.1; d<= 3.5; d+= 0.1){ - double density = rd.density(d); - double cumulative = rd.cumulativeProbability(d); - double dExp; - double cumExp; - if(d < min){ - dExp = 0; - cumExp = 0; - } else if( d > max){ - dExp = 0; - cumExp = 1; - } else { - dExp = 1.0 / (d * (logMax-logMin)); - cumExp = (Math.log(d) - logMin) / (logMax - logMin); - } - - assertTrue(dExp >= 0); - assertTrue(cumExp >= 0); - assertTrue(cumExp <= 1.0); - assertEquals(dExp, density, 1e-5); - assertEquals(cumExp, cumulative, 1e-5); - } - - rd.reseedRandomGenerator(12345); - for( int i=0; i<100; i++ ){ - double d = rd.sample(); - assertTrue(d >= min); - assertTrue(d <= max); - } - } - -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestCrossoverOperator.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestCrossoverOperator.java deleted file mode 100644 index 9297c3df7..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestCrossoverOperator.java +++ /dev/null @@ -1,40 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; - -public class TestCrossoverOperator extends CrossoverOperator { - - private final CrossoverResult[] results; - private int resultIdx = 0; - - public PopulationModel getPopulationModel() { - return populationModel; - } - - public TestCrossoverOperator(CrossoverResult[] results) { - this.results = results; - } - - @Override - public CrossoverResult crossover() { - return results[resultIdx++]; - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestMutationOperator.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestMutationOperator.java deleted file mode 100644 index 4718714d1..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestMutationOperator.java +++ /dev/null @@ -1,34 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator; - -public class TestMutationOperator implements MutationOperator { - - private final boolean[] results; - private int resultIdx = 0; - - public TestMutationOperator(boolean[] results) { - this.results = results; - } - - @Override - public boolean mutate(double[] genes) { - return results[resultIdx++]; - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestParentSelection.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestParentSelection.java deleted file mode 100644 index 7f9c33b14..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestParentSelection.java +++ /dev/null @@ -1,52 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; - -import java.util.List; - -public class TestParentSelection extends TwoParentSelection { - - public boolean hasBeenInitialized; - - private final double[][] parents; - - public TestParentSelection(double[][] parents) { - this.parents = parents; - } - - public TestParentSelection() { - this(null); - } - - @Override - public void initializeInstance(List population) { - super.initializeInstance(population); - hasBeenInitialized = true; - } - - @Override - public double[][] selectParents() { - return parents; - } - - public List getPopulation() { - return population; - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestPopulationInitializer.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestPopulationInitializer.java deleted file mode 100644 index 926555f79..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestPopulationInitializer.java +++ /dev/null @@ -1,30 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic; - -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; - -import java.util.ArrayList; -import java.util.List; - -public class TestPopulationInitializer implements PopulationInitializer { - @Override - public List getInitializedPopulation(int size) { - return new ArrayList<>(); - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestRandomGenerator.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestRandomGenerator.java deleted file mode 100644 index abeba96e8..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/TestRandomGenerator.java +++ /dev/null @@ -1,88 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic; - -import org.apache.commons.lang3.NotImplementedException; -import org.apache.commons.math3.random.RandomGenerator; - -public class TestRandomGenerator implements RandomGenerator { - private final int[] intRandomNumbers; - private int currentIntIdx = 0; - private final double[] doubleRandomNumbers; - private int currentDoubleIdx = 0; - - - public TestRandomGenerator(int[] intRandomNumbers, double[] doubleRandomNumbers) { - this.intRandomNumbers = intRandomNumbers; - this.doubleRandomNumbers = doubleRandomNumbers; - } - - @Override - public void setSeed(int i) { - - } - - @Override - public void setSeed(int[] ints) { - - } - - @Override - public void setSeed(long l) { - - } - - @Override - public void nextBytes(byte[] bytes) { - - } - - @Override - public int nextInt() { - return intRandomNumbers[currentIntIdx++]; - } - - @Override - public int nextInt(int i) { - return intRandomNumbers[currentIntIdx++]; - } - - @Override - public long nextLong() { - throw new NotImplementedException("Not implemented"); - } - - @Override - public boolean nextBoolean() { - throw new NotImplementedException("Not implemented"); - } - - @Override - public float nextFloat() { - throw new NotImplementedException("Not implemented"); - } - - @Override - public double nextDouble() { - return doubleRandomNumbers[currentDoubleIdx++]; - } - - @Override - public double nextGaussian() { - throw new NotImplementedException("Not implemented"); - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java deleted file mode 100644 index f234465b0..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ArithmeticCrossoverTests.java +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.crossover; - -import org.apache.commons.math3.random.RandomGenerator; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.ArithmeticCrossover; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; -import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; -import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class ArithmeticCrossoverTests extends BaseDL4JTest { - - @Test - public void ArithmeticCrossover_Crossover_OutsideCrossoverRate_ShouldReturnParent0() { - double[][] parents = new double[2][]; - parents[0] = new double[] {1.0}; - parents[1] = new double[] {2.0}; - - TestParentSelection parentSelection = new TestParentSelection(parents); - - RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0}); - - ArithmeticCrossover sut = - new ArithmeticCrossover.Builder().parentSelection(parentSelection).randomGenerator(rng).build(); - CrossoverResult result = sut.crossover(); - - Assertions.assertFalse(result.isModified()); - Assertions.assertEquals(1, result.getGenes().length); - Assertions.assertEquals(1.0, result.getGenes()[0], 0.001); - } - - @Test - public void ArithmeticCrossover_Crossover_WithinCrossoverRate_ShouldReturnLinearCombination() { - double[][] parents = new double[2][]; - parents[0] = new double[] {1.0}; - parents[1] = new double[] {2.0}; - - TestParentSelection parentSelection = new TestParentSelection(parents); - - RandomGenerator rng = new TestRandomGenerator(null, new double[] {0.1, 0.1}); - - ArithmeticCrossover sut = - new ArithmeticCrossover.Builder().parentSelection(parentSelection).randomGenerator(rng).build(); - CrossoverResult result = sut.crossover(); - - Assertions.assertTrue(result.isModified()); - Assertions.assertEquals(1, result.getGenes().length); - Assertions.assertEquals(1.9, result.getGenes()[0], 0.001); - } - -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java deleted file mode 100644 index 2cea0b608..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverOperatorTests.java +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.crossover; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; -import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator; -import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class CrossoverOperatorTests extends BaseDL4JTest { - - @Test - public void CrossoverOperator_initializeInstance_ShouldInitPopulationModel() throws IllegalAccessException { - TestCrossoverOperator sut = new TestCrossoverOperator(null); - - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - PopulationModel populationModel = - new PopulationModel.Builder().populationInitializer(populationInitializer).build(); - sut.initializeInstance(populationModel); - - Assertions.assertSame(populationModel, sut.getPopulationModel()); - - - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java deleted file mode 100644 index 120fa8a28..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/CrossoverPointsGeneratorTests.java +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.crossover; - -import org.apache.commons.math3.random.RandomGenerator; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.utils.CrossoverPointsGenerator; -import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.util.Deque; - -public class CrossoverPointsGeneratorTests extends BaseDL4JTest { - - @Test - public void CrossoverPointsGenerator_FixedNumberCrossovers() { - RandomGenerator rng = new TestRandomGenerator(new int[] {0}, null); - CrossoverPointsGenerator sut = new CrossoverPointsGenerator(10, 2, 2, rng); - - Deque result = sut.getCrossoverPoints(); - - Assertions.assertEquals(3, result.size()); - int a = result.pop(); - int b = result.pop(); - int c = result.pop(); - Assertions.assertTrue(a < b); - Assertions.assertTrue(b < c); - Assertions.assertEquals(Integer.MAX_VALUE, c); - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java deleted file mode 100644 index 64d56e5ac..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/KPointCrossoverTests.java +++ /dev/null @@ -1,67 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.crossover; - -import org.apache.commons.math3.random.RandomGenerator; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.KPointCrossover; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; -import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; -import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class KPointCrossoverTests extends BaseDL4JTest { - - @Test - public void KPointCrossover_BelowCrossoverRate_ShouldReturnParent0() { - RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0}); - - double[][] parents = new double[2][]; - parents[0] = new double[] {0.0}; - parents[1] = new double[] {1.0}; - TwoParentSelection parentSelection = new TestParentSelection(parents); - KPointCrossover sut = new KPointCrossover.Builder().randomGenerator(rng).crossoverRate(0.0) - .parentSelection(parentSelection).build(); - - CrossoverResult result = sut.crossover(); - - Assertions.assertFalse(result.isModified()); - Assertions.assertSame(parents[0], result.getGenes()); - } - - @Test - public void KPointCrossover_FixedNumberOfCrossovers() { - RandomGenerator rng = new TestRandomGenerator(new int[] {0, 1}, new double[] {0.0}); - - double[][] parents = new double[3][]; - parents[0] = new double[] {0.0, 0.0, 0.0, 0.0, 0.0}; - parents[1] = new double[] {1.0, 1.0, 1.0, 1.0, 1.0}; - parents[2] = new double[] {2.0, 2.0, 2.0, 2.0, 2.0}; - TwoParentSelection parentSelection = new TestParentSelection(parents); - KPointCrossover sut = new KPointCrossover.Builder().randomGenerator(rng).crossoverRate(1.0) - .parentSelection(parentSelection).numCrossovers(2).build(); - - CrossoverResult result = sut.crossover(); - - Assertions.assertTrue(result.isModified()); - for (double x : result.getGenes()) { - Assertions.assertTrue(x == 0.0 || x == 1.0); - } - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java deleted file mode 100644 index ca65e6ef0..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/ParentSelectionTests.java +++ /dev/null @@ -1,39 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.crossover; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; -import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.util.ArrayList; -import java.util.List; - -public class ParentSelectionTests extends BaseDL4JTest { - - @Test - public void ParentSelection_InitializeInstance_ShouldInitPopulation() { - TestParentSelection sut = new TestParentSelection(); - - List population = new ArrayList<>(); - sut.initializeInstance(population); - - Assertions.assertSame(population, sut.getPopulation()); - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java deleted file mode 100644 index 214ee0181..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/RandomTwoParentSelectionTests.java +++ /dev/null @@ -1,47 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.crossover; - -import org.apache.commons.math3.random.RandomGenerator; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.RandomTwoParentSelection; -import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.util.ArrayList; -import java.util.List; - -public class RandomTwoParentSelectionTests extends BaseDL4JTest { - @Test - public void RandomTwoParentSelection_ShouldReturnTwoDifferentParents() { - RandomGenerator rng = new TestRandomGenerator(new int[] {1, 1, 1, 0}, null); - RandomTwoParentSelection sut = new RandomTwoParentSelection(rng); - - List population = new ArrayList<>(); - population.add(new Chromosome(new double[] {1, 1, 1}, 1.0)); - population.add(new Chromosome(new double[] {2, 2, 2}, 2.0)); - population.add(new Chromosome(new double[] {3, 3, 3}, 3.0)); - sut.initializeInstance(population); - - double[][] result = sut.selectParents(); - - Assertions.assertSame(population.get(1).getGenes(), result[0]); - Assertions.assertSame(population.get(0).getGenes(), result[1]); - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java deleted file mode 100644 index 52fba0c59..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/SinglePointCrossoverTests.java +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.crossover; - -import org.apache.commons.math3.random.RandomGenerator; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.SinglePointCrossover; -import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; -import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class SinglePointCrossoverTests extends BaseDL4JTest { - @Test - public void SinglePointCrossover_BelowCrossoverRate_ShouldReturnParent0() { - RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0}); - - double[][] parents = new double[2][]; - parents[0] = new double[] {1.0, 1.0, 1.0}; - parents[1] = new double[] {2.0, 2.0, 2.0}; - TestParentSelection parentSelection = new TestParentSelection(parents); - - SinglePointCrossover sut = new SinglePointCrossover.Builder().parentSelection(parentSelection) - .randomGenerator(rng).crossoverRate(0.0).build(); - - CrossoverResult result = sut.crossover(); - - Assertions.assertFalse(result.isModified()); - Assertions.assertSame(parents[0], result.getGenes()); - } - - @Test - public void SinglePointCrossover_ShouldReturnSingleSplit() { - RandomGenerator rng = new TestRandomGenerator(new int[] {2}, new double[] {0.1}); - - double[][] parents = new double[2][]; - parents[0] = new double[] {1.0, 1.0, 1.0}; - parents[1] = new double[] {2.0, 2.0, 2.0}; - TestParentSelection parentSelection = new TestParentSelection(parents); - - SinglePointCrossover sut = new SinglePointCrossover.Builder().parentSelection(parentSelection) - .randomGenerator(rng).crossoverRate(0.5).build(); - - CrossoverResult result = sut.crossover(); - - Assertions.assertTrue(result.isModified()); - Assertions.assertEquals(1.0, result.getGenes()[0], 0.0); - Assertions.assertEquals(1.0, result.getGenes()[1], 0.0); - Assertions.assertEquals(2.0, result.getGenes()[2], 0.0); - - } - -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java deleted file mode 100644 index 972d528ed..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/TwoParentsCrossoverOperatorTests.java +++ /dev/null @@ -1,71 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.crossover; - -import org.apache.commons.lang3.NotImplementedException; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.TwoParentsCrossoverOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.parentselection.TwoParentSelection; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; -import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; -import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class TwoParentsCrossoverOperatorTests extends BaseDL4JTest { - - class TestTwoParentsCrossoverOperator extends TwoParentsCrossoverOperator { - - public TestTwoParentsCrossoverOperator(TwoParentSelection parentSelection) { - super(parentSelection); - } - - public TwoParentSelection getParentSelection() { - return parentSelection; - } - - @Override - public CrossoverResult crossover() { - throw new NotImplementedException("Not implemented"); - } - } - - @Test - public void TwoParentsCrossoverOperator_ctor_ShouldInitParentSelection() { - TestParentSelection parentSelection = new TestParentSelection(); - TestTwoParentsCrossoverOperator sut = new TestTwoParentsCrossoverOperator(parentSelection); - - Assertions.assertSame(parentSelection, sut.getParentSelection()); - } - - @Test - public void TwoParentsCrossoverOperator_initializeInstanceShouldInitializeParentSelection() { - TestParentSelection parentSelection = new TestParentSelection(); - TestTwoParentsCrossoverOperator sut = new TestTwoParentsCrossoverOperator(parentSelection); - - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - PopulationModel populationModel = - new PopulationModel.Builder().populationInitializer(populationInitializer).build(); - - sut.initializeInstance(populationModel); - - Assertions.assertTrue(parentSelection.hasBeenInitialized); - } - -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java deleted file mode 100644 index 5efff80b2..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/crossover/UniformCrossoverTests.java +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.crossover; - -import org.apache.commons.math3.random.RandomGenerator; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.UniformCrossover; -import org.deeplearning4j.arbiter.optimize.genetic.TestParentSelection; -import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class UniformCrossoverTests extends BaseDL4JTest { - - @Test - public void UniformCrossover_BelowCrossoverRate_ShouldReturnParent0() { - RandomGenerator rng = new TestRandomGenerator(null, new double[] {1.0}); - - double[][] parents = new double[2][]; - parents[0] = new double[] {1.0, 1.0, 1.0}; - parents[1] = new double[] {2.0, 2.0, 2.0}; - TestParentSelection parentSelection = new TestParentSelection(parents); - - UniformCrossover sut = new UniformCrossover.Builder().parentSelection(parentSelection).randomGenerator(rng) - .crossoverRate(0.0).build(); - - CrossoverResult result = sut.crossover(); - - Assertions.assertFalse(result.isModified()); - Assertions.assertSame(parents[0], result.getGenes()); - } - - @Test - public void UniformCrossover_ShouldReturnMixedParents() { - RandomGenerator rng = new TestRandomGenerator(null, new double[] {0.1, 0.1, 0.3, 0.2}); - - double[][] parents = new double[2][]; - parents[0] = new double[] {1.0, 1.0, 1.0}; - parents[1] = new double[] {2.0, 2.0, 2.0}; - TestParentSelection parentSelection = new TestParentSelection(parents); - - UniformCrossover sut = new UniformCrossover.Builder().parentSelection(parentSelection).randomGenerator(rng) - .crossoverRate(0.5).parentBiasFactor(0.3).build(); - - CrossoverResult result = sut.crossover(); - - Assertions.assertTrue(result.isModified()); - Assertions.assertEquals(1.0, result.getGenes()[0], 0.0); - Assertions.assertEquals(2.0, result.getGenes()[1], 0.0); - Assertions.assertEquals(1.0, result.getGenes()[2], 0.0); - } - -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java deleted file mode 100644 index c5cde76d6..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/LeastFitCullOperatorTests.java +++ /dev/null @@ -1,62 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.culling; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; -import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.LeastFitCullOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; -import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.util.ArrayList; -import java.util.List; - -public class LeastFitCullOperatorTests extends BaseDL4JTest { - - @Test - public void LeastFitCullingOperation_ShouldCullLastElements() { - LeastFitCullOperator sut = new LeastFitCullOperator(0.50); - - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - PopulationModel populationModel = new PopulationModel.Builder().populationInitializer(populationInitializer) - .populationSize(10).build(); - sut.initializeInstance(populationModel); - - List originalChromosomes = new ArrayList<>(); - for (int i = 0; i < 10; ++i) { - originalChromosomes.add(new Chromosome(null, (double) i)); - } - - List chromosomes = populationModel.getPopulation(); - for (int i = 0; i < 10; ++i) { - chromosomes.add(originalChromosomes.get(i)); - } - - sut.cullPopulation(); - - Assertions.assertEquals(5, chromosomes.size()); - for (int i = 0; i < 5; ++i) { - Assertions.assertSame(originalChromosomes.get(i), chromosomes.get(i)); - } - } - - -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java deleted file mode 100644 index ae09537f6..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/culling/RatioCullOperatorTests.java +++ /dev/null @@ -1,78 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.culling; - -import org.apache.commons.lang3.NotImplementedException; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; -import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.RatioCullOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; -import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.util.List; - -public class RatioCullOperatorTests extends BaseDL4JTest { - - class TestRatioCullOperator extends RatioCullOperator { - - public TestRatioCullOperator() { - super(); - } - - public TestRatioCullOperator(double ratio) { - super(ratio); - } - - public List getPopulation() { - return population; - } - - @Override - public void cullPopulation() { - throw new NotImplementedException("Not implemented"); - } - - public double getCullRatio() { - return cullRatio; - } - } - - @Test - public void RatioCullingOperation_ctorWithCullRatio_ShouldHaveParamRatio() { - TestRatioCullOperator sut = new TestRatioCullOperator(0.123); - - Assertions.assertEquals(0.123, sut.getCullRatio(), 0.0); - } - - @Test - public void RatioCullingOperation_initialize_shouldSetCulledSizeAndPopulation() throws IllegalAccessException { - TestRatioCullOperator sut = new TestRatioCullOperator(0.50); - - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - PopulationModel populationModel = new PopulationModel.Builder().populationInitializer(populationInitializer) - .populationSize(10).build(); - sut.initializeInstance(populationModel); - - Assertions.assertSame(populationModel.getPopulation(), sut.getPopulation()); - Assertions.assertEquals(5, sut.getCulledSize()); - } - -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java deleted file mode 100644 index 8b45ec9ad..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/mutation/RandomMutationOperatorTests.java +++ /dev/null @@ -1,73 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.mutation; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.RandomMutationOperator; -import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.lang.reflect.Field; -import java.util.Arrays; - -public class RandomMutationOperatorTests extends BaseDL4JTest { - @Test - public void RandomMutationOperator_DefaultBuild_ShouldNotBeNull() { - RandomMutationOperator sut = new RandomMutationOperator.Builder().build(); - Assertions.assertNotNull(sut); - } - - @Test - public void RandomMutationOperator_BuildWithMutationRate_ShouldUseSuppliedRate() throws Exception { - RandomMutationOperator sut = new RandomMutationOperator.Builder().mutationRate(0.123).build(); - - Field f = sut.getClass().getDeclaredField("mutationRate"); - f.setAccessible(true); - Double mutationRate = (Double) f.get(sut); - - Assertions.assertEquals(0.123, mutationRate, 0.0); - } - - @Test - public void RandomMutationOperator_BelowMutationRate_ShouldNotMutate() { - double[] randomNumbers = new double[] {0.1, 1.0, 1.0}; - - RandomMutationOperator sut = new RandomMutationOperator.Builder().mutationRate(0.1) - .randomGenerator(new TestRandomGenerator(null, randomNumbers)).build(); - - double[] genes = new double[] {-1.0, -1.0, -1.0}; - boolean hasMutated = sut.mutate(genes); - - Assertions.assertFalse(hasMutated); - Assertions.assertTrue(Arrays.equals(new double[] {-1.0, -1.0, -1.0}, genes)); - } - - @Test - public void RandomMutationOperator_AboveMutationRate_ShouldMutate() { - double[] randomNumbers = new double[] {0.099, 0.123, 1.0, 1.0}; - - RandomMutationOperator sut = new RandomMutationOperator.Builder().mutationRate(0.1) - .randomGenerator(new TestRandomGenerator(null, randomNumbers)).build(); - - double[] genes = new double[] {-1.0, -1.0, -1.0}; - boolean hasMutated = sut.mutate(genes); - - Assertions.assertTrue(hasMutated); - Assertions.assertTrue(Arrays.equals(new double[] {0.123, -1.0, -1.0}, genes)); - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java deleted file mode 100644 index 914a4be40..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/population/PopulationModelTests.java +++ /dev/null @@ -1,195 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.population; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.Chromosome; -import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationListener; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; -import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.util.List; - -public class PopulationModelTests extends BaseDL4JTest { - - private class TestCullOperator implements CullOperator { - - private final int culledSize; - public boolean hasCulled = false; - - public TestCullOperator(int culledSize) { - this.culledSize = culledSize; - } - - @Override - public void initializeInstance(PopulationModel populationModel) { - - } - - @Override - public void cullPopulation() { - hasCulled = true; - } - - @Override - public int getCulledSize() { - return culledSize; - } - } - - private class TestPopulationListener implements PopulationListener { - - public List population; - - @Override - public void onChanged(List population) { - this.population = population; - } - } - - @Test - public void PopulationModel_IsReadyToBreed_NotReadyToBreed_ShouldReturnFalse() { - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - PopulationModel sut = new PopulationModel.Builder().populationInitializer(populationInitializer) - .populationSize(5).cullOperator(new TestCullOperator(2)).build(); - - boolean result = sut.isReadyToBreed(); - - Assertions.assertFalse(result); - } - - @Test - public void PopulationModel_IsReadyToBreed_ReadyToBreed_ShouldReturnTrue() { - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - PopulationModel sut = new PopulationModel.Builder().populationInitializer(populationInitializer) - .populationSize(5).cullOperator(new TestCullOperator(1)).build(); - - sut.getPopulation().add(null); - - boolean result = sut.isReadyToBreed(); - - Assertions.assertTrue(result); - } - - @Test - public void PopulationModel_Add_MaximizeScore_ShouldOrderDescendingPopulation() { - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - PopulationModel sut = new PopulationModel.Builder().populationInitializer(populationInitializer) - .populationSize(5).cullOperator(new TestCullOperator(2)).build(); - - sut.initializeInstance(false); - - Chromosome[] chromosomes = new Chromosome[3]; - chromosomes[0] = new Chromosome(new double[0], 1.0); - chromosomes[1] = new Chromosome(new double[0], 100.0); - chromosomes[2] = new Chromosome(new double[0], 10.0); - sut.add(chromosomes[0]); - sut.add(chromosomes[1]); - sut.add(chromosomes[2]); - - Assertions.assertSame(chromosomes[1], sut.getPopulation().get(0)); - Assertions.assertSame(chromosomes[2], sut.getPopulation().get(1)); - Assertions.assertSame(chromosomes[0], sut.getPopulation().get(2)); - } - - @Test - public void PopulationModel_Add_MinimizeScore_ShouldOrderAscendingPopulation() { - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - PopulationModel sut = new 
PopulationModel.Builder().populationInitializer(populationInitializer) - .populationSize(5).cullOperator(new TestCullOperator(2)).build(); - - sut.initializeInstance(true); - - Chromosome[] chromosomes = new Chromosome[3]; - chromosomes[0] = new Chromosome(new double[0], 100.0); - chromosomes[1] = new Chromosome(new double[0], 1.0); - chromosomes[2] = new Chromosome(new double[0], 10.0); - sut.add(chromosomes[0]); - sut.add(chromosomes[1]); - sut.add(chromosomes[2]); - - Assertions.assertSame(chromosomes[1], sut.getPopulation().get(0)); - Assertions.assertSame(chromosomes[2], sut.getPopulation().get(1)); - Assertions.assertSame(chromosomes[0], sut.getPopulation().get(2)); - } - - @Test - public void PopulationModel_Add_ShouldTriggerPopulationListeners() { - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - PopulationModel sut = new PopulationModel.Builder().populationInitializer(populationInitializer) - .populationSize(5).cullOperator(new TestCullOperator(2)).build(); - - sut.initializeInstance(true); - - TestPopulationListener populationListener = new TestPopulationListener(); - sut.addListener(populationListener); - - sut.add(new Chromosome(new double[0], 100.0)); - - Assertions.assertSame(sut.getPopulation(), populationListener.population); - } - - @Test - public void PopulationModel_Add_BelowPopulationSize_ShouldNotCull() { - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - TestCullOperator cullOperator = new TestCullOperator(3); - - PopulationModel sut = new PopulationModel.Builder().populationInitializer(populationInitializer) - .populationSize(5).cullOperator(cullOperator).build(); - - sut.initializeInstance(true); - - sut.add(new Chromosome(new double[0], 1.0)); - sut.add(new Chromosome(new double[0], 2.0)); - sut.add(new Chromosome(new double[0], 3.0)); - sut.add(new Chromosome(new double[0], 4.0)); - sut.add(new Chromosome(new double[0], 5.0)); - - Assertions.assertFalse(cullOperator.hasCulled); - } - - @Test - public void PopulationModel_Add_AbovePopulationSize_ShouldCull() { - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - TestCullOperator cullOperator = new TestCullOperator(3); - - PopulationModel sut = new PopulationModel.Builder().populationInitializer(populationInitializer) - .populationSize(5).cullOperator(cullOperator).build(); - - sut.initializeInstance(true); - - sut.add(new Chromosome(new double[0], 1.0)); - sut.add(new Chromosome(new double[0], 2.0)); - sut.add(new Chromosome(new double[0], 3.0)); - sut.add(new Chromosome(new double[0], 4.0)); - sut.add(new Chromosome(new double[0], 5.0)); - sut.add(new Chromosome(new double[0], 6.0)); - - Assertions.assertTrue(cullOperator.hasCulled); - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java deleted file mode 100644 index 4a0b2a498..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/GeneticSelectionOperatorTests.java +++ /dev/null @@ -1,255 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.selection; - -import org.apache.commons.lang3.NotImplementedException; -import org.apache.commons.math3.random.RandomGenerator; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.crossover.CrossoverResult; -import org.deeplearning4j.arbiter.optimize.generator.genetic.culling.CullOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.exceptions.GeneticGenerationException; -import org.deeplearning4j.arbiter.optimize.generator.genetic.mutation.MutationOperator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; -import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.GeneticSelectionOperator; -import org.deeplearning4j.arbiter.optimize.genetic.TestCrossoverOperator; -import org.deeplearning4j.arbiter.optimize.genetic.TestMutationOperator; -import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; -import org.deeplearning4j.arbiter.optimize.genetic.TestRandomGenerator; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; - -public class GeneticSelectionOperatorTests extends BaseDL4JTest { - - private class TestCullOperator implements CullOperator { - - private final int culledSize; - - public TestCullOperator(int culledSize) { - - this.culledSize = culledSize; - } - - @Override - public void initializeInstance(PopulationModel populationModel) { - - } - - @Override - public void cullPopulation() { - throw new NotImplementedException("Not implemented"); - } - - @Override - public int getCulledSize() { - return culledSize; - } - } - - private class GeneticSelectionOperatorTestsMutationOperator implements MutationOperator { - - private boolean mutateResult; - - public GeneticSelectionOperatorTestsMutationOperator(boolean mutateResult) { - - this.mutateResult = mutateResult; - } - - @Override - public boolean mutate(double[] genes) { - return mutateResult; - } - } - - private class GeneticSelectionOperatorTestsCrossoverOperator extends CrossoverOperator { - - private CrossoverResult result; - - public GeneticSelectionOperatorTestsCrossoverOperator(CrossoverResult result) { - - this.result = result; - } - - @Override - public CrossoverResult crossover() { - return result; - } - } - - @Test - public void GeneticSelectionOperator_PopulationNotReadyToBreed_ShouldReturnRandomGenes() { - RandomGenerator rng = new TestRandomGenerator(null, new double[] {123.0}); - - PopulationInitializer populationInitializer = new 
TestPopulationInitializer(); - - TestCullOperator cullOperator = new TestCullOperator(1000); - PopulationModel populationModel = new PopulationModel.Builder().populationInitializer(populationInitializer) - .cullOperator(cullOperator).build(); - ChromosomeFactory chromosomeFactory = new ChromosomeFactory(); - chromosomeFactory.initializeInstance(1); - GeneticSelectionOperator sut = new GeneticSelectionOperator.Builder().randomGenerator(rng).build(); - sut.initializeInstance(populationModel, chromosomeFactory); - - double[] newGenes = sut.buildNextGenes(); - - Assertions.assertEquals(1, newGenes.length); - Assertions.assertEquals(123.0, newGenes[0], 0.0); - } - - @Test - public void GeneticSelectionOperator_NoModificationOnFirstTry() { - RandomGenerator rng = new TestRandomGenerator(null, new double[] {123.0}); - - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - TestCullOperator cullOperator = new TestCullOperator(-1); - - PopulationModel populationModel = new PopulationModel.Builder().populationInitializer(populationInitializer) - .cullOperator(cullOperator).build(); - - ChromosomeFactory chromosomeFactory = new ChromosomeFactory(); - chromosomeFactory.initializeInstance(1); - - CrossoverResult[] crossoverResults = new CrossoverResult[2]; - crossoverResults[0] = new CrossoverResult(false, new double[0]); - crossoverResults[1] = new CrossoverResult(true, new double[0]); - TestCrossoverOperator crossoverOperator = new TestCrossoverOperator(crossoverResults); - - boolean[] mutationResults = new boolean[] {false, false}; - TestMutationOperator mutationOperator = new TestMutationOperator(mutationResults); - - GeneticSelectionOperator sut = new GeneticSelectionOperator.Builder().randomGenerator(rng) - .crossoverOperator(crossoverOperator).mutationOperator(mutationOperator).build(); - sut.initializeInstance(populationModel, chromosomeFactory); - - double[] newGenes = sut.buildNextGenes(); - - Assertions.assertSame(crossoverResults[1].getGenes(), newGenes); - } - - @Test - public void GeneticSelectionOperator_MutationNoModificationOnFirstTry() { - RandomGenerator rng = new TestRandomGenerator(null, new double[] {123.0}); - - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - TestCullOperator cullOperator = new TestCullOperator(-1); - - PopulationModel populationModel = new PopulationModel.Builder().populationInitializer(populationInitializer) - .cullOperator(cullOperator).build(); - - ChromosomeFactory chromosomeFactory = new ChromosomeFactory(); - chromosomeFactory.initializeInstance(1); - - CrossoverResult[] crossoverResults = new CrossoverResult[3]; - crossoverResults[0] = new CrossoverResult(false, new double[0]); - crossoverResults[1] = new CrossoverResult(false, new double[0]); - crossoverResults[2] = new CrossoverResult(true, new double[0]); - TestCrossoverOperator crossoverOperator = new TestCrossoverOperator(crossoverResults); - - boolean[] mutationResults = new boolean[] {false, false, true}; - TestMutationOperator mutationOperator = new TestMutationOperator(mutationResults); - - GeneticSelectionOperator sut = new GeneticSelectionOperator.Builder().randomGenerator(rng) - .crossoverOperator(crossoverOperator).mutationOperator(mutationOperator).build(); - sut.initializeInstance(populationModel, chromosomeFactory); - - double[] newGenes = sut.buildNextGenes(); - - Assertions.assertSame(crossoverResults[2].getGenes(), newGenes); - } - - @Test - public void GeneticSelectionOperator_ShouldNotBuildDuplicates() { - RandomGenerator 
rng = new TestRandomGenerator(null, new double[] {123.0}); - - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - TestCullOperator cullOperator = new TestCullOperator(-1); - - PopulationModel populationModel = new PopulationModel.Builder().populationInitializer(populationInitializer) - .cullOperator(cullOperator).build(); - - ChromosomeFactory chromosomeFactory = new ChromosomeFactory(); - chromosomeFactory.initializeInstance(1); - - CrossoverResult[] crossoverResults = new CrossoverResult[3]; - crossoverResults[0] = new CrossoverResult(true, new double[] {1.0}); - crossoverResults[1] = new CrossoverResult(true, new double[] {1.0}); - crossoverResults[2] = new CrossoverResult(true, new double[] {2.0}); - TestCrossoverOperator crossoverOperator = new TestCrossoverOperator(crossoverResults); - - boolean[] mutationResults = new boolean[] {false, false, false}; - TestMutationOperator mutationOperator = new TestMutationOperator(mutationResults); - - GeneticSelectionOperator sut = new GeneticSelectionOperator.Builder().randomGenerator(rng) - .crossoverOperator(crossoverOperator).mutationOperator(mutationOperator).build(); - sut.initializeInstance(populationModel, chromosomeFactory); - - double[] newGenes = sut.buildNextGenes(); - assertArrayEquals(crossoverResults[0].getGenes(), newGenes, 1e-6); - - newGenes = sut.buildNextGenes(); - assertArrayEquals(crossoverResults[2].getGenes(), newGenes, 1e-6); - } - - @Test() - public void GeneticSelectionOperator_CrossoverAndMutationCantGenerateNew_ShouldThrow() { - Assertions.assertThrows(GeneticGenerationException.class, () -> { - TestCullOperator cullOperator = new TestCullOperator(-1); - - - PopulationModel populationModel = new PopulationModel.Builder().cullOperator(cullOperator).build(); - - MutationOperator mutationOperator = new GeneticSelectionOperatorTestsMutationOperator(false); - CrossoverOperator crossoverOperator = - new GeneticSelectionOperatorTestsCrossoverOperator(new CrossoverResult(false, null)); - - GeneticSelectionOperator sut = new GeneticSelectionOperator.Builder().crossoverOperator(crossoverOperator) - .mutationOperator(mutationOperator).build(); - sut.initializeInstance(populationModel, null); - - sut.buildNextGenes(); - }); - } - - @Test - public void GeneticSelectionOperator_CrossoverAndMutationAlwaysGenerateSame_ShouldThrow() { - Assertions.assertThrows(GeneticGenerationException.class, () -> { - TestCullOperator cullOperator = new TestCullOperator(-1); - - PopulationModel populationModel = new PopulationModel.Builder().cullOperator(cullOperator).build(); - - MutationOperator mutationOperator = new GeneticSelectionOperatorTestsMutationOperator(false); - CrossoverOperator crossoverOperator = new GeneticSelectionOperatorTestsCrossoverOperator( - new CrossoverResult(true, new double[]{1.0})); - - GeneticSelectionOperator sut = new GeneticSelectionOperator.Builder().crossoverOperator(crossoverOperator) - .mutationOperator(mutationOperator).build(); - sut.initializeInstance(populationModel, null); - - // This call is used to add the genes to the previousGenes collection - sut.buildNextGenes(); - - sut.buildNextGenes(); - }); - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java deleted file mode 100644 index 47bb3e37c..000000000 --- 
a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/genetic/selection/SelectionOperatorTests.java +++ /dev/null @@ -1,60 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.genetic.selection; - -import org.apache.commons.lang3.NotImplementedException; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.generator.genetic.ChromosomeFactory; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationInitializer; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; -import org.deeplearning4j.arbiter.optimize.generator.genetic.selection.SelectionOperator; -import org.deeplearning4j.arbiter.optimize.genetic.TestPopulationInitializer; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class SelectionOperatorTests extends BaseDL4JTest { - private class TestSelectionOperator extends SelectionOperator { - - public PopulationModel getPopulationModel() { - return populationModel; - } - - public ChromosomeFactory getChromosomeFactory() { - return chromosomeFactory; - } - - @Override - public double[] buildNextGenes() { - throw new NotImplementedException("Not implemented"); - } - } - - @Test - public void SelectionOperator_InitializeInstance_ShouldInitializeFields() { - TestSelectionOperator sut = new TestSelectionOperator(); - - PopulationInitializer populationInitializer = new TestPopulationInitializer(); - - PopulationModel populationModel = - new PopulationModel.Builder().populationInitializer(populationInitializer).build(); - ChromosomeFactory chromosomeFactory = new ChromosomeFactory(); - sut.initializeInstance(populationModel, chromosomeFactory); - - Assertions.assertSame(populationModel, sut.getPopulationModel()); - Assertions.assertSame(chromosomeFactory, sut.getChromosomeFactory()); - } -} diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java deleted file mode 100644 index 5f477018c..000000000 --- a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/parameter/TestParameterSpaces.java +++ /dev/null @@ -1,103 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize.parameter; - -import org.apache.commons.math3.distribution.NormalDistribution; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; -import org.junit.jupiter.api.Test; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -public class TestParameterSpaces extends BaseDL4JTest { - - - @Test - public void testContinuousParameterSpace() { - - ContinuousParameterSpace cps = new ContinuousParameterSpace(0, 1); - cps.setIndices(0); - - for (int i = 0; i < 10; i++) { - double d = i / 10.0; - assertEquals(d, cps.getValue(new double[]{d}), 0.0); - } - - cps = new ContinuousParameterSpace(10, 20); - cps.setIndices(0); - - for (int i = 0; i < 10; i++) { - double d = i / 10.0; - double exp = d * 10 + 10; - assertEquals(exp, cps.getValue(new double[]{d}), 0.0); - } - - - cps = new ContinuousParameterSpace(new NormalDistribution(0, 1)); - NormalDistribution nd = new NormalDistribution(0, 1); - cps.setIndices(0); - for (int i = 0; i < 11; i++) { - double d = i / 10.0; - assertEquals(nd.inverseCumulativeProbability(d), cps.getValue(new double[]{d}), 1e-4); - } - } - - @Test - public void testDiscreteParameterSpace() { - ParameterSpace dps = new DiscreteParameterSpace<>(0, 1, 2, 3, 4); - dps.setIndices(0); - - for (int i = 0; i < 5; i++) { - double d = i / 5.0 + 0.1; //Center - double dEdgeLower = i / 5.0 + 1e-8; //Edge case: just above split threshold - double dEdgeUpper = (i + 1) / 5.0 - 1e-8; //Edge case: just below split threshold - assertEquals(i, (int) dps.getValue(new double[]{d})); - assertEquals(i, (int) dps.getValue(new double[]{dEdgeLower})); - assertEquals(i, (int) dps.getValue(new double[]{dEdgeUpper})); - } - } - - @Test - public void testIntegerParameterSpace() { - ParameterSpace ips = new IntegerParameterSpace(0, 4); - ips.setIndices(0); - - for (int i = 0; i < 5; i++) { - double d = i / 5.0 + 0.1; //Center - double dEdgeLower = i / 5.0 + 1e-8; //Edge case: just above split threshold - double dEdgeUpper = (i + 1) / 5.0 - 1e-8; //Edge case: just below split threshold - assertEquals(i, (int) ips.getValue(new double[]{d})); - assertEquals(i, (int) ips.getValue(new double[]{dEdgeLower})); - assertEquals(i, (int) ips.getValue(new double[]{dEdgeUpper})); - } - } - - @Test - public void testBooleanSpace() { - ParameterSpace bSpace = new BooleanSpace(); - bSpace.setIndices(1); //randomly setting to non zero - - assertEquals(true, (boolean) bSpace.getValue(new double[]{0.0, 0.0})); - assertEquals(true, (boolean) bSpace.getValue(new double[]{0.1, 0.5})); - assertEquals(false, (boolean) bSpace.getValue(new double[]{0.2, 0.7})); - assertEquals(false, (boolean) bSpace.getValue(new double[]{0.3, 1.0})); - } - -} diff --git a/arbiter/arbiter-core/src/test/resources/logback.xml 
b/arbiter/arbiter-core/src/test/resources/logback.xml deleted file mode 100644 index 410bdaae9..000000000 --- a/arbiter/arbiter-core/src/test/resources/logback.xml +++ /dev/null @@ -1,51 +0,0 @@ - - - - - - logs/application.log - - %date - [%level] - from %logger in %thread - %n%message%n%xException%n - - - - - - %logger{15} - %message%n%xException{5} - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/arbiter/arbiter-deeplearning4j/pom.xml b/arbiter/arbiter-deeplearning4j/pom.xml deleted file mode 100644 index 2f7e202a3..000000000 --- a/arbiter/arbiter-deeplearning4j/pom.xml +++ /dev/null @@ -1,78 +0,0 @@ - - - - - arbiter - net.brutex.ai - 1.0.0-SNAPSHOT - - 4.0.0 - - arbiter-deeplearning4j - - - - - net.brutex.ai - arbiter-core - ${project.version} - - - - net.brutex.ai - deeplearning4j-core - ${project.version} - - - - ch.qos.logback - logback-classic - ${logback.version} - test - - - - com.fasterxml.jackson.core - jackson-core - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.databind.version} - - - - com.google.code.gson - gson - ${gson.version} - - - - net.brutex.ai - deeplearning4j-common-tests - ${project.version} - test - - - - diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/BaseNetworkSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/BaseNetworkSpace.java deleted file mode 100644 index 69621330d..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/BaseNetworkSpace.java +++ /dev/null @@ -1,615 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.adapter.ActivationParameterSpaceAdapter; -import org.deeplearning4j.arbiter.conf.dropout.DropoutSpace; -import org.deeplearning4j.arbiter.layers.LayerSpace; -import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; -import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.*; -import org.deeplearning4j.nn.conf.distribution.Distribution; -import org.deeplearning4j.nn.conf.dropout.Dropout; -import org.deeplearning4j.nn.conf.dropout.IDropout; -import org.deeplearning4j.nn.conf.stepfunctions.StepFunction; -import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; -import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.learning.config.IUpdater; -import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.core.JsonProcessingException; - -import java.util.*; - -/** - * This is an abstract ParameterSpace for both MultiLayerNetworks (MultiLayerSpace) and ComputationGraph (ComputationGraphSpace) - *

      - * Functionality here should match {@link org.deeplearning4j.nn.conf.NeuralNetConfiguration.Builder} - * - * @param Type of network (MultiLayerNetwork or ComputationGraph) - * @author Alex Black - */ -@EqualsAndHashCode(callSuper = false) -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -@Data -public abstract class BaseNetworkSpace extends AbstractParameterSpace { - - protected Long seed; - protected ParameterSpace optimizationAlgo; - protected ParameterSpace activationFunction; - protected ParameterSpace biasInit; - protected ParameterSpace weightInit; - protected ParameterSpace dist; - protected ParameterSpace maxNumLineSearchIterations; - protected ParameterSpace miniBatch; - protected ParameterSpace minimize; - protected ParameterSpace stepFunction; - protected ParameterSpace l1; - protected ParameterSpace l2; - protected ParameterSpace l1Bias; - protected ParameterSpace l2Bias; - protected ParameterSpace updater; - protected ParameterSpace biasUpdater; - protected ParameterSpace weightNoise; - private ParameterSpace dropout; - protected ParameterSpace gradientNormalization; - protected ParameterSpace gradientNormalizationThreshold; - protected ParameterSpace convolutionMode; - - protected List layerSpaces = new ArrayList<>(); - - //NeuralNetConfiguration.ListBuilder/MultiLayerConfiguration.Builder options: - protected ParameterSpace backpropType; - protected ParameterSpace tbpttFwdLength; - protected ParameterSpace tbpttBwdLength; - - protected ParameterSpace> allParamConstraints; - protected ParameterSpace> weightConstraints; - protected ParameterSpace> biasConstraints; - - protected int numEpochs = 1; - - - static { - JsonMapper.getMapper().registerSubtypes(ComputationGraphSpace.class, MultiLayerSpace.class); - YamlMapper.getMapper().registerSubtypes(ComputationGraphSpace.class, MultiLayerSpace.class); - } - - @SuppressWarnings("unchecked") - protected BaseNetworkSpace(Builder builder) { - this.seed = builder.seed; - this.optimizationAlgo = builder.optimizationAlgo; - this.activationFunction = builder.activationFunction; - this.biasInit = builder.biasInit; - this.weightInit = builder.weightInit; - this.dist = builder.dist; - this.maxNumLineSearchIterations = builder.maxNumLineSearchIterations; - this.miniBatch = builder.miniBatch; - this.minimize = builder.minimize; - this.stepFunction = builder.stepFunction; - this.l1 = builder.l1; - this.l2 = builder.l2; - this.l1Bias = builder.l1Bias; - this.l2Bias = builder.l2Bias; - this.updater = builder.updater; - this.biasUpdater = builder.biasUpdater; - this.weightNoise = builder.weightNoise; - this.dropout = builder.dropout; - this.gradientNormalization = builder.gradientNormalization; - this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold; - this.convolutionMode = builder.convolutionMode; - this.allParamConstraints = builder.allParamConstraints; - this.weightConstraints = builder.weightConstraints; - this.biasConstraints = builder.biasConstraints; - - this.backpropType = builder.backpropType; - this.tbpttFwdLength = builder.tbpttFwdLength; - this.tbpttBwdLength = builder.tbpttBwdLength; - - this.numEpochs = builder.numEpochs; - } - - protected BaseNetworkSpace() { - //Default constructor for Jackson json/yaml serialization - } - - - protected NeuralNetConfiguration.Builder randomGlobalConf(double[] values) { - //Create MultiLayerConfiguration... 
- NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder(); - if (seed != null) - builder.seed(seed); - if (optimizationAlgo != null) - builder.optimizationAlgo(optimizationAlgo.getValue(values)); - if (activationFunction != null) - builder.activation(activationFunction.getValue(values)); - if (biasInit != null) - builder.biasInit(biasInit.getValue(values)); - if (weightInit != null) - builder.weightInit(weightInit.getValue(values)); - if (dist != null) - builder.dist(dist.getValue(values)); - if (maxNumLineSearchIterations != null) - builder.maxNumLineSearchIterations(maxNumLineSearchIterations.getValue(values)); - if (miniBatch != null) - builder.miniBatch(miniBatch.getValue(values)); - if (minimize != null) - builder.minimize(minimize.getValue(values)); - if (stepFunction != null) - builder.stepFunction(stepFunction.getValue(values)); - if (l1 != null) - builder.l1(l1.getValue(values)); - if (l2 != null) - builder.l2(l2.getValue(values)); - if (l1Bias != null) - builder.l1Bias(l1Bias.getValue(values)); - if (l2Bias != null) - builder.l2Bias(l2Bias.getValue(values)); - if (updater != null) - builder.updater(updater.getValue(values)); - if (biasUpdater != null) - builder.biasUpdater(biasUpdater.getValue(values)); - if (weightNoise != null) - builder.weightNoise(weightNoise.getValue(values)); - if (dropout != null) - builder.dropOut(dropout.getValue(values)); - if (gradientNormalization != null) - builder.gradientNormalization(gradientNormalization.getValue(values)); - if (gradientNormalizationThreshold != null) - builder.gradientNormalizationThreshold(gradientNormalizationThreshold.getValue(values)); - if (convolutionMode != null) - builder.convolutionMode(convolutionMode.getValue(values)); - if (allParamConstraints != null){ - List c = allParamConstraints.getValue(values); - if(c != null){ - builder.constrainAllParameters(c.toArray(new LayerConstraint[c.size()])); - } - } - if (weightConstraints != null){ - List c = weightConstraints.getValue(values); - if(c != null){ - builder.constrainWeights(c.toArray(new LayerConstraint[c.size()])); - } - } - if (biasConstraints != null){ - List c = biasConstraints.getValue(values); - if(c != null){ - builder.constrainBias(c.toArray(new LayerConstraint[c.size()])); - } - } - - return builder; - } - - @Override - public List collectLeaves() { - Map global = getNestedSpaces(); - //Note: Results on previous line does NOT include the LayerSpaces, therefore we need to add these manually... - //This is because the type is a list, not a ParameterSpace - LinkedList stack = new LinkedList<>(); - stack.add(this); - - for (LayerConf layerConf : layerSpaces) { - LayerSpace ls = layerConf.getLayerSpace(); - stack.addAll(ls.collectLeaves()); - } - - List out = new ArrayList<>(); - while (!stack.isEmpty()) { - ParameterSpace next = stack.removeLast(); - if (next.isLeaf()) { - out.add(next); - } else { - Map m = next.getNestedSpaces(); - ParameterSpace[] arr = m.values().toArray(new ParameterSpace[m.size()]); - for (int i = arr.length - 1; i >= 0; i--) { - stack.add(arr[i]); - } - } - } - return out; - } - - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... 
indices) { - throw new UnsupportedOperationException("Cannot set indices for non leaf"); - } - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(); - - for (Map.Entry e : getNestedSpaces().entrySet()) { - sb.append(e.getKey()).append(": ").append(e.getValue()).append("\n"); - } - - int i = 0; - for (LayerConf conf : layerSpaces) { - - sb.append("Layer config ").append(i++).append(": (Number layers:").append(conf.numLayers) - .append(", duplicate: ").append(conf.duplicateConfig).append("), ") - .append(conf.layerSpace.toString()).append("\n"); - } - - - return sb.toString(); - } - - @AllArgsConstructor - @Data - @NoArgsConstructor - public static class LayerConf { - protected LayerSpace layerSpace; - protected String layerName; - protected String[] inputs; - protected ParameterSpace numLayers; - protected boolean duplicateConfig; - protected InputPreProcessor preProcessor; - } - - @SuppressWarnings("unchecked") - protected abstract static class Builder> { - private Long seed; - private ParameterSpace optimizationAlgo; - private ParameterSpace activationFunction; - private ParameterSpace biasInit; - private ParameterSpace weightInit; - private ParameterSpace dist; - private ParameterSpace maxNumLineSearchIterations; - private ParameterSpace miniBatch; - private ParameterSpace minimize; - private ParameterSpace stepFunction; - private ParameterSpace l1; - private ParameterSpace l2; - private ParameterSpace l1Bias; - private ParameterSpace l2Bias; - private ParameterSpace updater; - private ParameterSpace biasUpdater; - private ParameterSpace weightNoise; - private ParameterSpace dropout; - private ParameterSpace gradientNormalization; - private ParameterSpace gradientNormalizationThreshold; - private ParameterSpace convolutionMode; - - private ParameterSpace> allParamConstraints; - private ParameterSpace> weightConstraints; - private ParameterSpace> biasConstraints; - - //NeuralNetConfiguration.ListBuilder/MultiLayerConfiguration.Builder options: - private ParameterSpace backpropType; - private ParameterSpace tbpttFwdLength; - private ParameterSpace tbpttBwdLength; - - //Early stopping configuration / (fixed) number of epochs: - private EarlyStoppingConfiguration earlyStoppingConfiguration; - private int numEpochs = 1; - - protected boolean validateOutputLayerConfig = true; - - public T seed(long seed) { - this.seed = seed; - return (T) this; - } - - public T optimizationAlgo(OptimizationAlgorithm optimizationAlgorithm) { - return optimizationAlgo(new FixedValue<>(optimizationAlgorithm)); - } - - public T optimizationAlgo(ParameterSpace parameterSpace) { - this.optimizationAlgo = parameterSpace; - return (T) this; - } - - - public T activation(Activation activationFunction) { - return activation(new FixedValue<>(activationFunction)); - } - - public T activation(ParameterSpace activationFunction) { - return activationFn(new ActivationParameterSpaceAdapter(activationFunction)); - } - - public T activationFn(ParameterSpace activationFunction) { - this.activationFunction = activationFunction; - return (T) this; - } - - public T biasInit(double biasInit){ - return biasInit(new FixedValue<>(biasInit)); - } - - public T biasInit(ParameterSpace biasInit){ - this.biasInit = biasInit; - return (T) this; - } - - public T weightInit(WeightInit weightInit) { - return weightInit(new FixedValue<>(weightInit)); - } - - public T weightInit(ParameterSpace weightInit) { - this.weightInit = weightInit; - return (T) this; - } - - public T dist(Distribution dist) { - return dist(new 
FixedValue<>(dist)); - } - - public T dist(ParameterSpace dist) { - this.dist = dist; - return (T) this; - } - - public T maxNumLineSearchIterations(int maxNumLineSearchIterations) { - return maxNumLineSearchIterations(new FixedValue<>(maxNumLineSearchIterations)); - } - - public T maxNumLineSearchIterations(ParameterSpace maxNumLineSearchIterations) { - this.maxNumLineSearchIterations = maxNumLineSearchIterations; - return (T) this; - } - - public T miniBatch(boolean minibatch) { - return miniBatch(new FixedValue<>(minibatch)); - } - - public T miniBatch(ParameterSpace miniBatch) { - this.miniBatch = miniBatch; - return (T) this; - } - - public T minimize(boolean minimize) { - return minimize(new FixedValue<>(minimize)); - } - - public T minimize(ParameterSpace minimize) { - this.minimize = minimize; - return (T) this; - } - - public T stepFunction(StepFunction stepFunction) { - return stepFunction(new FixedValue<>(stepFunction)); - } - - public T stepFunction(ParameterSpace stepFunction) { - this.stepFunction = stepFunction; - return (T) this; - } - - public T l1(double l1) { - return l1(new FixedValue<>(l1)); - } - - public T l1(ParameterSpace l1) { - this.l1 = l1; - return (T) this; - } - - public T l2(double l2) { - return l2(new FixedValue<>(l2)); - } - - public T l2(ParameterSpace l2) { - this.l2 = l2; - return (T) this; - } - public T l1Bias(double l1Bias) { - return l1Bias(new FixedValue<>(l1Bias)); - } - - public T l1Bias(ParameterSpace l1Bias) { - this.l1Bias = l1Bias; - return (T) this; - } - - public T l2Bias(double l2Bias) { - return l2Bias(new FixedValue<>(l2Bias)); - } - - public T l2Bias(ParameterSpace l2Bias) { - this.l2Bias = l2Bias; - return (T) this; - } - - public T updater(IUpdater updater){ - return updater(new FixedValue<>(updater)); - } - - public T updater(ParameterSpace updater) { - this.updater = updater; - return (T) this; - } - - public T biasUpdater(IUpdater biasUpdater){ - return biasUpdater(new FixedValue<>(biasUpdater)); - } - - public T biasUpdater(ParameterSpace biasUpdater){ - this.biasUpdater = biasUpdater; - return (T)this; - } - - public T weightNoise(IWeightNoise weightNoise){ - return weightNoise(new FixedValue<>(weightNoise)); - } - - public T weightNoise(ParameterSpace weightNoise){ - this.weightNoise = weightNoise; - return (T) this; - } - - public T dropOut(double dropout){ - return idropOut(new Dropout(dropout)); - } - - public T dropOut(ParameterSpace dropOut){ - return idropOut(new DropoutSpace(dropOut)); - } - - public T idropOut(IDropout idropOut){ - return idropOut(new FixedValue<>(idropOut)); - } - - public T idropOut(ParameterSpace idropOut){ - this.dropout = idropOut; - return (T) this; - } - - public T gradientNormalization(GradientNormalization gradientNormalization) { - return gradientNormalization(new FixedValue<>(gradientNormalization)); - } - - public T gradientNormalization(ParameterSpace gradientNormalization) { - this.gradientNormalization = gradientNormalization; - return (T) this; - } - - public T gradientNormalizationThreshold(double threshold) { - return gradientNormalizationThreshold(new FixedValue<>(threshold)); - } - - public T gradientNormalizationThreshold(ParameterSpace gradientNormalizationThreshold) { - this.gradientNormalizationThreshold = gradientNormalizationThreshold; - return (T) this; - } - - public T convolutionMode(ConvolutionMode convolutionMode) { - return convolutionMode(new FixedValue(convolutionMode)); - } - - public T convolutionMode(ParameterSpace convolutionMode) { - this.convolutionMode = 
convolutionMode; - return (T) this; - } - - public T backpropType(BackpropType backpropType) { - return backpropType(new FixedValue<>(backpropType)); - } - - public T backpropType(ParameterSpace backpropType) { - this.backpropType = backpropType; - return (T) this; - } - - public T tbpttFwdLength(int tbpttFwdLength) { - return tbpttFwdLength(new FixedValue<>(tbpttFwdLength)); - } - - public T tbpttFwdLength(ParameterSpace tbpttFwdLength) { - this.tbpttFwdLength = tbpttFwdLength; - return (T) this; - } - - public T tbpttBwdLength(int tbpttBwdLength) { - return tbpttBwdLength(new FixedValue<>(tbpttBwdLength)); - } - - public T tbpttBwdLength(ParameterSpace tbpttBwdLength) { - this.tbpttBwdLength = tbpttBwdLength; - return (T) this; - } - - public T constrainWeights(LayerConstraint... constraints){ - return constrainWeights(new FixedValue>(Arrays.asList(constraints))); - } - - public T constrainWeights(ParameterSpace> constraints){ - this.weightConstraints = constraints; - return (T) this; - } - - public T constrainBias(LayerConstraint... constraints){ - return constrainBias(new FixedValue>(Arrays.asList(constraints))); - } - - public T constrainBias(ParameterSpace> constraints){ - this.biasConstraints = constraints; - return (T) this; - } - - public T constrainAllParams(LayerConstraint... constraints){ - return constrainAllParams(new FixedValue>(Arrays.asList(constraints))); - } - - public T constrainAllParams(ParameterSpace> constraints){ - this.allParamConstraints = constraints; - return (T) this; - } - - public T validateOutputLayerConfig(boolean validate){ - this.validateOutputLayerConfig = validate; - return (T) this; - } - - /** - * Fixed number of training epochs. Default: 1 - * Note if both EarlyStoppingConfiguration and number of epochs is present, early stopping will be used in preference. - */ - public T numEpochs(int numEpochs) { - this.numEpochs = numEpochs; - return (T) this; - } - - - public abstract E build(); - } - - /** - * Return a json configuration of this configuration space. - * - * @return - */ - public String toJson() { - try { - return JsonMapper.getMapper().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - /** - * Return a yaml configuration of this configuration space. - * - * @return - */ - public String toYaml() { - try { - return YamlMapper.getMapper().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/ComputationGraphSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/ComputationGraphSpace.java deleted file mode 100644 index 369300829..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/ComputationGraphSpace.java +++ /dev/null @@ -1,316 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter; - -import lombok.*; -import org.deeplearning4j.arbiter.layers.LayerSpace; -import org.deeplearning4j.arbiter.layers.fixed.FixedLayerSpace; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.TaskCreatorProvider; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; -import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper; -import org.deeplearning4j.arbiter.task.ComputationGraphTaskCreator; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.WorkspaceMode; -import org.deeplearning4j.nn.conf.graph.GraphVertex; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.Layer; -import org.deeplearning4j.nn.graph.ComputationGraph; -import com.fasterxml.jackson.annotation.JsonProperty; -import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.annotation.JsonTypeName; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -/** - * ComputationGraphSpace: Defines the space of valid hyperparameters for a ComputationGraph. - * Note that this for fixed graph structures only - * - * @author Alex Black - */ -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON ser/de -@Data -@EqualsAndHashCode(callSuper = true) -@JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "@class") -@JsonTypeName("ComputationGraphSpace") -public class ComputationGraphSpace extends BaseNetworkSpace { - static { - TaskCreatorProvider.registerDefaultTaskCreatorClass(ComputationGraphSpace.class, ComputationGraphTaskCreator.class); - } - - @JsonProperty - protected List layerSpaces = new ArrayList<>(); - @JsonProperty - protected List vertices = new ArrayList<>(); - @JsonProperty - protected String[] networkInputs; - @JsonProperty - protected String[] networkOutputs; - @JsonProperty - protected ParameterSpace inputTypes; - @JsonProperty - protected int numParameters; - @JsonProperty - protected WorkspaceMode trainingWorkspaceMode; - @JsonProperty - protected WorkspaceMode inferenceWorkspaceMode; - @JsonProperty - protected boolean validateOutputLayerConfig = true; - - //Early stopping configuration / (fixed) number of epochs: - protected EarlyStoppingConfiguration earlyStoppingConfiguration; - - protected ComputationGraphSpace(Builder builder) { - super(builder); - - this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration; - this.layerSpaces = builder.layerList; - this.vertices = builder.vertexList; - - this.networkInputs = builder.networkInputs; - this.networkOutputs = builder.networkOutputs; - this.inputTypes = builder.inputTypes; - this.trainingWorkspaceMode = builder.trainingWorkspaceMode; - this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode; - this.validateOutputLayerConfig = builder.validateOutputLayerConfig; - - //Determine total number of parameters: - List list = LeafUtils.getUniqueObjects(collectLeaves()); - for (ParameterSpace ps : list) - 
numParameters += ps.numParameters(); - } - - - @Override - public GraphConfiguration getValue(double[] values) { - //Create ComputationGraphConfiguration... - NeuralNetConfiguration.Builder builder = randomGlobalConf(values); - - ComputationGraphConfiguration.GraphBuilder graphBuilder = builder.graphBuilder(); - graphBuilder.addInputs(this.networkInputs); - graphBuilder.setOutputs(this.networkOutputs); - if (inputTypes != null) - graphBuilder.setInputTypes(inputTypes.getValue(values)); - - //Build/add our layers and vertices: - for (LayerConf c : layerSpaces) { - org.deeplearning4j.nn.conf.layers.Layer l = c.layerSpace.getValue(values); - graphBuilder.addLayer(c.getLayerName(), l, c.getPreProcessor(), c.getInputs()); - } - for (VertexConf gv : vertices) { - graphBuilder.addVertex(gv.getVertexName(), gv.getGraphVertex(), gv.getInputs()); - } - - if (backpropType != null) - graphBuilder.backpropType(backpropType.getValue(values)); - if (tbpttFwdLength != null) - graphBuilder.tBPTTForwardLength(tbpttFwdLength.getValue(values)); - if (tbpttBwdLength != null) - graphBuilder.tBPTTBackwardLength(tbpttBwdLength.getValue(values)); - graphBuilder.validateOutputLayerConfig(validateOutputLayerConfig); - - ComputationGraphConfiguration configuration = graphBuilder.build(); - - if (trainingWorkspaceMode != null) - configuration.setTrainingWorkspaceMode(trainingWorkspaceMode); - if (inferenceWorkspaceMode != null) - configuration.setInferenceWorkspaceMode(inferenceWorkspaceMode); - - return new GraphConfiguration(configuration, earlyStoppingConfiguration, numEpochs); - } - - @Override - public int numParameters() { - return numParameters; - } - - @Override - public List collectLeaves() { - List list = super.collectLeaves(); - for (LayerConf lc : layerSpaces) { - list.addAll(lc.layerSpace.collectLeaves()); - } - if (inputTypes != null) - list.add(inputTypes); - return list; - } - - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(super.toString()); - - for (LayerConf conf : layerSpaces) { - sb.append("Layer config: \"").append(conf.layerName).append("\", ").append(conf.layerSpace) - .append(", inputs: ").append(conf.inputs == null ? "[]" : Arrays.toString(conf.inputs)) - .append("\n"); - } - - for (VertexConf conf : vertices) { - sb.append("GraphVertex: \"").append(conf.vertexName).append("\", ").append(conf.graphVertex) - .append(", inputs: ").append(conf.inputs == null ? 
"[]" : Arrays.toString(conf.inputs)) - .append("\n"); - } - - if (earlyStoppingConfiguration != null) { - sb.append("Early stopping configuration:").append(earlyStoppingConfiguration.toString()).append("\n"); - } else { - sb.append("Training # epochs:").append(numEpochs).append("\n"); - } - - if (inputTypes != null) { - sb.append("Input types: ").append(inputTypes).append("\n"); - } - - return sb.toString(); - } - - @AllArgsConstructor - @Data - @NoArgsConstructor //For Jackson JSON - protected static class VertexConf { - protected GraphVertex graphVertex; - protected String vertexName; - protected String[] inputs; - } - - public static class Builder extends BaseNetworkSpace.Builder { - - protected List layerList = new ArrayList<>(); - protected List vertexList = new ArrayList<>(); - protected EarlyStoppingConfiguration earlyStoppingConfiguration; - protected String[] networkInputs; - protected String[] networkOutputs; - protected ParameterSpace inputTypes; - protected WorkspaceMode trainingWorkspaceMode; - protected WorkspaceMode inferenceWorkspaceMode; - - //Need: input types - //Early stopping configuration - //Graph nodes - - /** - * Early stopping configuration (optional). Note if both EarlyStoppingConfiguration and number of epochs is - * present, early stopping will be used in preference. - */ - public Builder earlyStoppingConfiguration( - EarlyStoppingConfiguration earlyStoppingConfiguration) { - this.earlyStoppingConfiguration = earlyStoppingConfiguration; - return this; - } - - public Builder layer(String layerName, LayerSpace layerSpace, String... layerInputs){ - return addLayer(layerName, layerSpace, layerInputs); - } - - public Builder layer(String layerName, LayerSpace layerSpace, InputPreProcessor preProcessor, - String... layerInputs) { - return addLayer(layerName, layerSpace, preProcessor, layerInputs); - } - - public Builder layer(String layerName, Layer layer, String... layerInputs){ - return layer(layerName, new FixedLayerSpace<>(layer), layerInputs); - } - - public Builder addLayer(String layerName, LayerSpace layerSpace, String... layerInputs) { - layerList.add(new LayerConf(layerSpace, layerName, layerInputs, new FixedValue<>(1), false, null)); - return this; - } - - public Builder addLayer(String layerName, LayerSpace layerSpace, InputPreProcessor preProcessor, - String... layerInputs){ - layerList.add(new LayerConf(layerSpace, layerName, layerInputs, new FixedValue<>(1), false, preProcessor)); - return this; - } - - public Builder addVertex(String vertexName, GraphVertex vertex, String... vertexInputs) { - vertexList.add(new VertexConf(vertex, vertexName, vertexInputs)); - return this; - } - - public Builder addInputs(String... networkInputs) { - this.networkInputs = networkInputs; - return this; - } - - public Builder setOutputs(String... networkOutputs) { - this.networkOutputs = networkOutputs; - return this; - } - - public Builder setInputTypes(InputType... 
inputTypes) { - return setInputTypes(new FixedValue(inputTypes)); - } - - public Builder setInputTypes(ParameterSpace inputTypes) { - this.inputTypes = inputTypes; - return this; - } - - public Builder trainingWorkspaceMode(WorkspaceMode workspaceMode){ - this.trainingWorkspaceMode = workspaceMode; - return this; - } - - public Builder inferenceWorkspaceMode(WorkspaceMode workspaceMode){ - this.inferenceWorkspaceMode = workspaceMode; - return this; - } - - @SuppressWarnings("unchecked") - public ComputationGraphSpace build() { - return new ComputationGraphSpace(this); - } - } - - - /** - * Instantiate a computation graph space from - * a raw json string - * @param json - * @return - */ - public static ComputationGraphSpace fromJson(String json) { - try { - return JsonMapper.getMapper().readValue(json, ComputationGraphSpace.class); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - /** - * Instantiate a computation graph space - * from a raw yaml string - * @param yaml - * @return - */ - public static ComputationGraphSpace fromYaml(String yaml) { - try { - return YamlMapper.getMapper().readValue(yaml, ComputationGraphSpace.class); - } catch (IOException e) { - throw new RuntimeException(e); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/DL4JConfiguration.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/DL4JConfiguration.java deleted file mode 100644 index 15eb7e3ba..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/DL4JConfiguration.java +++ /dev/null @@ -1,73 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter; - -import lombok.AllArgsConstructor; -import lombok.Data; -import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; -import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.fasterxml.jackson.databind.annotation.JsonSerialize; - -import java.io.Serializable; - -/** - * DL4JConfiguration: simple configuration method that contains the following:
      - * - MultiLayerConfiguration
      - * - Early stopping settings, OR number of epochs
      - * Note: if early stopping configuration is absent, a fixed number of epochs (default: 1) will be used. - * If both early stopping and number of epochs is present: early stopping will be used. - */ -@AllArgsConstructor -@Data -public class DL4JConfiguration implements Serializable { - @JsonSerialize - private MultiLayerConfiguration multiLayerConfiguration; - @JsonSerialize - private EarlyStoppingConfiguration earlyStoppingConfiguration; - @JsonSerialize - private Integer numEpochs; - - - /** - * Yaml mapping - * @return - */ - public String toYaml() { - try { - return YamlMapper.getMapper().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - /** - * Json mapping - * @return - */ - public String toJson() { - try { - return JsonMapper.getMapper().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/GraphConfiguration.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/GraphConfiguration.java deleted file mode 100644 index 4cb6cf685..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/GraphConfiguration.java +++ /dev/null @@ -1,67 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter; - - -import lombok.AllArgsConstructor; -import lombok.Data; -import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; -import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.graph.ComputationGraph; -import com.fasterxml.jackson.core.JsonProcessingException; - -import java.io.Serializable; - -/** - * Analogous to {@link DL4JConfiguration}, GraphConfiguration includes a configuration for ComputationGraphs, as well - * as early stopping (or, optionally numEpochs) fields. 
- */ -@AllArgsConstructor -@Data -public class GraphConfiguration implements Serializable { - private ComputationGraphConfiguration configuration; - private EarlyStoppingConfiguration earlyStoppingConfiguration; - private Integer numEpochs; - - - - /** - * Yaml mapping - * @return - */ - public String toYaml() { - try { - return YamlMapper.getMapper().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - /** - * Json mapping - * @return - */ - public String toJson() { - try { - return JsonMapper.getMapper().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/MultiLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/MultiLayerSpace.java deleted file mode 100644 index beeb0420b..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/MultiLayerSpace.java +++ /dev/null @@ -1,320 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.deeplearning4j.arbiter.layers.LayerSpace; -import org.deeplearning4j.arbiter.layers.fixed.FixedLayerSpace; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.TaskCreatorProvider; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; -import org.deeplearning4j.arbiter.optimize.serde.jackson.YamlMapper; -import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.WorkspaceMode; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.Layer; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -@Data -@EqualsAndHashCode(callSuper = true) -public class MultiLayerSpace extends BaseNetworkSpace { - - static { - TaskCreatorProvider.registerDefaultTaskCreatorClass(MultiLayerSpace.class, MultiLayerNetworkTaskCreator.class); - } - - @JsonProperty - protected ParameterSpace inputType; - @JsonProperty - protected ParameterSpace> inputPreProcessors; - - //Early stopping configuration / (fixed) number of epochs: - 
@JsonProperty - protected EarlyStoppingConfiguration earlyStoppingConfiguration; - @JsonProperty - protected int numParameters; - @JsonProperty - protected WorkspaceMode trainingWorkspaceMode; - @JsonProperty - protected WorkspaceMode inferenceWorkspaceMode; - @JsonProperty - protected boolean validateOutputLayerConfig = true; - - - protected MultiLayerSpace(Builder builder) { - super(builder); - this.inputType = builder.inputType; - this.inputPreProcessors = builder.inputPreProcessors; - - this.earlyStoppingConfiguration = builder.earlyStoppingConfiguration; - - this.layerSpaces = builder.layerSpaces; - - //Determine total number of parameters: - //Collect the leaves, and make sure they are unique. - //Note that the *object instances* must be unique - and consequently we don't want to use .equals(), as - // this would incorrectly filter out equal range parameter spaces - List allLeaves = collectLeaves(); - List list = LeafUtils.getUniqueObjects(allLeaves); - - for (ParameterSpace ps : list) { - int n = ps.numParameters(); - numParameters += ps.numParameters(); - } - - this.trainingWorkspaceMode = builder.trainingWorkspaceMode; - this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode; - this.validateOutputLayerConfig = builder.validateOutputLayerConfig; - } - - protected MultiLayerSpace() { - //Default constructor for Jackson json/yaml serialization - } - - @Override - public DL4JConfiguration getValue(double[] values) { - //First: create layer configs - List layers = new ArrayList<>(); - for (LayerConf c : layerSpaces) { - int n = c.numLayers.getValue(values); - if (c.duplicateConfig) { - //Generate N identical configs - org.deeplearning4j.nn.conf.layers.Layer l = c.layerSpace.getValue(values); - for (int i = 0; i < n; i++) { - layers.add(l.clone()); - } - } else { - throw new UnsupportedOperationException("Not yet implemented"); - } - } - - //Create MultiLayerConfiguration... 
- NeuralNetConfiguration.Builder builder = randomGlobalConf(values); - - NeuralNetConfiguration.ListBuilder listBuilder = builder.list(); - for (int i = 0; i < layers.size(); i++) { - listBuilder.layer(i, layers.get(i)); - } - - if (backpropType != null) - listBuilder.backpropType(backpropType.getValue(values)); - if (tbpttFwdLength != null) - listBuilder.tBPTTForwardLength(tbpttFwdLength.getValue(values)); - if (tbpttBwdLength != null) - listBuilder.tBPTTBackwardLength(tbpttBwdLength.getValue(values)); - if (inputType != null) - listBuilder.setInputType(inputType.getValue(values)); - if (inputPreProcessors != null) - listBuilder.setInputPreProcessors(inputPreProcessors.getValue(values)); - listBuilder.validateOutputLayerConfig(validateOutputLayerConfig); - - MultiLayerConfiguration configuration = listBuilder.build(); - - if (trainingWorkspaceMode != null) - configuration.setTrainingWorkspaceMode(trainingWorkspaceMode); - if (inferenceWorkspaceMode != null) - configuration.setInferenceWorkspaceMode(inferenceWorkspaceMode); - - - return new DL4JConfiguration(configuration, earlyStoppingConfiguration, numEpochs); - } - - @Override - public int numParameters() { - return numParameters; - } - - @Override - public List collectLeaves() { - List list = super.collectLeaves(); - for (LayerConf lc : layerSpaces) { - list.addAll(lc.numLayers.collectLeaves()); - list.addAll(lc.layerSpace.collectLeaves()); - } - if (inputType != null) - list.addAll(inputType.collectLeaves()); - if (inputPreProcessors != null) - list.addAll(inputPreProcessors.collectLeaves()); - return list; - } - - - @Override - public String toString() { - StringBuilder sb = new StringBuilder(super.toString()); - - int i = 0; - for (LayerConf conf : layerSpaces) { - - sb.append("Layer config ").append(i++).append(": (Number layers:").append(conf.numLayers) - .append(", duplicate: ").append(conf.duplicateConfig).append("), ") - .append(conf.layerSpace.toString()).append("\n"); - } - - if (inputType != null) - sb.append("inputType: ").append(inputType).append("\n"); - if (inputPreProcessors != null) - sb.append("inputPreProcessors: ").append(inputPreProcessors).append("\n"); - - if (earlyStoppingConfiguration != null) { - sb.append("Early stopping configuration:").append(earlyStoppingConfiguration.toString()).append("\n"); - } else { - sb.append("Training # epochs:").append(numEpochs).append("\n"); - } - - return sb.toString(); - } - - public LayerSpace getLayerSpace(int layerNumber) { - return layerSpaces.get(layerNumber).getLayerSpace(); - } - - public static class Builder extends BaseNetworkSpace.Builder { - protected List layerSpaces = new ArrayList<>(); - protected ParameterSpace inputType; - protected ParameterSpace> inputPreProcessors; - protected WorkspaceMode trainingWorkspaceMode; - protected WorkspaceMode inferenceWorkspaceMode; - - //Early stopping configuration - protected EarlyStoppingConfiguration earlyStoppingConfiguration; - - - - public Builder setInputType(InputType inputType) { - return setInputType(new FixedValue<>(inputType)); - } - - public Builder setInputType(ParameterSpace inputType) { - this.inputType = inputType; - return this; - } - - public Builder layer(Layer layer){ - return layer(new FixedLayerSpace<>(layer)); - } - - public Builder layer(LayerSpace layerSpace) { - return layer(layerSpace, new FixedValue<>(1)); - } - - public Builder layer(LayerSpace layerSpace, ParameterSpace numLayersDistribution) { - return addLayer(layerSpace, numLayersDistribution); - } - - - public Builder addLayer(LayerSpace 
layerSpace) { - return addLayer(layerSpace, new FixedValue<>(1)); - } - - /** - * duplicateConfig not supported. Will always be true - * @param layerSpace - * @param numLayersDistribution - * @param duplicateConfig - * @return - */ - @Deprecated - public Builder addLayer(LayerSpace layerSpace, ParameterSpace numLayersDistribution, boolean duplicateConfig) { - if (!duplicateConfig) throw new IllegalArgumentException("Duplicate Config false not supported"); - String layerName = "layer_" + layerSpaces.size(); - duplicateConfig = true; //hard coded to always duplicate layers - layerSpaces.add(new LayerConf(layerSpace, layerName, null, numLayersDistribution, duplicateConfig, null)); - return this; - } - - /** - * @param layerSpace - * @param numLayersDistribution Distribution for number of layers to generate - */ - public Builder addLayer(LayerSpace layerSpace, ParameterSpace numLayersDistribution) { - String layerName = "layer_" + layerSpaces.size(); - boolean duplicateConfig = true; //hard coded to always duplicate layers - layerSpaces.add(new LayerConf(layerSpace, layerName, null, numLayersDistribution, duplicateConfig, null)); - return this; - } - - /** - * Early stopping configuration (optional). Note if both EarlyStoppingConfiguration and number of epochs is - * present, early stopping will be used in preference. - */ - public Builder earlyStoppingConfiguration( - EarlyStoppingConfiguration earlyStoppingConfiguration) { - this.earlyStoppingConfiguration = earlyStoppingConfiguration; - return this; - } - - /** - * @param inputPreProcessors Input preprocessors to set for the model - */ - public Builder setInputPreProcessors(Map inputPreProcessors) { - return setInputPreProcessors(new FixedValue<>(inputPreProcessors)); - } - - /** - * @param inputPreProcessors Input preprocessors to set for the model - */ - public Builder setInputPreProcessors(ParameterSpace> inputPreProcessors) { - this.inputPreProcessors = inputPreProcessors; - return this; - } - - public Builder trainingWorkspaceMode(WorkspaceMode workspaceMode){ - this.trainingWorkspaceMode = workspaceMode; - return this; - } - - public Builder inferenceWorkspaceMode(WorkspaceMode workspaceMode){ - this.inferenceWorkspaceMode = workspaceMode; - return this; - } - - @SuppressWarnings("unchecked") - public MultiLayerSpace build() { - return new MultiLayerSpace(this); - } - } - - public static MultiLayerSpace fromJson(String json) { - try { - return JsonMapper.getMapper().readValue(json, MultiLayerSpace.class); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public static MultiLayerSpace fromYaml(String yaml) { - try { - return YamlMapper.getMapper().readValue(yaml, MultiLayerSpace.class); - } catch (IOException e) { - throw new RuntimeException(e); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/adapter/ActivationParameterSpaceAdapter.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/adapter/ActivationParameterSpaceAdapter.java deleted file mode 100644 index 5e666a00d..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/adapter/ActivationParameterSpaceAdapter.java +++ /dev/null @@ -1,58 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
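A hedged sketch of how the MultiLayerSpace builder deleted above was typically used. DenseLayerSpace, OutputLayerSpace, IntegerParameterSpace, ContinuousParameterSpace and the updater(...)/numEpochs(...) builder methods are assumed from the rest of this module and are not shown in this patch; only addLayer(...), build(), numParameters() and getValue(double[]) come from the class above.

    MultiLayerSpace space = new MultiLayerSpace.Builder()
            .updater(new AdamSpace(new ContinuousParameterSpace(1e-4, 1e-1)))
            .addLayer(new DenseLayerSpace.Builder()
                    .nIn(784).nOut(new IntegerParameterSpace(32, 256))
                    .activation(Activation.RELU).build())
            .addLayer(new OutputLayerSpace.Builder()
                    .nOut(10).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .numEpochs(2)
            .build();

    // Each candidate is a point in the unit hypercube; getValue(...) maps it to a
    // concrete DL4JConfiguration exactly as implemented in getValue(double[]) above.
    int n = space.numParameters();
    DL4JConfiguration candidate = space.getValue(new double[n]);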
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.adapter; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.adapter.ParameterSpaceAdapter; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; -import com.fasterxml.jackson.annotation.JsonProperty; - -/** - * A simple class to adapt a {@link Activation} parameter space to a {@link IActivation} parameter space - * - * @author Alex Black - */ -@Data -@NoArgsConstructor -@EqualsAndHashCode(callSuper = false) -public class ActivationParameterSpaceAdapter extends ParameterSpaceAdapter { - - private ParameterSpace activation; - - public ActivationParameterSpaceAdapter(@JsonProperty("activation") ParameterSpace activation) { - this.activation = activation; - } - - @Override - public IActivation convertValue(Activation from) { - return from.getActivationFunction(); - } - - @Override - protected ParameterSpace underlying() { - return activation; - } - - @Override - protected String underlyingName() { - return "activation"; - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/adapter/LossFunctionParameterSpaceAdapter.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/adapter/LossFunctionParameterSpaceAdapter.java deleted file mode 100644 index 2c4c2899c..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/adapter/LossFunctionParameterSpaceAdapter.java +++ /dev/null @@ -1,60 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.adapter; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.adapter.ParameterSpaceAdapter; -import org.nd4j.linalg.lossfunctions.ILossFunction; -import org.nd4j.linalg.lossfunctions.LossFunctions; -import com.fasterxml.jackson.annotation.JsonProperty; - -/** - * A simple class to adapt a {@link LossFunctions.LossFunction} parameter space to a {@link ILossFunction} parameter space - * - * @author Alex Black - */ -@Data -@NoArgsConstructor -@EqualsAndHashCode(callSuper = false) -public class LossFunctionParameterSpaceAdapter - extends ParameterSpaceAdapter { - - private ParameterSpace lossFunction; - - public LossFunctionParameterSpaceAdapter( - @JsonProperty("lossFunction") ParameterSpace lossFunction) { - this.lossFunction = lossFunction; - } - - @Override - protected ILossFunction convertValue(LossFunctions.LossFunction from) { - return from.getILossFunction(); - } - - @Override - protected ParameterSpace underlying() { - return lossFunction; - } - - @Override - protected String underlyingName() { - return "lossFunction"; - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/dropout/DropoutSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/dropout/DropoutSpace.java deleted file mode 100644 index 76443109e..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/dropout/DropoutSpace.java +++ /dev/null @@ -1,63 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.dropout; - -import lombok.AllArgsConstructor; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.nn.conf.dropout.Dropout; -import org.deeplearning4j.nn.conf.dropout.IDropout; - -import java.util.List; - -@AllArgsConstructor -@NoArgsConstructor -public class DropoutSpace extends AbstractParameterSpace { - - private ParameterSpace dropout; - - @Override - public Dropout getValue(double[] parameterValues) { - double p = dropout.getValue(parameterValues); - if(p == 0){ - //Special case: 0 dropout = "disabled" in DL4J. 
But Dropout class doesn't support this - return null; - } - return new Dropout(p); - } - - @Override - public int numParameters() { - return dropout.numParameters(); - } - - @Override - public List collectLeaves() { - return dropout.collectLeaves(); - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - dropout.setIndices(indices); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdaGradSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdaGradSpace.java deleted file mode 100644 index ca94a386a..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdaGradSpace.java +++ /dev/null @@ -1,66 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.Setter; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.nd4j.linalg.learning.config.AdaGrad; -import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.schedule.ISchedule; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -@Data -@EqualsAndHashCode(callSuper = false) -public class AdaGradSpace extends BaseUpdaterSpace { - - private ParameterSpace learningRate; - private ParameterSpace lrSchedule; - - public AdaGradSpace(ParameterSpace learningRate) { - this(learningRate, null); - } - - public static AdaGradSpace withLR(ParameterSpace lr){ - return new AdaGradSpace(lr, null); - } - - public static AdaGradSpace withLRSchedule(ParameterSpace lrSchedule){ - return new AdaGradSpace(null, lrSchedule); - } - - protected AdaGradSpace(@JsonProperty("learningRate") ParameterSpace learningRate, - @JsonProperty("lrSchedule") ParameterSpace lrSchedule){ - this.learningRate = learningRate; - this.lrSchedule = lrSchedule; - } - - @Override - public IUpdater getValue(double[] parameterValues) { - if(lrSchedule != null){ - return new AdaGrad(lrSchedule.getValue(parameterValues)); - } else { - return new AdaGrad(learningRate.getValue(parameterValues)); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdaMaxSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdaMaxSpace.java deleted file mode 100644 index 137e62f8d..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdaMaxSpace.java +++ /dev/null @@ -1,83 +0,0 @@ -/******************************************************************************* - * Copyright (c) 
2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.nd4j.linalg.learning.config.AdaMax; -import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.schedule.ISchedule; -import com.fasterxml.jackson.annotation.JsonProperty; - -@Data -@EqualsAndHashCode(callSuper = false) -public class AdaMaxSpace extends BaseUpdaterSpace { - - private ParameterSpace learningRate; - private ParameterSpace learningRateSchedule; - private ParameterSpace beta1; - private ParameterSpace beta2; - private ParameterSpace epsilon; - - public AdaMaxSpace(ParameterSpace learningRate) { - this(learningRate, null, null, null); - } - - public AdaMaxSpace(ParameterSpace learningRate, ParameterSpace beta1, - ParameterSpace beta2, ParameterSpace epsilon) { - this(learningRate, null, beta1, beta2, epsilon); - } - - public AdaMaxSpace(@JsonProperty("learningRate") ParameterSpace learningRate, - @JsonProperty("learningRateSchedule") ParameterSpace learningRateSchedule, - @JsonProperty("beta1") ParameterSpace beta1, - @JsonProperty("beta2") ParameterSpace beta2, - @JsonProperty("epsilon") ParameterSpace epsilon){ - this.learningRate = learningRate; - this.learningRateSchedule = learningRateSchedule; - this.beta1 = beta1; - this.beta2 = beta2; - this.epsilon = epsilon; - } - - public static AdaMaxSpace withLR(ParameterSpace lr){ - return new AdaMaxSpace(lr, null, null, null, null); - } - - public static AdaMaxSpace withLRSchedule(ParameterSpace lrSchedule){ - return new AdaMaxSpace(null, lrSchedule, null, null, null); - } - - @Override - public IUpdater getValue(double[] parameterValues) { - double lr = learningRate == null ? AdaMax.DEFAULT_ADAMAX_LEARNING_RATE : learningRate.getValue(parameterValues); - ISchedule lrS = learningRateSchedule == null ? null : learningRateSchedule.getValue(parameterValues); - double b1 = beta1 == null ? AdaMax.DEFAULT_ADAMAX_LEARNING_RATE : beta1.getValue(parameterValues); - double b2 = beta2 == null ? AdaMax.DEFAULT_ADAMAX_LEARNING_RATE : beta2.getValue(parameterValues); - double eps = epsilon == null ? 
AdaMax.DEFAULT_ADAMAX_LEARNING_RATE : epsilon.getValue(parameterValues); - if(lrS == null){ - return new AdaMax(lr, b1, b2, eps); - } else { - AdaMax a = new AdaMax(lrS); - a.setBeta1(b1); - a.setBeta2(b2); - a.setEpsilon(eps); - return a; - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdamSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdamSpace.java deleted file mode 100644 index 638b60782..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/AdamSpace.java +++ /dev/null @@ -1,83 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.nd4j.linalg.learning.config.Adam; -import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.schedule.ISchedule; -import com.fasterxml.jackson.annotation.JsonProperty; - -@Data -@EqualsAndHashCode(callSuper = false) -public class AdamSpace extends BaseUpdaterSpace { - - private ParameterSpace learningRate; - private ParameterSpace learningRateSchedule; - private ParameterSpace beta1; - private ParameterSpace beta2; - private ParameterSpace epsilon; - - public AdamSpace(ParameterSpace learningRate) { - this(learningRate, null, null, null); - } - - public AdamSpace(ParameterSpace learningRate, ParameterSpace beta1, - ParameterSpace beta2, ParameterSpace epsilon) { - this(learningRate, null, beta1, beta2, epsilon); - } - - public static AdamSpace withLR(ParameterSpace lr){ - return new AdamSpace(lr, null, null, null, null); - } - - public static AdamSpace withLRSchedule(ParameterSpace lrSchedule){ - return new AdamSpace(null, lrSchedule, null, null, null); - } - - protected AdamSpace(@JsonProperty("learningRate") ParameterSpace learningRate, - @JsonProperty("learningRateSchedule") ParameterSpace learningRateSchedule, - @JsonProperty("beta1") ParameterSpace beta1, - @JsonProperty("beta2") ParameterSpace beta2, - @JsonProperty("epsilon") ParameterSpace epsilon){ - this.learningRate = learningRate; - this.learningRateSchedule = learningRateSchedule; - this.beta1 = beta1; - this.beta2 = beta2; - this.epsilon = epsilon; - } - - @Override - public IUpdater getValue(double[] parameterValues) { - double lr = learningRate == null ? Adam.DEFAULT_ADAM_LEARNING_RATE : learningRate.getValue(parameterValues); - ISchedule lrS = learningRateSchedule == null ? null : learningRateSchedule.getValue(parameterValues); - double b1 = beta1 == null ? Adam.DEFAULT_ADAM_LEARNING_RATE : beta1.getValue(parameterValues); - double b2 = beta2 == null ? 
Adam.DEFAULT_ADAM_LEARNING_RATE : beta2.getValue(parameterValues); - double eps = epsilon == null ? Adam.DEFAULT_ADAM_LEARNING_RATE : epsilon.getValue(parameterValues); - if(lrS == null){ - return new Adam(lr, b1, b2, eps); - } else { - Adam a = new Adam(lrS); - a.setBeta1(b1); - a.setBeta2(b2); - a.setEpsilon(eps); - return a; - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/BaseUpdaterSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/BaseUpdaterSpace.java deleted file mode 100644 index ec1eca996..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/BaseUpdaterSpace.java +++ /dev/null @@ -1,70 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater; - -import lombok.Data; -import lombok.Getter; -import lombok.Setter; -import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.nd4j.linalg.learning.config.IUpdater; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Map; - -@Data -public abstract class BaseUpdaterSpace extends AbstractParameterSpace { - - @Override - public int numParameters() { - int count = 0; - for(ParameterSpace p : collectLeaves()){ - count += p.numParameters(); - } - return count; - } - - @Override - public List collectLeaves() { - Map nested = getNestedSpaces(); - List out = new ArrayList<>(); - for(ParameterSpace p : nested.values()){ - out.addAll(p.collectLeaves()); - } - return out; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices){ - int soFar = 0; - for(ParameterSpace p : collectLeaves()){ - int numParams = p.numParameters(); - if(numParams <= 0){ - continue; - } - int[] subset = Arrays.copyOfRange(indices, soFar, soFar + numParams); - p.setIndices(subset); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/NadamSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/NadamSpace.java deleted file mode 100644 index 16bc09127..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/NadamSpace.java +++ /dev/null @@ -1,83 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
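A short, hedged example of resolving one of these updater spaces to a concrete IUpdater. ContinuousParameterSpace is assumed from arbiter-core, and the manual setIndices(...) call is normally performed by the optimisation setup (via collectLeaves) rather than by hand.

    AdamSpace adamSpace = AdamSpace.withLR(new ContinuousParameterSpace(1e-4, 1e-1));
    adamSpace.setIndices(0);                          // single leaf -> slot 0 of the flat vector
    IUpdater updater = adamSpace.getValue(new double[]{0.5});
    // Only the learning rate is searched here; per getValue(...) above, the unset
    // coefficients fall back to the Adam.DEFAULT_ADAM_LEARNING_RATE constant used there.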
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.nd4j.linalg.learning.config.Nadam; -import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.schedule.ISchedule; -import com.fasterxml.jackson.annotation.JsonProperty; - -@Data -@EqualsAndHashCode(callSuper = false) -public class NadamSpace extends BaseUpdaterSpace { - - private ParameterSpace learningRate; - private ParameterSpace learningRateSchedule; - private ParameterSpace beta1; - private ParameterSpace beta2; - private ParameterSpace epsilon; - - public NadamSpace(ParameterSpace learningRate) { - this(learningRate, null, null, null); - } - - public NadamSpace(ParameterSpace learningRate, ParameterSpace beta1, - ParameterSpace beta2, ParameterSpace epsilon) { - this(learningRate, null, beta1, beta2, epsilon); - } - - public NadamSpace(@JsonProperty("learningRate") ParameterSpace learningRate, - @JsonProperty("learningRateSchedule") ParameterSpace learningRateSchedule, - @JsonProperty("beta1") ParameterSpace beta1, - @JsonProperty("beta2") ParameterSpace beta2, - @JsonProperty("epsilon") ParameterSpace epsilon){ - this.learningRate = learningRate; - this.learningRateSchedule = learningRateSchedule; - this.beta1 = beta1; - this.beta2 = beta2; - this.epsilon = epsilon; - } - - public static NadamSpace withLR(ParameterSpace lr){ - return new NadamSpace(lr, null, null, null, null); - } - - public static NadamSpace withLRSchedule(ParameterSpace lrSchedule){ - return new NadamSpace(null, lrSchedule, null, null, null); - } - - @Override - public IUpdater getValue(double[] parameterValues) { - double lr = learningRate == null ? Nadam.DEFAULT_NADAM_LEARNING_RATE : learningRate.getValue(parameterValues); - ISchedule lrS = learningRateSchedule == null ? null : learningRateSchedule.getValue(parameterValues); - double b1 = beta1 == null ? Nadam.DEFAULT_NADAM_LEARNING_RATE : beta1.getValue(parameterValues); - double b2 = beta2 == null ? Nadam.DEFAULT_NADAM_LEARNING_RATE : beta2.getValue(parameterValues); - double eps = epsilon == null ? Nadam.DEFAULT_NADAM_LEARNING_RATE : epsilon.getValue(parameterValues); - if(lrS == null){ - return new Nadam(lr, b1, b2, eps); - } else { - Nadam a = new Nadam(lrS); - a.setBeta1(b1); - a.setBeta2(b2); - a.setEpsilon(eps); - return a; - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/NesterovsSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/NesterovsSpace.java deleted file mode 100644 index 6bb059493..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/NesterovsSpace.java +++ /dev/null @@ -1,100 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.learning.config.Nesterovs; -import org.nd4j.linalg.schedule.ISchedule; -import com.fasterxml.jackson.annotation.JsonProperty; - -@Data -@EqualsAndHashCode(callSuper = false) -public class NesterovsSpace extends BaseUpdaterSpace { - - protected ParameterSpace learningRate; - protected ParameterSpace learningRateSchedule; - protected ParameterSpace momentum; - protected ParameterSpace momentumSchedule; - - public NesterovsSpace(ParameterSpace learningRate) { - this(learningRate, null); - } - - public NesterovsSpace(ParameterSpace learningRate, ParameterSpace momentum) { - this(learningRate, null, momentum, null); - } - - public NesterovsSpace(@JsonProperty("learningRate") ParameterSpace learningRate, - @JsonProperty("learningRateSchedule") ParameterSpace learningRateSchedule, - @JsonProperty("momentum") ParameterSpace momentum, - @JsonProperty("momentumSchedule") ParameterSpace momentumSchedule) { - this.learningRate = learningRate; - this.learningRateSchedule = learningRateSchedule; - this.momentum = momentum; - this.momentumSchedule = momentumSchedule; - } - - public static NesterovsSpace withLR(ParameterSpace lr){ - return new NesterovsSpace(lr, null, null, null); - } - - public static NesterovsSpace withLR(ParameterSpace lr, double momentum){ - return new NesterovsSpace(lr, null, new FixedValue<>(momentum), null); - } - - public static NesterovsSpace withLR(ParameterSpace lr, ParameterSpace momentum){ - return new NesterovsSpace(lr, null, momentum, null); - } - - public static NesterovsSpace withLRSchedule(ParameterSpace lrSchedule){ - return new NesterovsSpace(null, lrSchedule, null, null); - } - - public static NesterovsSpace withLRSchedule(ParameterSpace lrSchedule, double momentum){ - return new NesterovsSpace(null, lrSchedule, new FixedValue<>(momentum), null); - } - - public static NesterovsSpace withLRSchedule(ParameterSpace lrSchedule, ParameterSpace momentum){ - return new NesterovsSpace(null, lrSchedule, momentum, null); - } - - - @Override - public IUpdater getValue(double[] parameterValues) { - double lr = learningRate == null ? Nesterovs.DEFAULT_NESTEROV_LEARNING_RATE : learningRate.getValue(parameterValues); - ISchedule lrS = learningRateSchedule == null ? null : learningRateSchedule.getValue(parameterValues); - double m = momentum == null ? Nesterovs.DEFAULT_NESTEROV_MOMENTUM : momentum.getValue(parameterValues); - ISchedule mS = momentumSchedule == null ? 
null : momentumSchedule.getValue(parameterValues); - if(lrS == null){ - if(momentumSchedule == null){ - return new Nesterovs(lr, m); - } else { - return new Nesterovs(lr, mS); - } - } else { - if(momentumSchedule == null){ - return new Nesterovs(lrS, m); - } else { - return new Nesterovs(lrS, mS); - } - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/RmsPropSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/RmsPropSpace.java deleted file mode 100644 index 8590947d6..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/RmsPropSpace.java +++ /dev/null @@ -1,54 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.learning.config.RmsProp; -import org.nd4j.linalg.schedule.ISchedule; -import com.fasterxml.jackson.annotation.JsonProperty; - -@Data -@EqualsAndHashCode(callSuper = false) -public class RmsPropSpace extends BaseUpdaterSpace { - - protected ParameterSpace learningRate; - protected ParameterSpace learningRateSchedule; - - public RmsPropSpace(ParameterSpace learningRate) { - this(learningRate, null); - } - - public RmsPropSpace(@JsonProperty("learningRate") ParameterSpace learningRate, - @JsonProperty("learningRateSchedule") ParameterSpace learningRateSchedule){ - this.learningRate = learningRate; - this.learningRateSchedule = learningRateSchedule; - } - - @Override - public IUpdater getValue(double[] parameterValues) { - double lr = learningRate == null ? RmsProp.DEFAULT_RMSPROP_LEARNING_RATE : learningRate.getValue(parameterValues); - ISchedule lrS = learningRateSchedule == null ? null : learningRateSchedule.getValue(parameterValues); - if(lrS == null){ - return new RmsProp(lr); - } else { - return new RmsProp(lrS); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/SgdSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/SgdSpace.java deleted file mode 100644 index 0c136e114..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/SgdSpace.java +++ /dev/null @@ -1,54 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.learning.config.Sgd; -import org.nd4j.linalg.schedule.ISchedule; -import com.fasterxml.jackson.annotation.JsonProperty; - -@Data -@EqualsAndHashCode(callSuper = false) -public class SgdSpace extends BaseUpdaterSpace { - - protected ParameterSpace learningRate; - protected ParameterSpace learningRateSchedule; - - public SgdSpace(ParameterSpace learningRate) { - this(learningRate, null); - } - - public SgdSpace(@JsonProperty("learningRate") ParameterSpace learningRate, - @JsonProperty("learningRateSchedule") ParameterSpace learningRateSchedule){ - this.learningRate = learningRate; - this.learningRateSchedule = learningRateSchedule; - } - - @Override - public IUpdater getValue(double[] parameterValues) { - double lr = learningRate == null ? Sgd.DEFAULT_SGD_LR : learningRate.getValue(parameterValues); - ISchedule lrS = learningRateSchedule == null ? null : learningRateSchedule.getValue(parameterValues); - if(lrS == null){ - return new Sgd(lr); - } else { - return new Sgd(lrS); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/ExponentialScheduleSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/ExponentialScheduleSpace.java deleted file mode 100644 index 3977faa25..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/ExponentialScheduleSpace.java +++ /dev/null @@ -1,92 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater.schedule; - -import lombok.Data; -import lombok.NoArgsConstructor; -import lombok.NonNull; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.nd4j.linalg.schedule.ExponentialSchedule; -import org.nd4j.linalg.schedule.ISchedule; -import org.nd4j.linalg.schedule.ScheduleType; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.*; - -@NoArgsConstructor //JSON -@Data -public class ExponentialScheduleSpace implements ParameterSpace { - - private ScheduleType scheduleType; - private ParameterSpace initialValue; - private ParameterSpace gamma; - - public ExponentialScheduleSpace(@NonNull ScheduleType scheduleType, - @NonNull ParameterSpace initialValue, double gamma){ - this(scheduleType, initialValue, new FixedValue<>(gamma)); - } - - public ExponentialScheduleSpace(@NonNull @JsonProperty("scheduleType") ScheduleType scheduleType, - @NonNull @JsonProperty("initialValue") ParameterSpace initialValue, - @NonNull @JsonProperty("gamma") ParameterSpace gamma){ - this.scheduleType = scheduleType; - this.initialValue = initialValue; - this.gamma = gamma; - } - - @Override - public ISchedule getValue(double[] parameterValues) { - return new ExponentialSchedule(scheduleType, initialValue.getValue(parameterValues), gamma.getValue(parameterValues)); - } - - @Override - public int numParameters() { - return initialValue.numParameters() + gamma.numParameters(); - } - - @Override - public List collectLeaves() { - return Arrays.asList(initialValue, gamma); - } - - @Override - public Map getNestedSpaces() { - Map out = new LinkedHashMap<>(); - out.put("initialValue", initialValue); - out.put("gamma", gamma); - return out; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - if(initialValue.numParameters() > 0){ - int[] sub = Arrays.copyOfRange(indices, 0, initialValue.numParameters()); - initialValue.setIndices(sub); - } - if(gamma.numParameters() > 0){ - int inp = initialValue.numParameters(); - int[] sub = Arrays.copyOfRange(indices, inp, inp + gamma.numParameters()); - gamma.setIndices(sub); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/InverseScheduleSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/InverseScheduleSpace.java deleted file mode 100644 index a22c640a9..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/InverseScheduleSpace.java +++ /dev/null @@ -1,106 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
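A hedged sketch of the index bookkeeping in the schedule spaces above: with a searched initial value (one parameter) and a fixed gamma (zero parameters), the space consumes exactly one slot of the candidate vector. ContinuousParameterSpace is again assumed from arbiter-core; the constructor, setIndices(...) and getValue(...) behaviour follow the ExponentialScheduleSpace code shown above.

    ParameterSpace<ISchedule> lrSchedule = new ExponentialScheduleSpace(
            ScheduleType.EPOCH, new ContinuousParameterSpace(0.01, 0.1), 0.95);
    lrSchedule.setIndices(0);                             // only initialValue needs an index
    ISchedule schedule = lrSchedule.getValue(new double[]{0.25});   // gamma stays fixed at 0.95
    // Such a space can then back an updater space, e.g. AdaGradSpace.withLRSchedule(lrSchedule).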
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater.schedule; - -import lombok.Data; -import lombok.NoArgsConstructor; -import lombok.NonNull; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.nd4j.linalg.schedule.ExponentialSchedule; -import org.nd4j.linalg.schedule.ISchedule; -import org.nd4j.linalg.schedule.InverseSchedule; -import org.nd4j.linalg.schedule.ScheduleType; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.Arrays; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -@NoArgsConstructor //JSON -@Data -public class InverseScheduleSpace implements ParameterSpace { - - private ScheduleType scheduleType; - private ParameterSpace initialValue; - private ParameterSpace gamma; - private ParameterSpace power; - - public InverseScheduleSpace(@NonNull ScheduleType scheduleType, @NonNull ParameterSpace initialValue, - double gamma, double power){ - this(scheduleType, initialValue, new FixedValue<>(gamma), new FixedValue<>(power)); - } - - public InverseScheduleSpace(@NonNull @JsonProperty("scheduleType") ScheduleType scheduleType, - @NonNull @JsonProperty("initialValue") ParameterSpace initialValue, - @NonNull @JsonProperty("gamma") ParameterSpace gamma, - @NonNull @JsonProperty("power") ParameterSpace power){ - this.scheduleType = scheduleType; - this.initialValue = initialValue; - this.gamma = gamma; - this.power = power; - } - - @Override - public ISchedule getValue(double[] parameterValues) { - return new InverseSchedule(scheduleType, initialValue.getValue(parameterValues), - gamma.getValue(parameterValues), power.getValue(parameterValues)); - } - - @Override - public int numParameters() { - return initialValue.numParameters() + gamma.numParameters() + power.numParameters(); - } - - @Override - public List collectLeaves() { - return Arrays.asList(initialValue, gamma, power); - } - - @Override - public Map getNestedSpaces() { - Map out = new LinkedHashMap<>(); - out.put("initialValue", initialValue); - out.put("gamma", gamma); - out.put("power", power); - return out; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - if(initialValue.numParameters() > 0){ - int[] sub = Arrays.copyOfRange(indices, 0, initialValue.numParameters()); - initialValue.setIndices(sub); - } - if(gamma.numParameters() > 0){ - int inp = initialValue.numParameters(); - int[] sub = Arrays.copyOfRange(indices, inp, inp + gamma.numParameters()); - gamma.setIndices(sub); - } - if(power.numParameters() > 0){ - int np = initialValue.numParameters() + gamma.numParameters(); - int[] sub = Arrays.copyOfRange(indices, np, np + power.numParameters()); - power.setIndices(sub); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/PolyScheduleSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/PolyScheduleSpace.java deleted file mode 100644 index 9beff30b5..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/PolyScheduleSpace.java +++ /dev/null @@ -1,106 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater.schedule; - -import lombok.Data; -import lombok.NoArgsConstructor; -import lombok.NonNull; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.nd4j.linalg.schedule.ISchedule; -import org.nd4j.linalg.schedule.InverseSchedule; -import org.nd4j.linalg.schedule.PolySchedule; -import org.nd4j.linalg.schedule.ScheduleType; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.Arrays; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -@NoArgsConstructor //JSON -@Data -public class PolyScheduleSpace implements ParameterSpace { - - private ScheduleType scheduleType; - private ParameterSpace initialValue; - private ParameterSpace power; - private ParameterSpace maxIter; - - public PolyScheduleSpace(@NonNull ScheduleType scheduleType, @NonNull ParameterSpace initialValue, - double power, int maxIter){ - this(scheduleType, initialValue, new FixedValue<>(power), new FixedValue<>(maxIter)); - } - - public PolyScheduleSpace(@NonNull @JsonProperty("scheduleType") ScheduleType scheduleType, - @NonNull @JsonProperty("initialValue") ParameterSpace initialValue, - @NonNull @JsonProperty("power") ParameterSpace power, - @NonNull @JsonProperty("maxIter") ParameterSpace maxIter){ - this.scheduleType = scheduleType; - this.initialValue = initialValue; - this.power = power; - this.maxIter = maxIter; - } - - @Override - public ISchedule getValue(double[] parameterValues) { - return new PolySchedule(scheduleType, initialValue.getValue(parameterValues), - power.getValue(parameterValues), maxIter.getValue(parameterValues)); - } - - @Override - public int numParameters() { - return initialValue.numParameters() + power.numParameters() + maxIter.numParameters(); - } - - @Override - public List collectLeaves() { - return Arrays.asList(initialValue, power, maxIter); - } - - @Override - public Map getNestedSpaces() { - Map out = new LinkedHashMap<>(); - out.put("initialValue", initialValue); - out.put("power", power); - out.put("maxIter", maxIter); - return out; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... 
indices) { - if(initialValue.numParameters() > 0){ - int[] sub = Arrays.copyOfRange(indices, 0, initialValue.numParameters()); - initialValue.setIndices(sub); - } - if(power.numParameters() > 0){ - int np = initialValue.numParameters(); - int[] sub = Arrays.copyOfRange(indices, np, np + power.numParameters()); - power.setIndices(sub); - } - if(maxIter.numParameters() > 0){ - int np = initialValue.numParameters() + power.numParameters(); - int[] sub = Arrays.copyOfRange(indices, np, np + maxIter.numParameters()); - maxIter.setIndices(sub); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/SigmoidScheduleSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/SigmoidScheduleSpace.java deleted file mode 100644 index c8c5e4c3c..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/SigmoidScheduleSpace.java +++ /dev/null @@ -1,106 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater.schedule; - -import lombok.Data; -import lombok.NoArgsConstructor; -import lombok.NonNull; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.nd4j.linalg.schedule.ISchedule; -import org.nd4j.linalg.schedule.PolySchedule; -import org.nd4j.linalg.schedule.ScheduleType; -import org.nd4j.linalg.schedule.SigmoidSchedule; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.Arrays; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -@NoArgsConstructor //JSON -@Data -public class SigmoidScheduleSpace implements ParameterSpace { - - private ScheduleType scheduleType; - private ParameterSpace initialValue; - private ParameterSpace gamma; - private ParameterSpace stepSize; - - public SigmoidScheduleSpace(@NonNull ScheduleType scheduleType, @NonNull ParameterSpace initialValue, - double gamma, int stepSize){ - this(scheduleType, initialValue, new FixedValue<>(gamma), new FixedValue<>(stepSize)); - } - - public SigmoidScheduleSpace(@NonNull @JsonProperty("scheduleType") ScheduleType scheduleType, - @NonNull @JsonProperty("initialValue") ParameterSpace initialValue, - @NonNull @JsonProperty("gamma") ParameterSpace gamma, - @NonNull @JsonProperty("stepSize") ParameterSpace stepSize){ - this.scheduleType = scheduleType; - this.initialValue = initialValue; - this.gamma = gamma; - this.stepSize = stepSize; - } - - @Override - public ISchedule getValue(double[] parameterValues) { - return new SigmoidSchedule(scheduleType, initialValue.getValue(parameterValues), - gamma.getValue(parameterValues), stepSize.getValue(parameterValues)); - } - - @Override - 
public int numParameters() { - return initialValue.numParameters() + gamma.numParameters() + stepSize.numParameters(); - } - - @Override - public List collectLeaves() { - return Arrays.asList(initialValue, gamma, stepSize); - } - - @Override - public Map getNestedSpaces() { - Map out = new LinkedHashMap<>(); - out.put("initialValue", initialValue); - out.put("gamma", gamma); - out.put("stepSize", stepSize); - return out; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - if(initialValue.numParameters() > 0){ - int[] sub = Arrays.copyOfRange(indices, 0, initialValue.numParameters()); - initialValue.setIndices(sub); - } - if(gamma.numParameters() > 0){ - int np = initialValue.numParameters(); - int[] sub = Arrays.copyOfRange(indices, np, np + gamma.numParameters()); - gamma.setIndices(sub); - } - if(stepSize.numParameters() > 0){ - int np = initialValue.numParameters() + gamma.numParameters(); - int[] sub = Arrays.copyOfRange(indices, np, np + stepSize.numParameters()); - stepSize.setIndices(sub); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/StepScheduleSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/StepScheduleSpace.java deleted file mode 100644 index d37638e8d..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/conf/updater/schedule/StepScheduleSpace.java +++ /dev/null @@ -1,106 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.conf.updater.schedule; - -import lombok.Data; -import lombok.NoArgsConstructor; -import lombok.NonNull; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.nd4j.linalg.schedule.ISchedule; -import org.nd4j.linalg.schedule.InverseSchedule; -import org.nd4j.linalg.schedule.ScheduleType; -import org.nd4j.linalg.schedule.StepSchedule; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.util.Arrays; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; - -@NoArgsConstructor //JSON -@Data -public class StepScheduleSpace implements ParameterSpace { - - private ScheduleType scheduleType; - private ParameterSpace initialValue; - private ParameterSpace decayRate; - private ParameterSpace step; - - public StepScheduleSpace(@NonNull ScheduleType scheduleType, @NonNull ParameterSpace initialValue, - double decayRate, double step){ - this(scheduleType, initialValue, new FixedValue<>(decayRate), new FixedValue<>(step)); - } - - public StepScheduleSpace(@NonNull @JsonProperty("scheduleType") ScheduleType scheduleType, - @NonNull @JsonProperty("initialValue") ParameterSpace initialValue, - @NonNull @JsonProperty("decayRate") ParameterSpace decayRate, - @NonNull @JsonProperty("step") ParameterSpace step){ - this.scheduleType = scheduleType; - this.initialValue = initialValue; - this.decayRate = decayRate; - this.step = step; - } - - @Override - public ISchedule getValue(double[] parameterValues) { - return new StepSchedule(scheduleType, initialValue.getValue(parameterValues), - decayRate.getValue(parameterValues), step.getValue(parameterValues)); - } - - @Override - public int numParameters() { - return initialValue.numParameters() + decayRate.numParameters() + step.numParameters(); - } - - @Override - public List collectLeaves() { - return Arrays.asList(initialValue, decayRate, step); - } - - @Override - public Map getNestedSpaces() { - Map out = new LinkedHashMap<>(); - out.put("initialValue", initialValue); - out.put("decayRate", decayRate); - out.put("step", step); - return out; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - if(initialValue.numParameters() > 0){ - int[] sub = Arrays.copyOfRange(indices, 0, initialValue.numParameters()); - initialValue.setIndices(sub); - } - if(decayRate.numParameters() > 0){ - int inp = initialValue.numParameters(); - int[] sub = Arrays.copyOfRange(indices, inp, inp + decayRate.numParameters()); - decayRate.setIndices(sub); - } - if(step.numParameters() > 0){ - int np = initialValue.numParameters() + decayRate.numParameters(); - int[] sub = Arrays.copyOfRange(indices, np, np + step.numParameters()); - step.setIndices(sub); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/data/DataSetIteratorFactoryProvider.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/data/DataSetIteratorFactoryProvider.java deleted file mode 100644 index e5443c0d3..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/data/DataSetIteratorFactoryProvider.java +++ /dev/null @@ -1,85 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
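The three schedule spaces in this hunk (PolyScheduleSpace, SigmoidScheduleSpace, StepScheduleSpace) share one pattern: every tunable field is itself a ParameterSpace, setIndices(...) hands each leaf its offsets into the flattened candidate vector, and getValue(...) turns one candidate vector into a concrete ISchedule. A minimal sketch of that flow, assuming arbiter's ContinuousParameterSpace (not part of this hunk) for the searchable initial value:

import org.deeplearning4j.arbiter.conf.updater.schedule.PolyScheduleSpace;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.ScheduleType;

public class ScheduleSpaceSketch {
    public static void main(String[] args) {
        // Only initialValue is searchable here; power and maxIter stay fixed.
        PolyScheduleSpace space = new PolyScheduleSpace(ScheduleType.ITERATION,
                new ContinuousParameterSpace(0.01, 0.1), 2.0, 10000);

        space.setIndices(0);   // one searchable leaf -> one entry of the candidate vector
        ISchedule schedule = space.getValue(new double[]{0.5});
        // With the default uniform distribution, 0.5 lands mid-range (initial value ~0.055).
    }
}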
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.arbiter.data;
-
-import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
-import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;
-
-import java.util.Map;
-
-/**
- * This is a {@link DataProvider} for
- * an {@link DataSetIteratorFactory} which
- * based on a key of {@link DataSetIteratorFactoryProvider#FACTORY_KEY}
- * will create {@link org.nd4j.linalg.dataset.api.iterator.DataSetIterator}
- * for use with arbiter.
- *
- * This {@link DataProvider} is mainly meant for use for command line driven
- * applications.
- *
- * @author Adam Gibson
- */
-public class DataSetIteratorFactoryProvider implements DataProvider {
-
-    public final static String FACTORY_KEY = "org.deeplearning4j.arbiter.data.data.factory";
-
-    /**
-     * Get training data given some parameters for the data.
-     * Data parameters map is used to specify things like batch
-     * size data preprocessing
-     *
-     * @param dataParameters Parameters for data. May be null or empty for default data
-     * @return training data
-     */
-    @Override
-    public DataSetIteratorFactory trainData(Map<String, Object> dataParameters) {
-        return create(dataParameters);
-    }
-
-    /**
-     * Get training data given some parameters for the data. Data parameters map
-     * is used to specify things like batch
-     * size data preprocessing
-     *
-     * @param dataParameters Parameters for data. May be null or empty for default data
-     * @return training data
-     */
-    @Override
-    public DataSetIteratorFactory testData(Map<String, Object> dataParameters) {
-        return create(dataParameters);
-    }
-
-    @Override
-    public Class<?> getDataType() {
-        return DataSetIteratorFactory.class;
-    }
-
-    private DataSetIteratorFactory create(Map<String, Object> dataParameters) {
-        if (!dataParameters.containsKey(FACTORY_KEY))
-            throw new IllegalArgumentException(
-                            "No data set iterator factory class found. Please specify a class name with key "
-                                            + FACTORY_KEY);
-        String value = dataParameters.get(FACTORY_KEY).toString();
-        try {
-            Class<? extends DataSetIteratorFactory> clazz =
-                            (Class<? extends DataSetIteratorFactory>) Class.forName(value);
-            return clazz.newInstance();
-        } catch (Exception e) {
-            throw new RuntimeException("Could not create DataSetIteratorFactory instance - missing no-arg constructor?", e);
-        }
-    }
-}
diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/data/MnistDataProvider.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/data/MnistDataProvider.java
deleted file mode 100644
index c42837896..000000000
--- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/data/MnistDataProvider.java
+++ /dev/null
@@ -1,80 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
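DataSetIteratorFactoryProvider does not construct an iterator itself: it reads a DataSetIteratorFactory class name from the data-parameters map under FACTORY_KEY and instantiates it reflectively, which is why the no-arg constructor matters. A minimal sketch, with a purely hypothetical factory class name:

import java.util.HashMap;
import java.util.Map;

import org.deeplearning4j.arbiter.data.DataSetIteratorFactoryProvider;
import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory;

public class FactoryProviderSketch {
    public static void main(String[] args) {
        Map<String, Object> dataParameters = new HashMap<>();
        // Hypothetical factory class name; the class must implement
        // DataSetIteratorFactory and have a no-arg constructor.
        dataParameters.put(DataSetIteratorFactoryProvider.FACTORY_KEY,
                "com.example.MyDataSetIteratorFactory");

        DataSetIteratorFactoryProvider provider = new DataSetIteratorFactoryProvider();
        DataSetIteratorFactory factory = provider.trainData(dataParameters);
        // The optimisation runner would then call the factory to obtain the actual DataSetIterator.
    }
}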
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.data; - -import lombok.Data; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultipleEpochsIterator; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.io.IOException; -import java.util.Map; -import java.util.Random; - -/** - * - * MnistDataProvider - a DataProvider for the MNIST data set, with configurable number of epochs, batch size - * and RNG seed - * - * @author Alex Black - */ -@Data -@NoArgsConstructor -public class MnistDataProvider implements DataProvider{ - - private int numEpochs; - private int batchSize; - private int rngSeed; - - public MnistDataProvider(int numEpochs, int batchSize){ - this(numEpochs, batchSize, new Random().nextInt()); - } - - public MnistDataProvider(@JsonProperty("numEpochs") int numEpochs, @JsonProperty("batchSize") int batchSize, - @JsonProperty("rngSeed") int rngSeed) { - this.numEpochs = numEpochs; - this.batchSize = batchSize; - this.rngSeed = rngSeed; - } - - - @Override - public Object trainData(Map dataParameters) { - try { - return new MultipleEpochsIterator(numEpochs, new MnistDataSetIterator(batchSize, true, rngSeed)); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public Object testData(Map dataParameters) { - try { - return new MnistDataSetIterator(batchSize, false, 12345); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public Class getDataType() { - return DataSetIterator.class; - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/AlphaDropoutSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/AlphaDropoutSpace.java deleted file mode 100644 index f4e3801f5..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/AlphaDropoutSpace.java +++ /dev/null @@ -1,67 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
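MnistDataProvider wires the configured epoch count, batch size and RNG seed straight into MnistDataSetIterator (wrapped in a MultipleEpochsIterator for training). A minimal usage sketch; the data-parameters map is ignored by this provider, so passing null is acceptable:

import org.deeplearning4j.arbiter.data.MnistDataProvider;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class MnistProviderSketch {
    public static void main(String[] args) {
        // 3 epochs over MNIST, minibatches of 128, fixed seed for reproducibility.
        MnistDataProvider provider = new MnistDataProvider(3, 128, 12345);

        DataSetIterator train = (DataSetIterator) provider.trainData(null);
        DataSetIterator test = (DataSetIterator) provider.testData(null);
    }
}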
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.dropout; - -import lombok.AllArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.nn.conf.dropout.AlphaDropout; -import org.deeplearning4j.nn.conf.dropout.IDropout; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -@AllArgsConstructor -public class AlphaDropoutSpace implements ParameterSpace { - - private ParameterSpace dropout; - - public AlphaDropoutSpace(double activationRetainProbability){ - this(new FixedValue<>(activationRetainProbability)); - } - - @Override - public IDropout getValue(double[] parameterValues) { - return new AlphaDropout(dropout.getValue(parameterValues)); - } - - @Override - public int numParameters() { - return dropout.numParameters(); - } - - @Override - public List collectLeaves() { - return Collections.singletonList(dropout); - } - - @Override - public Map getNestedSpaces() { - return Collections.singletonMap("dropout", dropout); - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - dropout.setIndices(indices); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/DropoutSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/DropoutSpace.java deleted file mode 100644 index 52dea3155..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/DropoutSpace.java +++ /dev/null @@ -1,67 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.dropout; - -import lombok.AllArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.nn.conf.dropout.Dropout; -import org.deeplearning4j.nn.conf.dropout.IDropout; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -@AllArgsConstructor -public class DropoutSpace implements ParameterSpace { - - private ParameterSpace dropout; - - public DropoutSpace(double activationRetainProbability){ - this(new FixedValue<>(activationRetainProbability)); - } - - @Override - public IDropout getValue(double[] parameterValues) { - return new Dropout(dropout.getValue(parameterValues)); - } - - @Override - public int numParameters() { - return dropout.numParameters(); - } - - @Override - public List collectLeaves() { - return Collections.singletonList(dropout); - } - - @Override - public Map getNestedSpaces() { - return Collections.singletonMap("dropout", dropout); - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - dropout.setIndices(indices); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/GaussianDropoutSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/GaussianDropoutSpace.java deleted file mode 100644 index 0cb345f40..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/GaussianDropoutSpace.java +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
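AlphaDropoutSpace and DropoutSpace are single-leaf wrappers: numParameters(), setIndices(...) and getValue(...) all delegate to the wrapped retain-probability ParameterSpace. A minimal sketch, again assuming arbiter's ContinuousParameterSpace:

import org.deeplearning4j.arbiter.dropout.DropoutSpace;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.nn.conf.dropout.IDropout;

public class DropoutSpaceSketch {
    public static void main(String[] args) {
        // Retain probability searched between 0.5 and 0.95.
        DropoutSpace space = new DropoutSpace(new ContinuousParameterSpace(0.5, 0.95));
        space.setIndices(0);

        IDropout dropout = space.getValue(new double[]{0.0});   // lower bound -> Dropout(0.5)
    }
}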
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.dropout; - -import lombok.AllArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.nn.conf.dropout.Dropout; -import org.deeplearning4j.nn.conf.dropout.GaussianDropout; -import org.deeplearning4j.nn.conf.dropout.IDropout; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -@AllArgsConstructor -public class GaussianDropoutSpace implements ParameterSpace { - - private ParameterSpace rate; - - public GaussianDropoutSpace(double rate){ - this(new FixedValue<>(rate)); - } - - @Override - public IDropout getValue(double[] parameterValues) { - return new GaussianDropout(rate.getValue(parameterValues)); - } - - @Override - public int numParameters() { - return rate.numParameters(); - } - - @Override - public List collectLeaves() { - return Collections.singletonList(rate); - } - - @Override - public Map getNestedSpaces() { - return Collections.singletonMap("rate", rate); - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - rate.setIndices(indices); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/GaussianNoiseSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/GaussianNoiseSpace.java deleted file mode 100644 index 706d389ee..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/dropout/GaussianNoiseSpace.java +++ /dev/null @@ -1,67 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.dropout; - -import lombok.AllArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.nn.conf.dropout.GaussianNoise; -import org.deeplearning4j.nn.conf.dropout.IDropout; - -import java.util.Collections; -import java.util.List; -import java.util.Map; - -@AllArgsConstructor -public class GaussianNoiseSpace implements ParameterSpace { - - private ParameterSpace stddev; - - public GaussianNoiseSpace(double stddev){ - this(new FixedValue<>(stddev)); - } - - @Override - public IDropout getValue(double[] parameterValues) { - return new GaussianNoise(stddev.getValue(parameterValues)); - } - - @Override - public int numParameters() { - return stddev.numParameters(); - } - - @Override - public List collectLeaves() { - return Collections.singletonList(stddev); - } - - @Override - public Map getNestedSpaces() { - return Collections.singletonMap("stddev", stddev); - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - stddev.setIndices(indices); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/evaluator/multilayer/ClassificationEvaluator.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/evaluator/multilayer/ClassificationEvaluator.java deleted file mode 100644 index 3d13ea2f1..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/evaluator/multilayer/ClassificationEvaluator.java +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.evaluator.multilayer; - -import lombok.AllArgsConstructor; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.evaluation.ModelEvaluator; -import org.deeplearning4j.arbiter.scoring.util.ScoreUtil; -import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; - -import java.util.Arrays; -import java.util.List; -import java.util.Map; - -/** - * A model evaluator for doing additional - * evaluation (classification evaluation) - * for a {@link MultiLayerNetwork} given a {@link DataSetIterator} - * - * @author Alex Black - */ -@NoArgsConstructor -@AllArgsConstructor -public class ClassificationEvaluator implements ModelEvaluator { - private Map params = null; - - - @Override - public Evaluation evaluateModel(Object model, DataProvider dataProvider) { - - if (model instanceof MultiLayerNetwork) { - DataSetIterator iterator = ScoreUtil.getIterator(dataProvider.testData(params)); - return ScoreUtil.getEvaluation((MultiLayerNetwork) model, iterator); - } else { - DataSetIterator iterator = ScoreUtil.getIterator(dataProvider.testData(params)); - return ScoreUtil.getEvaluation((ComputationGraph) model, iterator); - } - } - - @Override - public List> getSupportedModelTypes() { - return Arrays.>asList(MultiLayerNetwork.class, ComputationGraph.class); - } - - @Override - public List> getSupportedDataTypes() { - return Arrays.>asList(DataSetIterator.class, MultiDataSetIterator.class); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/evaluator/multilayer/RegressionDataEvaluator.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/evaluator/multilayer/RegressionDataEvaluator.java deleted file mode 100644 index 7973b11e0..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/evaluator/multilayer/RegressionDataEvaluator.java +++ /dev/null @@ -1,62 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
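ClassificationEvaluator only pulls the provider's test iterator and defers to ScoreUtil for the actual Evaluation. A minimal sketch that pairs it with the MnistDataProvider from earlier in this patch; the trained MultiLayerNetwork is assumed to come from elsewhere:

import org.deeplearning4j.arbiter.data.MnistDataProvider;
import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

public class EvaluatorSketch {
    public static void evaluate(MultiLayerNetwork trainedNetwork) {
        ClassificationEvaluator evaluator = new ClassificationEvaluator();
        // Single pass over the MNIST test set, batch size 64.
        Evaluation eval = evaluator.evaluateModel(trainedNetwork, new MnistDataProvider(1, 64));
        System.out.println(eval.stats());
    }
}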
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.evaluator.multilayer; - -import lombok.AllArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.evaluation.ModelEvaluator; -import org.deeplearning4j.arbiter.scoring.RegressionValue; -import org.deeplearning4j.arbiter.scoring.util.ScoreUtil; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; - -import java.util.Arrays; -import java.util.List; -import java.util.Map; - -/** - * Created by agibsonccc on 3/12/17. - */ -@AllArgsConstructor -public class RegressionDataEvaluator implements ModelEvaluator { - private RegressionValue regressionValue; - private Map params = null; - - @Override - public Double evaluateModel(Object model, DataProvider dataProvider) { - - if (model instanceof MultiLayerNetwork) { - DataSetIterator iterator = ScoreUtil.getIterator(dataProvider.testData(params)); - return ScoreUtil.score((MultiLayerNetwork) model, iterator, regressionValue); - } else { - DataSetIterator iterator = ScoreUtil.getIterator(dataProvider.testData(params)); - return ScoreUtil.score((ComputationGraph) model, iterator, regressionValue); - } - } - - @Override - public List> getSupportedModelTypes() { - return Arrays.>asList(MultiLayerNetwork.class, ComputationGraph.class); - } - - @Override - public List> getSupportedDataTypes() { - return Arrays.>asList(DataSetIterator.class, MultiDataSetIterator.class); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/AbstractLSTMLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/AbstractLSTMLayerSpace.java deleted file mode 100644 index 7cad81a82..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/AbstractLSTMLayerSpace.java +++ /dev/null @@ -1,108 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.adapter.ActivationParameterSpaceAdapter; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.AbstractLSTM; -import org.deeplearning4j.nn.conf.layers.GravesLSTM; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; - -/** - * Layer space for LSTM layers - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization -public abstract class AbstractLSTMLayerSpace extends FeedForwardLayerSpace { - - protected ParameterSpace forgetGateBiasInit; - protected ParameterSpace gateActivationFn; - - protected AbstractLSTMLayerSpace(Builder builder) { - super(builder); - this.forgetGateBiasInit = builder.forgetGateBiasInit; - this.gateActivationFn = builder.gateActivationFn; - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - protected void setLayerOptionsBuilder(AbstractLSTM.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (forgetGateBiasInit != null) - builder.forgetGateBiasInit(forgetGateBiasInit.getValue(values)); - if(gateActivationFn != null) - builder.gateActivationFunction(gateActivationFn.getValue(values)); - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - StringBuilder sb = new StringBuilder(); //"AbstractLSTMLayerSpace("); - if (forgetGateBiasInit != null) - sb.append("forgetGateBiasInit: ").append(forgetGateBiasInit).append(delim); - if (gateActivationFn != null) - sb.append("gateActivationFn: ").append(gateActivationFn).append(delim); - sb.append(super.toString(delim)); - return sb.toString(); - } - - public static abstract class Builder extends FeedForwardLayerSpace.Builder { - - private ParameterSpace forgetGateBiasInit; - private ParameterSpace gateActivationFn; - - public T forgetGateBiasInit(double forgetGateBiasInit) { - return forgetGateBiasInit(new FixedValue<>(forgetGateBiasInit)); - } - - public T forgetGateBiasInit(ParameterSpace forgetGateBiasInit) { - this.forgetGateBiasInit = forgetGateBiasInit; - return (T)this; - } - - public T gateActivationFn(Activation activation){ - return gateActivationFn(activation.getActivationFunction()); - } - - public T gateActivation(ParameterSpace gateActivationFn){ - return gateActivationFn(new ActivationParameterSpaceAdapter(gateActivationFn)); - } - - public T gateActivationFn(IActivation gateActivationFn){ - return gateActivationFn(new FixedValue<>(gateActivationFn)); - } - - public T gateActivationFn(ParameterSpace gateActivationFn){ - this.gateActivationFn = gateActivationFn; - return (T)this; - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/ActivationLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/ActivationLayerSpace.java deleted file mode 100644 index 1d45d23c8..000000000 --- 
a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/ActivationLayerSpace.java +++ /dev/null @@ -1,94 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.adapter.ActivationParameterSpaceAdapter; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.ActivationLayer; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; - -/** - * Layer space for {@link ActivationLayer} - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class ActivationLayerSpace extends LayerSpace { - - private ParameterSpace activationFunction; - - protected ActivationLayerSpace(Builder builder) { - super(builder); - this.activationFunction = builder.activationFunction; - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - - @Override - public ActivationLayer getValue(double[] parameterValues) { - ActivationLayer.Builder b = new ActivationLayer.Builder(); - super.setLayerOptionsBuilder(b, parameterValues); - b.activation(activationFunction.getValue(parameterValues)); - return b.build(); - } - - public static class Builder extends LayerSpace.Builder { - - private ParameterSpace activationFunction; - - public Builder activation(Activation activation) { - return activation(new FixedValue<>(activation)); - } - - public Builder activation(IActivation iActivation) { - return activationFn(new FixedValue<>(iActivation)); - } - - public Builder activation(ParameterSpace activationFunction) { - return activationFn(new ActivationParameterSpaceAdapter(activationFunction)); - } - - public Builder activationFn(ParameterSpace activationFunction) { - this.activationFunction = activationFunction; - return this; - } - - @SuppressWarnings("unchecked") - public ActivationLayerSpace build() { - return new ActivationLayerSpace(this); - } - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - return "ActivationLayerSpace(" + super.toString(delim) + ")"; - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/AutoEncoderLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/AutoEncoderLayerSpace.java deleted file mode 100644 index a429a2c96..000000000 --- 
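ActivationLayerSpace exposes a single hyperparameter, the activation function. A minimal sketch, assuming arbiter's DiscreteParameterSpace and its usual bucketing of the [0,1) coordinate over the discrete options; in a full search the enclosing MultiLayerSpace assigns leaf indices, here the leaf index is set by hand:

import org.deeplearning4j.arbiter.layers.ActivationLayerSpace;
import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.nd4j.linalg.activations.Activation;

public class ActivationSpaceSketch {
    public static void main(String[] args) {
        DiscreteParameterSpace<Activation> act =
                new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH, Activation.ELU);
        act.setIndices(0);   // leaf index, normally assigned by the enclosing space

        ActivationLayerSpace space = new ActivationLayerSpace.Builder()
                .activation(act)
                .build();

        // 0.4 of the unit interval should select the middle of the three options (TANH).
        ActivationLayer layer = space.getValue(new double[]{0.4});
    }
}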
a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/AutoEncoderLayerSpace.java +++ /dev/null @@ -1,107 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.AutoEncoder; -import com.fasterxml.jackson.annotation.JsonProperty; - -/** - * Layer space for autoencoder layers - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class AutoEncoderLayerSpace extends BasePretrainNetworkLayerSpace { - @JsonProperty - private ParameterSpace corruptionLevel; - @JsonProperty - private ParameterSpace sparsity; - - private AutoEncoderLayerSpace(Builder builder) { - super(builder); - this.corruptionLevel = builder.corruptionLevel; - this.sparsity = builder.sparsity; - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public AutoEncoder getValue(double[] values) { - AutoEncoder.Builder b = new AutoEncoder.Builder(); - setLayerOptionsBuilder(b, values); - return b.build(); - } - - protected void setLayerOptionsBuilder(AutoEncoder.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (corruptionLevel != null) - builder.corruptionLevel(corruptionLevel.getValue(values)); - if (sparsity != null) - builder.sparsity(sparsity.getValue(values)); - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - StringBuilder sb = new StringBuilder("AutoEncoderLayerSpace("); - if (corruptionLevel != null) - sb.append("corruptionLevel: ").append(corruptionLevel).append(delim); - if (sparsity != null) - sb.append("sparsity: ").append(sparsity).append(delim); - sb.append(super.toString(delim)).append(")"); - return sb.toString(); - } - - public static class Builder extends BasePretrainNetworkLayerSpace.Builder { - - private ParameterSpace corruptionLevel; - private ParameterSpace sparsity; - - public Builder corruptionLevel(double corruptionLevel) { - return corruptionLevel(new FixedValue<>(corruptionLevel)); - } - - public Builder corruptionLevel(ParameterSpace corruptionLevel) { - this.corruptionLevel = corruptionLevel; - return this; - } - - public Builder sparsity(double sparsity) { - return sparsity(new FixedValue<>(sparsity)); - } - - public Builder sparsity(ParameterSpace sparsity) { - this.sparsity = sparsity; - return this; - } - - public 
AutoEncoderLayerSpace build() { - return new AutoEncoderLayerSpace(this); - } - - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseConvolutionLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseConvolutionLayerSpace.java deleted file mode 100644 index 11bf1f274..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseConvolutionLayerSpace.java +++ /dev/null @@ -1,162 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; -import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; - -/** - * Layer space for convolutional layers - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization -public abstract class BaseConvolutionLayerSpace extends FeedForwardLayerSpace { - protected ParameterSpace dilation; - protected ParameterSpace kernelSize; - protected ParameterSpace stride; - protected ParameterSpace padding; - protected ParameterSpace convolutionMode; - protected ParameterSpace hasBias; - - protected BaseConvolutionLayerSpace(Builder builder) { - super(builder); - this.dilation = builder.dilation; - this.kernelSize = builder.kernelSize; - this.stride = builder.stride; - this.padding = builder.padding; - this.convolutionMode = builder.convolutionMode; - this.hasBias = builder.hasBias; - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - protected void setLayerOptionsBuilder(ConvolutionLayer.BaseConvBuilder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (dilation != null) - builder.dilation(dilation.getValue(values)); - if (kernelSize != null) - builder.kernelSize(kernelSize.getValue(values)); - if (stride != null) - builder.stride(stride.getValue(values)); - if (padding != null) - builder.padding(padding.getValue(values)); - if (convolutionMode != null) - builder.convolutionMode(convolutionMode.getValue(values)); - if (hasBias != null) - builder.hasBias(hasBias.getValue(values)); - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - StringBuilder sb = new StringBuilder(); - if (dilation != null) - sb.append("dilation: 
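AutoEncoderLayerSpace adds corruptionLevel and sparsity on top of the pretrain-layer space. A minimal builder sketch, assuming ContinuousParameterSpace as before and the nIn/nOut methods inherited from the FeedForwardLayerSpace builder (outside this hunk):

import org.deeplearning4j.arbiter.layers.AutoEncoderLayerSpace;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;

public class AutoEncoderSpaceSketch {
    public static void main(String[] args) {
        AutoEncoderLayerSpace space = new AutoEncoderLayerSpace.Builder()
                .nIn(784)
                .nOut(250)
                .corruptionLevel(new ContinuousParameterSpace(0.1, 0.5))
                .sparsity(new ContinuousParameterSpace(0.0, 0.1))
                .build();
        // Two searchable leaves -> space.numParameters() == 2.
    }
}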
").append(dilation).append(delim); - if (kernelSize != null) - sb.append("kernelSize: ").append(kernelSize).append(delim); - if (stride != null) - sb.append("stride: ").append(stride).append(delim); - if (padding != null) - sb.append("padding: ").append(padding).append(delim); - if (convolutionMode != null) - sb.append("convolutionMode: ").append(convolutionMode).append(delim); - if (hasBias != null) - sb.append("hasBias: ").append(hasBias).append(delim); - sb.append(super.toString(delim)); - return sb.toString(); - } - - - public static abstract class Builder extends FeedForwardLayerSpace.Builder { - protected ParameterSpace dilation; - protected ParameterSpace kernelSize; - protected ParameterSpace stride; - protected ParameterSpace padding; - protected ParameterSpace convolutionMode; - protected ParameterSpace hasBias; - - public T dilation(int... dilation) { - return dilation(new FixedValue<>(dilation)); - } - - public T dilation(ParameterSpace dilation) { - this.dilation = dilation; - return (T) this; - } - public T kernelSize(int... kernelSize) { - return kernelSize(new FixedValue<>(kernelSize)); - } - - public T kernelSize(ParameterSpace kernelSize) { - this.kernelSize = kernelSize; - return (T)this; - } - - public T stride(int... stride) { - return stride(new FixedValue<>(stride)); - } - - public T stride(ParameterSpace stride) { - this.stride = stride; - return (T)this; - } - - public T padding(int... padding) { - return padding(new FixedValue<>(padding)); - } - - public T padding(ParameterSpace padding) { - this.padding = padding; - return (T)this; - } - - public T convolutionMode(ConvolutionMode convolutionMode) { - return convolutionMode(new FixedValue<>(convolutionMode)); - } - - public T convolutionMode(ParameterSpace convolutionMode) { - this.convolutionMode = convolutionMode; - return (T)this; - } - - public T hasBias(boolean hasBias){ - return hasBias(new FixedValue<>(hasBias)); - } - - public T hasBias(ParameterSpace hasBias){ - this.hasBias = hasBias; - return (T)this; - } - - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java deleted file mode 100644 index 255e76be5..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseLayerSpace.java +++ /dev/null @@ -1,292 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import com.google.common.base.Preconditions; -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.adapter.ActivationParameterSpaceAdapter; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.Updater; -import org.deeplearning4j.nn.conf.distribution.Distribution; -import org.deeplearning4j.nn.conf.layers.BaseLayer; -import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; -import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.learning.config.IUpdater; -import com.fasterxml.jackson.annotation.JsonInclude; - -import java.util.Map; - -/** - * BaseLayerSpace contains the common Layer hyperparameters; should match {@link BaseLayer} in terms of features - * - * @author Alex Black - */ -@JsonInclude(JsonInclude.Include.NON_NULL) - -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization -public abstract class BaseLayerSpace extends LayerSpace { - protected ParameterSpace activationFunction; - protected ParameterSpace weightInit; - protected ParameterSpace biasInit; - protected ParameterSpace dist; - protected ParameterSpace l1; - protected ParameterSpace l2; - protected ParameterSpace l1Bias; - protected ParameterSpace l2Bias; - protected ParameterSpace updater; - protected ParameterSpace biasUpdater; - protected ParameterSpace weightNoise; - protected ParameterSpace gradientNormalization; - protected ParameterSpace gradientNormalizationThreshold; - protected int numParameters; - - @SuppressWarnings("unchecked") - protected BaseLayerSpace(Builder builder) { - super(builder); - this.activationFunction = builder.activationFunction; - this.weightInit = builder.weightInit; - this.biasInit = builder.biasInit; - this.dist = builder.dist; - this.l1 = builder.l1; - this.l2 = builder.l2; - this.l1Bias = builder.l1Bias; - this.l2Bias = builder.l2Bias; - this.updater = builder.updater; - this.biasUpdater = builder.biasUpdater; - this.weightNoise = builder.weightNoise; - this.gradientNormalization = builder.gradientNormalization; - this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold; - } - - @Override - public int numParameters() { - return numParameters; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... 
indices) { - throw new UnsupportedOperationException("Cannot set indices for non-leaf parameter space"); - } - - - protected void setLayerOptionsBuilder(BaseLayer.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (activationFunction != null) - builder.activation(activationFunction.getValue(values)); - if (biasInit != null) - builder.biasInit(biasInit.getValue(values)); - if (weightInit != null) - builder.weightInit(weightInit.getValue(values)); - if (dist != null) - builder.dist(dist.getValue(values)); - if (l1 != null) - builder.l1(l1.getValue(values)); - if (l2 != null) - builder.l2(l2.getValue(values)); - if (l1Bias != null) - builder.l1Bias(l1Bias.getValue(values)); - if (l2Bias != null) - builder.l2Bias(l2Bias.getValue(values)); - if (updater != null) - builder.updater(updater.getValue(values)); - if (biasUpdater != null) - builder.biasUpdater(biasUpdater.getValue(values)); - if (weightNoise != null) - builder.weightNoise(weightNoise.getValue(values)); - if (gradientNormalization != null) - builder.gradientNormalization(gradientNormalization.getValue(values)); - if (gradientNormalizationThreshold != null) - builder.gradientNormalizationThreshold(gradientNormalizationThreshold.getValue(values)); - } - - - @Override - public String toString() { - return toString(", "); - } - - protected String toString(String delim) { - StringBuilder sb = new StringBuilder(); - - for (Map.Entry e : getNestedSpaces().entrySet()) { - sb.append(e.getKey()).append(": ").append(e.getValue()).append("\n"); - } - return sb.toString(); - } - - @SuppressWarnings("unchecked") - public abstract static class Builder extends LayerSpace.Builder { - protected ParameterSpace activationFunction; - protected ParameterSpace weightInit; - protected ParameterSpace biasInit; - protected ParameterSpace dist; - protected ParameterSpace l1; - protected ParameterSpace l2; - protected ParameterSpace l1Bias; - protected ParameterSpace l2Bias; - protected ParameterSpace updater; - protected ParameterSpace biasUpdater; - protected ParameterSpace weightNoise; - protected ParameterSpace gradientNormalization; - protected ParameterSpace gradientNormalizationThreshold; - - public T activation(Activation... 
activations){ - Preconditions.checkArgument(activations.length > 0, "Activations length must be 1 or more"); - if(activations.length == 1){ - return activation(activations[0]); - } - return activation(new DiscreteParameterSpace<>(activations)); - } - - public T activation(Activation activation) { - return activation(new FixedValue<>(activation)); - } - - public T activation(IActivation iActivation) { - return activationFn(new FixedValue<>(iActivation)); - } - - public T activation(ParameterSpace activationFunction) { - return activationFn(new ActivationParameterSpaceAdapter(activationFunction)); - } - - public T activationFn(ParameterSpace activationFunction) { - this.activationFunction = activationFunction; - return (T) this; - } - - public T weightInit(WeightInit weightInit) { - return (T) weightInit(new FixedValue(weightInit)); - } - - public T weightInit(ParameterSpace weightInit) { - this.weightInit = weightInit; - return (T) this; - } - - public T weightInit(Distribution distribution){ - weightInit(WeightInit.DISTRIBUTION); - return dist(distribution); - } - - public T biasInit(double biasInit){ - return biasInit(new FixedValue<>(biasInit)); - } - - public T biasInit(ParameterSpace biasInit){ - this.biasInit = biasInit; - return (T) this; - } - - public T dist(Distribution dist) { - return dist(new FixedValue<>(dist)); - } - - public T dist(ParameterSpace dist) { - this.dist = dist; - return (T) this; - } - - public T l1(double l1) { - return l1(new FixedValue(l1)); - } - - public T l1(ParameterSpace l1) { - this.l1 = l1; - return (T) this; - } - - public T l2(double l2) { - return l2(new FixedValue(l2)); - } - - public T l2(ParameterSpace l2) { - this.l2 = l2; - return (T) this; - } - - public T l1Bias(double l1Bias) { - return l1Bias(new FixedValue(l1Bias)); - } - - public T l1Bias(ParameterSpace l1Bias) { - this.l1Bias = l1Bias; - return (T) this; - } - - public T l2Bias(double l2Bias) { - return l2Bias(new FixedValue<>(l2Bias)); - } - - public T l2Bias(ParameterSpace l2Bias) { - this.l2Bias = l2Bias; - return (T) this; - } - - public T updater(IUpdater updater) { - return updater(new FixedValue<>(updater)); - } - - public T updater(ParameterSpace updater) { - this.updater = updater; - return (T) this; - } - - public T biasUpdater(IUpdater biasUpdater) { - return biasUpdater(new FixedValue<>(biasUpdater)); - } - - public T biasUpdater(ParameterSpace biasUpdater) { - this.biasUpdater = biasUpdater; - return (T) this; - } - - public T gradientNormalization(GradientNormalization gradientNormalization) { - return gradientNormalization(new FixedValue(gradientNormalization)); - } - - public T gradientNormalization(ParameterSpace gradientNormalization) { - this.gradientNormalization = gradientNormalization; - return (T) this; - } - - public T gradientNormalizationThreshold(double threshold) { - return gradientNormalizationThreshold(new FixedValue<>(threshold)); - } - - public T gradientNormalizationThreshold(ParameterSpace gradientNormalizationThreshold) { - this.gradientNormalizationThreshold = gradientNormalizationThreshold; - return (T) this; - } - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java deleted file mode 100644 index 857f729ad..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BaseOutputLayerSpace.java +++ /dev/null @@ -1,87 +0,0 @@ 
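BaseLayerSpace carries the hyperparameters shared by every layer space (activation, weight init, L1/L2, updater, gradient normalization); concrete subclasses only add their own fields. A minimal sketch against a concrete subclass, assuming DenseLayerSpace and IntegerParameterSpace from elsewhere in arbiter:

import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;

public class DenseLayerSpaceSketch {
    public static void main(String[] args) {
        DenseLayerSpace layerSpace = new DenseLayerSpace.Builder()
                .nIn(784)
                .nOut(new IntegerParameterSpace(32, 256))        // searched layer width
                .activation(Activation.RELU, Activation.TANH)    // discrete choice via the varargs overload above
                .l2(new ContinuousParameterSpace(1e-5, 1e-2))    // searched weight decay
                .updater(new Adam(1e-3))                         // fixed updater
                .build();
    }
}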
-/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.adapter.LossFunctionParameterSpaceAdapter; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; -import org.nd4j.linalg.lossfunctions.ILossFunction; -import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; - -/** - * @param Type of the (concrete) output layer - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PUBLIC) //For Jackson JSON/YAML deserialization -public abstract class BaseOutputLayerSpace extends FeedForwardLayerSpace { - - protected ParameterSpace lossFunction; - protected ParameterSpace hasBias; - - protected BaseOutputLayerSpace(Builder builder) { - super(builder); - this.lossFunction = builder.lossFunction; - this.hasBias = builder.hasBias; - } - - protected void setLayerOptionsBuilder(BaseOutputLayer.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (lossFunction != null) - builder.lossFunction(lossFunction.getValue(values)); - if (hasBias != null) - builder.hasBias(hasBias.getValue(values)); - } - - @SuppressWarnings("unchecked") - public static abstract class Builder extends FeedForwardLayerSpace.Builder { - - protected ParameterSpace lossFunction; - protected ParameterSpace hasBias; - - public T lossFunction(LossFunction lossFunction) { - return lossFunction(new FixedValue<>(lossFunction)); - } - - public T lossFunction(ParameterSpace lossFunction) { - return iLossFunction(new LossFunctionParameterSpaceAdapter(lossFunction)); - } - - public T iLossFunction(ILossFunction lossFunction) { - return iLossFunction(new FixedValue<>(lossFunction)); - } - - public T iLossFunction(ParameterSpace lossFunction) { - this.lossFunction = lossFunction; - return (T) this; - } - - public T hasBias(boolean hasBias){ - return hasBias(new FixedValue<>(hasBias)); - } - - public T hasBias(ParameterSpace hasBias){ - this.hasBias = hasBias; - return (T)this; - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BasePretrainNetworkLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BasePretrainNetworkLayerSpace.java deleted file mode 100644 index 9f554911b..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BasePretrainNetworkLayerSpace.java +++ /dev/null @@ -1,57 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 
Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork; -import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; -import com.fasterxml.jackson.annotation.JsonProperty; - - -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization -public abstract class BasePretrainNetworkLayerSpace extends FeedForwardLayerSpace { - @JsonProperty - protected ParameterSpace lossFunction; - - protected BasePretrainNetworkLayerSpace(Builder builder) { - super(builder); - this.lossFunction = builder.lossFunction; - } - - - public static abstract class Builder extends FeedForwardLayerSpace.Builder { - protected ParameterSpace lossFunction; - - public T lossFunction(LossFunction lossFunction) { - return lossFunction(new FixedValue(lossFunction)); - } - - public T lossFunction(ParameterSpace lossFunction) { - this.lossFunction = lossFunction; - return (T) this; - } - - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BatchNormalizationSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BatchNormalizationSpace.java deleted file mode 100644 index 9b55555ed..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/BatchNormalizationSpace.java +++ /dev/null @@ -1,214 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.layers.BatchNormalization; - -import java.util.Arrays; -import java.util.List; - -/** - * LayerSpace for batch normalization layers - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class BatchNormalizationSpace extends FeedForwardLayerSpace { - - protected ParameterSpace decay; - protected ParameterSpace eps; - protected ParameterSpace isMinibatch; - protected ParameterSpace lockGammaBeta; - protected ParameterSpace gamma; - protected ParameterSpace beta; - protected ParameterSpace> constrainBeta; - protected ParameterSpace> constrainGamma; - - private BatchNormalizationSpace(Builder builder) { - super(builder); - this.decay = builder.decay; - this.eps = builder.eps; - this.isMinibatch = builder.isMinibatch; - this.lockGammaBeta = builder.lockGammaBeta; - this.gamma = builder.gamma; - this.beta = builder.beta; - this.constrainBeta = builder.betaConstraints; - this.constrainGamma = builder.gammaConstraints; - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public BatchNormalization getValue(double[] parameterValues) { - BatchNormalization.Builder b = new BatchNormalization.Builder(); - setLayerOptionsBuilder(b, parameterValues); - return b.build(); - } - - protected void setLayerOptionsBuilder(BatchNormalization.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (decay != null) - builder.decay(decay.getValue(values)); - if (eps != null) - builder.eps(eps.getValue(values)); - if (isMinibatch != null) - builder.minibatch(isMinibatch.getValue(values)); - if (lockGammaBeta != null) - builder.lockGammaBeta(lockGammaBeta.getValue(values)); - if (gamma != null) - builder.gamma(gamma.getValue(values)); - if (beta != null) - builder.beta(beta.getValue(values)); - if (constrainBeta != null){ - List c = constrainBeta.getValue(values); - if(c != null){ - builder.constrainBeta(c.toArray(new LayerConstraint[c.size()])); - } - } - if (constrainGamma != null){ - List c = constrainGamma.getValue(values); - if(c != null){ - builder.constrainGamma(c.toArray(new LayerConstraint[c.size()])); - } - } - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - StringBuilder sb = new StringBuilder(); - sb.append("BatchNormalizationSpace(").append(super.toString(delim)); - if (decay != null) - sb.append("decay: ").append(decay).append(delim); - if (eps != null) - sb.append("eps: ").append(eps).append(delim); - if (isMinibatch != null) - sb.append("isMinibatch: ").append(isMinibatch).append(delim); - if (lockGammaBeta != null) - sb.append("lockGammaBeta: ").append(lockGammaBeta).append(delim); - if (gamma != null) - sb.append("gamma: ").append(gamma).append(delim); - if (beta != null) - sb.append("beta: ").append(beta).append(delim); - sb.append(")"); - return sb.toString(); - } - - 
public static class Builder extends FeedForwardLayerSpace.Builder { - - protected ParameterSpace decay; - protected ParameterSpace eps; - protected ParameterSpace isMinibatch; - protected ParameterSpace lockGammaBeta; - protected ParameterSpace gamma; - protected ParameterSpace beta; - protected ParameterSpace> betaConstraints; - protected ParameterSpace> gammaConstraints; - - public Builder minibatch(boolean minibatch) { - return minibatch(new FixedValue<>(minibatch)); - } - - public Builder minibatch(ParameterSpace minibatch) { - this.isMinibatch = minibatch; - return this; - } - - public Builder gamma(double gamma) { - return gamma(new FixedValue<>(gamma)); - } - - public Builder gamma(ParameterSpace gamma) { - this.gamma = gamma; - return this; - } - - public Builder beta(double beta) { - return beta(new FixedValue<>(beta)); - } - - public Builder beta(ParameterSpace beta) { - this.beta = beta; - return this; - } - - public Builder eps(double eps) { - return eps(new FixedValue<>(eps)); - } - - public Builder eps(ParameterSpace eps) { - this.eps = eps; - return this; - } - - public Builder decay(double decay) { - return decay(new FixedValue(decay)); - } - - public Builder decay(ParameterSpace decay) { - this.decay = decay; - return this; - } - - public Builder lockGammaBeta(boolean lockGammaBeta) { - return lockGammaBeta(new FixedValue<>(lockGammaBeta)); - } - - public Builder lockGammaBeta(ParameterSpace lockGammaBeta) { - this.lockGammaBeta = lockGammaBeta; - return this; - } - - public Builder constrainBeta(LayerConstraint... constraints) { - return constrainBeta(new FixedValue<>(Arrays.asList(constraints))); - } - - public Builder constrainBeta(ParameterSpace> constraints) { - this.betaConstraints = constraints; - return this; - } - - public Builder constrainGamma(LayerConstraint... constraints) { - return constrainGamma(new FixedValue<>(Arrays.asList(constraints))); - } - - public Builder constrainGamma(ParameterSpace> constraints) { - this.gammaConstraints = constraints; - return this; - } - - - @Override - public BatchNormalizationSpace build() { - return new BatchNormalizationSpace(this); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/Bidirectional.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/Bidirectional.java deleted file mode 100644 index 64cdfd369..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/Bidirectional.java +++ /dev/null @@ -1,67 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.arbiter.layers;
-
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
-import org.deeplearning4j.nn.conf.layers.Layer;
-
-import java.util.List;
-
-/**
- * Bidirectional layer wrapper. Can be used to wrap an existing layer space, in the same way that
- * {@link org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional} wraps a DL4J layer
- *
- * @author Alex Black
- */
-@NoArgsConstructor //JSON
-@Data
-public class Bidirectional extends LayerSpace<Layer> {
-
-    protected LayerSpace<?> layerSpace;
-
-    public Bidirectional(LayerSpace<?> layerSpace){
-        this.layerSpace = layerSpace;
-    }
-
-    @Override
-    public Layer getValue(double[] parameterValues) {
-        Layer underlying = layerSpace.getValue(parameterValues);
-        return new org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional(underlying);
-    }
-
-    @Override
-    public int numParameters() {
-        return layerSpace.numParameters();
-    }
-
-    @Override
-    public List<ParameterSpace> collectLeaves() {
-        return layerSpace.collectLeaves();
-    }
-
-    @Override
-    public boolean isLeaf() {
-        return layerSpace.isLeaf();
-    }
-
-    @Override
-    public void setIndices(int... indices) {
-        layerSpace.setIndices(indices);
-    }
-}
diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/CenterLossOutputLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/CenterLossOutputLayerSpace.java
deleted file mode 100644
index ecba732c3..000000000
--- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/CenterLossOutputLayerSpace.java
+++ /dev/null
@@ -1,87 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer; - -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization -public class CenterLossOutputLayerSpace extends BaseOutputLayerSpace { - - ParameterSpace alpha; - ParameterSpace lambda; - - protected CenterLossOutputLayerSpace(Builder builder){ - super(builder); - this.alpha = builder.alpha; - this.lambda = builder.lambda; - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public CenterLossOutputLayer getValue(double[] parameterValues) { - CenterLossOutputLayer.Builder b = new CenterLossOutputLayer.Builder(); - setLayerOptionsBuilder(b, parameterValues); - return b.build(); - } - - protected void setLayerBuilderOptions(CenterLossOutputLayer.Builder builder, double[] values){ - super.setLayerOptionsBuilder(builder, values); - if(alpha != null) - builder.alpha(alpha.getValue(values)); - if(lambda != null) - builder.lambda(lambda.getValue(values)); - } - - public static class Builder extends BaseOutputLayerSpace.Builder { - - ParameterSpace alpha; - ParameterSpace lambda; - - public Builder alpha(double alpha){ - return alpha(new FixedValue<>(alpha)); - } - - public Builder alpha(ParameterSpace alpha){ - this.alpha = alpha; - return this; - } - - public Builder lambda(double lambda){ - return lambda(new FixedValue<>(lambda)); - } - - public Builder lambda(ParameterSpace lambda){ - this.lambda = lambda; - return this; - } - - @Override - public CenterLossOutputLayerSpace build() { - return new CenterLossOutputLayerSpace(this); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/ConvolutionLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/ConvolutionLayerSpace.java deleted file mode 100644 index 110e5b6e7..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/ConvolutionLayerSpace.java +++ /dev/null @@ -1,172 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; - -/** - * Layer space for convolutional layers - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class ConvolutionLayerSpace extends FeedForwardLayerSpace { - protected ParameterSpace dilation; - protected ParameterSpace kernelSize; - protected ParameterSpace stride; - protected ParameterSpace padding; - protected ParameterSpace convolutionMode; - protected ParameterSpace hasBias; - - private ConvolutionLayerSpace(Builder builder) { - super(builder); - this.dilation = builder.dilation; - this.kernelSize = builder.kernelSize; - this.stride = builder.stride; - this.padding = builder.padding; - this.convolutionMode = builder.convolutionMode; - this.hasBias = builder.hasBias; - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public ConvolutionLayer getValue(double[] values) { - ConvolutionLayer.Builder b = new ConvolutionLayer.Builder(); - setLayerOptionsBuilder(b, values); - return b.build(); - } - - protected void setLayerOptionsBuilder(ConvolutionLayer.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (dilation != null) - builder.dilation(dilation.getValue(values)); - if (kernelSize != null) - builder.kernelSize(kernelSize.getValue(values)); - if (stride != null) - builder.stride(stride.getValue(values)); - if (padding != null) - builder.padding(padding.getValue(values)); - if (convolutionMode != null) - builder.convolutionMode(convolutionMode.getValue(values)); - if (hasBias != null) - builder.hasBias(hasBias.getValue(values)); - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - StringBuilder sb = new StringBuilder("ConvolutionLayerSpace("); - if (dilation != null) - sb.append("dilation: ").append(dilation).append(delim); - if (kernelSize != null) - sb.append("kernelSize: ").append(kernelSize).append(delim); - if (stride != null) - sb.append("stride: ").append(stride).append(delim); - if (padding != null) - sb.append("padding: ").append(padding).append(delim); - if (convolutionMode != null) - sb.append("convolutionMode: ").append(convolutionMode).append(delim); - if (hasBias != null) - sb.append("hasBias: ").append(hasBias).append(delim); - sb.append(super.toString(delim)).append(")"); - return sb.toString(); - } - - - public static class Builder extends FeedForwardLayerSpace.Builder { - protected ParameterSpace dilation; - protected ParameterSpace kernelSize; - protected ParameterSpace stride; - protected ParameterSpace padding; - protected ParameterSpace convolutionMode; - protected ParameterSpace hasBias; - - public Builder dilation(int... 
dilation) { - return dilation(new FixedValue<>(dilation)); - } - - public Builder dilation(ParameterSpace dilation) { - this.dilation = dilation; - return this; - } - public Builder kernelSize(int... kernelSize) { - return kernelSize(new FixedValue<>(kernelSize)); - } - - public Builder kernelSize(ParameterSpace kernelSize) { - this.kernelSize = kernelSize; - return this; - } - - public Builder stride(int... stride) { - return stride(new FixedValue<>(stride)); - } - - public Builder stride(ParameterSpace stride) { - this.stride = stride; - return this; - } - - public Builder padding(int... padding) { - return padding(new FixedValue<>(padding)); - } - - public Builder padding(ParameterSpace padding) { - this.padding = padding; - return this; - } - - public Builder convolutionMode(ConvolutionMode convolutionMode) { - return convolutionMode(new FixedValue<>(convolutionMode)); - } - - public Builder convolutionMode(ParameterSpace convolutionMode) { - this.convolutionMode = convolutionMode; - return this; - } - - public Builder hasBias(boolean hasBias){ - return hasBias(new FixedValue<>(hasBias)); - } - - public Builder hasBias(ParameterSpace hasBias){ - this.hasBias = hasBias; - return this; - } - - public ConvolutionLayerSpace build() { - return new ConvolutionLayerSpace(this); - } - - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/Deconvolution2DLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/Deconvolution2DLayerSpace.java deleted file mode 100644 index 72231f246..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/Deconvolution2DLayerSpace.java +++ /dev/null @@ -1,52 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.arbiter.layers;
-
-
-import lombok.AccessLevel;
-import lombok.Data;
-import lombok.EqualsAndHashCode;
-import lombok.NoArgsConstructor;
-import org.deeplearning4j.nn.conf.layers.Deconvolution2D;
-
-@Data
-@EqualsAndHashCode(callSuper = true)
-@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization
-public class Deconvolution2DLayerSpace extends BaseConvolutionLayerSpace<Deconvolution2D> {
-
-    protected Deconvolution2DLayerSpace(Builder builder) {
-        super(builder);
-    }
-
-    @Override
-    public Deconvolution2D getValue(double[] parameterValues) {
-        Deconvolution2D.Builder b = new Deconvolution2D.Builder();
-        setLayerOptionsBuilder(b, parameterValues);
-        return b.build();
-    }
-
-    protected void setLayerOptionsBuilder(Deconvolution2D.Builder builder, double[] values) {
-        super.setLayerOptionsBuilder(builder, values);
-    }
-
-    public static class Builder extends BaseConvolutionLayerSpace.Builder<Builder> {
-        @Override
-        public Deconvolution2DLayerSpace build() {
-            return new Deconvolution2DLayerSpace(this);
-        }
-    }
-}
diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/DenseLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/DenseLayerSpace.java
deleted file mode 100644
index 4a7ac3f28..000000000
--- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/DenseLayerSpace.java
+++ /dev/null
@@ -1,90 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.DenseLayer; - -/** - * layer hyperparameter configuration space for dense layers (i.e., multi-layer perceptron layers) - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor //For Jackson JSON/YAML deserialization -public class DenseLayerSpace extends FeedForwardLayerSpace { - - protected ParameterSpace hasBias; - - private DenseLayerSpace(Builder builder) { - super(builder); - - this.hasBias = builder.hasBias; - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public DenseLayer getValue(double[] values) { - //Using the builder here, to get default options - DenseLayer.Builder b = new DenseLayer.Builder(); - setLayerOptionsBuilder(b, values); - return b.build(); - } - - protected void setLayerOptionsBuilder(DenseLayer.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if(hasBias != null) - builder.hasBias(hasBias.getValue(values)); - } - - public static class Builder extends FeedForwardLayerSpace.Builder { - - protected ParameterSpace hasBias; - - public Builder hasBias(boolean hasBias){ - return hasBias(new FixedValue<>(hasBias)); - } - - public Builder hasBias(ParameterSpace hasBias){ - this.hasBias = hasBias; - return this; - } - - @Override - @SuppressWarnings("unchecked") - public DenseLayerSpace build() { - return new DenseLayerSpace(this); - } - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - return "DenseLayerSpace(" + super.toString(delim) + ")"; - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/DropoutLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/DropoutLayerSpace.java deleted file mode 100644 index 1e6ca7157..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/DropoutLayerSpace.java +++ /dev/null @@ -1,89 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.*; -import org.deeplearning4j.arbiter.dropout.DropoutSpace; -import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.nn.conf.dropout.IDropout; -import org.deeplearning4j.nn.conf.layers.DropoutLayer; - -import java.util.Collections; -import java.util.List; - -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization -public class DropoutLayerSpace extends LayerSpace { - - public DropoutLayerSpace(@NonNull ParameterSpace dropout){ - this.dropOut = dropout; - } - - protected DropoutLayerSpace(Builder builder){ - super(builder); - } - - @Override - public DropoutLayer getValue(double[] parameterValues) { - return new DropoutLayer.Builder().dropOut(dropOut.getValue(parameterValues)).build(); - } - - @Override - public int numParameters() { - return dropOut.numParameters(); - } - - @Override - public List collectLeaves() { - return dropOut.collectLeaves(); - } - - - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - dropOut.setIndices(indices); - } - - public static class Builder extends LayerSpace.Builder { - - public Builder dropOut(double d){ - return iDropOut(new DropoutSpace(new FixedValue<>(d))); - } - - public Builder dropOut(ParameterSpace dropOut){ - return iDropOut(new DropoutSpace(dropOut)); - } - - public Builder iDropOut(ParameterSpace dropout){ - this.dropOut = dropout; - return this; - } - - public DropoutLayerSpace build(){ - return new DropoutLayerSpace(this); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/EmbeddingLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/EmbeddingLayerSpace.java deleted file mode 100644 index 7aa5c5444..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/EmbeddingLayerSpace.java +++ /dev/null @@ -1,88 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.EmbeddingLayer; - -/** - * Layer hyperparameter configuration space for {@link org.deeplearning4j.nn.conf.layers.EmbeddingLayer} - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class EmbeddingLayerSpace extends FeedForwardLayerSpace { - private ParameterSpace hasBias; - - private EmbeddingLayerSpace(Builder builder) { - super(builder); - this.hasBias = builder.hasBias; - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public EmbeddingLayer getValue(double[] values) { - //Using the builder here, to get default options - EmbeddingLayer.Builder b = new EmbeddingLayer.Builder(); - setLayerOptionsBuilder(b, values); - return b.build(); - } - - protected void setLayerOptionsBuilder(EmbeddingLayer.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if(hasBias != null) - builder.hasBias(hasBias.getValue(values)); - } - - public static class Builder extends FeedForwardLayerSpace.Builder { - protected ParameterSpace hasBias; - - public Builder hasBias(boolean hasBias){ - return hasBias(new FixedValue<>(hasBias)); - } - - public Builder hasBias(ParameterSpace hasBias){ - this.hasBias = hasBias; - return this; - } - - @Override - @SuppressWarnings("unchecked") - public EmbeddingLayerSpace build() { - return new EmbeddingLayerSpace(this); - } - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - return "EmbeddingLayerSpace(" + super.toString(delim) + ")"; - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/FeedForwardLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/FeedForwardLayerSpace.java deleted file mode 100644 index 3ba0f3a06..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/FeedForwardLayerSpace.java +++ /dev/null @@ -1,154 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; - -import java.util.Arrays; -import java.util.List; - -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor //For Jackson JSON/YAML deserialization -public abstract class FeedForwardLayerSpace extends BaseLayerSpace { - protected ParameterSpace nIn; - protected ParameterSpace nOut; - protected ParameterSpace> constrainWeights; - protected ParameterSpace> constrainBias; - protected ParameterSpace> constrainAll; - - - protected FeedForwardLayerSpace(Builder builder) { - super(builder); - nIn = builder.nIn; - nOut = builder.nOut; - constrainWeights = builder.constrainWeights; - constrainBias = builder.constrainBias; - constrainAll = builder.constrainAll; - } - - protected void setLayerOptionsBuilder(FeedForwardLayer.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (nIn != null) - builder.nIn(nIn.getValue(values)); - if (nOut != null) - builder.nOut(nOut.getValue(values)); - if (constrainWeights != null){ - List c = constrainWeights.getValue(values); - if(c != null){ - builder.constrainWeights(c.toArray(new LayerConstraint[c.size()])); - } - } - if (constrainBias != null){ - List c = constrainBias.getValue(values); - if(c != null){ - builder.constrainBias(c.toArray(new LayerConstraint[c.size()])); - } - } - if (constrainAll != null){ - List c = constrainAll.getValue(values); - if(c != null){ - builder.constrainAllParameters(c.toArray(new LayerConstraint[c.size()])); - } - } - - } - - - public abstract static class Builder extends BaseLayerSpace.Builder { - - protected ParameterSpace nIn; - protected ParameterSpace nOut; - protected ParameterSpace> constrainWeights; - protected ParameterSpace> constrainBias; - protected ParameterSpace> constrainAll; - - public T nIn(int nIn) { - return nIn(new FixedValue<>(nIn)); - } - - public T nIn(ParameterSpace nIn) { - this.nIn = nIn; - return (T) this; - } - - public T nOut(int nOut) { - return nOut(new FixedValue<>(nOut)); - } - - public T nOut(ParameterSpace nOut) { - this.nOut = nOut; - return (T) this; - } - - public T constrainWeights(LayerConstraint... constraints){ - return constrainWeights(new FixedValue>(Arrays.asList(constraints))); - } - - public T constrainWeights(ParameterSpace> constraints){ - this.constrainWeights = constraints; - return (T) this; - } - - public T constrainBias(LayerConstraint... constraints){ - return constrainBias(new FixedValue>(Arrays.asList(constraints))); - } - - public T constrainBias(ParameterSpace> constraints){ - this.constrainBias = constraints; - return (T) this; - } - - public T constrainAllParams(LayerConstraint... 
constraints){ - return constrainAllParams(new FixedValue>(Arrays.asList(constraints))); - } - - public T constrainAllParams(ParameterSpace> constraints){ - this.constrainAll = constraints; - return (T) this; - } - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - protected String toString(String delim) { - StringBuilder sb = new StringBuilder(); - if (nIn != null) - sb.append("nIn: ").append(nIn).append(delim); - if (nOut != null) - sb.append("nOut: ").append(nOut).append(delim); - if (constrainWeights != null) - sb.append("constrainWeights: ").append(constrainWeights).append(delim); - if (constrainBias != null) - sb.append("constrainBias: ").append(constrainBias).append(delim); - if (constrainAll != null) - sb.append("constrainAllParams: ").append(constrainAll).append(delim); - sb.append(super.toString(delim)); - return sb.toString(); - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GlobalPoolingLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GlobalPoolingLayerSpace.java deleted file mode 100644 index 17bd22103..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GlobalPoolingLayerSpace.java +++ /dev/null @@ -1,135 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; -import org.deeplearning4j.nn.conf.layers.PoolingType; - -/** - * Layer space for a {@link GlobalPoolingLayer} - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class GlobalPoolingLayerSpace extends LayerSpace { - - protected ParameterSpace poolingDimensions; - protected ParameterSpace collapseDimensions; - protected ParameterSpace poolingType; - protected ParameterSpace pNorm; - - private int numParameters; - - private GlobalPoolingLayerSpace(Builder builder) { - super(builder); - this.poolingDimensions = builder.poolingDimensions; - this.collapseDimensions = builder.collapseDimensions; - this.poolingType = builder.poolingType; - this.pNorm = builder.pNorm; - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public GlobalPoolingLayer getValue(double[] parameterValues) { - GlobalPoolingLayer.Builder builder = new GlobalPoolingLayer.Builder(); - super.setLayerOptionsBuilder(builder, parameterValues); - if (poolingDimensions != null) - builder.poolingDimensions(poolingDimensions.getValue(parameterValues)); - if (collapseDimensions != null) - builder.collapseDimensions(collapseDimensions.getValue(parameterValues)); - if (poolingType != null) - builder.poolingType(poolingType.getValue(parameterValues)); - if (pNorm != null) - builder.pnorm(pNorm.getValue(parameterValues)); - return builder.build(); - } - - @Override - public int numParameters() { - return numParameters; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... indices) { - throw new UnsupportedOperationException("Cannot set indices for non-leaf parameter space"); - } - - - - public static class Builder extends LayerSpace.Builder { - - protected ParameterSpace poolingDimensions; - protected ParameterSpace collapseDimensions; - protected ParameterSpace poolingType; - protected ParameterSpace pNorm; - - public Builder poolingDimensions(int... 
poolingDimensions) { - return poolingDimensions(new FixedValue<>(poolingDimensions)); - } - - public Builder poolingDimensions(ParameterSpace poolingDimensions) { - this.poolingDimensions = poolingDimensions; - return this; - } - - public Builder collapseDimensions(boolean collapseDimensions) { - return collapseDimensions(new FixedValue<>(collapseDimensions)); - } - - public Builder collapseDimensions(ParameterSpace collapseDimensions) { - this.collapseDimensions = collapseDimensions; - return this; - } - - public Builder poolingType(PoolingType poolingType) { - return poolingType(new FixedValue<>(poolingType)); - } - - public Builder poolingType(ParameterSpace poolingType) { - this.poolingType = poolingType; - return this; - } - - public Builder pNorm(int pNorm) { - return pNorm(new FixedValue<>(pNorm)); - } - - public Builder pNorm(ParameterSpace pNorm) { - this.pNorm = pNorm; - return this; - } - - public GlobalPoolingLayerSpace build() { - return new GlobalPoolingLayerSpace(this); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GravesBidirectionalLSTMLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GravesBidirectionalLSTMLayerSpace.java deleted file mode 100644 index e42deacbe..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GravesBidirectionalLSTMLayerSpace.java +++ /dev/null @@ -1,97 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM; - -import java.util.List; - -/** - * Layer space for Bidirectional LSTM layers - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class GravesBidirectionalLSTMLayerSpace extends FeedForwardLayerSpace { - - private ParameterSpace forgetGateBiasInit; - - private GravesBidirectionalLSTMLayerSpace(Builder builder) { - super(builder); - this.forgetGateBiasInit = builder.forgetGateBiasInit; - - List l = collectLeaves(); - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - - @Override - public GravesBidirectionalLSTM getValue(double[] values) { - GravesBidirectionalLSTM.Builder b = new GravesBidirectionalLSTM.Builder(); - setLayerOptionsBuilder(b, values); - return b.build(); - } - - protected void setLayerOptionsBuilder(GravesBidirectionalLSTM.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (forgetGateBiasInit != null) - builder.forgetGateBiasInit(forgetGateBiasInit.getValue(values)); - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - StringBuilder sb = new StringBuilder("GravesBidirectionalLSTMLayerSpace("); - if (forgetGateBiasInit != null) - sb.append("forgetGateBiasInit: ").append(forgetGateBiasInit).append(delim); - sb.append(super.toString(delim)).append(")"); - return sb.toString(); - } - - public static class Builder extends FeedForwardLayerSpace.Builder { - - private ParameterSpace forgetGateBiasInit; - - public Builder forgetGateBiasInit(double forgetGateBiasInit) { - return forgetGateBiasInit(new FixedValue<>(forgetGateBiasInit)); - } - - public Builder forgetGateBiasInit(ParameterSpace forgetGateBiasInit) { - this.forgetGateBiasInit = forgetGateBiasInit; - return this; - } - - @Override - @SuppressWarnings("unchecked") - public GravesBidirectionalLSTMLayerSpace build() { - return new GravesBidirectionalLSTMLayerSpace(this); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GravesLSTMLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GravesLSTMLayerSpace.java deleted file mode 100644 index 9707836fa..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/GravesLSTMLayerSpace.java +++ /dev/null @@ -1,76 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.GravesLSTM; - -/** - * Layer space for LSTM layers - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class GravesLSTMLayerSpace extends AbstractLSTMLayerSpace { - - private GravesLSTMLayerSpace(Builder builder) { - super(builder); - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - - @Override - public GravesLSTM getValue(double[] values) { - GravesLSTM.Builder b = new GravesLSTM.Builder(); - setLayerOptionsBuilder(b, values); - return b.build(); - } - - protected void setLayerOptionsBuilder(GravesLSTM.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - StringBuilder sb = new StringBuilder("GravesLSTMLayerSpace("); - sb.append(super.toString(delim)).append(")"); - return sb.toString(); - } - - public static class Builder extends AbstractLSTMLayerSpace.Builder { - - @Override - @SuppressWarnings("unchecked") - public GravesLSTMLayerSpace build() { - return new GravesLSTMLayerSpace(this); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LSTMLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LSTMLayerSpace.java deleted file mode 100644 index 10e297134..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LSTMLayerSpace.java +++ /dev/null @@ -1,77 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.arbiter.layers;
-
-import lombok.AccessLevel;
-import lombok.Data;
-import lombok.EqualsAndHashCode;
-import lombok.NoArgsConstructor;
-import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
-import org.deeplearning4j.arbiter.optimize.parameter.FixedValue;
-import org.deeplearning4j.arbiter.util.LeafUtils;
-import org.deeplearning4j.nn.conf.layers.GravesLSTM;
-import org.deeplearning4j.nn.conf.layers.LSTM;
-
-/**
- * Layer space for LSTM layers
- *
- * @author Alex Black
- */
-@Data
-@EqualsAndHashCode(callSuper = true)
-@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization
-public class LSTMLayerSpace extends AbstractLSTMLayerSpace<LSTM> {
-
-    private LSTMLayerSpace(Builder builder) {
-        super(builder);
-
-        this.numParameters = LeafUtils.countUniqueParameters(collectLeaves());
-    }
-
-
-    @Override
-    public LSTM getValue(double[] values) {
-        LSTM.Builder b = new LSTM.Builder();
-        setLayerOptionsBuilder(b, values);
-        return b.build();
-    }
-
-    protected void setLayerOptionsBuilder(LSTM.Builder builder, double[] values) {
-        super.setLayerOptionsBuilder(builder, values);
-    }
-
-    @Override
-    public String toString() {
-        return toString(", ");
-    }
-
-    @Override
-    public String toString(String delim) {
-        StringBuilder sb = new StringBuilder("LSTMLayerSpace(");
-        sb.append(super.toString(delim)).append(")");
-        return sb.toString();
-    }
-
-    public static class Builder extends AbstractLSTMLayerSpace.Builder<Builder> {
-
-        @Override
-        @SuppressWarnings("unchecked")
-        public LSTMLayerSpace build() {
-            return new LSTMLayerSpace(this);
-        }
-    }
-}
diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LayerSpace.java
deleted file mode 100644
index eb77196d2..000000000
--- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LayerSpace.java
+++ /dev/null
@@ -1,138 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.dropout.DropoutSpace; -import org.deeplearning4j.arbiter.optimize.api.AbstractParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.nn.conf.dropout.IDropout; -import org.deeplearning4j.nn.conf.layers.Layer; -import com.fasterxml.jackson.annotation.JsonInclude; - -import java.util.ArrayList; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; - -/** - * LayerSpace contains common Layer hyperparameters; should match {@link Layer} in terms of features - * - * @author Alex Black - */ -@JsonInclude(JsonInclude.Include.NON_NULL) -@Data -@EqualsAndHashCode(callSuper = false) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization -public abstract class LayerSpace extends AbstractParameterSpace { - protected ParameterSpace dropOut; - protected int numParameters; - - protected LayerSpace(Builder builder) { - this.dropOut = builder.dropOut; - } - - @Override - public List collectLeaves() { - //To avoid manually coding EVERY parameter, in every layer: - // Do a depth-first search of nested spaces - LinkedList stack = new LinkedList<>(); - stack.add(this); - - List out = new ArrayList<>(); - while (!stack.isEmpty()) { - ParameterSpace next = stack.removeLast(); - if (next.isLeaf()) { - out.add(next); - } else { - Map m = next.getNestedSpaces(); - ParameterSpace[] arr = m.values().toArray(new ParameterSpace[m.size()]); - for (int i = arr.length - 1; i >= 0; i--) { - stack.add(arr[i]); - } - } - } - - return out; - } - - @Override - public int numParameters() { - return numParameters; - } - - @Override - public boolean isLeaf() { - return false; - } - - @Override - public void setIndices(int... 
indices) { - throw new UnsupportedOperationException("Cannot set indices for non-leaf parameter space"); - } - - - protected void setLayerOptionsBuilder(Layer.Builder builder, double[] values) { - if (dropOut != null) - builder.dropOut(dropOut.getValue(values)); - } - - - @Override - public String toString() { - return toString(", "); - } - - protected String toString(String delim) { - StringBuilder sb = new StringBuilder(); - if (dropOut != null) - sb.append("dropOut: ").append(dropOut).append(delim); - String s = sb.toString(); - - if (s.endsWith(delim)) { - //Remove final delimiter - int last = s.lastIndexOf(delim); - return s.substring(0, last); - } else - return s; - } - - @SuppressWarnings("unchecked") - public abstract static class Builder { - protected ParameterSpace dropOut; - - public T dropOut(double dropOut) { - return dropOut(new FixedValue<>(dropOut)); - } - - public T dropOut(ParameterSpace dropOut) { - return iDropOut(new DropoutSpace(dropOut)); - } - - public T iDropOut(ParameterSpace dropOut){ - this.dropOut = dropOut; - return (T) this; - } - - public abstract E build(); - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LocalResponseNormalizationLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LocalResponseNormalizationLayerSpace.java deleted file mode 100644 index eeeb5837f..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LocalResponseNormalizationLayerSpace.java +++ /dev/null @@ -1,119 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization; - -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class LocalResponseNormalizationLayerSpace extends LayerSpace { - - private ParameterSpace n; - private ParameterSpace k; - private ParameterSpace alpha; - private ParameterSpace beta; - - - private LocalResponseNormalizationLayerSpace(Builder builder) { - super(builder); - this.n = builder.n; - this.k = builder.k; - this.alpha = builder.alpha; - this.beta = builder.beta; - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public LocalResponseNormalization getValue(double[] values) { - LocalResponseNormalization.Builder b = new LocalResponseNormalization.Builder(); - setLayerOptionsBuilder(b, values); - return b.build(); - } - - protected void setLayerOptionsBuilder(LocalResponseNormalization.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (n != null) - builder.n(n.getValue(values)); - if (k != null) - builder.k(k.getValue(values)); - if (alpha != null) - builder.alpha(alpha.getValue(values)); - if (beta != null) - builder.beta(beta.getValue(values)); - } - - - public static class Builder extends LayerSpace.Builder { - - private ParameterSpace n; - private ParameterSpace k; - private ParameterSpace alpha; - private ParameterSpace beta; - - - public Builder n(double n) { - return n(new FixedValue<>(n)); - } - - public Builder n(ParameterSpace n) { - this.n = n; - return this; - } - - public Builder k(double k) { - return k(new FixedValue<>(k)); - } - - public Builder k(ParameterSpace k) { - this.k = k; - return this; - } - - public Builder alpha(double alpha) { - return alpha(new FixedValue<>(alpha)); - } - - public Builder alpha(ParameterSpace alpha) { - this.alpha = alpha; - return this; - } - - public Builder beta(double beta) { - return beta(new FixedValue<>(beta)); - } - - public Builder beta(ParameterSpace beta) { - this.beta = beta; - return this; - } - - public LocalResponseNormalizationLayerSpace build() { - return new LocalResponseNormalizationLayerSpace(this); - } - - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LossLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LossLayerSpace.java deleted file mode 100644 index fc0b8c4d1..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/LossLayerSpace.java +++ /dev/null @@ -1,105 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.adapter.ActivationParameterSpaceAdapter; -import org.deeplearning4j.arbiter.adapter.LossFunctionParameterSpaceAdapter; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.LossLayer; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.lossfunctions.ILossFunction; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization -public class LossLayerSpace extends LayerSpace { - - private ParameterSpace activationFunction; - protected ParameterSpace lossFunction; - - public LossLayerSpace(Builder builder){ - super(builder); - this.activationFunction = builder.activationFunction; - this.lossFunction = builder.lossFunction; - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public LossLayer getValue(double[] parameterValues) { - LossLayer.Builder b = new LossLayer.Builder(); - if(activationFunction != null) - b.activation(activationFunction.getValue(parameterValues)); - if(lossFunction != null) - b.lossFunction(lossFunction.getValue(parameterValues)); - return b.build(); - } - - - public static class Builder extends LayerSpace.Builder{ - - private ParameterSpace activationFunction; - protected ParameterSpace lossFunction; - - public Builder lossFunction(LossFunctions.LossFunction lossFunction) { - return lossFunction(new FixedValue<>(lossFunction)); - } - - public Builder lossFunction(ParameterSpace lossFunction) { - return iLossFunction(new LossFunctionParameterSpaceAdapter(lossFunction)); - } - - public Builder iLossFunction(ILossFunction lossFunction) { - return iLossFunction(new FixedValue<>(lossFunction)); - } - - public Builder iLossFunction(ParameterSpace lossFunction) { - this.lossFunction = lossFunction; - return this; - } - - public Builder activation(Activation activation) { - return activation(new FixedValue<>(activation)); - } - - public Builder activation(IActivation iActivation) { - return activationFn(new FixedValue<>(iActivation)); - } - - public Builder activation(ParameterSpace activationFunction) { - return activationFn(new ActivationParameterSpaceAdapter(activationFunction)); - } - - public Builder activationFn(ParameterSpace activationFunction) { - this.activationFunction = activationFunction; - return this; - } - - @Override - public LossLayerSpace build() { - return new LossLayerSpace(this); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/OCNNLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/OCNNLayerSpace.java deleted file 
mode 100644 index d4fc9553b..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/OCNNLayerSpace.java +++ /dev/null @@ -1,153 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer; - - -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class OCNNLayerSpace extends BaseOutputLayerSpace { - - - protected ParameterSpace nuSpace; - protected ParameterSpace initialRValue; - protected ParameterSpace hiddenLayerSize; - protected ParameterSpace windowSize; - protected ParameterSpace configureR; - - private OCNNLayerSpace(Builder builder) { - super(builder); - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - this.nuSpace = builder.nuSpace; - this.initialRValue = builder.initialRValue; - this.hiddenLayerSize = builder.hiddenLayerSize; - this.configureR = builder.configureR; - } - - - @Override - public OCNNOutputLayer getValue(double[] parameterValues) { - OCNNOutputLayer.Builder o = new OCNNOutputLayer.Builder(); - setLayerOptionsBuilder(o, parameterValues); - return o.build(); - } - - protected void setLayerOptionsBuilder(OCNNOutputLayer.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - builder.nu(nuSpace.getValue(values)); - builder.hiddenLayerSize(hiddenLayerSize.getValue(values)); - builder.initialRValue(initialRValue.getValue(values)); - builder.configureR(configureR.getValue(values)); - builder.windowSize(windowSize.getValue(values)); - } - - - public static class Builder extends BaseOutputLayerSpace.Builder { - protected ParameterSpace nuSpace; - protected ParameterSpace initialRValue; - protected ParameterSpace hiddenLayerSize; - protected ParameterSpace windowSize; - protected ParameterSpace configureR; - - public Builder nu(ParameterSpace nuSpace) { - this.nuSpace = nuSpace; - return this; - } - - /** - * Use hiddenLayerSize instead - * @param numHiddenSpace - * @return - */ - @Deprecated - public Builder numHidden(ParameterSpace numHiddenSpace) { - return hiddenLayerSize(numHiddenSpace); - } - - /** - * Use hiddenLayerSize instead - * @param numHidden - * @return - */ - @Deprecated - public Builder numHidden(int numHidden) { - return hiddenLayerSize(numHidden); - } - - public Builder hiddenLayerSize(ParameterSpace hiddenLayerSize) { - this.hiddenLayerSize = hiddenLayerSize; - return 
this; - } - - public Builder hiddenLayerSize(int hiddenLayerSize) { - this.hiddenLayerSize = new FixedValue<>(hiddenLayerSize); - return this; - } - - public Builder nu(double nu) { - this.nuSpace = new FixedValue<>(nu); - return this; - } - - public Builder initialRValue(double initialRValue) { - this.initialRValue = new FixedValue<>(initialRValue); - return this; - } - - public Builder initialRValue(ParameterSpace initialRValue) { - this.initialRValue = initialRValue; - return this; - } - - public Builder windowSize(int windowSize) { - this.windowSize = new FixedValue<>(windowSize); - return this; - } - - public Builder windowSize(ParameterSpace windowSize) { - this.windowSize = windowSize; - return this; - } - - public Builder configureR(boolean configureR) { - this.configureR = new FixedValue<>(configureR); - return this; - } - - public Builder configureR(ParameterSpace configureR) { - this.configureR = configureR; - return this; - } - - - @Override - @SuppressWarnings("unchecked") - public OCNNLayerSpace build() { - return new OCNNLayerSpace(this); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/OutputLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/OutputLayerSpace.java deleted file mode 100644 index 5e6479fce..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/OutputLayerSpace.java +++ /dev/null @@ -1,71 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
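// Usage sketch (illustrative only, not part of the patched sources): how the removed
// OCNNLayerSpace builder above is typically assembled. ContinuousParameterSpace and
// IntegerParameterSpace are assumed to be Arbiter's standard numeric parameter spaces;
// the builder methods themselves (nu, hiddenLayerSize, initialRValue, windowSize,
// configureR) are the ones declared in the deleted class.
OCNNLayerSpace ocnnSpace = new OCNNLayerSpace.Builder()
        .nu(new ContinuousParameterSpace(0.004, 0.1))         // nu searched over [0.004, 0.1]
        .hiddenLayerSize(new IntegerParameterSpace(32, 256))  // hidden layer size searched
        .initialRValue(0.1)                                    // fixed values are wrapped in FixedValue
        .windowSize(10000)
        .configureR(true)
        .build();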
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.OutputLayer; - -/** - * Layer hyperparameter configuration space for output layers - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class OutputLayerSpace extends BaseOutputLayerSpace { - - private OutputLayerSpace(Builder builder) { - super(builder); - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public OutputLayer getValue(double[] values) { - OutputLayer.Builder o = new OutputLayer.Builder(); - setLayerOptionsBuilder(o, values); - return o.build(); - } - - protected void setLayerOptionsBuilder(OutputLayer.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - } - - public static class Builder extends BaseOutputLayerSpace.Builder { - - @Override - @SuppressWarnings("unchecked") - public OutputLayerSpace build() { - return new OutputLayerSpace(this); - } - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - return "OutputLayerSpace(" + super.toString(delim) + ")"; - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/RnnOutputLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/RnnOutputLayerSpace.java deleted file mode 100644 index 4fba80d81..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/RnnOutputLayerSpace.java +++ /dev/null @@ -1,71 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
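// Usage sketch (illustrative; nOut, activation and lossFunction come from the inherited
// BaseOutputLayerSpace/FeedForwardLayerSpace builders, which are not shown in this hunk):
// a minimal OutputLayerSpace as it would typically appear in an Arbiter search space.
OutputLayerSpace outSpace = new OutputLayerSpace.Builder()
        .nOut(10)                                             // fixed output count
        .activation(Activation.SOFTMAX)                       // a ParameterSpace<Activation> overload also exists
        .lossFunction(LossFunctions.LossFunction.MCXENT)
        .build();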
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; - -/** - * Layer hyperparametor configuration space for RnnOutputLayer - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class RnnOutputLayerSpace extends BaseOutputLayerSpace { - - private RnnOutputLayerSpace(Builder builder) { - super(builder); - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public RnnOutputLayer getValue(double[] values) { - RnnOutputLayer.Builder b = new RnnOutputLayer.Builder(); - setLayerOptionsBuilder(b, values); - return b.build(); - } - - protected void setLayerOptionsBuilder(RnnOutputLayer.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - return "RnnOutputLayerSpace(" + super.toString(delim) + ")"; - } - - public static class Builder extends BaseOutputLayerSpace.Builder { - - @Override - @SuppressWarnings("unchecked") - public RnnOutputLayerSpace build() { - return new RnnOutputLayerSpace(this); - } - } - - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/SeparableConvolution2DLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/SeparableConvolution2DLayerSpace.java deleted file mode 100644 index 64a0d26a6..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/SeparableConvolution2DLayerSpace.java +++ /dev/null @@ -1,101 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D; - -import java.util.Arrays; -import java.util.List; - -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For Jackson JSON/YAML deserialization -public class SeparableConvolution2DLayerSpace extends BaseConvolutionLayerSpace { - - private ParameterSpace depthMultiplier; - protected ParameterSpace> pointWiseConstraints; - - protected SeparableConvolution2DLayerSpace(Builder builder){ - super(builder); - this.depthMultiplier = builder.depthMultiplier; - this.pointWiseConstraints = builder.pointWiseConstraints; - } - - @Override - public SeparableConvolution2D getValue(double[] parameterValues) { - SeparableConvolution2D.Builder b = new SeparableConvolution2D.Builder(); - setLayerOptionsBuilder(b, parameterValues); - return b.build(); - } - - protected void setLayerOptionsBuilder(SeparableConvolution2D.Builder builder, double[] values){ - super.setLayerOptionsBuilder(builder, values); - if (kernelSize != null) - builder.kernelSize(kernelSize.getValue(values)); - if (stride != null) - builder.stride(stride.getValue(values)); - if (padding != null) - builder.padding(padding.getValue(values)); - if (convolutionMode != null) - builder.convolutionMode(convolutionMode.getValue(values)); - if (hasBias != null) - builder.hasBias(hasBias.getValue(values)); - if (depthMultiplier != null) - builder.depthMultiplier(depthMultiplier.getValue(values)); - if (pointWiseConstraints != null){ - List c = pointWiseConstraints.getValue(values); - if(c != null){ - builder.constrainPointWise(c.toArray(new LayerConstraint[c.size()])); - } - } - } - - - public static class Builder extends BaseConvolutionLayerSpace.Builder{ - private ParameterSpace depthMultiplier; - protected ParameterSpace> pointWiseConstraints; - - public Builder constrainPointWise(LayerConstraint... constraints){ - return constrainPointWise(new FixedValue>(Arrays.asList(constraints))); - } - - public Builder constrainPointWise(ParameterSpace> constraints){ - this.pointWiseConstraints = constraints; - return this; - } - - public Builder depthMultiplier(int depthMultiplier){ - return depthMultiplier(new FixedValue<>(depthMultiplier)); - } - - public Builder depthMultiplier(ParameterSpace depthMultiplier){ - this.depthMultiplier = depthMultiplier; - return this; - } - - public SeparableConvolution2DLayerSpace build(){ - return new SeparableConvolution2DLayerSpace(this); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/SubsamplingLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/SubsamplingLayerSpace.java deleted file mode 100644 index 5f1e32dab..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/SubsamplingLayerSpace.java +++ /dev/null @@ -1,208 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; - -/** - * Layer hyperparameter configuration space for subsampling layers - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class SubsamplingLayerSpace extends LayerSpace { - - protected ParameterSpace convolutionMode; - protected ParameterSpace poolingType; - protected ParameterSpace dilation; - protected ParameterSpace kernelSize; - protected ParameterSpace stride; - protected ParameterSpace padding; - protected ParameterSpace pnorm; - protected ParameterSpace eps; - - private SubsamplingLayerSpace(Builder builder) { - super(builder); - this.convolutionMode = builder.convolutionMode; - this.poolingType = builder.poolingType; - this.kernelSize = builder.kernelSize; - this.dilation = builder.dilation; - this.stride = builder.stride; - this.padding = builder.padding; - this.pnorm = builder.pnorm; - this.eps = builder.eps; - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public SubsamplingLayer getValue(double[] values) { - SubsamplingLayer.Builder b = new SubsamplingLayer.Builder(); - setLayerOptionsBuilder(b, values); - return b.build(); - } - - protected void setLayerOptionsBuilder(SubsamplingLayer.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (convolutionMode != null) - builder.convolutionMode(convolutionMode.getValue(values)); - if (poolingType != null) - builder.poolingType(poolingType.getValue(values)); - if (dilation !=null) - builder.dilation(dilation.getValue(values)); - if (kernelSize != null) - builder.kernelSize(kernelSize.getValue(values)); - if (stride != null) - builder.stride(stride.getValue(values)); - if (padding != null) - builder.padding(padding.getValue(values)); - if(pnorm != null) - builder.pnorm(pnorm.getValue(values)); - if(eps != null) - builder.eps(eps.getValue(values)); - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - StringBuilder sb = new StringBuilder("SubsamplingLayerSpace("); - if (convolutionMode != null) - sb.append("convolutionMode: ").append(convolutionMode).append(delim); - if (poolingType != null) - sb.append("poolingType: ").append(poolingType).append(delim); - if (dilation != null) - sb.append("dilation: ").append(dilation).append(delim); - if (kernelSize != null) - 
sb.append("kernelSize: ").append(kernelSize).append(delim); - if (stride != null) - sb.append("stride: ").append(stride).append(delim); - if (padding != null) - sb.append("padding: ").append(padding).append(delim); - if (pnorm != null) - sb.append("pnorm: ").append(pnorm).append(delim); - if (eps != null) - sb.append("eps: ").append(eps).append(delim); - sb.append(super.toString(delim)).append(")"); - return sb.toString(); - } - - - public static class Builder extends FeedForwardLayerSpace.Builder { - - protected ParameterSpace convolutionMode; - protected ParameterSpace poolingType; - protected ParameterSpace dilation; - protected ParameterSpace kernelSize; - protected ParameterSpace stride; - protected ParameterSpace padding; - protected ParameterSpace pnorm; - protected ParameterSpace eps; - - public Builder convolutionMode(ConvolutionMode convolutionMode){ - return convolutionMode(new FixedValue<>(convolutionMode)); - } - - public Builder convolutionMode(ParameterSpace convolutionMode){ - this.convolutionMode = convolutionMode; - return this; - } - - public Builder poolingType(SubsamplingLayer.PoolingType poolingType) { - return poolingType(new FixedValue<>(poolingType)); - } - - public Builder poolingType(ParameterSpace poolingType) { - this.poolingType = poolingType; - return this; - } - - public Builder dilation(int... dilation) { - return dilation(new FixedValue<>(dilation)); - } - - public Builder dilation(ParameterSpace dilation) { - this.dilation = dilation; - return this; - } - - public Builder kernelSize(int... kernelSize) { - return kernelSize(new FixedValue<>(kernelSize)); - } - - public Builder kernelSize(ParameterSpace kernelSize) { - this.kernelSize = kernelSize; - return this; - } - - public Builder stride(int... stride) { - return stride(new FixedValue(stride)); - } - - public Builder stride(ParameterSpace stride) { - this.stride = stride; - return this; - } - - public Builder padding(int... padding) { - return padding(new FixedValue(padding)); - } - - public Builder padding(ParameterSpace padding) { - this.padding = padding; - return this; - } - - public Builder pnorm(int pnorm){ - return pnorm(new FixedValue<>(pnorm)); - } - - public Builder pnorm(ParameterSpace pnorm){ - this.pnorm = pnorm; - return this; - } - - public Builder eps(double eps){ - return eps(new FixedValue<>(eps)); - } - - public Builder eps(ParameterSpace eps){ - this.eps = eps; - return this; - } - - @SuppressWarnings("unchecked") - public SubsamplingLayerSpace build() { - return new SubsamplingLayerSpace(this); - } - - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/VariationalAutoencoderLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/VariationalAutoencoderLayerSpace.java deleted file mode 100644 index 2138ea8ec..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/VariationalAutoencoderLayerSpace.java +++ /dev/null @@ -1,182 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
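// Usage sketch (illustrative; DiscreteParameterSpace is assumed to be Arbiter's categorical
// parameter space): making the pooling type of the removed SubsamplingLayerSpace a
// searchable hyperparameter while keeping kernel size and stride fixed.
SubsamplingLayerSpace poolSpace = new SubsamplingLayerSpace.Builder()
        .poolingType(new DiscreteParameterSpace<>(SubsamplingLayer.PoolingType.MAX,
                                                  SubsamplingLayer.PoolingType.AVG))
        .kernelSize(2, 2)   // fixed; the builder wraps it in FixedValue
        .stride(2, 2)
        .build();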
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper; -import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution; -import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.lossfunctions.ILossFunction; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -/** - * Layer space for {@link VariationalAutoencoder} - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PRIVATE) //For Jackson JSON/YAML deserialization -public class VariationalAutoencoderLayerSpace extends BasePretrainNetworkLayerSpace { - - private ParameterSpace encoderLayerSizes; - private ParameterSpace decoderLayerSizes; - private ParameterSpace outputDistribution; - private ParameterSpace pzxActivationFn; - private ParameterSpace numSamples; - - protected VariationalAutoencoderLayerSpace(Builder builder) { - super(builder); - - this.encoderLayerSizes = builder.encoderLayerSizes; - this.decoderLayerSizes = builder.decoderLayerSizes; - this.outputDistribution = builder.outputDistribution; - this.pzxActivationFn = builder.pzxActivationFn; - this.numSamples = builder.numSamples; - - this.numParameters = LeafUtils.countUniqueParameters(collectLeaves()); - } - - @Override - public VariationalAutoencoder getValue(double[] parameterValues) { - VariationalAutoencoder.Builder b = new VariationalAutoencoder.Builder(); - setLayerOptionsBuilder(b, parameterValues); - return b.build(); - } - - protected void setLayerOptionsBuilder(VariationalAutoencoder.Builder builder, double[] values) { - super.setLayerOptionsBuilder(builder, values); - if (encoderLayerSizes != null) - builder.encoderLayerSizes(encoderLayerSizes.getValue(values)); - if (decoderLayerSizes != null) - builder.decoderLayerSizes(decoderLayerSizes.getValue(values)); - if (outputDistribution != null) - builder.reconstructionDistribution(outputDistribution.getValue(values)); - if (pzxActivationFn != null) - builder.pzxActivationFn(pzxActivationFn.getValue(values)); - if (numSamples != null) - builder.numSamples(numSamples.getValue(values)); - } - - @Override - public String toString() { - return toString(", "); - } - - @Override - public String toString(String delim) { - StringBuilder sb = new StringBuilder("VariationalAutoencoderLayerSpace("); - if (encoderLayerSizes != null) - sb.append("encoderLayerSizes: ").append(encoderLayerSizes).append(delim); - if (decoderLayerSizes != null) - sb.append("decoderLayerSizes: ").append(decoderLayerSizes).append(delim); - if (outputDistribution != null) - sb.append("reconstructionDistribution: 
").append(outputDistribution).append(delim); - if (pzxActivationFn != null) - sb.append("pzxActivationFn: ").append(pzxActivationFn).append(delim); - if (numSamples != null) - sb.append("numSamples: ").append(numSamples).append(delim); - sb.append(super.toString(delim)).append(")"); - return sb.toString(); - } - - public static class Builder extends BasePretrainNetworkLayerSpace.Builder { - - private ParameterSpace encoderLayerSizes; - private ParameterSpace decoderLayerSizes; - private ParameterSpace outputDistribution; - private ParameterSpace pzxActivationFn; - private ParameterSpace numSamples; - - - public Builder encoderLayerSizes(int... encoderLayerSizes) { - return encoderLayerSizes(new FixedValue<>(encoderLayerSizes)); - } - - public Builder encoderLayerSizes(ParameterSpace encoderLayerSizes) { - this.encoderLayerSizes = encoderLayerSizes; - return this; - } - - public Builder decoderLayerSizes(int... decoderLayerSizes) { - return decoderLayerSizes(new FixedValue<>(decoderLayerSizes)); - } - - public Builder decoderLayerSizes(ParameterSpace decoderLayerSizes) { - this.decoderLayerSizes = decoderLayerSizes; - return this; - } - - public Builder reconstructionDistribution(ReconstructionDistribution distribution) { - return reconstructionDistribution(new FixedValue<>(distribution)); - } - - public Builder reconstructionDistribution(ParameterSpace distribution) { - this.outputDistribution = distribution; - return this; - } - - public Builder lossFunction(IActivation outputActivationFn, LossFunctions.LossFunction lossFunction) { - return lossFunction(outputActivationFn, lossFunction.getILossFunction()); - } - - public Builder lossFunction(Activation outputActivationFn, LossFunctions.LossFunction lossFunction) { - return lossFunction(outputActivationFn.getActivationFunction(), lossFunction.getILossFunction()); - } - - public Builder lossFunction(IActivation outputActivationFn, ILossFunction lossFunction) { - return reconstructionDistribution(new LossFunctionWrapper(outputActivationFn, lossFunction)); - } - - public Builder pzxActivationFn(IActivation activationFunction) { - return pzxActivationFn(new FixedValue<>(activationFunction)); - } - - public Builder pzxActivationFn(ParameterSpace activationFunction) { - this.pzxActivationFn = activationFunction; - return this; - } - - public Builder pzxActivationFunction(Activation activation) { - return pzxActivationFn(activation.getActivationFunction()); - } - - public Builder numSamples(int numSamples) { - return numSamples(new FixedValue<>(numSamples)); - } - - public Builder numSamples(ParameterSpace numSamples) { - this.numSamples = numSamples; - return this; - } - - - @Override - public E build() { - return (E) new VariationalAutoencoderLayerSpace(this); - } - - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/fixed/FixedLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/fixed/FixedLayerSpace.java deleted file mode 100644 index fb1afc299..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/layers/fixed/FixedLayerSpace.java +++ /dev/null @@ -1,71 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.layers.fixed; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.layers.LayerSpace; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.nn.conf.layers.Layer; - -import java.util.Collections; -import java.util.List; - -/** - * A layer space that wraps a DL4J layer, without any optimizable hyperparameters - * - * @param Type of layer - * - * @author Alex Black - */ -@AllArgsConstructor -@NoArgsConstructor -@Data -@EqualsAndHashCode(callSuper = false) -public class FixedLayerSpace extends LayerSpace { - - protected T layer; - - @Override - public T getValue(double[] parameterValues) { - return (T)layer.clone(); - } - - @Override - public int numParameters() { - return 0; - } - - @Override - public boolean isLeaf() { - return true; - } - - @Override - public void setIndices(int[] idxs){ - if(idxs != null && idxs.length > 0){ - throw new IllegalStateException("Cannot set indices: no parameters"); - } - } - - @Override - public List collectLeaves() { - return Collections.singletonList(this); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/listener/DL4JArbiterStatusReportingListener.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/listener/DL4JArbiterStatusReportingListener.java deleted file mode 100644 index 0c89984c9..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/listener/DL4JArbiterStatusReportingListener.java +++ /dev/null @@ -1,49 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
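// Usage sketch (illustrative; DenseLayer is the standard DL4J layer class, and the
// generic type parameter is assumed from the wrapped-layer design described above):
// wrapping a concrete, non-searchable layer so it can sit alongside tunable LayerSpaces.
FixedLayerSpace<DenseLayer> fixedDense = new FixedLayerSpace<>(
        new DenseLayer.Builder()
                .nIn(784)
                .nOut(128)
                .activation(Activation.RELU)
                .build());
// getValue(...) simply clones the wrapped layer; numParameters() is 0.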
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.listener; - -import lombok.AllArgsConstructor; -import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; -import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; -import org.deeplearning4j.nn.api.Model; -import org.deeplearning4j.optimize.api.BaseTrainingListener; -import org.deeplearning4j.optimize.api.IterationListener; - -import java.util.List; - -/** - * A simple DL4J Iteration listener that calls Arbiter's status listeners - * - * @author Alex Black - */ -@AllArgsConstructor -public class DL4JArbiterStatusReportingListener extends BaseTrainingListener { - - private List statusListeners; - private CandidateInfo candidateInfo; - - @Override - public void iterationDone(Model model, int iteration, int epoch) { - if (statusListeners == null) { - return; - } - - for (StatusListener sl : statusListeners) { - sl.onCandidateIteration(candidateInfo, model, iteration); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/saver/local/FileModelSaver.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/saver/local/FileModelSaver.java deleted file mode 100644 index 167f1f9d1..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/saver/local/FileModelSaver.java +++ /dev/null @@ -1,147 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.saver.local; - -import lombok.AllArgsConstructor; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import lombok.NonNull; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.io.FileUtils; -import org.apache.commons.io.FilenameUtils; -import org.deeplearning4j.arbiter.DL4JConfiguration; -import org.deeplearning4j.arbiter.GraphConfiguration; -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.nn.api.Model; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.util.ModelSerializer; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.io.*; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -/** - * Basic MultiLayerNetwork saver. 
Saves config, parameters and score to: baseDir/0/, baseDir/1/, etc - * where index is given by OptimizationResult.getIndex() - * - * @author Alex Black - */ -@Slf4j -@NoArgsConstructor -@AllArgsConstructor -@EqualsAndHashCode -public class FileModelSaver implements ResultSaver { - @JsonProperty - private String path; - private File fPath; - - @JsonCreator - public FileModelSaver(@NonNull String path) { - this(new File(path)); - } - - public FileModelSaver(@NonNull File file){ - this.path = file.getPath(); - this.fPath = file; - - if(!fPath.exists()){ - fPath.mkdirs(); - } else if (!fPath.isDirectory()) { - throw new IllegalArgumentException("Invalid path: exists and is not directory. " + path); - } - - log.info("FileModelSaver saving networks to local directory: {}", path); - } - - @Override - public ResultReference saveModel(OptimizationResult result, Object modelResult) throws IOException { - String dir = new File(path, result.getIndex() + "/").getAbsolutePath(); - - File f = new File(dir); - f.mkdir(); - - File modelFile = new File(FilenameUtils.concat(dir, "model.bin")); - File scoreFile = new File(FilenameUtils.concat(dir, "score.txt")); - File additionalResultsFile = new File(FilenameUtils.concat(dir, "additionalResults.bin")); - File esConfigFile = new File(FilenameUtils.concat(dir, "earlyStoppingConfig.bin")); - File numEpochsFile = new File(FilenameUtils.concat(dir, "numEpochs.txt")); - - FileUtils.writeStringToFile(scoreFile, String.valueOf(result.getScore())); - - Model m = (Model) modelResult; - ModelSerializer.writeModel(m, modelFile, true); - - - Object additionalResults = result.getModelSpecificResults(); - if (additionalResults != null && additionalResults instanceof Serializable) { - try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(additionalResultsFile))) { - oos.writeObject(additionalResults); - } - } - - //Write early stopping configuration (if present) to file: - int nEpochs; - EarlyStoppingConfiguration esc; - if (result.getCandidate().getValue() instanceof DL4JConfiguration) { - DL4JConfiguration c = ((DL4JConfiguration) result.getCandidate().getValue()); - esc = c.getEarlyStoppingConfiguration(); - nEpochs = c.getNumEpochs(); - } else { - GraphConfiguration c = ((GraphConfiguration) result.getCandidate().getValue()); - esc = c.getEarlyStoppingConfiguration(); - nEpochs = c.getNumEpochs(); - } - - - if (esc != null) { - try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(esConfigFile))) { - oos.writeObject(esc); - } - } else { - FileUtils.writeStringToFile(numEpochsFile, String.valueOf(nEpochs)); - } - - log.debug("Deeplearning4j model result (id={}, score={}) saved to directory: {}", result.getIndex(), - result.getScore(), dir); - - boolean isGraph = m instanceof ComputationGraph; - return new LocalFileNetResultReference(result.getIndex(), dir, isGraph, modelFile, scoreFile, - additionalResultsFile, esConfigFile, numEpochsFile, result.getCandidate()); - } - - @Override - public List> getSupportedCandidateTypes() { - return Collections.>singletonList(Object.class); - } - - @Override - public List> getSupportedModelTypes() { - return Arrays.>asList(MultiLayerNetwork.class, ComputationGraph.class); - } - - @Override - public String toString() { - return "FileModelSaver(path=" + path + ")"; - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/saver/local/LocalFileNetResultReference.java 
b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/saver/local/LocalFileNetResultReference.java deleted file mode 100644 index db46e011e..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/saver/local/LocalFileNetResultReference.java +++ /dev/null @@ -1,103 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.saver.local; - -import lombok.AllArgsConstructor; -import org.apache.commons.io.FileUtils; -import org.deeplearning4j.arbiter.DL4JConfiguration; -import org.deeplearning4j.arbiter.optimize.api.Candidate; -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.nn.api.Model; -import org.deeplearning4j.util.ModelSerializer; - -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.ObjectInputStream; - -/** - * Result reference for MultiLayerNetworks and ComputationGraphs saved to local file system - */ -@AllArgsConstructor -public class LocalFileNetResultReference implements ResultReference { - - private int index; - private String dir; - private boolean isGraph; - private File modelFile; - private File scoreFile; - private File additionalResultsFile; - private File esConfigFile; - private File numEpochsFile; - private Candidate candidate; - - @Override - public OptimizationResult getResult() throws IOException { - - - String scoreStr = FileUtils.readFileToString(scoreFile); - //TODO: properly parsing. Probably want to store additional info other than just score... 
- double d = Double.parseDouble(scoreStr); - - EarlyStoppingConfiguration earlyStoppingConfiguration = null; - if (esConfigFile != null && esConfigFile.exists()) { - try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(esConfigFile))) { - earlyStoppingConfiguration = (EarlyStoppingConfiguration) ois.readObject(); - } catch (ClassNotFoundException e) { - throw new RuntimeException("Error loading early stopping configuration", e); - } - } - int nEpochs = 1; - if (numEpochsFile != null && numEpochsFile.exists()) { - String numEpochs = FileUtils.readFileToString(numEpochsFile); - nEpochs = Integer.parseInt(numEpochs); - } - - - - Object additionalResults; - if (additionalResultsFile.exists()) { - try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(additionalResultsFile))) { - additionalResults = ois.readObject(); - } catch (ClassNotFoundException e) { - throw new RuntimeException("Error loading additional results", e); - } - } else { - additionalResults = null; - } - - return new OptimizationResult(candidate, d, index, additionalResults, null, this); - } - - @Override - public Object getResultModel() throws IOException { - Model m; - if (isGraph) { - m = ModelSerializer.restoreComputationGraph(modelFile, false); - } else { - m = ModelSerializer.restoreMultiLayerNetwork(modelFile, false); - } - return m; - } - - @Override - public String toString() { - return "LocalFileNetResultReference(" + dir + ")"; - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/RegressionValue.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/RegressionValue.java deleted file mode 100644 index 304750dc8..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/RegressionValue.java +++ /dev/null @@ -1,32 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.scoring; - -/** - * Enumeration used to select the type of regression statistics to optimize on, with the various regression score functions - * - MSE: mean squared error
      - * - MAE: mean absolute error
      - * - RMSE: root mean squared error
      - * - RSE: relative squared error
      - * - CorrCoeff: correlation coefficient
      - * - * @deprecated Use {@link org.deeplearning4j.eval.RegressionEvaluation.Metric} - */ -@Deprecated -public enum RegressionValue { - MSE, MAE, RMSE, RSE, CorrCoeff -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/ScoreFunctions.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/ScoreFunctions.java deleted file mode 100644 index f9e57a597..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/ScoreFunctions.java +++ /dev/null @@ -1,66 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.scoring; - - -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.scoring.impl.TestSetAccuracyScoreFunction; -import org.deeplearning4j.arbiter.scoring.impl.TestSetF1ScoreFunction; -import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction; -import org.deeplearning4j.arbiter.scoring.impl.TestSetRegressionScoreFunction; - -/** - * ScoreFunctions provides static methods for getting score functions for DL4J MultiLayerNetwork and ComputationGraph - * - * @author Alex Black - */ -public class ScoreFunctions { - - private ScoreFunctions() {} - - /** - * Calculate the loss (score/loss function value) on a test set, for a MultiLayerNetwork - * - * @param average Average (divide by number of examples) - */ - public static ScoreFunction testSetLoss(boolean average) { - return new TestSetLossScoreFunction(average); - } - - /** - * Calculate the accuracy on a test set, for a MultiLayerNetwork - */ - public static ScoreFunction testSetAccuracy() { - return new TestSetAccuracyScoreFunction(); - } - - - /** - * Calculate the f1 score on a test set - */ - public static ScoreFunction testSetF1() { - return new TestSetF1ScoreFunction(); - } - - /** - * Calculate a regression value (MSE, MAE etc) on a test set - */ - public static ScoreFunction testSetRegression(RegressionValue regressionValue) { - return new TestSetRegressionScoreFunction(regressionValue); - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/BaseNetScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/BaseNetScoreFunction.java deleted file mode 100644 index 1d38ada7c..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/BaseNetScoreFunction.java +++ /dev/null @@ -1,103 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
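// Usage sketch exercising the factory methods of the removed ScoreFunctions class above;
// each call returns a ScoreFunction suitable for an Arbiter optimization run.
ScoreFunction accuracy = ScoreFunctions.testSetAccuracy();                       // classification accuracy (maximized)
ScoreFunction avgLoss  = ScoreFunctions.testSetLoss(true);                       // test-set loss, averaged over examples
ScoreFunction f1       = ScoreFunctions.testSetF1();                             // F1 score
ScoreFunction mse      = ScoreFunctions.testSetRegression(RegressionValue.MSE);  // uses the deprecated RegressionValue enum above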
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.scoring.impl; - -import lombok.EqualsAndHashCode; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; - -import java.io.IOException; -import java.util.Arrays; -import java.util.List; -import java.util.Map; -import java.util.Properties; - -/** - * Created by Alex on 23/07/2017. - */ -@EqualsAndHashCode -public abstract class BaseNetScoreFunction implements ScoreFunction { - - - @Override - public double score(Object model, DataProvider dataProvider, Map dataParameters) { - Object testData = dataProvider.testData(dataParameters); - return score(model, testData); - } - - @Override - public double score(Object model, Class dataSource, Properties dataSourceProperties) { - DataSource ds; - try{ - ds = dataSource.newInstance(); - if (dataSourceProperties != null) { - ds.configure(dataSourceProperties); - } - } catch (Exception e){ - throw new RuntimeException("Error creating DataSource instance - missing no-arg constructor?", e); - } - return score(model, ds.testData()); - } - - protected double score(Object model, Object testData){ - if (model instanceof MultiLayerNetwork) { - if (testData instanceof DataSetIterator) { - return score((MultiLayerNetwork) model, (DataSetIterator) testData); - } else if(testData instanceof MultiDataSetIterator){ - return score((MultiLayerNetwork) model, (MultiDataSetIterator) testData); - } else if(testData instanceof DataSetIteratorFactory){ - return score((MultiLayerNetwork)model, ((DataSetIteratorFactory)testData).create()); - } else { - throw new RuntimeException("Unknown type of data: " + testData.getClass()); - } - } else { - if (testData instanceof DataSetIterator) { - return score((ComputationGraph) model, (DataSetIterator) testData); - } else if(testData instanceof DataSetIteratorFactory){ - return score((ComputationGraph) model, ((DataSetIteratorFactory)testData).create()); - } else if(testData instanceof MultiDataSetIterator) { - return score((ComputationGraph) model, (MultiDataSetIterator) testData); - } else { - throw new RuntimeException("Unknown type of data: " + testData.getClass()); - } - } - } - - @Override - public List> getSupportedModelTypes() { - return Arrays.>asList(MultiLayerNetwork.class, ComputationGraph.class); - } - - @Override - public List> getSupportedDataTypes() { - return Arrays.>asList(DataSetIterator.class, MultiDataSetIterator.class); - } - - public abstract double score(MultiLayerNetwork net, 
DataSetIterator iterator); - - public abstract double score(MultiLayerNetwork net, MultiDataSetIterator iterator); - - public abstract double score(ComputationGraph graph, DataSetIterator iterator); - - public abstract double score(ComputationGraph graph, MultiDataSetIterator iterator); -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/EvaluationScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/EvaluationScoreFunction.java deleted file mode 100644 index 7e71425d5..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/EvaluationScoreFunction.java +++ /dev/null @@ -1,86 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.scoring.impl; - -import lombok.*; -import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.evaluation.classification.Evaluation; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; - -/** - * Score function that calculates an evaluation {@link Evaluation.Metric} on the test set for a - * {@link MultiLayerNetwork} or {@link ComputationGraph} - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //JSON -public class EvaluationScoreFunction extends BaseNetScoreFunction { - - protected Evaluation.Metric metric; - - /** - * @param metric Evaluation metric to calculate - */ - public EvaluationScoreFunction(@NonNull org.deeplearning4j.eval.Evaluation.Metric metric) { - this(metric.toNd4j()); - } - - /** - * @param metric Evaluation metric to calculate - */ - public EvaluationScoreFunction(@NonNull Evaluation.Metric metric) { - this.metric = metric; - } - - @Override - public String toString() { - return "EvaluationScoreFunction(metric=" + metric + ")"; - } - - @Override - public double score(MultiLayerNetwork net, DataSetIterator iterator) { - Evaluation e = net.evaluate(iterator); - return e.scoreForMetric(metric); - } - - @Override - public double score(MultiLayerNetwork net, MultiDataSetIterator iterator) { - return score(net, new MultiDataSetWrapperIterator(iterator)); - } - - @Override - public double score(ComputationGraph graph, DataSetIterator iterator) { - Evaluation e = graph.evaluate(iterator); - return e.scoreForMetric(metric); - } - - @Override - public double score(ComputationGraph graph, MultiDataSetIterator iterator) { - Evaluation e = graph.evaluate(iterator); - return e.scoreForMetric(metric); - } - - @Override - public boolean minimize() { - return false; //Want 
to maximize all evaluation metrics: Accuracy, F1, precision, recall, g-measure, mcc - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java deleted file mode 100644 index 9203963e3..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/ROCScoreFunction.java +++ /dev/null @@ -1,122 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.scoring.impl; - -import lombok.*; -import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.evaluation.classification.ROC; -import org.nd4j.evaluation.classification.ROCBinary; -import org.nd4j.evaluation.classification.ROCMultiClass; -import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; - -/** - * Score function that calculates AUC (area under ROC curve) or AUPRC (area under precision/recall curve) on a test set - * for a {@link MultiLayerNetwork} or {@link ComputationGraph} - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //JSON -public class ROCScoreFunction extends BaseNetScoreFunction { - - /** - * Type of ROC evaluation to perform:
      - * ROC: use {@link ROC} to perform evaluation (single output binary classification)
      - * BINARY: use {@link ROCBinary} to perform evaluation (multi-output/multi-task binary classification)
      - * MULTICLASS: use {@link ROCMultiClass} to perform evaluation (1 vs. all multi-class classification) - * - */ - public enum ROCType {ROC, BINARY, MULTICLASS} - - /** - * Metric to calculate.
      - * AUC: Area under ROC curve
      - * AUPRC: Area under precision/recall curve - */ - public enum Metric {AUC, AUPRC}; - - protected ROCType type; - protected Metric metric; - - /** - * @param type ROC type to use for evaluation - * @param metric Evaluation metric to calculate - */ - public ROCScoreFunction(@NonNull ROCType type, @NonNull Metric metric) { - this.type = type; - this.metric = metric; - } - - @Override - public String toString() { - return "ROCScoreFunction(type=" + type + ",metric=" + metric + ")"; - } - - @Override - public double score(MultiLayerNetwork net, DataSetIterator iterator) { - switch (type){ - case ROC: - ROC r = net.evaluateROC(iterator); - return metric == Metric.AUC ? r.calculateAUC() : r.calculateAUCPR(); - case BINARY: - ROCBinary r2 = net.doEvaluation(iterator, new ROCBinary())[0]; - return metric == Metric.AUC ? r2.calculateAverageAuc() : r2.calculateAverageAUCPR(); - case MULTICLASS: - ROCMultiClass r3 = net.evaluateROCMultiClass(iterator); - return metric == Metric.AUC ? r3.calculateAverageAUC() : r3.calculateAverageAUCPR(); - default: - throw new RuntimeException("Unknown type: " + type); - } - } - - @Override - public double score(MultiLayerNetwork net, MultiDataSetIterator iterator) { - return score(net, new MultiDataSetWrapperIterator(iterator)); - } - - @Override - public double score(ComputationGraph graph, DataSetIterator iterator) { - return score(graph, new MultiDataSetIteratorAdapter(iterator)); - } - - @Override - public double score(ComputationGraph net, MultiDataSetIterator iterator) { - switch (type){ - case ROC: - ROC r = net.evaluateROC(iterator); - return metric == Metric.AUC ? r.calculateAUC() : r.calculateAUCPR(); - case BINARY: - ROCBinary r2 = net.doEvaluation(iterator, new ROCBinary())[0]; - return metric == Metric.AUC ? r2.calculateAverageAuc() : r2.calculateAverageAUCPR(); - case MULTICLASS: - ROCMultiClass r3 = net.evaluateROCMultiClass(iterator, 0); - return metric == Metric.AUC ? r3.calculateAverageAUC() : r3.calculateAverageAUCPR(); - default: - throw new RuntimeException("Unknown type: " + type); - } - } - - @Override - public boolean minimize() { - return false; //Want to maximize all evaluation metrics: Accuracy, F1, precision, recall, g-measure, mcc - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java deleted file mode 100644 index 51fcd9898..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/RegressionScoreFunction.java +++ /dev/null @@ -1,92 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.scoring.impl; - -import lombok.*; -import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.evaluation.regression.RegressionEvaluation; -import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; - -/** - * Score function for regression (including multi-label regression) for a MultiLayerNetwork or ComputationGraph - * on a test set. Supports all regression metrics: {@link Metric} - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For JSON -public class RegressionScoreFunction extends BaseNetScoreFunction { - - protected Metric metric; - - public RegressionScoreFunction(@NonNull org.deeplearning4j.eval.RegressionEvaluation.Metric metric) { - this(metric.toNd4j()); - } - - public RegressionScoreFunction(@NonNull Metric metric) { - this.metric = metric; - } - - @Override - public boolean minimize() { - switch (metric) { - case MSE: - case MAE: - case RMSE: - case RSE: - return true; - case PC: - case R2: - return false; - default: - throw new IllegalStateException("Unknown metric: " + metric); - } - } - - @Override - public String toString() { - return "RegressionScoreFunction(metric=" + metric + ")"; - } - - @Override - public double score(MultiLayerNetwork net, DataSetIterator iterator) { - RegressionEvaluation e = net.evaluateRegression(iterator); - return e.scoreForMetric(metric); - } - - @Override - public double score(MultiLayerNetwork net, MultiDataSetIterator iterator) { - return score(net, new MultiDataSetWrapperIterator(iterator)); - } - - @Override - public double score(ComputationGraph graph, DataSetIterator iterator) { - RegressionEvaluation e = graph.evaluateRegression(iterator); - return e.scoreForMetric(metric); - } - - @Override - public double score(ComputationGraph graph, MultiDataSetIterator iterator) { - RegressionEvaluation e = graph.evaluateRegression(iterator); - return e.scoreForMetric(metric); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetAccuracyScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetAccuracyScoreFunction.java deleted file mode 100644 index 34b051663..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetAccuracyScoreFunction.java +++ /dev/null @@ -1,72 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.scoring.impl; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; - -/** - * Score function that calculates the accuracy on a - * test set for a {@link MultiLayerNetwork} or {@link ComputationGraph} - * - * @author Alex Black - * @deprecated Use {@link EvaluationScoreFunction} - */ -@Data -@EqualsAndHashCode(callSuper = true) -@Deprecated -public class TestSetAccuracyScoreFunction extends BaseNetScoreFunction { - - - @Override - public String toString() { - return "TestSetAccuracyScoreFunction()"; - } - - @Override - public double score(MultiLayerNetwork net, DataSetIterator iterator) { - Evaluation e = net.evaluate(iterator); - return e.accuracy(); - } - - @Override - public double score(MultiLayerNetwork net, MultiDataSetIterator iterator) { - throw new UnsupportedOperationException("Cannot evaluate MultiLayerNetwork on MultiDataSetIterator"); - } - - @Override - public double score(ComputationGraph graph, DataSetIterator iterator) { - Evaluation e = graph.evaluate(iterator); - return e.accuracy(); - } - - @Override - public double score(ComputationGraph graph, MultiDataSetIterator iterator) { - Evaluation e = graph.evaluate(iterator); - return e.accuracy(); - } - - @Override - public boolean minimize() { - return false; - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetF1ScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetF1ScoreFunction.java deleted file mode 100644 index 24516a1d7..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetF1ScoreFunction.java +++ /dev/null @@ -1,72 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.scoring.impl; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; - -/** - * Score function that calculates the F1 score - * on a test set for a {@link MultiLayerNetwork} or {@link ComputationGraph} - * - * @author Alex Black - * @deprecated Use {@link EvaluationScoreFunction} - */ -@Data -@EqualsAndHashCode(callSuper = true) -@Deprecated -public class TestSetF1ScoreFunction extends BaseNetScoreFunction { - - @Override - public boolean minimize() { - return false; //false -> maximize - } - - - @Override - public String toString() { - return "TestSetF1ScoreFunction"; - } - - @Override - public double score(MultiLayerNetwork net, DataSetIterator iterator) { - Evaluation e = net.evaluate(iterator); - return e.f1(); - } - - @Override - public double score(MultiLayerNetwork net, MultiDataSetIterator iterator) { - throw new UnsupportedOperationException("Cannot evaluate MultiLayerNetwork on MultiDataSetIterator"); - } - - @Override - public double score(ComputationGraph graph, DataSetIterator iterator) { - Evaluation e = graph.evaluate(iterator); - return e.f1(); - } - - @Override - public double score(ComputationGraph graph, MultiDataSetIterator iterator) { - Evaluation e = graph.evaluate(iterator); - return e.f1(); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetLossScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetLossScoreFunction.java deleted file mode 100644 index f44df800e..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetLossScoreFunction.java +++ /dev/null @@ -1,78 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.scoring.impl; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.deeplearning4j.arbiter.scoring.util.ScoreUtil; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import com.fasterxml.jackson.annotation.JsonProperty; - -/** - * Score function that calculates the test set loss - * on a test set for a {@link MultiLayerNetwork} or {@link ComputationGraph} - * - * @author Alex Black - */ -@Data -@EqualsAndHashCode(callSuper = false) -public class TestSetLossScoreFunction extends BaseNetScoreFunction { - @JsonProperty - private final boolean average; - - public TestSetLossScoreFunction() { - this(true); - } - - public TestSetLossScoreFunction(boolean average) { - this.average = average; - } - - - @Override - public boolean minimize() { - return true; - } - - @Override - public String toString() { - return "TestSetLossScoreFunction()"; - } - - @Override - public double score(MultiLayerNetwork net, DataSetIterator iterator) { - return ScoreUtil.score(net, iterator, average); - } - - @Override - public double score(MultiLayerNetwork net, MultiDataSetIterator iterator) { - throw new UnsupportedOperationException("Cannot evaluate MultiLayerNetwork on MultiDataSetIterator"); - } - - @Override - public double score(ComputationGraph graph, DataSetIterator iterator) { - return ScoreUtil.score(graph, iterator, average); - } - - @Override - public double score(ComputationGraph graph, MultiDataSetIterator iterator) { - return ScoreUtil.score(graph, iterator, average); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetRegressionScoreFunction.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetRegressionScoreFunction.java deleted file mode 100644 index 0a27cea4e..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/impl/TestSetRegressionScoreFunction.java +++ /dev/null @@ -1,85 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.scoring.impl; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; -import org.deeplearning4j.arbiter.scoring.RegressionValue; -import org.deeplearning4j.arbiter.scoring.util.ScoreUtil; -import org.deeplearning4j.eval.RegressionEvaluation; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; - -/** - * Score function for regression (including multi-label regression) for a MultiLayerNetwork or ComputationGraph - * on a test set - * - * @author Alex Black - * @deprecated Use {@link RegressionScoreFunction} - */ -@Data -@EqualsAndHashCode(callSuper = true) -@NoArgsConstructor(access = AccessLevel.PROTECTED) //For JSON -@Deprecated -public class TestSetRegressionScoreFunction extends BaseNetScoreFunction { - private RegressionValue regressionValue; - - /** - * @param regressionValue The type of evaluation to do: MSE, MAE, RMSE, etc - */ - public TestSetRegressionScoreFunction(RegressionValue regressionValue) { - this.regressionValue = regressionValue; - } - - - @Override - public boolean minimize() { - return regressionValue != RegressionValue.CorrCoeff; //Maximize correlation coefficient, minimize the remaining ones - } - - @Override - public String toString() { - return "TestSetRegressionScoreFunction(type=" + regressionValue + ")"; - } - - @Override - public double score(MultiLayerNetwork net, DataSetIterator iterator) { - RegressionEvaluation e = net.evaluateRegression(iterator); - return ScoreUtil.getScoreFromRegressionEval(e, regressionValue); - } - - @Override - public double score(MultiLayerNetwork net, MultiDataSetIterator iterator) { - throw new UnsupportedOperationException("Cannot evaluate MultiLayerNetwork on MultiDataSetIterator"); - } - - @Override - public double score(ComputationGraph graph, DataSetIterator iterator) { - RegressionEvaluation e = graph.evaluateRegression(iterator); - return ScoreUtil.getScoreFromRegressionEval(e, regressionValue); - } - - @Override - public double score(ComputationGraph graph, MultiDataSetIterator iterator) { - RegressionEvaluation e = graph.evaluateRegression(iterator); - return ScoreUtil.getScoreFromRegressionEval(e, regressionValue); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/util/ScoreUtil.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/util/ScoreUtil.java deleted file mode 100644 index 303defe35..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/scoring/util/ScoreUtil.java +++ /dev/null @@ -1,328 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.scoring.util; - -import org.deeplearning4j.arbiter.scoring.RegressionValue; -import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.eval.RegressionEvaluation; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; -import org.nd4j.linalg.dataset.api.MultiDataSet; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIteratorFactory; - - - -/** - * Various utilities for functions used in arbiter. - * - * @author Adam Gibson - */ -public class ScoreUtil { - - - - /** - * Get a {@link DataSetIterator} - * from the given object whether it's a {@link DataSetIterator} - * or {@link DataSetIteratorFactory}, any other type will throw - * an {@link IllegalArgumentException} - * @param o the object to get the iterator from - * @return the datasetiterator from the given objects - */ - public static MultiDataSetIterator getMultiIterator(Object o) { - if (o instanceof MultiDataSetIterator) { - return (MultiDataSetIterator) o; - } else if (o instanceof MultiDataSetIteratorFactory) { - MultiDataSetIteratorFactory factory = (MultiDataSetIteratorFactory) o; - return factory.create(); - } else if( o instanceof DataSetIterator ){ - return new MultiDataSetIteratorAdapter((DataSetIterator)o); - } else if( o instanceof DataSetIteratorFactory ){ - return new MultiDataSetIteratorAdapter(((DataSetIteratorFactory)o).create()); - } - - throw new IllegalArgumentException("Type must either be DataSetIterator or DataSetIteratorFactory"); - } - - - /** - * Get a {@link DataSetIterator} - * from the given object whether it's a {@link DataSetIterator} - * or {@link DataSetIteratorFactory}, any other type will throw - * an {@link IllegalArgumentException} - * @param o the object to get the iterator from - * @return the datasetiterator from the given objects - */ - public static DataSetIterator getIterator(Object o) { - if (o instanceof DataSetIterator) - return (DataSetIterator) o; - else if (o instanceof DataSetIteratorFactory) { - DataSetIteratorFactory factory = (DataSetIteratorFactory) o; - return factory.create(); - } - - throw new IllegalArgumentException("Type must either be DataSetIterator or DataSetIteratorFactory"); - } - - /** - * - * @param model - * @param testData - * @return - */ - public static Evaluation getEvaluation(MultiLayerNetwork model, DataSetIterator testData) { - return model.evaluate(testData); - } - - /** - * Get the evaluation - * for the given model and test dataset - * @param model the model to get the evaluation from - * @param testData the test data to do the evaluation on - * @return the evaluation object with accumulated statistics - * for the current test data - */ - public static Evaluation getEvaluation(ComputationGraph model, MultiDataSetIterator testData) { - if (model.getNumOutputArrays() != 1) - throw new IllegalStateException("GraphSetSetAccuracyScoreFunction cannot be " - + "applied to 
ComputationGraphs with more than one output. NumOutputs = " - + model.getNumOutputArrays()); - - return model.evaluate(testData); - } - - - /** - * Get the evaluation - * for the given model and test dataset - * @param model the model to get the evaluation from - * @param testData the test data to do the evaluation on - * @return the evaluation object with accumulated statistics - * for the current test data - */ - public static Evaluation getEvaluation(ComputationGraph model, DataSetIterator testData) { - if (model.getNumOutputArrays() != 1) - throw new IllegalStateException("GraphSetSetAccuracyScoreFunctionDataSet cannot be " - + "applied to ComputationGraphs with more than one output. NumOutputs = " - + model.getNumOutputArrays()); - - return model.evaluate(testData); - } - - - - /** - * Score based on the loss function - * @param model the model to score with - * @param testData the test data to score - * @param average whether to average the score - * for the whole batch or not - * @return the score for the given test set - */ - public static double score(ComputationGraph model, MultiDataSetIterator testData, boolean average) { - //TODO: do this properly taking into account division by N, L1/L2 etc - double sumScore = 0.0; - int totalExamples = 0; - while (testData.hasNext()) { - MultiDataSet ds = testData.next(); - long numExamples = ds.getFeatures(0).size(0); - sumScore += numExamples * model.score(ds); - totalExamples += numExamples; - } - - if (!average) - return sumScore; - return sumScore / totalExamples; - } - - /** - * Score based on the loss function - * @param model the model to score with - * @param testData the test data to score - * @param average whether to average the score - * for the whole batch or not - * @return the score for the given test set - */ - public static double score(ComputationGraph model, DataSetIterator testData, boolean average) { - //TODO: do this properly taking into account division by N, L1/L2 etc - double sumScore = 0.0; - int totalExamples = 0; - while (testData.hasNext()) { - DataSet ds = testData.next(); - int numExamples = ds.numExamples(); - - sumScore += numExamples * model.score(ds); - totalExamples += numExamples; - } - - if (!average) - return sumScore; - return sumScore / totalExamples; - } - - - /** - * - * @param model - * @param testSet - * @param regressionValue - * @return - */ - public static double score(ComputationGraph model, MultiDataSetIterator testSet, RegressionValue regressionValue) { - int nOutputs = model.getNumOutputArrays(); - - RegressionEvaluation[] evaluations = new RegressionEvaluation[nOutputs]; - for (int i = 0; i < evaluations.length; i++) - evaluations[i] = new RegressionEvaluation(); - - while (testSet.hasNext()) { - MultiDataSet next = testSet.next(); - INDArray[] labels = next.getLabels(); - - if (next.hasMaskArrays()) { - INDArray[] fMasks = next.getFeaturesMaskArrays(); - INDArray[] lMasks = next.getLabelsMaskArrays(); - - model.setLayerMaskArrays(fMasks, lMasks); - - INDArray[] outputs = model.output(false, next.getFeatures()); - for (int i = 0; i < evaluations.length; i++) { - if (lMasks != null && lMasks[i] != null) { - evaluations[i].evalTimeSeries(labels[i], outputs[i], lMasks[i]); - } else { - evaluations[i].evalTimeSeries(labels[i], outputs[i]); - } - } - - model.clearLayerMaskArrays(); - } else { - INDArray[] outputs = model.output(false, next.getFeatures()); - for (int i = 0; i < evaluations.length; i++) { - if (labels[i].rank() == 3) { - evaluations[i].evalTimeSeries(labels[i], outputs[i]); - } 
else { - evaluations[i].eval(labels[i], outputs[i]); - } - } - } - } - - double sum = 0.0; - int totalColumns = 0; - for (int i = 0; i < evaluations.length; i++) { - int nColumns = evaluations[i].numColumns(); - totalColumns += nColumns; - sum += getScoreFromRegressionEval(evaluations[i], regressionValue); - } - if (regressionValue == RegressionValue.CorrCoeff) - sum /= totalColumns; - - return sum; - } - - - /** - * Run a {@link RegressionEvaluation} - * over a {@link DataSetIterator} - * @param model the model to use - * @param testSet the test set iterator - * @param regressionValue the regression type to use - * @return - */ - public static double score(ComputationGraph model, DataSetIterator testSet, RegressionValue regressionValue) { - RegressionEvaluation evaluation = model.evaluateRegression(testSet); - return getScoreFromRegressionEval(evaluation, regressionValue); - } - - - /** - * Score the given test data - * with the given multi layer network - * @param model model to use - * @param testData the test data to test with - * @param average whether to average the score or not - * @return the score for the given test data given the model - */ - public static double score(MultiLayerNetwork model, DataSetIterator testData, boolean average) { - //TODO: do this properly taking into account division by N, L1/L2 etc - double sumScore = 0.0; - int totalExamples = 0; - while (testData.hasNext()) { - DataSet ds = testData.next(); - int numExamples = ds.numExamples(); - - sumScore += numExamples * model.score(ds); - totalExamples += numExamples; - } - - if (!average) - return sumScore; - return sumScore / totalExamples; - } - - - /** - * Score the given multi layer network - * @param model the model to score - * @param testSet the test set - * @param regressionValue the regression function to use - * @return the score from the given test set - */ - public static double score(MultiLayerNetwork model, DataSetIterator testSet, RegressionValue regressionValue) { - RegressionEvaluation eval = model.evaluateRegression(testSet); - return getScoreFromRegressionEval(eval, regressionValue); - } - - - @Deprecated - public static double getScoreFromRegressionEval(RegressionEvaluation eval, RegressionValue regressionValue) { - double sum = 0.0; - int nColumns = eval.numColumns(); - switch (regressionValue) { - case MSE: - for (int i = 0; i < nColumns; i++) - sum += eval.meanSquaredError(i); - break; - case MAE: - for (int i = 0; i < nColumns; i++) - sum += eval.meanAbsoluteError(i); - break; - case RMSE: - for (int i = 0; i < nColumns; i++) - sum += eval.rootMeanSquaredError(i); - break; - case RSE: - for (int i = 0; i < nColumns; i++) - sum += eval.relativeSquaredError(i); - break; - case CorrCoeff: - for (int i = 0; i < nColumns; i++) - sum += eval.correlationR2(i); - sum /= nColumns; - break; - } - - return sum; - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/ComputationGraphTaskCreator.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/ComputationGraphTaskCreator.java deleted file mode 100644 index 53a9fe0aa..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/ComputationGraphTaskCreator.java +++ /dev/null @@ -1,267 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.task; - -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.exception.ExceptionUtils; -import org.deeplearning4j.arbiter.GraphConfiguration; -import org.deeplearning4j.arbiter.listener.DL4JArbiterStatusReportingListener; -import org.deeplearning4j.arbiter.optimize.api.Candidate; -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.api.TaskCreator; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.evaluation.ModelEvaluator; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; -import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; -import org.deeplearning4j.arbiter.scoring.util.ScoreUtil; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.earlystopping.EarlyStoppingResult; -import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.IOException; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.Callable; - -/** - * Task creator for ComputationGraph - * - * @author Alex Black - */ -@AllArgsConstructor -@NoArgsConstructor -@Slf4j -public class ComputationGraphTaskCreator implements TaskCreator { - - private ModelEvaluator modelEvaluator; - @Getter - @Setter - private TaskListener taskListener; - - public ComputationGraphTaskCreator(ModelEvaluator modelEvaluator){ - this(modelEvaluator, null); - } - - @Override - public Callable create(Candidate candidate, DataProvider dataProvider, - ScoreFunction scoreFunction, List statusListener, - IOptimizationRunner runner) { - - return new GraphLearningTask(candidate, dataProvider, scoreFunction, modelEvaluator, statusListener, - taskListener, runner); - } - - @Override - public Callable create(Candidate candidate, Class dataSource, Properties dataSourceProperties, - ScoreFunction scoreFunction, List statusListeners, IOptimizationRunner runner) { - return new GraphLearningTask(candidate, dataSource, dataSourceProperties, scoreFunction, modelEvaluator, statusListeners, - taskListener, runner); - } - - @AllArgsConstructor - 
private static class GraphLearningTask implements Callable { - - private Candidate candidate; - private DataProvider dataProvider; - private Class dataSource; - private Properties dataSourceProperties; - private ScoreFunction scoreFunction; - private ModelEvaluator modelEvaluator; - private List listeners; - private TaskListener taskListener; - private IOptimizationRunner runner; - - private long startTime; - - public GraphLearningTask(Candidate candidate, DataProvider dataProvider, ScoreFunction scoreFunction, - ModelEvaluator modelEvaluator, List listeners, - TaskListener taskListener, IOptimizationRunner runner) { - this.candidate = candidate; - this.dataProvider = dataProvider; - this.scoreFunction = scoreFunction; - this.modelEvaluator = modelEvaluator; - this.listeners = listeners; - this.taskListener = taskListener; - this.runner = runner; - } - - public GraphLearningTask(Candidate candidate, Class dataSource, Properties dataSourceProperties, - ScoreFunction scoreFunction, ModelEvaluator modelEvaluator, List listeners, - TaskListener taskListener, IOptimizationRunner runner) { - this.candidate = candidate; - this.dataSource = dataSource; - this.dataSourceProperties = dataSourceProperties; - this.scoreFunction = scoreFunction; - this.modelEvaluator = modelEvaluator; - this.listeners = listeners; - this.taskListener = taskListener; - this.runner = runner; - } - - - @Override - public OptimizationResult call() throws Exception { - - try { - OptimizationResult result = callHelper(); - if(listeners != null && !listeners.isEmpty()){ - CandidateInfo ci = new CandidateInfo(candidate.getIndex(), CandidateStatus.Complete, result.getScore(), - startTime, startTime, System.currentTimeMillis(), candidate.getFlatParameters(), null); - for(StatusListener sl : listeners){ - try{ - sl.onCandidateStatusChange(ci, runner, result); - } catch (Exception e){ - log.error("Error in status listener for candidate {}", candidate.getIndex(), e); - } - } - } - return result; - } catch (Throwable e) { - String stackTrace = ExceptionUtils.getStackTrace(e); - log.warn("Execution failed for task {}", candidate.getIndex(), e); - - CandidateInfo ci = new CandidateInfo(candidate.getIndex(), CandidateStatus.Failed, null, startTime, - null, null, candidate.getFlatParameters(), stackTrace); - return new OptimizationResult(candidate, null, candidate.getIndex(), null, ci, null); - } finally { - //Destroy workspaces to free memory - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - System.gc(); - try { - //Sleep for a few seconds - workspace destruction and memory deallocation happens quickly but doesn't - // happen instantly; if we didn't have this, we may run into a situation where the next thread/task - // tries to allocate before WS memory is fully deallocated, resulting in an OOM in memory constrained - // environments - Thread.sleep(2000L); - } catch (Exception e){ } - } - } - - private OptimizationResult callHelper() throws Exception { - startTime = System.currentTimeMillis(); - CandidateInfo ci = new CandidateInfo(candidate.getIndex(), CandidateStatus.Running, null, startTime, startTime, - null, candidate.getFlatParameters(), null); - - //Create network - ComputationGraph net = new ComputationGraph(((GraphConfiguration) candidate.getValue()).getConfiguration()); - net.init(); - - if(taskListener != null){ - net = taskListener.preProcess(net, candidate); - } - - if (listeners != null) { - net.addListeners(new DL4JArbiterStatusReportingListener(listeners, ci)); - } - - //For DataSetIterator: wraps 
in a MultiDataSetIterator, hence method can be used for both - MultiDataSetIterator iterator; - if(dataSource != null){ - try { - DataSource dsInstance = dataSource.newInstance(); - if (dataSourceProperties != null) - dsInstance.configure(dataSourceProperties); - iterator = ScoreUtil.getMultiIterator(dsInstance.trainData()); - } catch (Exception e){ - throw new RuntimeException("Error instantiating instance of DataSource for class " + dataSource.getName() + - " - no zero-arg constructor?",e); - } - } else { - iterator = ScoreUtil.getMultiIterator(dataProvider.trainData(candidate.getDataParameters())); - } - - - EarlyStoppingConfiguration esConfig = - ((GraphConfiguration) candidate.getValue()).getEarlyStoppingConfiguration(); - EarlyStoppingResult esResult = null; - if (esConfig != null) { - EarlyStoppingGraphTrainer trainer = new EarlyStoppingGraphTrainer(esConfig, net, iterator, null); - esResult = trainer.fit(); - net = esResult.getBestModel(); //Can return null if failed OR if - - switch (esResult.getTerminationReason()) { - case Error: - ci.setCandidateStatus(CandidateStatus.Failed); - ci.setExceptionStackTrace(esResult.getTerminationDetails()); - break; - case IterationTerminationCondition: - case EpochTerminationCondition: - ci.setCandidateStatus(CandidateStatus.Complete); - break; - } - - } else { - //Fixed number of epochs - int nEpochs = ((GraphConfiguration) candidate.getValue()).getNumEpochs(); - for (int i = 0; i < nEpochs; i++) { - net.fit(iterator); - } - ci.setCandidateStatus(CandidateStatus.Complete); - } - Nd4j.getExecutioner().commit(); - - Object additionalEvaluation = null; - if (esConfig != null && esResult.getTerminationReason() != EarlyStoppingResult.TerminationReason.Error) { - additionalEvaluation = - (modelEvaluator != null ? modelEvaluator.evaluateModel(net, dataProvider) : null); - } - - Double score = null; - if (net != null) { - if(dataSource != null){ - score = scoreFunction.score(net, dataSource, dataSourceProperties); - } else { - score = scoreFunction.score(net, dataProvider, candidate.getDataParameters()); - } - ci.setScore(score); - } - - if(taskListener != null){ - taskListener.postProcess(net, candidate); - } - - OptimizationResult result = new OptimizationResult(candidate, score, candidate.getIndex(), additionalEvaluation, ci, null); - - //Save the model: - ResultSaver saver = runner.getConfiguration().getResultSaver(); - ResultReference resultReference = null; - if (saver != null) { - try { - resultReference = saver.saveModel(result, net); - } catch (IOException e) { - //TODO: Do we want ta warn or fail on IOException? - log.warn("Error saving model (id={}): IOException thrown. ", result.getIndex(), e); - } - } - result.setResultReference(resultReference); - return result; - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/MultiLayerNetworkTaskCreator.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/MultiLayerNetworkTaskCreator.java deleted file mode 100644 index 5c2fb0703..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/MultiLayerNetworkTaskCreator.java +++ /dev/null @@ -1,265 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.task; - -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.Setter; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.exception.ExceptionUtils; -import org.deeplearning4j.arbiter.DL4JConfiguration; -import org.deeplearning4j.arbiter.listener.DL4JArbiterStatusReportingListener; -import org.deeplearning4j.arbiter.optimize.api.Candidate; -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.api.TaskCreator; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.evaluation.ModelEvaluator; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; -import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; -import org.deeplearning4j.arbiter.scoring.util.ScoreUtil; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.earlystopping.EarlyStoppingResult; -import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.IOException; -import java.util.List; -import java.util.Properties; -import java.util.concurrent.Callable; - -/** - * Task creator for MultiLayerNetworks - * - * @author Alex Black - */ -@AllArgsConstructor -@NoArgsConstructor -@Slf4j -public class MultiLayerNetworkTaskCreator implements TaskCreator { - - private ModelEvaluator modelEvaluator; - @Getter - @Setter - private TaskListener taskListener; - - public MultiLayerNetworkTaskCreator(ModelEvaluator modelEvaluator){ - this(modelEvaluator, null); - } - - @Override - public Callable create(Candidate candidate, DataProvider dataProvider, - ScoreFunction scoreFunction, List statusListeners, - IOptimizationRunner runner) { - - return new DL4JLearningTask(candidate, dataProvider, scoreFunction, modelEvaluator, statusListeners, taskListener, runner); - } - - @Override - public Callable create(Candidate candidate, Class dataSource, Properties dataSourceProperties, - ScoreFunction scoreFunction, List statusListeners, IOptimizationRunner runner) { - return new DL4JLearningTask(candidate, dataSource, dataSourceProperties, scoreFunction, modelEvaluator, statusListeners, taskListener, runner); - } - - - private static class DL4JLearningTask implements Callable { - - private Candidate candidate; - private DataProvider dataProvider; - private Class dataSource; - private Properties dataSourceProperties; - private 
ScoreFunction scoreFunction; - private ModelEvaluator modelEvaluator; - private List listeners; - private TaskListener taskListener; - private IOptimizationRunner runner; - - private long startTime; - - public DL4JLearningTask(Candidate candidate, DataProvider dataProvider, ScoreFunction scoreFunction, - ModelEvaluator modelEvaluator, List listeners, TaskListener taskListener, - IOptimizationRunner runner) { - this.candidate = candidate; - this.dataProvider = dataProvider; - this.scoreFunction = scoreFunction; - this.modelEvaluator = modelEvaluator; - this.listeners = listeners; - this.taskListener = taskListener; - this.runner = runner; - } - - public DL4JLearningTask(Candidate candidate, Class dataSource, Properties dataSourceProperties, - ScoreFunction scoreFunction, ModelEvaluator modelEvaluator, List listeners, TaskListener taskListener, - IOptimizationRunner runner) { - this.candidate = candidate; - this.dataSource = dataSource; - this.dataSourceProperties = dataSourceProperties; - this.scoreFunction = scoreFunction; - this.modelEvaluator = modelEvaluator; - this.listeners = listeners; - this.taskListener = taskListener; - this.runner = runner; - } - - - @Override - public OptimizationResult call() { - - try { - OptimizationResult result = callHelper(); - if(listeners != null && !listeners.isEmpty()){ - CandidateInfo ci = new CandidateInfo(candidate.getIndex(), CandidateStatus.Complete, result.getScore(), - startTime, startTime, System.currentTimeMillis(), candidate.getFlatParameters(), null); - for(StatusListener sl : listeners){ - try{ - sl.onCandidateStatusChange(ci, runner, result); - } catch (Exception e){ - log.error("Error in status listener for candidate {}", candidate.getIndex(), e); - } - } - } - return result; - } catch (Throwable e) { - String stackTrace = ExceptionUtils.getStackTrace(e); - log.warn( "Execution failed for task {}", candidate.getIndex(), e ); - - CandidateInfo ci = new CandidateInfo(candidate.getIndex(), CandidateStatus.Failed, null, startTime, - null, null, candidate.getFlatParameters(), stackTrace); - return new OptimizationResult(candidate, null, candidate.getIndex(), null, ci, null); - } finally { - //Destroy workspaces to free memory - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - System.gc(); - try { - //Sleep for a few seconds - workspace destruction and memory deallocation happens quickly but doesn't - // happen instantly; if we didn't have this, we may run into a situation where the next thread/task - // tries to allocate before WS memory is fully deallocated, resulting in an OOM in memory constrained - // environments - Thread.sleep(2000L); - } catch (Exception e){ } - } - } - - private OptimizationResult callHelper() { - startTime = System.currentTimeMillis(); - CandidateInfo ci = new CandidateInfo(candidate.getIndex(), CandidateStatus.Running, null, - startTime, startTime, null, candidate.getFlatParameters(), null); - - //Create network - MultiLayerNetwork net = new MultiLayerNetwork( - ((DL4JConfiguration) candidate.getValue()).getMultiLayerConfiguration()); - net.init(); - - if(taskListener != null){ - net = taskListener.preProcess(net, candidate); - } - - if (listeners != null) { - net.addListeners(new DL4JArbiterStatusReportingListener(listeners, ci)); - } - - //Early stopping or fixed number of epochs: - DataSetIterator dataSetIterator; - if(dataSource != null){ - DataSource dsInstance; - try{ - dsInstance = dataSource.newInstance(); - } catch (Exception e){ - throw new RuntimeException("Error instantiating instance 
of DataSource for class " + dataSource.getName() + - " - no zero-arg constructor?",e); - } - if(dataSourceProperties != null) - dsInstance.configure(dataSourceProperties); - dataSetIterator = ScoreUtil.getIterator(dsInstance.trainData()); - } else { - dataSetIterator = ScoreUtil.getIterator(dataProvider.trainData(candidate.getDataParameters())); - } - - - EarlyStoppingConfiguration esConfig = - ((DL4JConfiguration) candidate.getValue()).getEarlyStoppingConfiguration(); - EarlyStoppingResult esResult = null; - if (esConfig != null) { - EarlyStoppingTrainer trainer = new EarlyStoppingTrainer(esConfig, net, dataSetIterator, null); - esResult = trainer.fit(); - net = esResult.getBestModel(); //Can return null if failed OR if - - switch (esResult.getTerminationReason()) { - case Error: - ci.setCandidateStatus(CandidateStatus.Failed); - ci.setExceptionStackTrace(esResult.getTerminationDetails()); - break; - case IterationTerminationCondition: - case EpochTerminationCondition: - ci.setCandidateStatus(CandidateStatus.Complete); - break; - } - - } else { - //Fixed number of epochs - int nEpochs = ((DL4JConfiguration) candidate.getValue()).getNumEpochs(); - for (int i = 0; i < nEpochs; i++) { - net.fit(dataSetIterator); - } - ci.setCandidateStatus(CandidateStatus.Complete); - } - - Object additionalEvaluation = null; - if (esConfig != null && esResult.getTerminationReason() != EarlyStoppingResult.TerminationReason.Error) { - additionalEvaluation = - (modelEvaluator != null ? modelEvaluator.evaluateModel(net, dataProvider) : null); - } - - Double score = null; - if (net != null) { - if(dataSource != null){ - score = scoreFunction.score(net, dataSource, dataSourceProperties); - } else { - score = scoreFunction.score(net, dataProvider, candidate.getDataParameters()); - } - ci.setScore(score); - } - - if(taskListener != null){ - taskListener.postProcess(net, candidate); - } - - OptimizationResult result = new OptimizationResult(candidate, score, candidate.getIndex(), additionalEvaluation, ci, null); - //Save the model: - ResultSaver saver = runner.getConfiguration().getResultSaver(); - ResultReference resultReference = null; - if (saver != null) { - try { - resultReference = saver.saveModel(result, net); - } catch (IOException e) { - //TODO: Do we want ta warn or fail on IOException? - log.warn("Error saving model (id={}): IOException thrown. ", result.getIndex(), e); - } - } - result.setResultReference(resultReference); - return result; - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/TaskListener.java b/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/TaskListener.java deleted file mode 100644 index ecf262548..000000000 --- a/arbiter/arbiter-deeplearning4j/src/main/java/org/deeplearning4j/arbiter/task/TaskListener.java +++ /dev/null @@ -1,49 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.task; - -import org.deeplearning4j.arbiter.optimize.api.Candidate; -import org.deeplearning4j.nn.api.Model; - -import java.io.Serializable; - -/** - * TaskListener: can be used to preprocess and post process a model (MultiLayerNetwork or ComputationGraph) before/after - * training, in a {@link MultiLayerNetworkTaskCreator} or {@link ComputationGraphTaskCreator} - * - * @author Alex Black - */ -public interface TaskListener extends Serializable { - - /** - * Preprocess the model, before any training has taken place. - *
      - * Can be used to (for example) set listeners on a model before training starts - * @param model Model to preprocess - * @param candidate Candidate information, for the current model - * @return The updated model (usually the same one as the input, perhaps with modifications) - */ - T preProcess(T model, Candidate candidate); - - /** - * Post process the model, after any training has taken place - * @param model Model to postprocess - * @param candidate Candidate information, for the current model - */ - void postProcess(Model model, Candidate candidate); - -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java deleted file mode 100644 index 06e00219f..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,50 +0,0 @@ -/* ****************************************************************************** - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ -package org.deeplearning4j.arbiter; - -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.BaseDL4JTest; -import org.nd4j.common.tests.AbstractAssertTestsClass; - -import java.util.*; - -/** - * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) - * extends BaseDl4jTest - either directly or indirectly. - * Other than a small set of exceptions, all tests must extend this - * - * @author Alex Black - */ - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.deeplearning4j.arbiter"; - } - - @Override - protected Class getBaseClass() { - return BaseDL4JTest.class; - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/TestUtils.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/TestUtils.java deleted file mode 100644 index ea5e0eddd..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/TestUtils.java +++ /dev/null @@ -1,243 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. 
- * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter; - -import org.apache.commons.compress.utils.IOUtils; -import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.layers.BaseLayer; -import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.util.ModelSerializer; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.learning.regularization.L1Regularization; -import org.nd4j.linalg.learning.regularization.L2Regularization; -import org.nd4j.linalg.learning.regularization.Regularization; -import org.nd4j.linalg.learning.regularization.WeightDecay; - -import java.io.*; -import java.util.List; -import java.util.Random; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; - -public class TestUtils { - - public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){ - - MultiLayerNetwork restored; - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ModelSerializer.writeModel(net, baos, true); - byte[] bytes = baos.toByteArray(); - - ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); - - assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); - assertEquals(net.params(), restored.params()); - } catch (IOException e){ - //Should never happen - throw new RuntimeException(e); - } - - //Also check the MultiLayerConfiguration is serializable (required by Spark etc) - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - serializeDeserializeJava(conf); - - return restored; - } - - public static ComputationGraph testModelSerialization(ComputationGraph net){ - - ComputationGraph restored; - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ModelSerializer.writeModel(net, baos, true); - byte[] bytes = baos.toByteArray(); - - ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - restored = ModelSerializer.restoreComputationGraph(bais, true); - - assertEquals(net.getConfiguration(), restored.getConfiguration()); - assertEquals(net.params(), restored.params()); - } catch (IOException e){ - //Should never happen - throw new RuntimeException(e); - } - - //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) - ComputationGraphConfiguration conf = net.getConfiguration(); - serializeDeserializeJava(conf); - - return restored; - } - - private static T serializeDeserializeJava(T object){ - byte[] bytes; - try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){ - oos.writeObject(object); - oos.close(); - bytes = baos.toByteArray(); - } catch (IOException e){ - //Should 
never happen - throw new RuntimeException(e); - } - - T out; - try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){ - out = (T)ois.readObject(); - } catch (IOException | ClassNotFoundException e){ - throw new RuntimeException(e); - } - - assertEquals(object, out); - return out; - } - - public static INDArray randomOneHot(long examples, long nOut){ - return randomOneHot(examples, nOut, new Random(12345)); - } - - public static INDArray randomOneHot(long examples, long nOut, long rngSeed){ - return randomOneHot(examples, nOut, new Random(rngSeed)); - } - - public static INDArray randomOneHot(long examples, long nOut, Random rng){ - INDArray arr = Nd4j.create(examples, nOut); - for( int i=0; i l){ - for(Regularization r : l){ - if(r instanceof L1Regularization){ - return (L1Regularization) r; - } - } - return null; - } - - public static L2Regularization getL2Reg(BaseLayer baseLayer){ - return getL2Reg(baseLayer.getRegularization()); - } - - public static L2Regularization getL2Reg(List l){ - for(Regularization r : l){ - if(r instanceof L2Regularization){ - return (L2Regularization) r; - } - } - return null; - } - - public static WeightDecay getWeightDecayReg(BaseLayer bl){ - return getWeightDecayReg(bl.getRegularization()); - } - - public static WeightDecay getWeightDecayReg(List l){ - for(Regularization r : l){ - if(r instanceof WeightDecay){ - return (WeightDecay) r; - } - } - return null; - } - - public static double getL1(BaseLayer layer) { - List l = layer.getRegularization(); - return getL1(l); - } - - public static double getL1(List l){ - L1Regularization l1Reg = null; - for(Regularization reg : l){ - if(reg instanceof L1Regularization) - l1Reg = (L1Regularization) reg; - } - assertNotNull(l1Reg); - return l1Reg.getL1().valueAt(0,0); - } - - public static double getL2(BaseLayer layer) { - List l = layer.getRegularization(); - return getL2(l); - } - - public static double getL2(List l){ - L2Regularization l2Reg = null; - for(Regularization reg : l){ - if(reg instanceof L2Regularization) - l2Reg = (L2Regularization) reg; - } - assertNotNull(l2Reg); - return l2Reg.getL2().valueAt(0,0); - } - - public static double getL1(AbstractSameDiffLayer layer){ - return getL1(layer.getRegularization()); - } - - public static double getL2(AbstractSameDiffLayer layer){ - return getL2(layer.getRegularization()); - } - - public static double getWeightDecay(BaseLayer layer) { - return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java deleted file mode 100644 index b34280911..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestComputationGraphSpace.java +++ /dev/null @@ -1,168 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.computationgraph; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.ComputationGraphSpace; -import org.deeplearning4j.arbiter.TestUtils; -import org.deeplearning4j.arbiter.conf.updater.SgdSpace; -import org.deeplearning4j.arbiter.layers.DenseLayerSpace; -import org.deeplearning4j.arbiter.layers.OutputLayerSpace; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.graph.LayerVertex; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.BaseLayer; -import org.deeplearning4j.nn.conf.layers.DenseLayer; -import org.deeplearning4j.nn.conf.layers.OutputLayer; -import org.junit.jupiter.api.Test; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.learning.config.Sgd; -import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; - -import java.util.List; -import java.util.Random; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -public class TestComputationGraphSpace extends BaseDL4JTest { - - @Test - public void testBasic() { - - ComputationGraphConfiguration expected = new NeuralNetConfiguration.Builder() - .updater(new Sgd(0.005)) - .seed(12345) - .graphBuilder().addInputs("in") - .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in") - .addLayer("1", new DenseLayer.Builder().nIn(10).nOut(10).build(), "0").addLayer("2", - new OutputLayer.Builder().lossFunction(LossFunction.MCXENT) - .activation(Activation.SOFTMAX) - .nIn(10).nOut(5) - .build(), - "1") - .setOutputs("2").build(); - - ComputationGraphSpace cgs = new ComputationGraphSpace.Builder() - .updater(new Sgd(0.005)) - .seed(12345).addInputs("in") - .addLayer("0", new DenseLayerSpace.Builder().nIn(10).nOut(10).build(), "in") - .addLayer("1", new DenseLayerSpace.Builder().nIn(10).nOut(10).build(), "0") - .addLayer("2", new OutputLayerSpace.Builder().lossFunction(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(5) - .build(), "1") - .setOutputs("2").setInputTypes(InputType.feedForward(10)) - .build(); - - int nParams = cgs.numParameters(); - assertEquals(0, nParams); - - ComputationGraphConfiguration conf = cgs.getValue(new double[0]).getConfiguration(); - - assertEquals(expected, conf); - } - - @Test - public void testBasic2() { - - ComputationGraphSpace mls = new ComputationGraphSpace.Builder() - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.2, 0.5)) - .addInputs("in").addLayer("0", - new DenseLayerSpace.Builder().nIn(10).nOut(10) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.TANH)) - .build(), - "in") - .addLayer("1", new OutputLayerSpace.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX) - .build(), "0") - 
.setOutputs("1").setInputTypes(InputType.feedForward(10)).build(); - - int nParams = mls.numParameters(); - assertEquals(3, nParams); - - //Assign numbers to each leaf ParameterSpace object (normally done by candidate generator) - List noDuplicatesList = LeafUtils.getUniqueObjects(mls.collectLeaves()); - - //Second: assign each a number - int c = 0; - for (ParameterSpace ps : noDuplicatesList) { - int np = ps.numParameters(); - if (np == 1) { - ps.setIndices(c++); - } else { - int[] values = new int[np]; - for (int j = 0; j < np; j++) - values[c++] = j; - ps.setIndices(values); - } - } - - int reluCount = 0; - int tanhCount = 0; - - Random r = new Random(12345); - - for (int i = 0; i < 50; i++) { - - double[] rvs = new double[nParams]; - for (int j = 0; j < rvs.length; j++) - rvs[j] = r.nextDouble(); - - - ComputationGraphConfiguration conf = mls.getValue(rvs).getConfiguration(); - - int nLayers = conf.getVertexInputs().size(); - assertEquals(2, nLayers); - - for (int j = 0; j < nLayers; j++) { - NeuralNetConfiguration layerConf = - ((LayerVertex) conf.getVertices().get(String.valueOf(j))).getLayerConf(); - - double lr = ((Sgd)((BaseLayer) layerConf.getLayer()).getIUpdater()).getLearningRate(); - assertTrue(lr >= 0.0001 && lr <= 0.1); - double l2 = TestUtils.getL2(((BaseLayer) layerConf.getLayer())); - assertTrue(l2 >= 0.2 && l2 <= 0.5); - - if (j == nLayers - 1) { //Output layer - assertEquals(Activation.SOFTMAX.getActivationFunction(), - ((BaseLayer) layerConf.getLayer()).getActivationFn()); - } else { - IActivation actFn = ((BaseLayer) layerConf.getLayer()).getActivationFn(); - assertTrue(Activation.RELU.getActivationFunction().equals(actFn) || - Activation.TANH.getActivationFunction().equals(actFn)); - if (Activation.RELU.getActivationFunction().equals(actFn)) - reluCount++; - else - tanhCount++; - } - } - } - -// System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount); - assertTrue(reluCount > 0); - assertTrue(tanhCount > 0); - - } - - -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java deleted file mode 100644 index e9b8a7f73..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecution.java +++ /dev/null @@ -1,373 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.computationgraph; - -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.ComputationGraphSpace; -import org.deeplearning4j.arbiter.conf.updater.AdamSpace; -import org.deeplearning4j.arbiter.conf.updater.SgdSpace; -import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator; -import org.deeplearning4j.arbiter.layers.DenseLayerSpace; -import org.deeplearning4j.arbiter.layers.OutputLayerSpace; -import org.deeplearning4j.arbiter.multilayernetwork.TestDL4JLocalExecution; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; -import org.deeplearning4j.arbiter.saver.local.FileModelSaver; -import org.deeplearning4j.arbiter.scoring.ScoreFunctions; -import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction; -import org.deeplearning4j.arbiter.task.ComputationGraphTaskCreator; -import org.deeplearning4j.arbiter.util.TestDataFactoryProviderMnist; -import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver; -import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculatorCG; -import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator; -import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition; -import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.function.Supplier; -import org.nd4j.linalg.lossfunctions.LossFunctions; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.io.File; -import java.io.IOException; -import java.io.Serializable; -import java.util.HashMap; -import java.util.List; -import 
java.util.Map; -import java.util.Properties; -import java.util.concurrent.TimeUnit; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -@Slf4j -public class TestGraphLocalExecution extends BaseDL4JTest { - - @TempDir - public File testDir; - - @BeforeAll - public static void before(){ - Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - } - - @Override - public long getTimeoutMilliseconds() { - return 120_000L; - } - - @Test - public void testLocalExecutionDataSources() throws Exception { - - for( int dataApproach = 0; dataApproach<3; dataApproach++ ) { - log.info("////////////////// Starting Test: {} ///////////////////", dataApproach); - - //Define: network config (hyperparameter space) - ComputationGraphSpace mls = new ComputationGraphSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.0001, 0.01)) - .addInputs("in") - .addLayer("0", - new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(10, 20)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.TANH)) - .build(), "in") //1-2 identical layers (except nIn) - .addLayer("1", new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "0") - .setOutputs("1") - .setInputTypes(InputType.feedForward(784)) - .numEpochs(3).build(); - - DataProvider dp = null; - Class ds = null; - Properties dsP = null; - CandidateGenerator candidateGenerator; - - if(dataApproach == 0){ - ds = TestDL4JLocalExecution.MnistDataSource.class; - dsP = new Properties(); - dsP.setProperty("minibatch", "2"); - candidateGenerator = new RandomSearchGenerator(mls); - } else if(dataApproach == 1) { - //DataProvider approach - dp = new TestDL4JLocalExecution.MnistDataProvider(); - candidateGenerator = new RandomSearchGenerator(mls); - } else { - //Factory approach - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - candidateGenerator = new RandomSearchGenerator(mls, commands); - dp = new DataSetIteratorFactoryProvider(); - } - - File f = testDir; - File modelSave = new File(f, "modelSaveDir"); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator) - .dataProvider(dp) - .dataSource(ds, dsP) - .modelSaver(new FileModelSaver(modelSave)) - .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS), - new MaxCandidatesCondition(3)) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration,new ComputationGraphTaskCreator(new ClassificationEvaluator())); - - runner.execute(); - - List results = runner.getResults(); - assertTrue(results.size() > 0); - -// System.out.println("----- COMPLETE - " + results.size() + " results -----"); - } - } - - - @Test - public void testLocalExecution() throws Exception { - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - //Define: network config (hyperparameter space) - ComputationGraphSpace mls = new ComputationGraphSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) 
- .l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in") - .setInputTypes(InputType.feedForward(4)) - .addLayer("layer0", - new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), - "in") - .addLayer("out", new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "layer0") - .setOutputs("out").numEpochs(3).build(); - - //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); - DataProvider dataProvider = new DataSetIteratorFactoryProvider(); - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterDL4JTest\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - f.deleteOnExit(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true)) - .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS), - new MaxCandidatesCondition(3)) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, - new ComputationGraphTaskCreator(new ClassificationEvaluator())); - - runner.execute(); - - assertEquals(0, runner.numCandidatesFailed()); - assertTrue(runner.numCandidatesCompleted() > 0); - } - - @Test - public void testLocalExecutionMDS() throws Exception { - //Define: network config (hyperparameter space) - ComputationGraphSpace mls = new ComputationGraphSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in") - .setInputTypes(InputType.feedForward(784)) - .addLayer("layer0", - new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), - "in") - .addLayer("out", new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "layer0") - .setOutputs("out").numEpochs(3).build(); - - //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, null); - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterDL4JTest\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - f.deleteOnExit(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator) - .dataProvider(new TestMdsDataProvider(1, 32)) - .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(ScoreFunctions.testSetLoss(true)) - .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS), - new MaxCandidatesCondition(3)) - .scoreFunction(ScoreFunctions.testSetAccuracy()) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator()); - - runner.execute(); - - assertEquals(0, runner.numCandidatesFailed()); - assertTrue(runner.numCandidatesCompleted() > 0); - } - - public static class TestMdsDataProvider implements DataProvider 
{ - private int numEpochs; - private int batchSize; - - public TestMdsDataProvider(@JsonProperty("numEpochs") int numEpochs, @JsonProperty("batchSize") int batchSize) { - this.numEpochs = numEpochs; - this.batchSize = batchSize; - } - - private TestMdsDataProvider() { - } - - - @Override - public Object trainData(Map dataParameters) { - try { - DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 3 * batchSize), false, true, true, 12345); - return new MultiDataSetIteratorAdapter(new MultipleEpochsIterator(numEpochs, underlying)); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public Object testData(Map dataParameters) { - try { - DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 2 * batchSize), false, false, false, 12345); - return new MultiDataSetIteratorAdapter(underlying); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public Class getDataType() { - return MultiDataSetIterator.class; - } - } - - @Test - public void testLocalExecutionEarlyStopping() throws Exception { - EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder() - .epochTerminationConditions(new MaxEpochsTerminationCondition(2)) - .scoreCalculator(new ScoreProvider()) - .modelSaver(new InMemoryModelSaver()).build(); - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - //Define: network config (hyperparameter space) - ComputationGraphSpace cgs = new ComputationGraphSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new AdamSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in") - .setInputTypes(InputType.feedForward(784)) - .addLayer("first", - new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2, 10)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.TANH)) - .build(), - "in") //1-2 identical layers (except nIn) - .addLayer("out", new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "first") - .setOutputs("out").earlyStoppingConfiguration(esConf).build(); - - //Define configuration: - - CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, commands); - DataProvider dataProvider = new DataSetIteratorFactoryProvider(); - - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterDL4JTest2CG\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - f.deleteOnExit(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator) - .dataProvider(dataProvider) - .scoreFunction(ScoreFunctions.testSetF1()) - .modelSaver(new FileModelSaver(modelSavePath)) - .terminationConditions(new MaxTimeCondition(15, TimeUnit.SECONDS), - new MaxCandidatesCondition(3)) - .build(); - - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator()); - runner.execute(); - - assertEquals(0, runner.numCandidatesFailed()); - assertTrue(runner.numCandidatesCompleted() > 0); - } - - private static class ScoreProvider implements Supplier, Serializable { - @Override - public ScoreCalculator get() { - try { - return new 
DataSetLossCalculatorCG(new MnistDataSetIterator(4, 8), true); - } catch (Exception e){ - throw new RuntimeException(e); - } - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java deleted file mode 100644 index 05815a020..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/computationgraph/TestGraphLocalExecutionGenetic.java +++ /dev/null @@ -1,212 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.computationgraph; - -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.ComputationGraphSpace; -import org.deeplearning4j.arbiter.conf.updater.SgdSpace; -import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator; -import org.deeplearning4j.arbiter.layers.DenseLayerSpace; -import org.deeplearning4j.arbiter.layers.OutputLayerSpace; -import org.deeplearning4j.arbiter.multilayernetwork.TestDL4JLocalExecution; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.generator.GeneticSearchCandidateGenerator; -import org.deeplearning4j.arbiter.optimize.generator.genetic.population.PopulationModel; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; -import org.deeplearning4j.arbiter.saver.local.FileModelSaver; -import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction; -import org.deeplearning4j.arbiter.task.ComputationGraphTaskCreator; -import org.deeplearning4j.arbiter.util.TestDataFactoryProviderMnist; -import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import 
org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculatorCG; -import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator; -import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import org.nd4j.common.function.Supplier; -import org.nd4j.linalg.lossfunctions.LossFunctions; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.io.File; -import java.io.IOException; -import java.io.Serializable; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import java.util.concurrent.TimeUnit; - -import static org.junit.jupiter.api.Assertions.assertTrue; - -@Slf4j -public class TestGraphLocalExecutionGenetic extends BaseDL4JTest { - - @TempDir - public File testDir; - - @Override - public long getTimeoutMilliseconds() { - return 120_000L; - } - - @Test - public void testLocalExecutionDataSources() throws Exception { - for (int dataApproach = 0; dataApproach < 3; dataApproach++) { - log.info("////////////////// Starting Test: {} ///////////////////", dataApproach); - - //Define: network config (hyperparameter space) - ComputationGraphSpace mls = new ComputationGraphSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.0001, 0.01)) - .addInputs("in") - .addLayer("0", - new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(5, 32)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH, Activation.LEAKYRELU)) - .build(), "in") //1-2 identical layers (except nIn) - .addLayer("1", new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "0") - .setOutputs("1") - .setInputTypes(InputType.feedForward(784)) - .numEpochs(3).build(); - - DataProvider dp = null; - Class ds = null; - Properties dsP = null; - CandidateGenerator candidateGenerator; - - TestSetLossScoreFunction scoreFunction = new TestSetLossScoreFunction(); - - if (dataApproach == 0) { - ds = TestDL4JLocalExecution.MnistDataSource.class; - dsP = new Properties(); - dsP.setProperty("minibatch", "2"); - - candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction) - .populationModel(new PopulationModel.Builder().populationSize(5).build()) - .build(); - } else if (dataApproach == 1) { - //DataProvider approach - dp = new TestDL4JLocalExecution.MnistDataProvider(); - - candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction) - .populationModel(new PopulationModel.Builder().populationSize(5).build()) - .build(); - } else { - //Factory approach - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - candidateGenerator = new GeneticSearchCandidateGenerator.Builder(mls, scoreFunction) - .dataParameters(commands) - .populationModel(new PopulationModel.Builder().populationSize(5).build()) - .build(); - dp = new DataSetIteratorFactoryProvider(); - } - - File f = testDir; - File modelSave = new File(f, "modelSaveDir"); 
- - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator) - .dataProvider(dp) - .dataSource(ds, dsP) - .modelSaver(new FileModelSaver(modelSave)) - .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(20, TimeUnit.SECONDS), - new MaxCandidatesCondition(3)) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator(new ClassificationEvaluator())); - - runner.execute(); - - List results = runner.getResults(); - assertTrue(results.size() > 0); - -// System.out.println("----- COMPLETE - " + results.size() + " results -----"); - } - } - - public static class TestMdsDataProvider implements DataProvider { - private int numEpochs; - private int batchSize; - - public TestMdsDataProvider(@JsonProperty("numEpochs") int numEpochs, @JsonProperty("batchSize") int batchSize) { - this.numEpochs = numEpochs; - this.batchSize = batchSize; - } - - private TestMdsDataProvider() { - } - - - @Override - public Object trainData(Map dataParameters) { - try { - DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(60000, 10 * batchSize), false, true, true, 12345); - return new MultiDataSetIteratorAdapter(new MultipleEpochsIterator(numEpochs, underlying)); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public Object testData(Map dataParameters) { - try { - DataSetIterator underlying = new MnistDataSetIterator(batchSize, Math.min(10000, 5 * batchSize), false, false, false, 12345); - return new MultiDataSetIteratorAdapter(underlying); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public Class getDataType() { - return MultiDataSetIterator.class; - } - } - - private static class ScoreProvider implements Supplier, Serializable { - @Override - public ScoreCalculator get() { - try { - return new DataSetLossCalculatorCG(new MnistDataSetIterator(128, 1280), true); - } catch (Exception e){ - throw new RuntimeException(e); - } - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java deleted file mode 100644 index f4d539089..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/json/TestJson.java +++ /dev/null @@ -1,268 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.json; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.ComputationGraphSpace; -import org.deeplearning4j.arbiter.MultiLayerSpace; -import org.deeplearning4j.arbiter.conf.updater.AdaMaxSpace; -import org.deeplearning4j.arbiter.conf.updater.AdamSpace; -import org.deeplearning4j.arbiter.conf.updater.SgdSpace; -import org.deeplearning4j.arbiter.layers.DenseLayerSpace; -import org.deeplearning4j.arbiter.layers.OutputLayerSpace; -import org.deeplearning4j.arbiter.multilayernetwork.MnistDataSetIteratorFactory; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; -import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; -import org.deeplearning4j.arbiter.scoring.RegressionValue; -import org.deeplearning4j.arbiter.scoring.ScoreFunctions; -import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction; -import org.deeplearning4j.arbiter.util.TestDataFactoryProviderMnist; -import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver; -import org.deeplearning4j.earlystopping.scorecalc.ClassificationScoreCalculator; -import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculatorCG; -import org.deeplearning4j.earlystopping.scorecalc.base.BaseIEvaluationScoreCalculator; -import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition; -import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.eval.IEvaluation; -import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.junit.jupiter.api.Test; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -import java.util.HashMap; -import java.util.Map; -import java.util.Properties; -import java.util.concurrent.TimeUnit; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; - -/** - * Created by Alex on 14/02/2017. 
- */ -public class TestJson extends BaseDL4JTest { - - @Test - public void testMultiLayerSpaceJson() { - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) - .l2(new ContinuousParameterSpace(0.0001, 0.05)) - .addLayer(new DenseLayerSpace.Builder().nIn(1).nOut(new IntegerParameterSpace(5, 30)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.SOFTPLUS, - Activation.LEAKYRELU)) - .build(), new IntegerParameterSpace(1, 2), true) //1-2 identical layers - .addLayer(new DenseLayerSpace.Builder().nIn(4).nOut(new IntegerParameterSpace(2, 10)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers - .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .iLossFunction(LossFunctions.LossFunction.MCXENT.getILossFunction()).build()) - .setInputType(InputType.convolutional(28, 28, 1)).build(); - - String asJson = mls.toJson(); - // System.out.println(asJson); - - MultiLayerSpace fromJson = MultiLayerSpace.fromJson(asJson); - - assertEquals(mls, fromJson); - } - - - - @Test - public void testOptimizationFromJson() { - EarlyStoppingConfiguration esConf = - new EarlyStoppingConfiguration.Builder() - .epochTerminationConditions(new MaxEpochsTerminationCondition(100)) - .scoreCalculator(new DataSetLossCalculatorCG(new IrisDataSetIterator(150, 150), - true)) - .modelSaver(new InMemoryModelSaver()).build(); - - //Define: network config (hyperparameter space) - ComputationGraphSpace cgs = new ComputationGraphSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new AdaMaxSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in") - .setInputTypes(InputType.feedForward(4)) - .addLayer("first", - new DenseLayerSpace.Builder().nIn(4).nOut(new IntegerParameterSpace(2, 10)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.TANH)) - .build(), - "in") //1-2 identical layers (except nIn) - .addLayer("out", new OutputLayerSpace.Builder().nOut(3).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "first") - .setOutputs("out").earlyStoppingConfiguration(esConf).build(); - - //Define configuration: - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, commands); - DataProvider dataProvider = new DataSetIteratorFactoryProvider(); - - - OptimizationConfiguration configuration = - new OptimizationConfiguration.Builder().candidateGenerator(candidateGenerator) - .dataProvider(dataProvider).scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - String json = configuration.toJson(); - OptimizationConfiguration loadConf = OptimizationConfiguration.fromJson(json); - assertEquals(configuration, loadConf); - } - - @Test - public void testOptimizationFromJsonDataSource() { - for(boolean withProperties : new boolean[]{false, true}) { - //Define: network config (hyperparameter space) - ComputationGraphSpace cgs = new ComputationGraphSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - 
.updater(new AdaMaxSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.0001, 0.01)).addInputs("in") - .setInputTypes(InputType.feedForward(4)) - .addLayer("first", - new DenseLayerSpace.Builder().nIn(4).nOut(new IntegerParameterSpace(2, 10)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.TANH)) - .build(), - "in") //1-2 identical layers (except nIn) - .addLayer("out", new OutputLayerSpace.Builder().nOut(3).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "first") - .setOutputs("out").build(); - - //Define configuration: - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs, commands); - - Properties p = new Properties(); - p.setProperty("minibatch", "16"); - - OptimizationConfiguration configuration = - new OptimizationConfiguration.Builder().candidateGenerator(candidateGenerator) - .dataSource(MnistDataSource.class, (withProperties ? p : null)) - .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - String json = configuration.toJson(); - OptimizationConfiguration loadConf = OptimizationConfiguration.fromJson(json); - assertEquals(configuration, loadConf); - assertNotNull(loadConf.getDataSource()); - if(withProperties){ - assertNotNull(loadConf.getDataSourceProperties()); - } - } - } - - @Test - public void testComputationGraphSpaceJson() { - ParameterSpace p = new IntegerParameterSpace(10, 100); - ComputationGraphSpace cgs = - new ComputationGraphSpace.Builder() - .updater(new AdamSpace(new DiscreteParameterSpace<>(0.1, 0.5, 1.0))) - .seed(12345).addInputs("in") - .addLayer("0", new DenseLayerSpace.Builder() - .nIn(new IntegerParameterSpace(1, 100)).nOut(p).build(), "in") - .addLayer("1", new DenseLayerSpace.Builder().nIn(p).nOut(10).build(), "0") - .addLayer("2", new OutputLayerSpace.Builder().iLossFunction( - LossFunctions.LossFunction.MCXENT.getILossFunction()).nIn(10) - .nOut(5).build(), "1") - .setOutputs("2").build(); - - String asJson = cgs.toJson(); - ComputationGraphSpace fromJson = ComputationGraphSpace.fromJson(asJson); - - assertEquals(cgs, fromJson); - } - - @Test - public void testScoreFunctionJson() throws Exception { - - ScoreFunction[] scoreFunctions = new ScoreFunction[]{ - ScoreFunctions.testSetAccuracy(), ScoreFunctions.testSetF1(), - ScoreFunctions.testSetLoss(true), ScoreFunctions.testSetRegression(RegressionValue.MAE), - ScoreFunctions.testSetRegression(RegressionValue.RMSE)}; - - for(ScoreFunction sc : scoreFunctions){ - String json = JsonMapper.getMapper().writeValueAsString(sc); - ScoreFunction fromJson = JsonMapper.getMapper().readValue(json, ScoreFunction.class); - - assertEquals(sc, fromJson); - } - } - - - public static class MnistDataSource implements DataSource { - private int minibatch; - - public MnistDataSource(){ - - } - - @Override - public void configure(Properties properties) { - this.minibatch = Integer.parseInt(properties.getProperty("minibatch", "16")); - } - - @Override - public Object trainData() { - try { - return new MnistDataSetIterator(minibatch, true, 12345); - } catch (Exception e){ - throw new RuntimeException(e); - } - } - - @Override - public Object testData() { - try { - return new MnistDataSetIterator(minibatch, true, 12345); - } catch (Exception e){ - 
throw new RuntimeException(e); - } - } - - @Override - public Class getDataType() { - return DataSetIterator.class; - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java deleted file mode 100644 index ea754990a..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MNISTOptimizationTest.java +++ /dev/null @@ -1,166 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.multilayernetwork; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.MultiLayerSpace; -import org.deeplearning4j.arbiter.conf.updater.SgdSpace; -import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace; -import org.deeplearning4j.arbiter.layers.DenseLayerSpace; -import org.deeplearning4j.arbiter.layers.OutputLayerSpace; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; -import org.deeplearning4j.arbiter.saver.local.FileModelSaver; -import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction; -import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator; -import org.deeplearning4j.arbiter.util.TestDataFactoryProviderMnist; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver; -import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator; -import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition; -import org.deeplearning4j.earlystopping.termination.MaxScoreIterationTerminationCondition; -import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition; -import 
org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -import java.io.File; -import java.util.HashMap; -import java.util.Map; -import java.util.concurrent.TimeUnit; - -// import org.deeplearning4j.arbiter.optimize.ui.ArbiterUIServer; -// import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener; - -/** Not strictly a unit test. Rather: part example, part debugging on MNIST */ -public class MNISTOptimizationTest extends BaseDL4JTest { - - public static void main(String[] args) throws Exception { - EarlyStoppingConfiguration esConf = - new EarlyStoppingConfiguration.Builder() - .epochTerminationConditions(new MaxEpochsTerminationCondition(3)) - .iterationTerminationConditions( - new MaxTimeIterationTerminationCondition(5, TimeUnit.MINUTES), - new MaxScoreIterationTerminationCondition(4.6) //Random score: -log_e(0.1) ~= 2.3 - ).scoreCalculator(new DataSetLossCalculator(new MnistDataSetIterator(64, 2000, false, false, true, 123), true)).modelSaver(new InMemoryModelSaver()).build(); - - //Define: network config (hyperparameter space) - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) - .l2(new ContinuousParameterSpace(0.0001, 0.05)) - .addLayer( - new ConvolutionLayerSpace.Builder().nIn(1) - .nOut(new IntegerParameterSpace(5, 30)) - .kernelSize(new DiscreteParameterSpace<>(new int[] {3, 3}, - new int[] {4, 4}, new int[] {5, 5})) - .stride(new DiscreteParameterSpace<>(new int[] {1, 1}, - new int[] {2, 2})) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.SOFTPLUS, Activation.LEAKYRELU)) - .build(), - new IntegerParameterSpace(1, 2)) //1-2 identical layers - .addLayer(new DenseLayerSpace.Builder().nIn(4).nOut(new IntegerParameterSpace(2, 10)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), new IntegerParameterSpace(0, 1)) //0 to 1 layers - .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .earlyStoppingConfiguration(esConf).build(); - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); - DataProvider dataProvider = new MnistDataSetProvider(); - - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterMNISTSmall\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator) - .dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - - // ArbiterUIServer server = ArbiterUIServer.getInstance(); - // 
runner.addListeners(new UIOptimizationRunnerStatusListener(server)); - - runner.execute(); - - - System.out.println("----- COMPLETE -----"); - } - - - private static class MnistDataSetProvider implements DataProvider { - - @Override - public DataSetIterator trainData(Map dataParameters) { - try { - if (dataParameters == null || dataParameters.isEmpty()) { - return new MnistDataSetIterator(64, 10000, false, true, true, 123); - } - if (dataParameters.containsKey("batchsize")) { - int b = (Integer) dataParameters.get("batchsize"); - return new MnistDataSetIterator(b, 10000, false, true, true, 123); - } - return new MnistDataSetIterator(64, 10000, false, true, true, 123); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public DataSetIterator testData(Map dataParameters) { - return trainData(dataParameters); - } - - @Override - public Class getDataType() { - return DataSetIterator.class; - } - - @Override - public String toString() { - return "MnistDataSetProvider()"; - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MnistDataSetIteratorFactory.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MnistDataSetIteratorFactory.java deleted file mode 100644 index 55c2643a9..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/MnistDataSetIteratorFactory.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.multilayernetwork; - -import lombok.Data; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; - -import java.io.IOException; - -/** - * Created by agibsonccc on 3/13/17. - */ -@Data -public class MnistDataSetIteratorFactory implements DataSetIteratorFactory { - /** - * @return - */ - @Override - public DataSetIterator create() { - try { - return new MnistDataSetIterator(1000, 1000); - } catch (IOException e) { - throw new RuntimeException(e); - } - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java deleted file mode 100644 index aee1d022c..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestDL4JLocalExecution.java +++ /dev/null @@ -1,381 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
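The deleted tests feed training data through an indirection: a DataSetIteratorFactory implementation (such as the MnistDataSetIteratorFactory removed just above) is referenced by its canonical class name under DataSetIteratorFactoryProvider.FACTORY_KEY and instantiated reflectively when a candidate is trained. The following sketch shows that wiring under those assumptions; FactoryWiringSketch is a hypothetical name, and the deleted tests themselves pass TestDataFactoryProviderMnist here rather than this particular factory.

import org.deeplearning4j.arbiter.multilayernetwork.MnistDataSetIteratorFactory;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider;

import java.util.HashMap;
import java.util.Map;

public class FactoryWiringSketch {
    public static void main(String[] args) {
        // The factory travels as a class-name string; Arbiter instantiates it
        // reflectively when the data is actually needed.
        Map<String, Object> commands = new HashMap<>();
        commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY,
                MnistDataSetIteratorFactory.class.getCanonicalName());

        // The provider resolves FACTORY_KEY from the data parameters at run time.
        DataProvider dataProvider = new DataSetIteratorFactoryProvider();

        // As in the deleted tests, 'commands' would be handed to a candidate
        // generator (e.g. new RandomSearchGenerator(space, commands)) and
        // 'dataProvider' to OptimizationConfiguration.Builder#dataProvider(...).
    }
}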
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.multilayernetwork; - -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.MultiLayerSpace; -import org.deeplearning4j.arbiter.conf.updater.SgdSpace; -import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator; -import org.deeplearning4j.arbiter.layers.DenseLayerSpace; -import org.deeplearning4j.arbiter.layers.OCNNLayerSpace; -import org.deeplearning4j.arbiter.layers.OutputLayerSpace; -import org.deeplearning4j.arbiter.optimize.api.Candidate; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; -import org.deeplearning4j.arbiter.saver.local.FileModelSaver; -import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction; -import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator; -import org.deeplearning4j.arbiter.util.TestDataFactoryProviderMnist; -import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; -import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; -import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver; -import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator; -import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition; -import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; -import org.nd4j.linalg.activations.Activation; -import 
org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -import java.io.File; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Properties; -import java.util.concurrent.TimeUnit; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -@Slf4j -public class TestDL4JLocalExecution extends BaseDL4JTest { - - @TempDir - public File testDir; - - @BeforeAll - public static void before(){ - Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - } - - @Test - public void testLocalExecution() throws Exception { - - for( int dataApproach = 0; dataApproach<3; dataApproach++ ) { - log.info("////////////////// Starting Test: {} ///////////////////", dataApproach); - - //Define: network config (hyperparameter space) - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.0001, 0.01)) - .addLayer( - new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(10, 20)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.TANH)) - .build()) //1-2 identical layers (except nIn) - .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .numEpochs(3).build(); - - DataProvider dp = null; - Class ds = null; - Properties dsP = null; - CandidateGenerator candidateGenerator; - - if(dataApproach == 0){ - ds = MnistDataSource.class; - dsP = new Properties(); - dsP.setProperty("minibatch", "2"); - candidateGenerator = new RandomSearchGenerator(mls); - } else if(dataApproach == 1) { - //DataProvider approach - dp = new MnistDataProvider(); - candidateGenerator = new RandomSearchGenerator(mls); - } else { - //Factory approach - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - candidateGenerator = new RandomSearchGenerator(mls, commands); - dp = new DataSetIteratorFactoryProvider(); - } - - File f = testDir; - File modelSave = new File(f, "modelSaveDir"); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator) - .dataProvider(dp) - .dataSource(ds, dsP) - .modelSaver(new FileModelSaver(modelSave)) - .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(5, TimeUnit.SECONDS), - new MaxCandidatesCondition(5)) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, - new MultiLayerNetworkTaskCreator(new ClassificationEvaluator())); - - runner.execute(); - - List results = runner.getResults(); - assertTrue(results.size() > 0); - - System.out.println("----- COMPLETE - " + results.size() + " results -----"); - } - } - - public static class MnistDataSource implements DataSource { - private int minibatch; - - public MnistDataSource(){ - - } - - @Override - public void configure(Properties properties) { - this.minibatch = Integer.parseInt(properties.getProperty("minibatch", "16")); - } - - @Override - public Object trainData() { - try { - return new EarlyTerminationDataSetIterator(new MnistDataSetIterator(minibatch, true, 12345), 3); - } catch (Exception e){ 
- throw new RuntimeException(e); - } - } - - @Override - public Object testData() { - try { - return new EarlyTerminationDataSetIterator(new MnistDataSetIterator(minibatch, true, 12345), 3); - } catch (Exception e){ - throw new RuntimeException(e); - } - } - - @Override - public Class getDataType() { - return DataSetIterator.class; - } - } - - public static class MnistDataProvider implements DataProvider { - private int minibatch = 8; - - @Override - public Object trainData(Map dataParameters) { - try { - return new EarlyTerminationDataSetIterator(new MnistDataSetIterator(minibatch, true, 12345), 3); - } catch (Exception e){ - throw new RuntimeException(e); - } - } - - @Override - public Object testData(Map dataParameters) { - try { - return new EarlyTerminationDataSetIterator(new MnistDataSetIterator(minibatch, true, 12345), 3); - } catch (Exception e){ - throw new RuntimeException(e); - } - } - - @Override - public Class getDataType() { - return DataSetIterator.class; - } - } - - @Test - //@org.junit.Ignore - public void testLocalExecutionGridSearch() throws Exception { - - //Define: network config (hyperparameter space) - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) - .l2(new ContinuousParameterSpace(0.0001, 0.01)) - .addLayer( - new DenseLayerSpace.Builder().nIn(4).nOut(new IntegerParameterSpace(2, 10)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.TANH)) - .build(), - new IntegerParameterSpace(1, 2)) //1-2 identical layers (except nIn) - .addLayer(new OutputLayerSpace.Builder().nOut(3).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .numEpochs(3).build(); - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - CandidateGenerator candidateGenerator = new GridSearchCandidateGenerator(mls, 5, - GridSearchCandidateGenerator.Mode.Sequential, commands); - DataProvider dataProvider = new DataSetIteratorFactoryProvider(); - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterDL4JTest/").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - f.deleteOnExit(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, - new MultiLayerNetworkTaskCreator(new ClassificationEvaluator())); - - runner.execute(); - - System.out.println("----- COMPLETE -----"); - } - - @Test - //@Ignore - public void testLocalExecutionEarlyStopping() throws Exception { - EarlyStoppingConfiguration esConf = new EarlyStoppingConfiguration.Builder() - .epochTerminationConditions(new MaxEpochsTerminationCondition(100)) - .scoreCalculator(new DataSetLossCalculator(new IrisDataSetIterator(150, 150), true)) - .modelSaver(new InMemoryModelSaver()).build(); - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - - //Define: 
network config (hyperparameter space) - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.0001, 0.01)) - .addLayer(new DenseLayerSpace.Builder().nIn(4).nOut(new IntegerParameterSpace(2, 10)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.TANH)) - .build(), - new IntegerParameterSpace(1, 2)) //1-2 identical layers (except nIn) - .addLayer(new OutputLayerSpace.Builder().nOut(3).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .earlyStoppingConfiguration(esConf).build(); - - //Define configuration: - - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); - DataProvider dataProvider = new DataSetIteratorFactoryProvider(); - - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterDL4JTest2\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - f.deleteOnExit(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(2, TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, - new MultiLayerNetworkTaskCreator(new ClassificationEvaluator())); - - runner.execute(); - System.out.println("----- COMPLETE -----"); - } - - - @Test - public void testOcnn() { - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - - //Define: network config (hyperparameter space) - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.0001, 0.01)) - .addLayer( - new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(250, 500)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.TANH)) - .build(), - new IntegerParameterSpace(1, 2)) //1-2 identical layers (except nIn) - .addLayer(new OCNNLayerSpace.Builder().nu(new ContinuousParameterSpace(0.0001, 0.1)) - .numHidden(new DiscreteParameterSpace(784 / 2,784 / 4)) - .activation(Activation.HARDSIGMOID) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutionalFlat(28,28,1)) - .build(); - - //Define configuration: - - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); - DataProvider dataProvider = new DataSetIteratorFactoryProvider(); - - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterDL4JTest3\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - f.deleteOnExit(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)).scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(2, 
TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - - //candidate generation: uncomment execute if you want to run - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, - new MultiLayerNetworkTaskCreator(new ClassificationEvaluator())); - - Candidate candidate = candidateGenerator.getCandidate(); - - // runner.execute(); - System.out.println("----- COMPLETE -----"); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java deleted file mode 100644 index a3d4e3657..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestErrors.java +++ /dev/null @@ -1,158 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.multilayernetwork; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.ComputationGraphSpace; -import org.deeplearning4j.arbiter.MultiLayerSpace; -import org.deeplearning4j.arbiter.layers.DenseLayerSpace; -import org.deeplearning4j.arbiter.layers.OutputLayerSpace; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; -import org.deeplearning4j.arbiter.saver.local.FileModelSaver; -import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction; -import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator; -import org.deeplearning4j.arbiter.util.TestDataProviderMnist; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.Timeout; -import org.junit.jupiter.api.io.TempDir; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -import java.io.File; - -@Timeout(20) -public class TestErrors extends BaseDL4JTest { - - @TempDir - public File temp; - - @Test - public void testAllInvalidConfig() throws Exception { - //Invalid config - basically check that this actually terminates - - File f = temp; - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .addLayer(new DenseLayerSpace.Builder().nIn(4).nOut(new FixedValue<>(0)) //INVALID: nOut of 0 - .activation(Activation.TANH) - .build()) - .addLayer(new OutputLayerSpace.Builder().nOut(3).activation(Activation.SOFTMAX) - 
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .build(); - - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) - .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions( - new MaxCandidatesCondition(5)) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration); - runner.execute(); - } - - - @Test - public void testAllInvalidDataConfigMismatch() throws Exception { - //Valid config - but mismatched with provided data - - File f = temp; - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .addLayer(new DenseLayerSpace.Builder().nIn(4).nOut(10) //INVALID: nOut of 0 - .activation(Activation.TANH) - .build()) - .addLayer(new OutputLayerSpace.Builder().nIn(10).nOut(3).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .build(); - - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) - .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions( - new MaxCandidatesCondition(5)) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration); - runner.execute(); - } - - - @Test - public void testAllInvalidConfigCG() throws Exception { - //Invalid config - basically check that this actually terminates - - File f = temp; - ComputationGraphSpace mls = new ComputationGraphSpace.Builder() - .addInputs("in") - .layer("0", new DenseLayerSpace.Builder().nIn(4).nOut(new FixedValue<>(0)) //INVALID: nOut of 0 - .activation(Activation.TANH) - .build(), "in") - .layer("1", new OutputLayerSpace.Builder().nOut(3).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "0") - .setOutputs("1") - .build(); - - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) - .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxCandidatesCondition(5)) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration); - runner.execute(); - } - - - @Test - public void testAllInvalidDataConfigMismatchCG() throws Exception { - //Valid config - but mismatched with provided data - - File f = temp; - ComputationGraphSpace mls = new ComputationGraphSpace.Builder() - .addInputs("in") - .layer("0", new DenseLayerSpace.Builder().nIn(4).nOut(10) - .activation(Activation.TANH).build(), "in") - .addLayer("1", new OutputLayerSpace.Builder().nIn(10).nOut(3).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "0") - .setOutputs("1") - .build(); - - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); - - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(new TestDataProviderMnist(32, 3)) - .modelSaver(new FileModelSaver(f)).scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions( 
- new MaxCandidatesCondition(5)) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - runner.execute(); - } - -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java deleted file mode 100644 index 3f25c66db..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestLayerSpace.java +++ /dev/null @@ -1,314 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.multilayernetwork; - -import org.apache.commons.lang3.ArrayUtils; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.TestUtils; -import org.deeplearning4j.arbiter.conf.updater.SgdSpace; -import org.deeplearning4j.arbiter.layers.*; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.BooleanSpace; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; -import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.constraint.MaxNormConstraint; -import org.deeplearning4j.nn.conf.constraint.MinMaxNormConstraint; -import org.deeplearning4j.nn.conf.constraint.NonNegativeConstraint; -import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint; -import org.deeplearning4j.nn.conf.layers.*; -import org.junit.jupiter.api.Test; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.learning.config.Sgd; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Random; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -public class TestLayerSpace extends BaseDL4JTest { - - @Test - public void testBasic1() { - - DenseLayer expected = new DenseLayer.Builder().nOut(13).activation(Activation.RELU).build(); - - DenseLayerSpace space = new DenseLayerSpace.Builder().nOut(13).activation(Activation.RELU).build(); - - int nParam = space.numParameters(); - assertEquals(0, nParam); - DenseLayer actual = space.getValue(new double[nParam]); - - assertEquals(expected, actual); - } - - @Test - public void testBasic2() { - - Activation[] actFns = new Activation[]{Activation.SOFTSIGN, Activation.RELU, Activation.LEAKYRELU}; - Random r = new Random(12345); - - for (int i = 0; i < 20; i++) { - - new 
DenseLayer.Builder().build();
-
-            DenseLayerSpace ls =
-                    new DenseLayerSpace.Builder().nOut(20)
-                            .updater(new SgdSpace(new ContinuousParameterSpace(0.3, 0.4)))
-                            .l2(new ContinuousParameterSpace(0.01, 0.1))
-                            .activation(new DiscreteParameterSpace<>(actFns)).build();
-
-            //Set the parameter numbers...
-            List<ParameterSpace> list = ls.collectLeaves();
-            int k = 0;
-            for (int j = 0; j < list.size(); j++) {
-                if (list.get(j).numParameters() > 0) {
-                    list.get(j).setIndices(k++);
-                }
-            }
-
-            int nParam = ls.numParameters();
-            assertEquals(3, nParam);
-
-            double[] d = new double[nParam];
-            for (int j = 0; j < d.length; j++) {
-                d[j] = r.nextDouble();
-            }
-
-            DenseLayer l = ls.getValue(d);
-
-            assertEquals(20, l.getNOut());
-            double lr = ((Sgd) l.getIUpdater()).getLearningRate();
-            double l2 = TestUtils.getL2(l);
-            IActivation activation = l.getActivationFn();
-
-//            System.out.println(lr + "\t" + l2 + "\t" + activation);
-
-            assertTrue(lr >= 0.3 && lr <= 0.4);
-            assertTrue(l2 >= 0.01 && l2 <= 0.1);
-            assertTrue(containsActivationFunction(actFns, activation));
-        }
-    }
-
-    @Test
-    public void testBatchNorm() {
-        BatchNormalizationSpace sp = new BatchNormalizationSpace.Builder().gamma(1.5)
-                .beta(new ContinuousParameterSpace(2, 3)).lockGammaBeta(true).build();
-
-        //Set the parameter numbers...
-        List<ParameterSpace> list = sp.collectLeaves();
-        int k = 0;
-        for (int j = 0; j < list.size(); j++) {
-            if (list.get(j).numParameters() > 0) {
-                list.get(j).setIndices(k++);
-            }
-        }
-
-        BatchNormalization bn = sp.getValue(new double[]{0.6});
-        assertTrue(bn.isLockGammaBeta());
-        assertEquals(1.5, bn.getGamma(), 0.0);
-        assertEquals(0.6 * (3 - 2) + 2, bn.getBeta(), 1e-4);
-    }
-
-    @Test
-    public void testBatchNormConstrain() {
-
-        ArrayList<List<LayerConstraint>> constrainListOptions = new ArrayList<List<LayerConstraint>>();
-        constrainListOptions.add(Collections.singletonList((LayerConstraint) new MaxNormConstraint(0.5, 1)));
-        constrainListOptions.add(Collections.singletonList((LayerConstraint) new MinMaxNormConstraint(0.3, 0.4, 1.0, 1)));
-        constrainListOptions.add(Collections.singletonList((LayerConstraint) new NonNegativeConstraint()));
-        constrainListOptions.add(Collections.singletonList((LayerConstraint) new UnitNormConstraint(1)));
-
-        DiscreteParameterSpace<List<LayerConstraint>> constrainParamSpace = new DiscreteParameterSpace<>(constrainListOptions);
-        BatchNormalizationSpace sp = new BatchNormalizationSpace.Builder().gamma(1.5)
-                .beta(0.6).lockGammaBeta(true).constrainBeta(constrainParamSpace).constrainGamma(new NonNegativeConstraint()).build();
-
-        BatchNormalization bnExpected = new BatchNormalization.Builder().gamma(1.5)
-                .beta(0.6).lockGammaBeta(true).constrainBeta(new NonNegativeConstraint()).constrainGamma(new NonNegativeConstraint()).build();
-        //Set the parameter numbers...
-        List<ParameterSpace> list = sp.collectLeaves();
-        int k = 0;
-        for (int j = 0; j < list.size(); j++) {
-            if (list.get(j).numParameters() > 0) {
-                list.get(j).setIndices(k++);
-            }
-        }
-
-        assertEquals(1, sp.getNumParameters());
-        BatchNormalization bn = sp.getValue(new double[]{0.6});
-        assertEquals(bnExpected, bn); //0.6 should pick the 3rd value in discrete param space
-
-        //assertEquals(bn.getConstraints().size(),2); This throws an NPE but I believe this is an issue with actual impl of BatchNormalization not arbiter
-}
-
-    @Test
-    public void testActivationLayer() {
-        Activation[] actFns = new Activation[]{Activation.SOFTSIGN, Activation.RELU, Activation.LEAKYRELU};
-
-        ActivationLayerSpace als =
-                new ActivationLayerSpace.Builder().activation(new DiscreteParameterSpace<>(actFns)).build();
-        //Set the parameter numbers...
-        List<ParameterSpace> list = als.collectLeaves();
-        for (int j = 0; j < list.size(); j++) {
-            list.get(j).setIndices(j);
-        }
-
-        int nParam = als.numParameters();
-        assertEquals(1, nParam);
-
-        Random r = new Random(12345);
-
-        for (int i = 0; i < 20; i++) {
-
-            double[] d = new double[nParam];
-            for (int j = 0; j < d.length; j++) {
-                d[j] = r.nextDouble();
-            }
-
-            ActivationLayer al = als.getValue(d);
-            IActivation activation = al.getActivationFn();
-
-//            System.out.println(activation);
-
-            assertTrue(containsActivationFunction(actFns, activation));
-        }
-    }
-
-    @Test
-    public void testEmbeddingLayer() {
-
-        Activation[] actFns = new Activation[]{Activation.SOFTSIGN, Activation.RELU, Activation.LEAKYRELU};
-
-        EmbeddingLayerSpace els = new EmbeddingLayerSpace.Builder().activation(new DiscreteParameterSpace<>(actFns))
-                .nIn(10).nOut(new IntegerParameterSpace(10, 20)).build();
-        //Set the parameter numbers...
-        List<ParameterSpace> list = els.collectLeaves();
-        int k = 0;
-        for (int j = 0; j < list.size(); j++) {
-            if (list.get(j).numParameters() > 0) {
-                list.get(j).setIndices(k++);
-            }
-        }
-
-        int nParam = els.numParameters();
-        assertEquals(2, nParam);
-
-        Random r = new Random(12345);
-
-        for (int i = 0; i < 20; i++) {
-
-            double[] d = new double[nParam];
-            for (int j = 0; j < d.length; j++) {
-                d[j] = r.nextDouble();
-            }
-
-            EmbeddingLayer el = els.getValue(d);
-            IActivation activation = el.getActivationFn();
-            long nOut = el.getNOut();
-
-//            System.out.println(activation + "\t" + nOut);
-
-            assertTrue(containsActivationFunction(actFns, activation));
-            assertTrue(nOut >= 10 && nOut <= 20);
-        }
-    }
-
-    @Test
-    public void testSimpleConv() {
-        ConvolutionLayer conv2d = new Convolution2D.Builder().dilation(1,2).kernelSize(2,2).nIn(2).nOut(3).build();
-        ConvolutionLayerSpace conv2dSpace = new ConvolutionLayerSpace.Builder().dilation(1,2).kernelSize(2,2).nIn(2).nOut(3).build();
-        assertEquals(0, conv2dSpace.getNumParameters());
-        assertEquals(conv2d, conv2dSpace.getValue(new double[0]));
-
-        Deconvolution2DLayerSpace deconvd2dls = new Deconvolution2DLayerSpace.Builder().dilation(2,1).nIn(2).nOut(2).hasBias(new BooleanSpace()).build();
-        assertEquals(1, deconvd2dls.getNumParameters());
-        //Set the parameter numbers...
-        List<ParameterSpace> list = deconvd2dls.collectLeaves();
-        int k = 0;
-        for (int j = 0; j < list.size(); j++) {
-            if (list.get(j).numParameters() > 0) {
-                list.get(j).setIndices(k++);
-            }
-        }
-        Deconvolution2D actual = deconvd2dls.getValue(new double[]{0.9});
-        assertTrue(!actual.hasBias());
-        assertEquals(ArrayUtils.toString(new int[] {2,1}), ArrayUtils.toString(actual.getDilation()));
-    }
-
-    @Test
-    public void testGravesBidirectionalLayer() {
-
-        Activation[] actFns = new Activation[]{Activation.SOFTSIGN, Activation.RELU, Activation.LEAKYRELU};
-
-        GravesBidirectionalLSTMLayerSpace ls =
-                new GravesBidirectionalLSTMLayerSpace.Builder().activation(new DiscreteParameterSpace<>(actFns))
-                        .forgetGateBiasInit(new ContinuousParameterSpace(0.5, 0.8)).nIn(10)
-                        .nOut(new IntegerParameterSpace(10, 20)).build();
-        //Set the parameter numbers...
- List list = ls.collectLeaves(); - int k = 0; - for (int j = 0; j < list.size(); j++) { - if (list.get(j).numParameters() > 0) { - list.get(j).setIndices(k++); - } - } - - int nParam = ls.numParameters(); - assertEquals(3, nParam); //Excluding fixed value for nIn - - Random r = new Random(12345); - - for (int i = 0; i < 20; i++) { - - double[] d = new double[nParam]; - for (int j = 0; j < d.length; j++) { - d[j] = r.nextDouble(); - } - - GravesBidirectionalLSTM el = ls.getValue(d); - IActivation activation = el.getActivationFn(); - long nOut = el.getNOut(); - double forgetGate = el.getForgetGateBiasInit(); - -// System.out.println(activation + "\t" + nOut + "\t" + forgetGate); - - assertTrue(containsActivationFunction(actFns, activation)); - assertTrue(nOut >= 10 && nOut <= 20); - assertTrue(forgetGate >= 0.5 && forgetGate <= 0.8); - } - } - - private static boolean containsActivationFunction(Activation[] activationFunctions, - IActivation activationFunction) { - for (Activation af : activationFunctions) { - if (activationFunction.equals(af.getActivationFunction())) - return true; - } - return false; - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java deleted file mode 100644 index 784df4628..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestMultiLayerSpace.java +++ /dev/null @@ -1,819 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.multilayernetwork; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.DL4JConfiguration; -import org.deeplearning4j.arbiter.MultiLayerSpace; -import org.deeplearning4j.arbiter.TestUtils; -import org.deeplearning4j.arbiter.conf.updater.AdamSpace; -import org.deeplearning4j.arbiter.conf.updater.NesterovsSpace; -import org.deeplearning4j.arbiter.conf.updater.SgdSpace; -import org.deeplearning4j.arbiter.layers.*; -import org.deeplearning4j.arbiter.optimize.api.Candidate; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import org.deeplearning4j.arbiter.optimize.parameter.FixedValue; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.math.MathOp; -import org.deeplearning4j.arbiter.optimize.parameter.math.Op; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; -import org.deeplearning4j.arbiter.saver.local.FileModelSaver; -import org.deeplearning4j.arbiter.scoring.impl.TestSetAccuracyScoreFunction; -import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator; -import org.deeplearning4j.arbiter.util.LeafUtils; -import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator; -import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.constraint.NonNegativeConstraint; -import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint; -import org.deeplearning4j.nn.conf.dropout.Dropout; -import org.deeplearning4j.nn.conf.dropout.IDropout; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.*; -import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution; -import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution; -import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution; -import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.io.TempDir; -import org.nd4j.linalg.activations.Activation; -import 
org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.learning.config.Nesterovs; -import org.nd4j.linalg.learning.config.Sgd; -import org.nd4j.linalg.lossfunctions.ILossFunction; -import org.nd4j.linalg.lossfunctions.LossFunctions; -import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; -import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; -import org.nd4j.linalg.lossfunctions.impl.LossMSE; -import org.nd4j.common.primitives.Pair; - -import java.io.File; -import java.lang.reflect.Field; -import java.util.*; - -import static org.junit.jupiter.api.Assertions.*; - -public class TestMultiLayerSpace extends BaseDL4JTest { - - @TempDir - public File testDir; - - @BeforeAll - public static void before(){ - Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - } - - @Test - public void testBasic() { - - MultiLayerConfiguration expected = - new NeuralNetConfiguration.Builder() - .updater(new Sgd(0.005)).seed(12345).list() - .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(2, - new OutputLayer.Builder().lossFunction(LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(10).nOut(5).build()) - - .build(); - - MultiLayerSpace mls = - new MultiLayerSpace.Builder() - .updater(new Sgd(0.005)).seed(12345) - .addLayer(new DenseLayerSpace.Builder().nIn(10).nOut(10).build(), - new FixedValue<>(2)) //2 identical layers - .addLayer(new OutputLayerSpace.Builder().lossFunction(LossFunction.MCXENT) - .activation(Activation.SOFTMAX) - .nIn(10).nOut(5).build()).build(); - - int nParams = mls.numParameters(); - assertEquals(0, nParams); - - MultiLayerConfiguration conf = mls.getValue(new double[0]).getMultiLayerConfiguration(); - - assertEquals(expected, conf); - } - - @Test - public void testBasic0() { - MultiLayerConfiguration expected = - new NeuralNetConfiguration.Builder() - .l1Bias(0.4) - .l2Bias(0.5) - .constrainBias(new NonNegativeConstraint()) - .updater(new Sgd(0.005)).seed(12345).list() - .layer(0, new DenseLayer.Builder().l1Bias(0.6).nIn(10).nOut(10).build()) - .layer(1, new DenseLayer.Builder().l2Bias(0.7).constrainBias(new UnitNormConstraint()).nIn(10).nOut(10).build()).layer(2, - new OutputLayer.Builder().lossFunction(LossFunction.MCXENT).activation(Activation.SOFTMAX) - .nIn(10).nOut(5).build()) - .build(); - - MultiLayerSpace mls = - new MultiLayerSpace.Builder() - .l1Bias(0.4) - .l2Bias(0.5) - .constrainBias(new NonNegativeConstraint()) - .updater(new Sgd(0.005)).seed(12345) - .addLayer(new DenseLayerSpace.Builder().l1Bias(new ContinuousParameterSpace(0,1)).nIn(10).nOut(10).build()) - .addLayer(new DenseLayerSpace.Builder().l2Bias(0.7).constrainBias(new UnitNormConstraint()).nIn(10).nOut(10).build()) - .addLayer(new OutputLayerSpace.Builder().lossFunction(LossFunction.MCXENT).activation(Activation.SOFTMAX) - .nIn(10).nOut(5).build()) - .build(); - - int nParams = mls.numParameters(); - assertEquals(1, nParams); - - //Assign numbers to each leaf ParameterSpace object (normally done by candidate generator - manual here for testing) - List noDuplicatesList = LeafUtils.getUniqueObjects(mls.collectLeaves()); - - //Second: assign each a number - int c = 0; - for (ParameterSpace ps : noDuplicatesList) { - int np = ps.numParameters(); - if (np == 1) { - ps.setIndices(c++); - } else { - int[] values = new int[np]; 
- for (int j = 0; j < np; j++) - values[c++] = j; - ps.setIndices(values); - } - } - MultiLayerConfiguration conf = mls.getValue(new double[] {0.6}).getMultiLayerConfiguration(); - - assertEquals(expected, conf); - } - - @Test - public void testILossFunctionGetsSet() { - ILossFunction lossFunction = new LossMCXENT(Nd4j.create(new float[] {1f, 2f}, new long[]{1,2})); - - MultiLayerConfiguration expected = - new NeuralNetConfiguration.Builder().updater(new Sgd(0.005)).seed(12345).list() - .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(2, - new OutputLayer.Builder().lossFunction(lossFunction) - .activation(Activation.SOFTMAX).nIn(10).nOut(5).build()) - .build(); - - MultiLayerSpace mls = new MultiLayerSpace.Builder().updater(new Sgd(0.005)).seed(12345) - .addLayer(new DenseLayerSpace.Builder().nIn(10).nOut(10).build(), new FixedValue<>(2)) //2 identical layers - .addLayer(new OutputLayerSpace.Builder().iLossFunction(lossFunction).activation(Activation.SOFTMAX).nIn(10).nOut(5).build()) - .build(); - - int nParams = mls.numParameters(); - assertEquals(0, nParams); - - MultiLayerConfiguration conf = mls.getValue(new double[0]).getMultiLayerConfiguration(); - - assertEquals(expected, conf); - } - - @Test - public void testBasic2() { - - MultiLayerSpace mls = - new MultiLayerSpace.Builder().updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.2, 0.5)) - .convolutionMode(ConvolutionMode.Same) - .addLayer(new ConvolutionLayerSpace.Builder().nIn(3).nOut(3).kernelSize(2, 2) - .stride(1, 1).build()) - .addLayer(new DenseLayerSpace.Builder().nIn(10).nOut(10) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.TANH)) - .build(), new IntegerParameterSpace(1, 3)) //1-3 identical layers - .addLayer(new OutputLayerSpace.Builder().nIn(10).nOut(10) - .activation(Activation.SOFTMAX).build()) - .build(); - - int nParams = mls.numParameters(); - assertEquals(4, nParams); - - //Assign numbers to each leaf ParameterSpace object (normally done by candidate generator - manual here for testing) - List noDuplicatesList = LeafUtils.getUniqueObjects(mls.collectLeaves()); - - //Second: assign each a number - int c = 0; - for (ParameterSpace ps : noDuplicatesList) { - int np = ps.numParameters(); - if (np == 1) { - ps.setIndices(c++); - } else { - int[] values = new int[np]; - for (int j = 0; j < np; j++) - values[c++] = j; - ps.setIndices(values); - } - } - - - int[] nLayerCounts = new int[3]; - int reluCount = 0; - int tanhCount = 0; - - Random r = new Random(12345); - - for (int i = 0; i < 50; i++) { - - double[] rvs = new double[nParams]; - for (int j = 0; j < rvs.length; j++) - rvs[j] = r.nextDouble(); - - - MultiLayerConfiguration conf = mls.getValue(rvs).getMultiLayerConfiguration(); - - int nLayers = conf.getConfs().size(); - assertTrue(nLayers >= 3 && nLayers <= 5); //1 conv + 1-3 dense layers + 1 output layer: 2 to 4 - - int nLayersExOutputLayer = nLayers - 1; - nLayerCounts[nLayersExOutputLayer - 2]++; - - for (int j = 0; j < nLayers; j++) { - NeuralNetConfiguration layerConf = conf.getConf(j); - - double lr = ((Sgd)((BaseLayer) layerConf.getLayer()).getIUpdater()).getLearningRate(); - assertTrue(lr >= 0.0001 && lr <= 0.1); - double l2 = TestUtils.getL2((BaseLayer) layerConf.getLayer()); - assertTrue(l2 >= 0.2 && l2 <= 0.5); - - if (j == nLayers - 1) { //Output layer - assertEquals(Activation.SOFTMAX.getActivationFunction(), ((BaseLayer) 
layerConf.getLayer()).getActivationFn()); - } else if (j == 0) { - //Conv layer - ConvolutionLayer cl = (ConvolutionLayer) layerConf.getLayer(); - assertEquals(3, cl.getNIn()); - assertEquals(3, cl.getNOut()); - assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); - } else { - IActivation actFn = ((BaseLayer) layerConf.getLayer()).getActivationFn(); - assertTrue(Activation.RELU.getActivationFunction().equals(actFn) || - Activation.TANH.getActivationFunction().equals(actFn)); - if (Activation.RELU.getActivationFunction().equals(actFn)) - reluCount++; - else - tanhCount++; - } - } - } - - for (int i = 0; i < 3; i++) { - assertTrue(nLayerCounts[i] >= 5); //Expect approx equal (50/3 each), but some variation randomly - } - -// System.out.println("Number of layers: " + Arrays.toString(nLayerCounts)); -// System.out.println("ReLU vs. Tanh: " + reluCount + "\t" + tanhCount); - - } - - @Test - public void testGlobalPoolingBasic() { - - MultiLayerConfiguration expected = new NeuralNetConfiguration.Builder().updater(new Sgd(0.005)).seed(12345).list() - .layer(0, new GravesLSTM.Builder().nIn(10).nOut(10).build()) - .layer(1, new GlobalPoolingLayer.Builder().poolingType(PoolingType.SUM).pnorm(7).build()) - .layer(2, new OutputLayer.Builder().lossFunction(LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(5).build()) - .build(); - - MultiLayerSpace mls = - new MultiLayerSpace.Builder().updater(new Sgd(0.005)).seed(12345) - .addLayer(new GravesLSTMLayerSpace.Builder().nIn(10).nOut(10).build()) - .addLayer(new GlobalPoolingLayerSpace.Builder().poolingType(PoolingType.SUM) - .pNorm(7).build()) - .addLayer(new OutputLayerSpace.Builder().lossFunction(LossFunction.MCXENT) - .activation(Activation.SOFTMAX) - .nIn(10).nOut(5).build()) - .build(); - - int nParams = mls.numParameters(); - assertEquals(0, nParams); - - MultiLayerConfiguration conf = mls.getValue(new double[0]).getMultiLayerConfiguration(); - - assertEquals(expected, conf); - } - - - @Test - public void testVariationalAutoencoderLayerSpaceBasic() { - MultiLayerSpace mls = - new MultiLayerSpace.Builder() - .updater(new Sgd(0.005)).seed( - 12345) - .addLayer(new VariationalAutoencoderLayerSpace.Builder() - .nIn(new IntegerParameterSpace(50, 75)).nOut(200) - .encoderLayerSizes(234, 567).decoderLayerSizes(123, 456) - .reconstructionDistribution( - new DiscreteParameterSpace( - new GaussianReconstructionDistribution(), - new BernoulliReconstructionDistribution())) - .build()) - .build(); - - int numParams = mls.numParameters(); - - //Assign numbers to each leaf ParameterSpace object (normally done by candidate generator - manual here for testing) - List noDuplicatesList = LeafUtils.getUniqueObjects(mls.collectLeaves()); - - //Second: assign each a number - int c = 0; - for (ParameterSpace ps : noDuplicatesList) { - int np = ps.numParameters(); - if (np == 1) { - ps.setIndices(c++); - } else { - int[] values = new int[np]; - for (int j = 0; j < np; j++) - values[c++] = j; - ps.setIndices(values); - } - } - - double[] zeros = new double[numParams]; - - DL4JConfiguration configuration = mls.getValue(zeros); - - MultiLayerConfiguration conf = configuration.getMultiLayerConfiguration(); - assertEquals(1, conf.getConfs().size()); - - NeuralNetConfiguration nnc = conf.getConf(0); - VariationalAutoencoder vae = (VariationalAutoencoder) nnc.getLayer(); - - assertEquals(50, vae.getNIn()); - assertEquals(200, vae.getNOut()); - - assertArrayEquals(new int[] {234, 567}, vae.getEncoderLayerSizes()); - assertArrayEquals(new int[] {123, 456}, 
vae.getDecoderLayerSizes()); - - assertTrue(vae.getOutputDistribution() instanceof GaussianReconstructionDistribution); - - - - double[] ones = new double[numParams]; - for (int i = 0; i < ones.length; i++) - ones[i] = 1.0; - - configuration = mls.getValue(ones); - - conf = configuration.getMultiLayerConfiguration(); - assertEquals(1, conf.getConfs().size()); - - nnc = conf.getConf(0); - vae = (VariationalAutoencoder) nnc.getLayer(); - - assertEquals(75, vae.getNIn()); - assertEquals(200, vae.getNOut()); - - assertArrayEquals(new int[] {234, 567}, vae.getEncoderLayerSizes()); - assertArrayEquals(new int[] {123, 456}, vae.getDecoderLayerSizes()); - - assertTrue(vae.getOutputDistribution() instanceof BernoulliReconstructionDistribution); - } - - @Test - public void testInputTypeBasic() throws Exception { - - ParameterSpace layerSizeHyperparam = new IntegerParameterSpace(20, 60); - - MultiLayerSpace hyperparameterSpace = new MultiLayerSpace.Builder().l2(0.0001) - .weightInit(WeightInit.XAVIER).updater(new Nesterovs()) - .addLayer(new ConvolutionLayerSpace.Builder().kernelSize(5, 5).nIn(1).stride(1, 1) - .nOut(layerSizeHyperparam).activation(Activation.IDENTITY).build()) - .addLayer(new SubsamplingLayerSpace.Builder().poolingType(SubsamplingLayer.PoolingType.MAX) - .kernelSize(2, 2).stride(2, 2).build()) - .addLayer(new ConvolutionLayerSpace.Builder().kernelSize(5, 5) - //Note that nIn need not be specified in later layers - .stride(1, 1).nOut(50).activation(Activation.IDENTITY).build()) - .addLayer(new SubsamplingLayerSpace.Builder().poolingType(SubsamplingLayer.PoolingType.MAX) - .kernelSize(2, 2).stride(2, 2).build()) - .addLayer(new DenseLayerSpace.Builder().activation(Activation.RELU).nOut(500).build()) - .addLayer(new OutputLayerSpace.Builder() - .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10) - .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); - - - DataProvider dataProvider = new TestDataSetProvider(); - - File f = testDir; - if (f.exists()) - f.delete(); - f.mkdir(); - ResultSaver modelSaver = new FileModelSaver(f.getAbsolutePath()); - - ScoreFunction scoreFunction = new TestSetAccuracyScoreFunction(); - - int maxCandidates = 4; - TerminationCondition[] terminationConditions; - terminationConditions = new TerminationCondition[] {new MaxCandidatesCondition(maxCandidates)}; - - //Given these configuration options, let's put them all together: - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(new RandomSearchGenerator(hyperparameterSpace, null)) - .dataProvider(dataProvider).modelSaver(modelSaver).scoreFunction(scoreFunction) - .terminationConditions(terminationConditions).build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - runner.execute(); - - assertEquals(maxCandidates, runner.getResults().size()); - } - - - @Test - public void testSameRanges() { - - ParameterSpace l1Hyperparam = new ContinuousParameterSpace(0.001, 0.1); - ParameterSpace l2Hyperparam = new ContinuousParameterSpace(0.001, 0.1); - - MultiLayerSpace hyperparameterSpace = - new MultiLayerSpace.Builder().addLayer(new DenseLayerSpace.Builder().nIn(10).nOut(10).build()) - .l1(l1Hyperparam).l2(l2Hyperparam).build(); - - CandidateGenerator c = new RandomSearchGenerator(hyperparameterSpace, null); - - Candidate candidate = c.getCandidate(); - } - - @Test - public void testWeightedLossFunction() { - - 
MultiLayerConfiguration expected = - new NeuralNetConfiguration.Builder().updater(new Sgd(0.005)).seed(12345).list() - .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(2, - new OutputLayer.Builder() - .lossFunction(new LossMSE(Nd4j.create( - new double[] {1, 2, 3, 4, 5}, new long[]{1,5}))) - .nIn(10).nOut(5).build()) - .build(); - - MultiLayerSpace mls = - new MultiLayerSpace.Builder().updater(new Sgd(0.005)).seed(12345) - .addLayer(new DenseLayerSpace.Builder().nIn(10).nOut(10).build(), - new FixedValue<>(2)) //2 identical layers - .addLayer(new OutputLayerSpace.Builder() - .iLossFunction(new LossMSE(Nd4j.create(new double[] {1, 2, 3, 4, 5}, new long[]{1,5}))) - .nIn(10).nOut(5).build()) - .build(); - - int nParams = mls.numParameters(); - assertEquals(0, nParams); - - MultiLayerConfiguration conf = mls.getValue(new double[0]).getMultiLayerConfiguration(); - - assertEquals(expected, conf); - - String json = mls.toJson(); - MultiLayerSpace fromJson = MultiLayerSpace.fromJson(json); - - assertEquals(mls, fromJson); - } - - - @Test - public void testBidirectional() throws Exception { - - MultiLayerSpace mls = - new MultiLayerSpace.Builder().updater(new Sgd(0.005)) - .seed(12345) - .layer(new Bidirectional(new LSTMLayerSpace.Builder() - .nIn(10).nOut(10).build())) - .build(); - - DL4JConfiguration conf = mls.getValue(new double[0]); - MultiLayerConfiguration c2 = conf.getMultiLayerConfiguration(); - - MultiLayerNetwork net = new MultiLayerNetwork(c2); - net.init(); - - assertEquals(1, net.getnLayers()); - assertTrue(net.getLayer(0) instanceof BidirectionalLayer); - BidirectionalLayer bl = (BidirectionalLayer)net.getLayer(0); - - Field f = BidirectionalLayer.class.getDeclaredField("fwd"); - Field b = BidirectionalLayer.class.getDeclaredField("bwd"); - f.setAccessible(true); - b.setAccessible(true); - org.deeplearning4j.nn.layers.recurrent.LSTM lstmFwd = (org.deeplearning4j.nn.layers.recurrent.LSTM) f.get(bl); - org.deeplearning4j.nn.layers.recurrent.LSTM lstmBwd = (org.deeplearning4j.nn.layers.recurrent.LSTM) b.get(bl); - - assertEquals(10, ((LSTM)lstmFwd.conf().getLayer()).getNIn()); - assertEquals(10, ((LSTM)lstmFwd.conf().getLayer()).getNOut()); - assertEquals(10, ((LSTM)lstmBwd.conf().getLayer()).getNIn()); - assertEquals(10, ((LSTM)lstmBwd.conf().getLayer()).getNOut()); - } - - - @Test - public void testMathOps() { - - ParameterSpace firstLayerSize = new IntegerParameterSpace(10,30); - ParameterSpace secondLayerSize = new MathOp<>(firstLayerSize, Op.MUL, 3); - ParameterSpace firstLayerLR = new ContinuousParameterSpace(0.01, 0.1); - ParameterSpace secondLayerLR = new MathOp<>(firstLayerLR, Op.ADD, 0.2); - - MultiLayerSpace mls = - new MultiLayerSpace.Builder().updater(new Sgd(0.005)) - .seed(12345) - .layer(new DenseLayerSpace.Builder().nOut(firstLayerSize) - .updater(new AdamSpace(firstLayerLR)) - .build()) - .layer(new OutputLayerSpace.Builder().nOut(secondLayerSize) - .updater(new AdamSpace(secondLayerLR)) - .activation(Activation.SOFTMAX) - .build()) - .setInputType(InputType.feedForward(10)) - .build(); - - int nParams = mls.numParameters(); - assertEquals(2, nParams); - - new RandomSearchGenerator(mls, null); //Initializes the indices - - Random r = new Random(12345); - for( int i=0; i<10; i++ ){ - double[] d = new double[nParams]; - for( int j=0; j dropout = new DiscreteParameterSpace<>(0.0, 0.5); - - MultiLayerSpace mls = - new MultiLayerSpace.Builder().updater(new Sgd(0.005)) - .dropOut(dropout) - 
.seed(12345) - .layer(new DenseLayerSpace.Builder().nOut(10) - .build()) - .layer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .build()) - .setInputType(InputType.feedForward(10)) - .build(); - - int nParams = mls.numParameters(); - assertEquals(1, nParams); - - new RandomSearchGenerator(mls, null); //Initializes the indices - - Random r = new Random(12345); - int countNull = 0; - int count05 = 0; - for( int i=0; i<10; i++ ){ - double[] d = new double[nParams]; - for( int j=0; j 0); - assertTrue(count05 > 0); - } - - - private static class TestDataSetProvider implements DataProvider { - - @Override - public Object trainData(Map dataParameters) { - return new ExistingDataSetIterator( - Collections.singletonList(new DataSet(Nd4j.create(1, 1, 28, 28), Nd4j.create(1,10)))); - } - - @Override - public Object testData(Map dataParameters) { - return new ExistingDataSetIterator( - Collections.singletonList(new DataSet(Nd4j.create(1, 1, 28, 28), Nd4j.create(1,10)))); - } - - @Override - public Class getDataType() { - return DataSetIterator.class; - } - } - - - @Test - public void testDropout(){ - - MultiLayerSpace mls = new MultiLayerSpace.Builder().updater(new Sgd(0.005)).seed(12345) - .addLayer(new ConvolutionLayerSpace.Builder().nOut(2) - .dropOut(new ContinuousParameterSpace(0.4,0.6)) - .build()) - .addLayer(new GlobalPoolingLayerSpace.Builder().dropOut(new ContinuousParameterSpace(0.4,0.6)).build()) - .addLayer(new OutputLayerSpace.Builder().activation(Activation.SOFTMAX).nIn(10).nOut(5).build()) - .setInputType(InputType.convolutional(28, 28, 1)) - .build(); - - int nParams = mls.numParameters(); - List l = LeafUtils.getUniqueObjects(mls.collectLeaves()); - int x=0; - for( ParameterSpace p : l){ - int n = p.numParameters(); - int[] arr = new int[n]; - for(int i=0; i l = LeafUtils.getUniqueObjects(mls.collectLeaves()); - int x=0; - for( ParameterSpace p : l){ - int n = p.numParameters(); - int[] arr = new int[n]; - for(int i=0; i learningRateHyperparam = new DiscreteParameterSpace<>(0.003, 0.005, 0.01, 0.05); - ParameterSpace layerSizeHyperparam1 = new DiscreteParameterSpace<>(32, 64, 96, 128); - ParameterSpace layerSizeHyperparam2 = new DiscreteParameterSpace<>(32, 64, 96, 128); - ParameterSpace dropoutHyperparam = new DiscreteParameterSpace<>(0.8, 0.9); - - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .updater(new AdamSpace(learningRateHyperparam)) - .weightInit(WeightInit.XAVIER) - .l2(0.0001) - .addLayer(new DenseLayerSpace.Builder() - .nIn(10) - .nOut(layerSizeHyperparam1) - .build()) - .addLayer(new BatchNormalizationSpace.Builder() - .nOut(layerSizeHyperparam1) - .activation(Activation.RELU) - .build()) - .addLayer(new DropoutLayerSpace.Builder() - .dropOut(dropoutHyperparam) - .build()) - .addLayer(new DenseLayerSpace.Builder() - .nOut(layerSizeHyperparam2) - .build()) - .addLayer(new BatchNormalizationSpace.Builder() - .nOut(layerSizeHyperparam2) - .activation(Activation.RELU) - .build()) - .addLayer(new DropoutLayerSpace.Builder() - .dropOut(dropoutHyperparam) - .build()) - .addLayer(new OutputLayerSpace.Builder() - .nOut(10) - .activation(Activation.SOFTMAX) - .lossFunction(LossFunction.MCXENT) - .build()) - .build(); - - assertEquals(4, mls.getNumParameters()); - - for( int discreteCount : new int[]{1, 5}) { - GridSearchCandidateGenerator generator = new GridSearchCandidateGenerator(mls, discreteCount, GridSearchCandidateGenerator.Mode.Sequential, null); - - int expCandidates = 4 * 4 * 4 * 2; - assertEquals(expCandidates, 
generator.getTotalNumCandidates()); - - int count = 0; - while (generator.hasMoreCandidates()) { - generator.getCandidate(); - count++; - } - - - assertEquals(expCandidates, count); - } - } - - - @Test - public void testGridCandidateGenerator(){ - ParameterSpace layerSizeParam = new DiscreteParameterSpace<>(32, 48, 64); - ParameterSpace learningRateParam = new DiscreteParameterSpace<>(0.005, 0.007, 0.01); - - MultiLayerSpace hyperParamaterSpace = new MultiLayerSpace.Builder() - .seed(12345) - .biasInit(1) - .l2(1e-4) - .updater(new NesterovsSpace(learningRateParam)) - .addLayer(new DenseLayerSpace.Builder().nIn(10).nOut(layerSizeParam) - .weightInit(WeightInit.XAVIER) - .activation(Activation.RELU) - .build()) - .addLayer(new DenseLayerSpace.Builder().nIn(layerSizeParam).nOut(layerSizeParam) - .weightInit(WeightInit.XAVIER) - .activation(Activation.RELU) - .build()) - .addLayer(new OutputLayerSpace.Builder() - .lossFunction(LossFunctions.LossFunction.MSE) - .weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX) - .nIn(layerSizeParam).nOut(10).build()) - .build(); - - CandidateGenerator candidateGenerator = new GridSearchCandidateGenerator(hyperParamaterSpace, 30, GridSearchCandidateGenerator.Mode.Sequential, null); -// CandidateGenerator candidateGenerator = new RandomSearchGenerator(hyperParamaterSpace); - - Set> expCandidates = new HashSet<>(); - for(Double d : new double[]{0.005, 0.007, 0.01}){ - for(int i : new int[]{32, 48, 64}){ - expCandidates.add(new Pair<>(d, i)); - } - } - - Set> actCandidates = new HashSet<>(); - while(candidateGenerator.hasMoreCandidates()) { - Candidate conf = candidateGenerator.getCandidate(); - MultiLayerConfiguration mlc = conf.getValue().getMultiLayerConfiguration(); - FeedForwardLayer ffl = ((FeedForwardLayer) mlc.getConf(0).getLayer()); -// System.out.println(ffl.getIUpdater() + ", " + ffl.getNOut()); - actCandidates.add(new Pair<>(ffl.getIUpdater().getLearningRate(0,0), (int)ffl.getNOut())); - } - - assertEquals(expCandidates, actCandidates); - } -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java deleted file mode 100644 index e9c34c947..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/multilayernetwork/TestScoreFunctions.java +++ /dev/null @@ -1,220 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.multilayernetwork; - -import lombok.AllArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.MultiLayerSpace; -import org.deeplearning4j.arbiter.conf.updater.AdamSpace; -import org.deeplearning4j.arbiter.layers.OutputLayerSpace; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.saving.InMemoryResultSaver; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; -import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; -import org.deeplearning4j.arbiter.scoring.impl.ROCScoreFunction; -import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.eval.ROC; -import org.deeplearning4j.eval.ROCBinary; -import org.deeplearning4j.eval.ROCMultiClass; -import org.deeplearning4j.nn.conf.WorkspaceMode; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.nn.weights.WeightInit; -import org.junit.jupiter.api.Test; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.api.DataSet; -import org.nd4j.linalg.dataset.api.DataSetPreProcessor; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -import java.io.IOException; -import java.util.List; -import java.util.Map; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -@Slf4j -public class TestScoreFunctions extends BaseDL4JTest { - - - @Override - public long getTimeoutMilliseconds() { - return 60000L; - } - - @Test - public void testROCScoreFunctions() throws Exception { - - - for (boolean auc : new boolean[]{true, false}) { - for (ROCScoreFunction.ROCType rocType : ROCScoreFunction.ROCType.values()) { - String msg = (auc ? "AUC" : "AUPRC") + " - " + rocType; - log.info("Starting: " + msg); - - ParameterSpace lr = new ContinuousParameterSpace(1e-5, 1e-3); - - int nOut = (rocType == ROCScoreFunction.ROCType.ROC ? 2 : 10); - LossFunctions.LossFunction lf = (rocType == ROCScoreFunction.ROCType.BINARY ? - LossFunctions.LossFunction.XENT : LossFunctions.LossFunction.MCXENT); - Activation a = (rocType == ROCScoreFunction.ROCType.BINARY ? 
Activation.SIGMOID : Activation.SOFTMAX); - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .trainingWorkspaceMode(WorkspaceMode.NONE) - .inferenceWorkspaceMode(WorkspaceMode.NONE) - .updater(new AdamSpace(lr)) - .weightInit(WeightInit.XAVIER) - .layer(new OutputLayerSpace.Builder().nIn(784).nOut(nOut) - .activation(a) - .lossFunction(lf).build()) - .build(); - - CandidateGenerator cg = new RandomSearchGenerator(mls); - ResultSaver rs = new InMemoryResultSaver(); - ScoreFunction sf = new ROCScoreFunction(rocType, (auc ? ROCScoreFunction.Metric.AUC : ROCScoreFunction.Metric.AUPRC)); - - - OptimizationConfiguration oc = new OptimizationConfiguration.Builder() - .candidateGenerator(cg) - .dataProvider(new DP(rocType)) - .modelSaver(rs) - .scoreFunction(sf) - .terminationConditions(new MaxCandidatesCondition(3)) - .rngSeed(12345) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(oc, new MultiLayerNetworkTaskCreator()); - runner.execute(); - - List list = runner.getResults(); - - for (ResultReference rr : list) { - DataSetIterator testIter = new MnistDataSetIterator(4, 16, false, false, false, 12345); - testIter.setPreProcessor(new PreProc(rocType)); - - OptimizationResult or = rr.getResult(); - MultiLayerNetwork net = (MultiLayerNetwork) or.getResultReference().getResultModel(); - - double expScore; - switch (rocType){ - case ROC: - if(auc){ - expScore = net.doEvaluation(testIter, new ROC())[0].calculateAUC(); - } else { - expScore = net.doEvaluation(testIter, new ROC())[0].calculateAUCPR(); - } - break; - case BINARY: - if(auc){ - expScore = net.doEvaluation(testIter, new ROCBinary())[0].calculateAverageAuc(); - } else { - expScore = net.doEvaluation(testIter, new ROCBinary())[0].calculateAverageAUCPR(); - } - break; - case MULTICLASS: - if(auc){ - expScore = net.doEvaluation(testIter, new ROCMultiClass())[0].calculateAverageAUC(); - } else { - expScore = net.doEvaluation(testIter, new ROCMultiClass())[0].calculateAverageAUCPR(); - } - break; - default: - throw new RuntimeException(); - } - - - DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345); - iter.setPreProcessor(new PreProc(rocType)); - - assertEquals(expScore, or.getScore(), 1e-4, msg); - } - } - } - } - - @AllArgsConstructor - public static class DP implements DataProvider { - - protected ROCScoreFunction.ROCType rocType; - - @Override - public Object trainData(Map dataParameters) { - try { - DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345); - iter.setPreProcessor(new PreProc(rocType)); - return iter; - } catch (IOException e){ - throw new RuntimeException(e); - } - } - - @Override - public Object testData(Map dataParameters) { - try { - DataSetIterator iter = new MnistDataSetIterator(4, 16, false, false, false, 12345); - iter.setPreProcessor(new PreProc(rocType)); - return iter; - } catch (IOException e){ - throw new RuntimeException(e); - } - } - - @Override - public Class getDataType() { - return DataSetIterator.class; - } - } - - @AllArgsConstructor - public static class PreProc implements DataSetPreProcessor { - protected ROCScoreFunction.ROCType rocType; - - @Override - public void preProcess(DataSet toPreProcess) { - switch (rocType){ - case ROC: - //Convert to binary - long mb = toPreProcess.getLabels().size(0); - INDArray argMax = Nd4j.argMax(toPreProcess.getLabels(), 1); - INDArray newLabel = Nd4j.create(mb, 2); - for( int i=0; i dataParameters) { - try { - return new EarlyTerminationDataSetIterator(new 
MnistDataSetIterator(batchSize, true, 12345), terminationIter); - } catch (Exception e){ - throw new RuntimeException(e); - } - } - - @Override - public Object testData(Map dataParameters) { - try { - return new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, false, 12345), terminationIter); - } catch (Exception e){ - throw new RuntimeException(e); - } - } - - @Override - public Class getDataType() { - return DataSetIterator.class; - } - - -} diff --git a/arbiter/arbiter-deeplearning4j/src/test/resources/logback.xml b/arbiter/arbiter-deeplearning4j/src/test/resources/logback.xml deleted file mode 100644 index 410bdaae9..000000000 --- a/arbiter/arbiter-deeplearning4j/src/test/resources/logback.xml +++ /dev/null @@ -1,51 +0,0 @@ - - - - - - logs/application.log - - %date - [%level] - from %logger in %thread - %n%message%n%xException%n - - - - - - %logger{15} - %message%n%xException{5} - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/arbiter/arbiter-server/pom.xml b/arbiter/arbiter-server/pom.xml deleted file mode 100644 index c38549354..000000000 --- a/arbiter/arbiter-server/pom.xml +++ /dev/null @@ -1,63 +0,0 @@ - - - - - arbiter - net.brutex.ai - 1.0.0-SNAPSHOT - - 4.0.0 - - arbiter-server - jar - - arbiter-server - - - UTF-8 - - - - - com.beust - jcommander - 1.27 - - - net.brutex.ai - arbiter-deeplearning4j - ${project.version} - - - - net.brutex.ai - deeplearning4j-common-tests - ${project.version} - test - - - - - - test-nd4j-native - - - test-nd4j-cuda-10.2 - - - diff --git a/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliGenerator.java b/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliGenerator.java deleted file mode 100644 index af19a81f7..000000000 --- a/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliGenerator.java +++ /dev/null @@ -1,286 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.server; - -import com.beust.jcommander.JCommander; -import com.beust.jcommander.Parameter; -import com.beust.jcommander.ParameterException; -import org.apache.commons.io.FileUtils; -import org.deeplearning4j.arbiter.ComputationGraphSpace; -import org.deeplearning4j.arbiter.MultiLayerSpace; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition; -import org.deeplearning4j.arbiter.optimize.generator.GridSearchCandidateGenerator; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.saver.local.FileModelSaver; -import org.deeplearning4j.arbiter.scoring.RegressionValue; -import org.deeplearning4j.arbiter.scoring.ScoreFunctions; - -import java.io.File; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.TimeUnit; - -/** - * Generate an {@link OptimizationConfiguration} - * via the command line interface. - * You can then use this configuration json file from - * {@link ArbiterCliRunner} - * - * @author Adam Gibson - */ -public class ArbiterCliGenerator { - @Parameter(names = {"--searchSpacePath"}) - private String searchSpacePath = null; - @Parameter(names = {"--candidateType"},required = true) - private String candidateType = null; - @Parameter(names = {"--discretizationCount"}) - private int discretizationCount = 5; - @Parameter(names = {"--gridSearchOrder"}) - private String gridSearchOrder = null; - @Parameter(names = {"--neuralNetType"},required = true) - private String neuralNetType = null; - @Parameter(names = {"--dataSetIteratorClass"},required = true) - private String dataSetIteratorClass = null; - @Parameter(names = {"--modelOutputPath"},required = true) - private String modelOutputPath = null; - @Parameter(names = {"--score"},required = true) - private String score = null; - @Parameter(names = {"--problemType"},required = true) - private String problemType = CLASSIFICIATION; - @Parameter(names = {"--configSavePath"},required = true) - private String configSavePath = null; - - @Parameter(names = {"--duration"},description = "The number of minutes to run for. Default is -1 which means run till convergence.") - private long duration = -1; - @Parameter(names = {"--numCandidates"},description = "The number of candidates to generate. 
Default is 1.") - private int numCandidates = 1; - - public final static String REGRESSION_MULTI = "regression"; - public final static String REGRESSION = "regression"; - public final static String CLASSIFICIATION = "classification"; - - public final static String RANDOM_CANDIDATE = "random"; - public final static String GRID_SEARCH_CANDIDATE = "gridsearch"; - - public final static String SEQUENTIAL_ORDER = "sequence"; - public final static String RANDOM_ORDER = "random"; - - public final static String COMP_GRAPH = "compgraph"; - public final static String MULTI_LAYER = "multilayer"; - - public final static String ACCURACY = "accuracy"; - public final static String F1 = "f1"; - - public final static String ACCURACY_MULTI = "accuracy_multi"; - public final static String F1_MULTI = "f1_multi"; - - - public final static String REGRESSION_SCORE = "regression_score"; - public final static String REGRESSION_SCORE_MULTI = "regression_score_multi"; - - public void runMain(String...args) throws Exception { - JCommander jcmdr = new JCommander(this); - - try { - jcmdr.parse(args); - } catch(ParameterException e) { - System.err.println(e.getMessage()); - //User provides invalid input -> print the usage info - jcmdr.usage(); - try{ Thread.sleep(500); } catch(Exception e2){ } - System.exit(1); - } - - - DataProvider dataProvider = new DataSetIteratorFactoryProvider(); - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY,dataSetIteratorClass); - - - if(neuralNetType.equals(MULTI_LAYER)) { - MultiLayerSpace multiLayerSpace = loadMultiLayer(); - CandidateGenerator candidateGenerator = null; - if(candidateType.equals(GRID_SEARCH_CANDIDATE)) { - candidateGenerator = new RandomSearchGenerator(multiLayerSpace,commands); - - - - } - else if(candidateType.equals(RANDOM_CANDIDATE)) { - candidateGenerator = new RandomSearchGenerator(multiLayerSpace,commands); - - } - - if(problemType.equals(CLASSIFICIATION)) { - OptimizationConfiguration configuration - = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator) - .dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelOutputPath)) - .scoreFunction(scoreFunctionMultiLayerNetwork()) - .terminationConditions(getConditions()) - .build(); - FileUtils.writeStringToFile(new File(configSavePath),configuration.toJson()); - - } - else if(problemType.equals(REGRESSION)) { - OptimizationConfiguration configuration - = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator) - .dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelOutputPath)) - .scoreFunction(scoreFunctionMultiLayerNetwork()) - .terminationConditions(getConditions()) - .build(); - FileUtils.writeStringToFile(new File(configSavePath),configuration.toJson()); - - } - - - } - else if(neuralNetType.equals(COMP_GRAPH)) { - ComputationGraphSpace computationGraphSpace = loadCompGraph(); - CandidateGenerator candidateGenerator = null; - if(candidateType.equals(GRID_SEARCH_CANDIDATE)) { - candidateGenerator = new RandomSearchGenerator(computationGraphSpace,commands); - - } - else if(candidateType.equals(RANDOM_CANDIDATE)) { - candidateGenerator = new RandomSearchGenerator(computationGraphSpace,commands); - - } - - - if(problemType.equals(CLASSIFICIATION)) { - OptimizationConfiguration configuration - = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator) - .dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelOutputPath)) - .scoreFunction(scoreFunctionCompGraph()) - 
.terminationConditions(getConditions()) - .build(); - - FileUtils.writeStringToFile(new File(configSavePath),configuration.toJson()); - } - else { - OptimizationConfiguration configuration - = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator) - .dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelOutputPath)) - .scoreFunction(scoreFunctionCompGraph()) - .terminationConditions(getConditions()) - .build(); - FileUtils.writeStringToFile(new File(configSavePath),configuration.toJson()); - - - } - - - } - - - } - - public static void main(String...args) throws Exception { - new ArbiterCliGenerator().runMain(args); - } - - private List getConditions() { - List ret = new ArrayList<>(); - if(duration > 0) { - ret.add(new MaxTimeCondition(duration,TimeUnit.MINUTES)); - } - - if(numCandidates > 0) - ret.add(new MaxCandidatesCondition(numCandidates)); - if(ret.isEmpty()) { - ret.add(new MaxCandidatesCondition(1)); - } - return ret; - } - - - private GridSearchCandidateGenerator.Mode getMode() { - if(gridSearchOrder.equals(RANDOM_ORDER)) - return GridSearchCandidateGenerator.Mode.RandomOrder; - else if(gridSearchOrder.equals(SEQUENTIAL_ORDER)) { - return GridSearchCandidateGenerator.Mode.Sequential; - } - else throw new IllegalArgumentException("Illegal mode " + gridSearchOrder); - } - - private ScoreFunction scoreFunctionCompGraph() { - if(problemType.equals(CLASSIFICIATION)) { - switch(score) { - case ACCURACY: return ScoreFunctions.testSetAccuracy(); - case F1: return ScoreFunctions.testSetF1(); - case F1_MULTI : return ScoreFunctions.testSetF1(); - case ACCURACY_MULTI: return ScoreFunctions.testSetAccuracy(); - - default: throw new IllegalArgumentException("Score " + score + " not valid for type " + problemType); - } - } - else if(problemType.equals(REGRESSION)) { - switch(score) { - case REGRESSION_SCORE: return ScoreFunctions.testSetRegression(RegressionValue.valueOf(score)); - case REGRESSION_SCORE_MULTI: return ScoreFunctions.testSetRegression(RegressionValue.valueOf(score)); - default: throw new IllegalArgumentException("Score " + score + " not valid for type " + problemType); - } - } - throw new IllegalStateException("Illegal problem type " + problemType); - } - - private ScoreFunction scoreFunctionMultiLayerNetwork() { - if(problemType.equals(CLASSIFICIATION)) { - switch(score) { - case ACCURACY: return ScoreFunctions.testSetAccuracy(); - case F1: return ScoreFunctions.testSetF1(); - - default: throw new IllegalArgumentException("Score " + score + " not valid for type " + problemType); - } - } - else if(problemType.equals(REGRESSION)) { - switch(score) { - case REGRESSION_SCORE: return ScoreFunctions.testSetRegression(RegressionValue.valueOf(score)); - default: throw new IllegalArgumentException("Score " + score + " not valid for type " + problemType); - - } - } - throw new IllegalStateException("Illegal problem type " + problemType); - } - - private ComputationGraphSpace loadCompGraph() throws Exception { - ComputationGraphSpace multiLayerSpace = ComputationGraphSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath))); - return multiLayerSpace; - } - - private MultiLayerSpace loadMultiLayer() throws Exception { - MultiLayerSpace multiLayerSpace = MultiLayerSpace.fromJson(FileUtils.readFileToString(new File(searchSpacePath))); - return multiLayerSpace; - } -} diff --git a/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliRunner.java 
b/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliRunner.java deleted file mode 100644 index c845828cf..000000000 --- a/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/ArbiterCliRunner.java +++ /dev/null @@ -1,152 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.server; - -import com.beust.jcommander.JCommander; -import com.beust.jcommander.Parameter; -import com.beust.jcommander.ParameterException; -import org.apache.commons.io.FileUtils; -import org.deeplearning4j.arbiter.evaluator.multilayer.ClassificationEvaluator; -import org.deeplearning4j.arbiter.evaluator.multilayer.RegressionDataEvaluator; -import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; -import org.deeplearning4j.arbiter.scoring.RegressionValue; -import org.deeplearning4j.arbiter.server.cli.NeuralNetTypeValidator; -import org.deeplearning4j.arbiter.server.cli.ProblemTypeValidator; -import org.deeplearning4j.arbiter.task.ComputationGraphTaskCreator; -import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator; - -import java.io.File; -import java.util.HashMap; -import java.util.Map; - -/** - * Options: - * --dataSetIteratorClass - --modelSavePath - Default: /tmp - * --neuralNetType - --optimizationConfigPath - --problemType - Default: classification - --regressionType - - - - @author Adam Gibson - */ -public class ArbiterCliRunner { - @Parameter(names = {"--modelSavePath"}) - private String modelSavePath = System.getProperty("java.io.tmpdir"); - @Parameter(names = {"--optimizationConfigPath"}) - private String optimizationConfigPath = null; - @Parameter(names = {"--problemType"},validateWith = ProblemTypeValidator.class) - private String problemType = CLASSIFICATION; - @Parameter(names = {"--regressionType"}) - private String regressionType = null; - @Parameter(names = {"--dataSetIteratorClass"},required = true) - private String dataSetIteratorClass = null; - @Parameter(names = {"--neuralNetType"},required = true,validateWith = NeuralNetTypeValidator.class) - private String neuralNetType = null; - - public final static String CLASSIFICATION = "classification"; - public final static String REGRESSION = "regression"; - - - public final static String COMP_GRAPH = "compgraph"; - public final static String MULTI_LAYER_NETWORK = "multilayernetwork"; - - public void runMain(String...args) throws Exception { - JCommander jcmdr = new JCommander(this); - - try { - jcmdr.parse(args); - } catch(ParameterException e) { - System.err.println(e.getMessage()); - //User 
provides invalid input -> print the usage info - jcmdr.usage(); - try{ Thread.sleep(500); } catch(Exception e2){ } - System.exit(1); - } - - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY,dataSetIteratorClass); - - File f = new File(modelSavePath); - - if(f.exists()) f.delete(); - f.mkdir(); - f.deleteOnExit(); - - if(problemType.equals(REGRESSION)) { - if(neuralNetType.equals(COMP_GRAPH)) { - OptimizationConfiguration configuration - = OptimizationConfiguration.fromJson( - FileUtils.readFileToString(new File(optimizationConfigPath))); - - IOptimizationRunner runner - = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator( - new RegressionDataEvaluator(RegressionValue.valueOf(regressionType),commands))); - runner.execute(); - } - else if(neuralNetType.equals(MULTI_LAYER_NETWORK)) { - OptimizationConfiguration configuration = OptimizationConfiguration. - fromJson(FileUtils.readFileToString(new File(optimizationConfigPath))); - - IOptimizationRunner runner - = new LocalOptimizationRunner( - configuration, - new MultiLayerNetworkTaskCreator( - new RegressionDataEvaluator( - RegressionValue.valueOf(regressionType), - commands))); - runner.execute(); - } - } - - else if(problemType.equals(CLASSIFICATION)) { - if(neuralNetType.equals(COMP_GRAPH)) { - OptimizationConfiguration configuration - = OptimizationConfiguration.fromJson(FileUtils.readFileToString(new File(optimizationConfigPath))); - - IOptimizationRunner runner - = new LocalOptimizationRunner( - configuration,new ComputationGraphTaskCreator(new ClassificationEvaluator())); - - runner.execute(); - } - else if(neuralNetType.equals(MULTI_LAYER_NETWORK)) { - OptimizationConfiguration configuration = OptimizationConfiguration - .fromJson(FileUtils.readFileToString(new File(optimizationConfigPath))); - - IOptimizationRunner runner - = new LocalOptimizationRunner(configuration, - new MultiLayerNetworkTaskCreator( - new ClassificationEvaluator()) - ); - - runner.execute(); - } - } - } - public static void main(String...args) throws Exception { - new ArbiterCliRunner().runMain(args); - } - -} diff --git a/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/cli/NeuralNetTypeValidator.java b/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/cli/NeuralNetTypeValidator.java deleted file mode 100644 index 1a338bdc0..000000000 --- a/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/cli/NeuralNetTypeValidator.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.server.cli; - -import com.beust.jcommander.IParameterValidator; -import com.beust.jcommander.ParameterException; -import org.deeplearning4j.arbiter.server.ArbiterCliRunner; - -/** - * Created by agibsonccc on 3/13/17. - */ -public class NeuralNetTypeValidator implements IParameterValidator { - /** - * Validate the parameter. - * - * @param name The name of the parameter (e.g. "-host"). - * @param value The value of the parameter that we need to validate - * @throws ParameterException Thrown if the value of the parameter is invalid. - */ - @Override - public void validate(String name, String value) throws ParameterException { - if(!value.equals(ArbiterCliRunner.MULTI_LAYER_NETWORK) || value.equals(ArbiterCliRunner.COMP_GRAPH)) { - throw new ParameterException("Neural net type can only be " + ArbiterCliRunner.COMP_GRAPH + " or " + ArbiterCliRunner.MULTI_LAYER_NETWORK); - - } - } -} diff --git a/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/cli/ProblemTypeValidator.java b/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/cli/ProblemTypeValidator.java deleted file mode 100644 index 3df2f6449..000000000 --- a/arbiter/arbiter-server/src/main/java/org/deeplearning4j/arbiter/server/cli/ProblemTypeValidator.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.server.cli; - -import com.beust.jcommander.IParameterValidator; -import com.beust.jcommander.ParameterException; -import org.deeplearning4j.arbiter.server.ArbiterCliGenerator; - -/** - * Created by agibsonccc on 3/13/17. - */ -public class ProblemTypeValidator implements IParameterValidator { - /** - * Validate the parameter. - * - * @param name The name of the parameter (e.g. "-host"). - * @param value The value of the parameter that we need to validate - * @throws ParameterException Thrown if the value of the parameter is invalid. 
- */ - @Override - public void validate(String name, String value) throws ParameterException { - if(!value.equals(ArbiterCliGenerator.REGRESSION) || value.equals(ArbiterCliGenerator.CLASSIFICIATION)) { - throw new ParameterException("Problem type can only be " + ArbiterCliGenerator.REGRESSION + " or " + ArbiterCliGenerator.CLASSIFICIATION); - - } - } -} diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java deleted file mode 100644 index 5efcd9657..000000000 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/ArbiterCLIRunnerTest.java +++ /dev/null @@ -1,121 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.server; - -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.io.FileUtils; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.arbiter.MultiLayerSpace; -import org.deeplearning4j.arbiter.conf.updater.SgdSpace; -import org.deeplearning4j.arbiter.layers.DenseLayerSpace; -import org.deeplearning4j.arbiter.layers.OutputLayerSpace; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSetIteratorFactoryProvider; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; -import org.deeplearning4j.arbiter.saver.local.FileModelSaver; -import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction; -import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.junit.jupiter.api.Test; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.learning.config.Sgd; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -import java.io.File; -import java.util.HashMap; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.TimeUnit; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -/** - * Created by agibsonccc on 3/12/17. 
- */ -@Slf4j -public class ArbiterCLIRunnerTest extends BaseDL4JTest { - - @Override - public long getTimeoutMilliseconds() { - return 90000; - } - - @Test - public void testCliRunner() throws Exception { - ArbiterCliRunner cliRunner = new ArbiterCliRunner(); - - //Define: network config (hyperparameter space) - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.1))) - .l2(new ContinuousParameterSpace(0.0001, 0.01)) - .addLayer(new DenseLayerSpace.Builder().nIn(784).nOut(new IntegerParameterSpace(2,10)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build()) - .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .numEpochs(3).build(); - assertEquals(mls,MultiLayerSpace.fromJson(mls.toJson())); - //Define configuration: - Map commands = new HashMap<>(); - commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY,TestDataFactoryProviderMnist.class.getCanonicalName()); - - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls,commands); - DataProvider dataProvider = new DataSetIteratorFactoryProvider(); - - -// String modelSavePath = FilenameUtils.concat(System.getProperty("java.io.tmpdir"),"ArbiterDL4JTest/"); - String modelSavePath = new File(System.getProperty("java.io.tmpdir"),"ArbiterDL4JTest/").getAbsolutePath(); - File dir = new File(modelSavePath); - if(!dir.exists()) - dir.mkdirs(); - String configPath = System.getProperty("java.io.tmpdir") + File.separator + UUID.randomUUID().toString() + ".json"; - OptimizationConfiguration configuration - = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator) - .dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(new TestSetLossScoreFunction()) - .terminationConditions(new MaxTimeCondition(30, TimeUnit.SECONDS), - new MaxCandidatesCondition(5)) - .build(); - assertEquals(configuration,OptimizationConfiguration.fromJson(configuration.toJson())); - - FileUtils.writeStringToFile(new File(configPath),configuration.toJson()); -// System.out.println(configuration.toJson()); - configuration.toJson(); - - log.info("Starting test"); - cliRunner.runMain( - "--dataSetIteratorClass", - TestDataFactoryProviderMnist.class.getCanonicalName(), - "--neuralNetType", - ArbiterCliRunner.MULTI_LAYER_NETWORK, - "--optimizationConfigPath", - configPath - ); - } - - - -} diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java deleted file mode 100644 index 256a8af9b..000000000 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,50 +0,0 @@ -/* ****************************************************************************** - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ -package org.deeplearning4j.arbiter.server; - -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.BaseDL4JTest; -import org.nd4j.common.tests.AbstractAssertTestsClass; - -import java.util.*; - -/** - * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) - * extends BaseDl4jTest - either directly or indirectly. - * Other than a small set of exceptions, all tests must extend this - * - * @author Alex Black - */ - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.deeplearning4j.arbiter.server"; - } - - @Override - protected Class getBaseClass() { - return BaseDL4JTest.class; - } -} diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java deleted file mode 100644 index 57bef758d..000000000 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java +++ /dev/null @@ -1,43 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.server; - -import lombok.Data; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; - -import java.io.IOException; - -/** - * Created by agibsonccc on 3/13/17. 
- */ -@Data -public class MnistDataSetIteratorFactory extends BaseDL4JTest implements DataSetIteratorFactory { - /** - * @return - */ - @Override - public DataSetIterator create() { - try { - return new MnistDataSetIterator(1000,1000); - } catch (IOException e) { - throw new RuntimeException(e); - } - } -} diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java deleted file mode 100644 index c4a75ffb4..000000000 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.server; - -import lombok.AllArgsConstructor; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; - -@AllArgsConstructor -public class TestDataFactoryProviderMnist extends BaseDL4JTest implements DataSetIteratorFactory { - - private int batchSize; - private int terminationIter; - - public TestDataFactoryProviderMnist(){ - this(16, 10); - } - - @Override - public DataSetIterator create() { - try { - return new EarlyTerminationDataSetIterator(new MnistDataSetIterator(batchSize, true, 12345), terminationIter); - } catch (Exception e){ - throw new RuntimeException(e); - } - } -} diff --git a/arbiter/arbiter-ui/pom.xml b/arbiter/arbiter-ui/pom.xml deleted file mode 100644 index 86f4530a9..000000000 --- a/arbiter/arbiter-ui/pom.xml +++ /dev/null @@ -1,73 +0,0 @@ - - - - - - arbiter - net.brutex.ai - 1.0.0-SNAPSHOT - - - 4.0.0 - - arbiter-ui - arbiter-ui - - - - net.brutex.ai - arbiter-core - ${project.version} - - - - net.brutex.ai - deeplearning4j-ui - ${project.version} - - - - net.brutex.ai - deeplearning4j-common-tests - ${project.version} - test - - - - net.brutex.ai - arbiter-deeplearning4j - ${project.version} - - - com.fasterxml.jackson.core - jackson-core - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-annotations - ${jackson.version} - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.databind.version} - - - diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/UpdateStatus.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/UpdateStatus.java deleted file mode 100644 index a92b4f0e7..000000000 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/UpdateStatus.java +++ /dev/null @@ -1,33 +0,0 @@ 
-/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.ui; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; - -@AllArgsConstructor -@NoArgsConstructor -@EqualsAndHashCode -@Data -public class UpdateStatus { - - private long statusUpdateTime; - private long settingsUpdateTime; - private long resultsUpdateTime; -} diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/BaseJavaPersistable.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/BaseJavaPersistable.java deleted file mode 100644 index 1fb699e0b..000000000 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/BaseJavaPersistable.java +++ /dev/null @@ -1,159 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.ui.data; - -import lombok.AllArgsConstructor; -import org.apache.commons.compress.utils.IOUtils; -import org.deeplearning4j.core.storage.Persistable; -import org.deeplearning4j.arbiter.ui.module.ArbiterModule; - -import java.io.*; -import java.lang.reflect.Field; -import java.lang.reflect.Modifier; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; - -/** - * Common implementation - * - * @author Alex Black - */ -@AllArgsConstructor -public abstract class BaseJavaPersistable implements Persistable { - - private String sessionId; - private long timestamp; - - public BaseJavaPersistable(Builder builder){ - this.sessionId = builder.sessionId; - this.timestamp = builder.timestamp; - } - - protected BaseJavaPersistable(){ - //No-arg costructor for Pesistable encoding/decoding - } - - @Override - public String getTypeID() { - return ArbiterModule.ARBITER_UI_TYPE_ID; - } - - @Override - public long getTimeStamp() { - return timestamp; - } - - @Override - public String getSessionID() { - return sessionId; - } - - @Override - public int encodingLengthBytes() { - //TODO - presumably a more efficient way to do this - byte[] encoded = encode(); - return encoded.length; - } - - @Override - public byte[] encode() { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - try (ObjectOutputStream oos = new ObjectOutputStream(baos)) { - oos.writeObject(this); - } catch (IOException e) { - throw new RuntimeException(e); //Should never happen - } - return baos.toByteArray(); - } - - @Override - public void encode(ByteBuffer buffer) { - buffer.put(encode()); - } - - @Override - public void encode(OutputStream outputStream) throws IOException { - try (ObjectOutputStream oos = new ObjectOutputStream(outputStream)) { - oos.writeObject(this); - } - } - - @Override - public void decode(byte[] decode) { - BaseJavaPersistable r; - try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(decode))) { - r = (BaseJavaPersistable) ois.readObject(); - } catch (IOException | ClassNotFoundException e) { - throw new RuntimeException(e); //Should never happen - } - - //Need to manually build and walk the class heirarchy... - Class currClass = this.getClass(); - List> classHeirarchy = new ArrayList<>(); - while (currClass != Object.class) { - classHeirarchy.add(currClass); - currClass = currClass.getSuperclass(); - } - - for (int i = classHeirarchy.size() - 1; i >= 0; i--) { - //Use reflection here to avoid a mass of boilerplate code... 
- Field[] allFields = classHeirarchy.get(i).getDeclaredFields(); - - for (Field f : allFields) { - if (Modifier.isStatic(f.getModifiers())) { - //Skip static fields - continue; - } - f.setAccessible(true); - try { - f.set(this, f.get(r)); - } catch (IllegalAccessException e) { - throw new RuntimeException(e); //Should never happen - } - } - } - } - - @Override - public void decode(ByteBuffer buffer) { - byte[] bytes = new byte[buffer.remaining()]; - buffer.get(bytes); - decode(bytes); - } - - @Override - public void decode(InputStream inputStream) throws IOException { - decode(IOUtils.toByteArray(inputStream)); - } - - public static abstract class Builder> { - protected String sessionId; - protected long timestamp; - - public T sessionId(String sessionId){ - this.sessionId = sessionId; - return (T) this; - } - - public T timestamp(long timestamp){ - this.timestamp = timestamp; - return (T) this; - } - - } -} diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java deleted file mode 100644 index 9a6c3faa9..000000000 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/GlobalConfigPersistable.java +++ /dev/null @@ -1,119 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.ui.data; - -import lombok.Getter; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; -import org.deeplearning4j.arbiter.ui.module.ArbiterModule; -import org.deeplearning4j.core.storage.Persistable; - -import java.io.IOException; - -/** - * - * A {@link Persistable} implemention for global settings - * @author Alex Black - */ -@Getter -public class GlobalConfigPersistable extends BaseJavaPersistable { - public static final String GLOBAL_WORKER_ID = "global"; - - private String optimizationConfigJson; - private int[] candidateCounts; //queued, completed, failed, total - private String optimizationRunner; - - public GlobalConfigPersistable(String sessionId, long timestamp){ - super(sessionId, timestamp); - } - - public GlobalConfigPersistable(Builder builder){ - super(builder); - this.optimizationConfigJson = builder.optimizationConfigJson; - this.candidateCounts = builder.candidateCounts; - if(this.candidateCounts == null){ - this.candidateCounts = new int[4]; - } - this.optimizationRunner = builder.optimizationRunner; - } - - public GlobalConfigPersistable(){ - //No-arg costructor for Pesistable encoding/decoding - } - - @Override - public String getTypeID() { - return ArbiterModule.ARBITER_UI_TYPE_ID; - } - - @Override - public String getWorkerID() { - return GLOBAL_WORKER_ID; - } - - - public OptimizationConfiguration getOptimizationConfiguration(){ - try { - return JsonMapper.getMapper().readValue(optimizationConfigJson, OptimizationConfiguration.class); - } catch (IOException e){ - throw new RuntimeException(e); - } - } - - public int getCandidatesQueued(){ - return candidateCounts[0]; - } - - public int getCandidatesCompleted(){ - return candidateCounts[1]; - } - - public int getCandidatesFailed(){ - return candidateCounts[2]; - } - - public int getCandidatesTotal(){ - return candidateCounts[3]; - } - - public static class Builder extends BaseJavaPersistable.Builder{ - - private String optimizationConfigJson; - private int[] candidateCounts; //queued, completed, failed, total - private String optimizationRunner; - - public Builder optimizationConfigJson(String optimizationConfigJson){ - this.optimizationConfigJson = optimizationConfigJson; - return this; - } - - public Builder candidateCounts(int queued, int completed, int failed, int total){ - this.candidateCounts = new int[]{queued, completed, failed, total}; - return this; - } - - public Builder optimizationRunner(String optimizationRunner){ - this.optimizationRunner = optimizationRunner; - return this; - } - - public GlobalConfigPersistable build(){ - return new GlobalConfigPersistable(this); - } - - } -} diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/ModelInfoPersistable.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/ModelInfoPersistable.java deleted file mode 100644 index 4d1ee4e5f..000000000 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/data/ModelInfoPersistable.java +++ /dev/null @@ -1,163 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. 
- * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.ui.data; - -import lombok.Data; -import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus; -import org.deeplearning4j.core.storage.Persistable; - -/** - * A {@link Persistable} implemention for model results - i.e., results for - * each model - * - * @author Alex BLack - */ -@Data -public class ModelInfoPersistable extends BaseJavaPersistable { - - private String workerId; - private Integer modelIdx; - private Double score; - private CandidateStatus status; - private long lastUpdateTime; - private long numParameters; - private int numLayers; - //From candidate generator - this + model hyperparam space means we can work out specific hyperparam - // settings for this model - private double[] paramSpaceValues; - private int totalNumUpdates; - //Values for score vs. iteration chart - private int[] iter; - private float[] scoreVsIter; - private String modelConfigJson; - private String exceptionStackTrace; - - public ModelInfoPersistable(String sessionId, String workerId, long timeStamp){ - super(sessionId, timeStamp); - - this.workerId = workerId; - } - - private ModelInfoPersistable(Builder builder){ - super(builder); - this.workerId = builder.workerId; - this.modelIdx = builder.modelIdx; - this.score = builder.score; - this.status = builder.status; - this.iter = builder.iter; - this.scoreVsIter = builder.scoreVsIter; - this.lastUpdateTime = builder.lastUpdateTime; - this.numParameters = builder.numParameters; - this.numLayers = builder.numLayers; - this.paramSpaceValues = builder.paramSpaceValues; - this.modelConfigJson = builder.modelConfigJson; - this.totalNumUpdates = builder.totalNumUpdates; - this.exceptionStackTrace = builder.exceptionStackTrace; - } - - public ModelInfoPersistable(){ - //No-arg costructor for Pesistable encoding/decoding - } - - @Override - public String getWorkerID() { - return workerId; - } - - - public static class Builder extends BaseJavaPersistable.Builder { - - private String workerId; - private Integer modelIdx; - private Double score; - private CandidateStatus status; - private long lastUpdateTime;; - private long numParameters; - private int numLayers; - private int totalNumUpdates; - private double[] paramSpaceValues; - private int[] iter; - private float[] scoreVsIter; - private String modelConfigJson; - private String exceptionStackTrace; - - public Builder workerId(String workerId){ - this.workerId = workerId; - return this; - } - - public Builder modelIdx(Integer idx){ - this.modelIdx = idx; - return this; - } - - public Builder score(Double score){ - this.score = score; - return this; - } - - public Builder status(CandidateStatus status){ - this.status = status; - return this; - } - - public Builder scoreVsIter(int[] iter, float[] scoreVsIter){ - this.iter = iter; - this.scoreVsIter = scoreVsIter; - return this; - } - - public Builder lastUpdateTime(long 
lastUpdateTime){ - this.lastUpdateTime = lastUpdateTime; - return this; - } - - public Builder numParameters(long numParameters){ - this.numParameters = numParameters; - return this; - } - - public Builder numLayers(int numLayers){ - this.numLayers = numLayers; - return this; - } - - public Builder totalNumUpdates(int totalNumUpdates){ - this.totalNumUpdates = totalNumUpdates; - return this; - } - - public Builder paramSpaceValues(double[] paramSpaceValues){ - this.paramSpaceValues = paramSpaceValues; - return this; - } - - public Builder modelConfigJson(String modelConfigJson){ - this.modelConfigJson = modelConfigJson; - return this; - } - - public Builder exceptionStackTrace(String exceptionStackTrace){ - this.exceptionStackTrace = exceptionStackTrace; - return this; - } - - public ModelInfoPersistable build(){ - return new ModelInfoPersistable(this); - } - } -} diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java deleted file mode 100644 index c14258be2..000000000 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/listener/ArbiterStatusListener.java +++ /dev/null @@ -1,238 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.ui.listener; - -import it.unimi.dsi.fastutil.floats.FloatArrayList; -import it.unimi.dsi.fastutil.ints.IntArrayList; -import lombok.NonNull; -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.core.storage.Persistable; -import org.deeplearning4j.core.storage.StatsStorageRouter; -import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; -import org.deeplearning4j.arbiter.optimize.runner.CandidateInfo; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; -import org.deeplearning4j.arbiter.optimize.serde.jackson.JsonMapper; -import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable; -import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.nd4j.common.primitives.Pair; - -import java.io.IOException; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; - -/** - * A {@link StatusListener} for reporting Arbiter/DL4J optimization results to a {@link StatsStorageRouter} - * - * @author Alex Black - */ -@Slf4j -public class ArbiterStatusListener implements StatusListener { - - public static final int MAX_SCORE_VS_ITER_PTS = 1024; //Above this: subsample... 
every 2nd, 4th, 8th etc - - private final String sessionId; - private final StatsStorageRouter statsStorage; - - private String ocJson; - private long startTime = 0; - - private Map candidateScoreVsIterSubsampleFreq = new ConcurrentHashMap<>(); - private Map> candidateScoreVsIter = new ConcurrentHashMap<>(); - - private Map lastModelInfoPersistable = new ConcurrentHashMap<>(); - - public ArbiterStatusListener(@NonNull StatsStorageRouter statsStorage) { - this(UUID.randomUUID().toString(), statsStorage); - } - - public ArbiterStatusListener(@NonNull String sessionId, @NonNull StatsStorageRouter statsStorage){ - this.sessionId = sessionId; - this.statsStorage = statsStorage; - } - - @Override - public void onInitialization(IOptimizationRunner r) { - Persistable p = getNewStatusPersistable(r); - statsStorage.putStaticInfo(p); - } - - @Override - public void onShutdown(IOptimizationRunner runner) { - //No op? - - } - - @Override - public void onRunnerStatusChange(IOptimizationRunner r) { - Persistable p = getNewStatusPersistable(r); - statsStorage.putStaticInfo(p); - } - - @Override - public void onCandidateStatusChange(CandidateInfo candidateInfo, IOptimizationRunner runner, OptimizationResult result) { - ModelInfoPersistable p = lastModelInfoPersistable.get(candidateInfo.getIndex()); - if(p == null){ - p = new ModelInfoPersistable.Builder() - .timestamp(candidateInfo.getCreatedTime()) - .sessionId(sessionId) - .workerId(String.valueOf(candidateInfo.getIndex())) - .modelIdx(candidateInfo.getIndex()) - .score(candidateInfo.getScore()) - .status(candidateInfo.getCandidateStatus()) - .exceptionStackTrace(candidateInfo.getExceptionStackTrace()) - .build(); - - lastModelInfoPersistable.put(candidateInfo.getIndex(), p); - } - - if(p.getScore() == null){ - p.setScore(candidateInfo.getScore()); - } - - if(result != null && p.getExceptionStackTrace() == null && result.getCandidateInfo().getExceptionStackTrace() != null){ - //Update exceptions that may have occurred since earlier model info instance - p.setExceptionStackTrace(result.getCandidateInfo().getExceptionStackTrace()); - } - - p.setStatus(candidateInfo.getCandidateStatus()); - - statsStorage.putUpdate(p); - } - - @Override - public void onCandidateIteration(CandidateInfo candidateInfo, Object candidate, int iteration) { - double score; - long numParams; - int numLayers; - String modelConfigJson; - int totalNumUpdates; - if(candidate instanceof MultiLayerNetwork){ - MultiLayerNetwork m = (MultiLayerNetwork)candidate; - score = m.score(); - numParams = m.numParams(); - numLayers = m.getnLayers(); - modelConfigJson = m.getLayerWiseConfigurations().toJson(); - totalNumUpdates = m.getLayerWiseConfigurations().getIterationCount(); - } else if(candidate instanceof ComputationGraph) { - ComputationGraph cg = (ComputationGraph)candidate; - score = cg.score(); - numParams = cg.numParams(); - numLayers = cg.getNumLayers(); - modelConfigJson = cg.getConfiguration().toJson(); - totalNumUpdates = cg.getConfiguration().getIterationCount(); - } else { - score = 0; - numParams = 0; - numLayers = 0; - totalNumUpdates = 0; - modelConfigJson = ""; - } - - int idx = candidateInfo.getIndex(); - - Pair pair = candidateScoreVsIter.computeIfAbsent(idx, k -> new Pair<>(new IntArrayList(), new FloatArrayList())); - - IntArrayList iter = pair.getFirst(); - FloatArrayList scores = pair.getSecond(); - - //Do we need subsampling to avoid having too many data points? 
- int subsamplingFreq = candidateScoreVsIterSubsampleFreq.computeIfAbsent(idx, k -> 1); - if(iteration / subsamplingFreq > MAX_SCORE_VS_ITER_PTS){ - //Double subsampling frequency and re-parse data - subsamplingFreq *= 2; - candidateScoreVsIterSubsampleFreq.put(idx, subsamplingFreq); - - IntArrayList newIter = new IntArrayList(); - FloatArrayList newScores = new FloatArrayList(); - for( int i=0; i(iter, scores)); - } - - if(iteration % subsamplingFreq == 0) { - iter.add(iteration); - scores.add((float) score); - } - - - int[] iters = iter.toIntArray(); - float[] fScores = new float[iters.length]; - for( int i=0; i T fromJson(String json, Class type){ - try{ - return getMapper().readValue(json, type); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public static ObjectMapper getInstance(){ - return MAPPER; - } - -} diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/UIUtils.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/UIUtils.java deleted file mode 100644 index 8ea969c82..000000000 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/misc/UIUtils.java +++ /dev/null @@ -1,112 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.ui.misc; - -import org.joda.time.Period; -import org.joda.time.PeriodType; -import org.joda.time.format.PeriodFormatter; -import org.joda.time.format.PeriodFormatterBuilder; - -/** - * Created by Alex on 20/07/2017. - */ -public class UIUtils { - - /** - * Convert the "messy" min/max values on a dataset to something clean. 
For example, 0.895732 becomes 1.0 - * - * @param max Maximum data point value - * @param min Minimum data point value - * @param nTick Number of tick marks desired on chart (good setting: 5) - * @return double[] of length 2 - with new minimum and maximum - */ - public static double[] graphNiceRange(double max, double min, int nTick){ - if(max == min || !Double.isFinite(max)){ - if(max == 0.0 || !Double.isFinite(max)){ - return new double[]{0.0, 1.0}; - } - - return graphNiceRange(1.5 * max, 0.5 * max, nTick); - } - - double range = niceNum(max-min, false); - double d = niceNum(range / (nTick-1), true ); - double graphMin = Math.floor(min/d)*d; - double graphMax = Math.ceil(max/d)*d; - - - return new double[]{graphMin, graphMax}; - } - - public static double niceNum(double x, boolean round){ - double exp = Math.floor(Math.log10(x)); - double f = x / Math.pow(10, exp); - - double nf; - if(round){ - if(f < 1.5 ){ - nf = 1; - } else if( f < 3){ - nf = 2; - } else if( f < 7){ - nf = 5; - } else { - nf = 10; - } - } else { - if(f <= 1 ){ - nf = 1; - } else if( f <= 2){ - nf = 2; - } else if( f <= 5){ - nf = 5; - } else { - nf = 10; - } - } - return nf * Math.pow(10, exp); - } - - /** - * Format the duration in milliseconds to a human readable String, with "yr", "days", "hr" etc prefixes - * - * - * @param durationMs Duration in milliseconds - * @return Human readable string - */ - public static String formatDuration(long durationMs){ - Period period = Period.seconds((int)(durationMs/1000L)); - Period p2 = period.normalizedStandard(PeriodType.yearMonthDayTime()); - - PeriodFormatter formatter = new PeriodFormatterBuilder() - .appendYears() - .appendSuffix(" yr ") - .appendMonths() - .appendSuffix(" months ") - .appendDays() - .appendSuffix(" days ") - .appendHours() - .appendSuffix(" hr ") - .appendMinutes() - .appendSuffix(" min ") - .appendSeconds() - .appendSuffix(" sec") - .toFormatter(); - - return formatter.print(p2); - } -} diff --git a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java b/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java deleted file mode 100644 index 1ee0ce729..000000000 --- a/arbiter/arbiter-ui/src/main/java/org/deeplearning4j/arbiter/ui/module/ArbiterModule.java +++ /dev/null @@ -1,943 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * Copyright (c) 2019 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.ui.module; - -import com.fasterxml.jackson.core.JsonProcessingException; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.vertx.ext.web.RoutingContext; -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.core.storage.Persistable; -import org.deeplearning4j.core.storage.StatsStorage; -import org.deeplearning4j.core.storage.StatsStorageEvent; -import org.deeplearning4j.core.storage.StatsStorageListener; -import org.deeplearning4j.arbiter.BaseNetworkSpace; -import org.deeplearning4j.arbiter.layers.LayerSpace; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.runner.CandidateStatus; -import org.deeplearning4j.arbiter.ui.UpdateStatus; -import org.deeplearning4j.arbiter.ui.data.GlobalConfigPersistable; -import org.deeplearning4j.arbiter.ui.data.ModelInfoPersistable; -import org.deeplearning4j.arbiter.ui.misc.UIUtils; -import org.deeplearning4j.arbiter.util.ObjectUtils; -import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.deeplearning4j.ui.VertxUIServer; -import org.deeplearning4j.ui.api.Component; -import org.deeplearning4j.ui.api.*; -import org.deeplearning4j.ui.components.chart.ChartLine; -import org.deeplearning4j.ui.components.chart.ChartScatter; -import org.deeplearning4j.ui.components.chart.style.StyleChart; -import org.deeplearning4j.ui.components.component.ComponentDiv; -import org.deeplearning4j.ui.components.component.style.StyleDiv; -import org.deeplearning4j.ui.components.table.ComponentTable; -import org.deeplearning4j.ui.components.table.style.StyleTable; -import org.deeplearning4j.ui.components.text.ComponentText; -import org.deeplearning4j.ui.components.text.style.StyleText; -import org.deeplearning4j.ui.i18n.I18NResource; -import org.joda.time.format.DateTimeFormat; -import org.joda.time.format.DateTimeFormatter; -import org.nd4j.common.function.Function; -import org.nd4j.common.primitives.Pair; - -import java.awt.*; -import java.text.DecimalFormat; -import java.util.List; -import java.util.*; -import java.util.concurrent.atomic.AtomicBoolean; - -/** - * A Deeplearning4j {@link UIModule}, for integration with DL4J's user interface - * - * @author Alex Black - */ -@Slf4j -public class ArbiterModule implements UIModule { - - private static final DecimalFormat DECIMAL_FORMAT_2DP = new DecimalFormat("#.00"); - private static final DateTimeFormatter TIME_FORMATTER = DateTimeFormat.forPattern("YYYY-MM-dd HH:mm ZZ"); - public static final String ARBITER_UI_TYPE_ID = "ArbiterUI"; - - private AtomicBoolean loggedArbiterAddress = new AtomicBoolean(false); - private Map knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>()); - private String currentSessionID; - - private Map lastUpdateForSession = Collections.synchronizedMap(new HashMap<>()); - - //Styles for UI: - private static final StyleTable STYLE_TABLE = new StyleTable.Builder() - .width(100, LengthUnit.Percent) - .backgroundColor(Color.WHITE) - .borderWidth(1) - .columnWidths(LengthUnit.Percent, 30, 70) - .build(); - - private static final StyleTable STYLE_TABLE3_25_25_50 = new StyleTable.Builder() - .width(100, LengthUnit.Percent) - .backgroundColor(Color.WHITE) - .borderWidth(1) - 
.columnWidths(LengthUnit.Percent, 25, 25, 50) - .build(); - - private static final StyleDiv STYLE_DIV_WIDTH_100_PC = new StyleDiv.Builder() - .width(100, LengthUnit.Percent) - .build(); - - private static final ComponentDiv DIV_SPACER_20PX = new ComponentDiv(new StyleDiv.Builder() - .width(100,LengthUnit.Percent) - .height(20, LengthUnit.Px).build()); - - private static final ComponentDiv DIV_SPACER_60PX = new ComponentDiv(new StyleDiv.Builder() - .width(100,LengthUnit.Percent) - .height(60, LengthUnit.Px).build()); - - private static final StyleChart STYLE_CHART_560_320 = new StyleChart.Builder() - .width(560, LengthUnit.Px) - .height(320, LengthUnit.Px) - .build(); - - private static final StyleChart STYLE_CHART_800_400 = new StyleChart.Builder() - .width(800, LengthUnit.Px) - .height(400, LengthUnit.Px) - .build(); - - - private StyleText STYLE_TEXT_SZ12 = new StyleText.Builder() - .fontSize(12) - .build(); - - //Set whitespacePre(true) to avoid losing new lines, tabs, multiple spaces etc - private StyleText STYLE_TEXT_SZ10_WHITESPACE_PRE = new StyleText.Builder() - .fontSize(10) - .whitespacePre(true) - .build(); - - - @Override - public List getCallbackTypeIDs() { - return Collections.singletonList(ARBITER_UI_TYPE_ID); - } - - @Override - public List getRoutes() { - boolean multiSession = VertxUIServer.getMultiSession().get(); - List r = new ArrayList<>(); - r.add(new Route("/arbiter/multisession", HttpMethod.GET, - (path, rc) -> rc.response().end(multiSession ? "true" : "false"))); - if (multiSession) { - r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> this.listSessions(rc))); - r.add(new Route("/arbiter/:sessionId", HttpMethod.GET, (path, rc) -> { - if (knownSessionIDs.containsKey(path.get(0))) { - rc.response() - .putHeader("content-type", "text/html; charset=utf-8") - .sendFile("templates/ArbiterUI.html"); - } else { - sessionNotFound(path.get(0), rc.request().path(), rc); - } - })); - - r.add(new Route("/arbiter/:sessionId/lastUpdate", HttpMethod.GET, (path, rc) -> { - if (knownSessionIDs.containsKey(path.get(0))) { - this.getLastUpdateTime(path.get(0), rc); - } else { - sessionNotFound(path.get(0), rc.request().path(), rc); - } - })); - r.add(new Route("/arbiter/:sessionId/candidateInfo/:id", HttpMethod.GET, (path, rc) -> { - if (knownSessionIDs.containsKey(path.get(0))) { - this.getCandidateInfo(path.get(0), path.get(1), rc); - } else { - sessionNotFound(path.get(0), rc.request().path(), rc); - } - })); - r.add(new Route("/arbiter/:sessionId/config", HttpMethod.GET, (path, rc) -> { - if (knownSessionIDs.containsKey(path.get(0))) { - this.getOptimizationConfig(path.get(0), rc); - } else { - sessionNotFound(path.get(0), rc.request().path(), rc); - } - })); - r.add(new Route("/arbiter/:sessionId/results", HttpMethod.GET, (path, rc) -> { - if (knownSessionIDs.containsKey(path.get(0))) { - this.getSummaryResults(path.get(0), rc); - } else { - sessionNotFound(path.get(0), rc.request().path(), rc); - } - })); - r.add(new Route("/arbiter/:sessionId/summary", HttpMethod.GET, (path, rc) -> { - if (knownSessionIDs.containsKey(path.get(0))) { - this.getSummaryStatus(path.get(0), rc); - } else { - sessionNotFound(path.get(0), rc.request().path(), rc); - } - })); - } else { - r.add(new Route("/arbiter", HttpMethod.GET, (path, rc) -> rc.response() - .putHeader("content-type", "text/html; charset=utf-8") - .sendFile("templates/ArbiterUI.html"))); - r.add(new Route("/arbiter/lastUpdate", HttpMethod.GET, (path, rc) -> this.getLastUpdateTime(null, rc))); - r.add(new 
Route("/arbiter/candidateInfo/:id", HttpMethod.GET, - (path, rc) -> this.getCandidateInfo(null, path.get(0), rc))); - r.add(new Route("/arbiter/config", HttpMethod.GET, (path, rc) -> this.getOptimizationConfig(null, rc))); - r.add(new Route("/arbiter/results", HttpMethod.GET, (path, rc) -> this.getSummaryResults(null, rc))); - r.add(new Route("/arbiter/summary", HttpMethod.GET, (path, rc) -> this.getSummaryStatus(null, rc))); - - r.add(new Route("/arbiter/sessions/current", HttpMethod.GET, (path, rc) -> this.currentSession(rc))); - r.add(new Route("/arbiter/sessions/set/:to", HttpMethod.GET, - (path, rc) -> this.setSession(path.get(0), rc))); - } - // common for single- and multi-session mode - r.add(new Route("/arbiter/sessions/all", HttpMethod.GET, (path, rc) -> this.sessionInfo(rc))); - - return r; - } - - - /** - * Load StatsStorage via provider, or return "not found" - * - * @param sessionId session ID to look fo with provider - * @param targetPath one of overview / model / system, or null - * @param rc routing context - */ - private void sessionNotFound(String sessionId, String targetPath, RoutingContext rc) { - Function loader = VertxUIServer.getInstance().getStatsStorageLoader(); - if (loader != null && loader.apply(sessionId)) { - if (targetPath != null) { - rc.reroute(targetPath); - } else { - rc.response().end(); - } - } else { - rc.response().setStatusCode(HttpResponseStatus.NOT_FOUND.code()) - .end("Unknown session ID: " + sessionId); - } - } - - - /** - * List optimization sessions. Returns a HTML list of arbiter sessions - */ - private synchronized void listSessions(RoutingContext rc) { - StringBuilder sb = new StringBuilder("\n" + - "\n" + - "\n" + - " \n" + - " Optimization sessions - DL4J Arbiter UI\n" + - " \n" + - "\n" + - " \n" + - "

        <h1>DL4J Arbiter UI</h1>\n" + - "        <p>UI server is in multi-session mode." + - " To visualize an optimization session, please select one from the following list.</p>\n" + - "        <h2>List of attached optimization sessions</h2>
      \n"); - if (!knownSessionIDs.isEmpty()) { - sb.append(" "); - } else { - sb.append("No optimization session attached."); - } - - sb.append(" \n" + - "\n"); - - rc.response() - .putHeader("content-type", "text/html; charset=utf-8") - .end(sb.toString()); - } - - @Override - public void reportStorageEvents(Collection events) { - boolean attachedArbiter = false; - for (StatsStorageEvent sse : events) { - if (ARBITER_UI_TYPE_ID.equals(sse.getTypeID())) { - if (sse.getEventType() == StatsStorageListener.EventType.PostStaticInfo) { - knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage()); - } - - Long lastUpdate = lastUpdateForSession.get(sse.getSessionID()); - if (lastUpdate == null) { - lastUpdateForSession.put(sse.getSessionID(), sse.getTimestamp()); - } else if (sse.getTimestamp() > lastUpdate) { - lastUpdateForSession.put(sse.getSessionID(), sse.getTimestamp()); //Should be thread safe - read only elsewhere - } - attachedArbiter = true; - } - } - - if(currentSessionID == null){ - getDefaultSession(); - } - - if(attachedArbiter && !loggedArbiterAddress.getAndSet(true)){ - String address = UIServer.getInstance().getAddress(); - address += "/arbiter"; - log.info("DL4J Arbiter Hyperparameter Optimization UI: {}", address); - } - } - - @Override - public synchronized void onAttach(StatsStorage statsStorage) { - for (String sessionID : statsStorage.listSessionIDs()) { - for (String typeID : statsStorage.listTypeIDsForSession(sessionID)) { - if (!ARBITER_UI_TYPE_ID.equals(typeID)) - continue; - knownSessionIDs.put(sessionID, statsStorage); - } - } - - if (currentSessionID == null) - getDefaultSession(); - } - - private void currentSession(RoutingContext rc) { - String sid = currentSessionID == null ? "" : currentSessionID; - rc.response() - .putHeader("content-type", "application/json") - .end(asJson(sid)); - } - - private void sessionInfo(RoutingContext rc) { - rc.response() - .putHeader("content-type", "application/json") - .end(asJson(knownSessionIDs.keySet())); - } - - private void setSession(String newSessionID, RoutingContext rc) { - log.debug("Arbiter UI: Set to session {}", newSessionID); - - if (knownSessionIDs.containsKey(newSessionID)) { - currentSessionID = newSessionID; - rc.response().end(); - } else { - rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()).end("Unknown session ID: " + newSessionID); - } - } - - private void getDefaultSession() { - if (currentSessionID != null) - return; - - long mostRecentTime = Long.MIN_VALUE; - String sessionID = null; - for (Map.Entry entry : knownSessionIDs.entrySet()) { - List staticInfos = entry.getValue().getAllStaticInfos(entry.getKey(), ARBITER_UI_TYPE_ID); - if (staticInfos == null || staticInfos.size() == 0) - continue; - Persistable p = staticInfos.get(0); - long thisTime = p.getTimeStamp(); - if (thisTime > mostRecentTime) { - mostRecentTime = thisTime; - sessionID = entry.getKey(); - } - } - - if (sessionID != null) { - currentSessionID = sessionID; - } - } - - @Override - public void onDetach(StatsStorage statsStorage) { - for (String s : knownSessionIDs.keySet()) { - if (knownSessionIDs.get(s) == statsStorage) { - knownSessionIDs.remove(s); - } - } - } - - @Override - public List getInternationalizationResources() { - return Collections.emptyList(); - } - - /** - * Return the last update time for the page - * @param sessionId session ID (optional, for multi-session mode) - * @param rc routing context - */ - private void getLastUpdateTime(String sessionId, RoutingContext rc){ - if (sessionId == null) { - 
sessionId = currentSessionID; - } - StatsStorage ss = knownSessionIDs.get(sessionId); - List latestUpdates = ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID); - long t = 0; - if (latestUpdates.isEmpty()) { - t = System.currentTimeMillis(); - } else { - for (Persistable update : latestUpdates) { - if (update.getTimeStamp() > t) { - t = update.getTimeStamp(); - } - } - } - UpdateStatus us = new UpdateStatus(t, t, t); - - rc.response().putHeader("content-type", "application/json").end(asJson(us)); - } - - private String asJson(Object o){ - try{ - return JsonMappers.getMapper().writeValueAsString(o); - } catch (JsonProcessingException e){ - throw new RuntimeException("Error converting object to JSON", e); - } - } - - /** - * Get the info for a specific candidate - last section in the UI - * @param sessionId session ID (optional, for multi-session mode) - * @param candidateId ID for the candidate - * @param rc routing context - */ - private void getCandidateInfo(String sessionId, String candidateId, RoutingContext rc){ - if (sessionId == null) { - sessionId = currentSessionID; - } - StatsStorage ss = knownSessionIDs.get(sessionId); - if(ss == null){ - log.debug("getModelLastUpdateTimes(): Session ID is unknown: {}", sessionId); - rc.response().end(); - return; - } - - GlobalConfigPersistable gcp = (GlobalConfigPersistable)ss - .getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); - OptimizationConfiguration oc = gcp.getOptimizationConfiguration(); - - Persistable p = ss.getLatestUpdate(sessionId, ARBITER_UI_TYPE_ID, candidateId); - if(p == null){ - String title = "No results found for model " + candidateId + "."; - ComponentText ct = new ComponentText.Builder(title,STYLE_TEXT_SZ12).build(); - rc.response() - .putHeader("content-type", "application/json") - .end(asJson(ct)); - return; - } - - ModelInfoPersistable mip = (ModelInfoPersistable)p; - - //First: static info - // Hyperparameter configuration/settings - // Number of parameters - // Maybe memory info in the future? - - //Second: dynamic info - //Runtime - // Performance stats (total minibatches, total time, - // Score vs. time - - List components = new ArrayList<>(); - - //First table: mix of static + dynamic in a table - long runtimeDurationMs = mip.getLastUpdateTime() - mip.getTimeStamp(); - double avgMinibatchesPerSec = mip.getTotalNumUpdates() / (runtimeDurationMs/1000.0); - String avgMinibatchesPerSecStr = DECIMAL_FORMAT_2DP.format(avgMinibatchesPerSec); - String runtimeStr = UIUtils.formatDuration(runtimeDurationMs); - - if(mip.getStatus() == CandidateStatus.Failed){ - runtimeStr = ""; - avgMinibatchesPerSecStr = ""; - } - - String[][] table = new String[][]{ - {"Model Index", String.valueOf(mip.getModelIdx())}, - {"Status", mip.getStatus().toString()}, - {"Model Score", mip.getScore() == null ? 
"" : String.valueOf(mip.getScore())}, - {"Created", TIME_FORMATTER.print(mip.getTimeStamp())}, - {"Runtime", runtimeStr}, - {"Total Number of Model Updates", String.valueOf(mip.getTotalNumUpdates())}, - {"Average # Updates / Sec", avgMinibatchesPerSecStr}, - {"Number of Parameters", String.valueOf(mip.getNumParameters())}, - {"Number of Layers", String.valueOf(mip.getNumLayers())} - }; - - ComponentTable cTable = new ComponentTable.Builder(STYLE_TABLE) - .content(table) - .header("Model Information", "") - .build(); - components.add(cTable); - - - //Second: parameter space values, in multiple tables - double[] paramSpaceValues = mip.getParamSpaceValues(); - if(paramSpaceValues != null){ - BaseNetworkSpace bns = (BaseNetworkSpace)oc.getCandidateGenerator().getParameterSpace(); - Map m = bns.getNestedSpaces(); - - String[][] hSpaceTable = new String[m.size()][3]; - int i=0; - for(Map.Entry e : m.entrySet()){ - hSpaceTable[i][0] = e.getKey(); - Object currCandidateValue = e.getValue().getValue(paramSpaceValues); - hSpaceTable[i][1] = ObjectUtils.valueToString(currCandidateValue); - hSpaceTable[i][2] = e.getValue().toString(); - i++; - } - - String[] hSpaceTableHeader = new String[]{"Hyperparameter", "Model Value", "Hyperparameter Space"}; - - ComponentTable ct2 = new ComponentTable.Builder(STYLE_TABLE3_25_25_50) - .content(hSpaceTable) - .header(hSpaceTableHeader) - .build(); - - - String title = "Global Network Configuration"; - components.add(DIV_SPACER_20PX); - components.add(new ComponentText.Builder(title, STYLE_TEXT_SZ12).build()); - components.add(ct2); - - List layerConfs = bns.getLayerSpaces(); - - for(BaseNetworkSpace.LayerConf l : layerConfs){ - LayerSpace ls = l.getLayerSpace(); - Map lpsm = ls.getNestedSpaces(); - - String[][] t = new String[lpsm.size()][3]; - i=0; - for(Map.Entry e : lpsm.entrySet()){ - t[i][0] = e.getKey(); - Object currCandidateValue = e.getValue().getValue(paramSpaceValues); - t[i][1] = ObjectUtils.valueToString(currCandidateValue); - t[i][2] = e.getValue().toString(); - i++; - } - - ComponentTable ct3 = new ComponentTable.Builder(STYLE_TABLE3_25_25_50) - .content(t) - .header(hSpaceTableHeader) - .build(); - - title = "Layer Space: " + ls.getClass().getSimpleName() + ", Name: " + l.getLayerName(); - - components.add(DIV_SPACER_20PX); - components.add(new ComponentText.Builder(title, STYLE_TEXT_SZ12).build()); - components.add(ct3); - } - } - - - //Third: Score vs. time chart - int[] iters = mip.getIter(); - float[] scores = mip.getScoreVsIter(); - - if(iters != null) { - double[] si = new double[iters.length]; - double[] scoresD = new double[iters.length]; - - double minScore = Double.MAX_VALUE; - double maxScore = -Double.MAX_VALUE; - for( int i=0; i components = new ArrayList<>(); - - GlobalConfigPersistable gcp = (GlobalConfigPersistable)p; - OptimizationConfiguration oc = gcp.getOptimizationConfiguration(); - - //Report optimization settings/configuration. 
- String[] tableHeader = {"Configuration", "Value"}; - String [] dataSourceOrProvider; - if (oc.getDataProvider() != null) { - dataSourceOrProvider = new String[] {"Data Provider", oc.getDataProvider().toString()}; - } - else { - dataSourceOrProvider = new String[] {"Data Source", oc.getDataSource().getCanonicalName()}; - } - String[][] table = new String[][]{ - {"Candidate Generator", oc.getCandidateGenerator().getClass().getSimpleName()}, - dataSourceOrProvider, - {"Score Function", oc.getScoreFunction().toString()}, - {"Result Saver", oc.getResultSaver().toString()}, - }; - - ComponentTable ct = new ComponentTable.Builder(STYLE_TABLE) - .content(table) - .header(tableHeader) - .build(); - components.add(ct); - - - String title = "Global Network Configuration"; - components.add(DIV_SPACER_20PX); - components.add(new ComponentText.Builder(title, STYLE_TEXT_SZ12).build()); - BaseNetworkSpace ps = (BaseNetworkSpace)oc.getCandidateGenerator().getParameterSpace(); - Map m = ps.getNestedSpaces(); - - String[][] hSpaceTable = new String[m.size()][2]; - int i=0; - for(Map.Entry e : m.entrySet()){ - hSpaceTable[i][0] = e.getKey(); - hSpaceTable[i][1] = e.getValue().toString(); - i++; - } - - components.add(DIV_SPACER_20PX); - String[] hSpaceTableHeader = new String[]{"Hyperparameter", "Hyperparameter Configuration"}; - - ComponentTable ct2 = new ComponentTable.Builder(STYLE_TABLE) - .content(hSpaceTable) - .header(hSpaceTableHeader) - .build(); - components.add(ct2); - - //Configuration for each layer: - List layerConfs = ps.getLayerSpaces(); - for(BaseNetworkSpace.LayerConf l : layerConfs){ - LayerSpace ls = l.getLayerSpace(); - Map lpsm = ls.getNestedSpaces(); - - String[][] t = new String[lpsm.size()][2]; - i=0; - for(Map.Entry e : lpsm.entrySet()){ - t[i][0] = e.getKey(); - t[i][1] = e.getValue().toString(); - i++; - } - - ComponentTable ct3 = new ComponentTable.Builder(STYLE_TABLE) - .content(t) - .header(hSpaceTableHeader) - .build(); - - title = "Layer Space: " + ls.getClass().getSimpleName() + ", Name: " + l.getLayerName(); - - components.add(DIV_SPACER_20PX); - components.add(new ComponentText.Builder(title, STYLE_TEXT_SZ12).build()); - components.add(ct3); - } - - ComponentDiv cd = new ComponentDiv(STYLE_DIV_WIDTH_100_PC, components); - - rc.response().putHeader("content-type", "application/json").end(asJson(cd)); - } - - /** - * Get candidates summary results list - third section on the page: Results table - * @param sessionId session ID (optional, for multi-session mode) - * @param rc routing context - */ - private void getSummaryResults(String sessionId, RoutingContext rc){ - if (sessionId == null) { - sessionId = currentSessionID; - } - StatsStorage ss = knownSessionIDs.get(sessionId); - if(ss == null){ - log.debug("getSummaryResults(): Session ID is unknown: {}", sessionId); - rc.response().end(); - return; - } - - List allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID)); - List table = new ArrayList<>(); - for(Persistable per : allModelInfoTemp){ - ModelInfoPersistable mip = (ModelInfoPersistable)per; - String score = (mip.getScore() == null ? 
"" : mip.getScore().toString()); - table.add(new String[]{mip.getModelIdx().toString(), score, mip.getStatus().toString()}); - } - - rc.response().putHeader("content-type", "application/json").end(asJson(table)); - } - - /** - * Get summary status information: first section in the page - * @param sessionId session ID (optional, for multi-session mode) - * @param rc routing context - */ - private void getSummaryStatus(String sessionId, RoutingContext rc){ - if (sessionId == null) { - sessionId = currentSessionID; - } - StatsStorage ss = knownSessionIDs.get(sessionId); - if(ss == null){ - log.debug("getOptimizationConfig(): Session ID is unknown: {}", sessionId); - rc.response().end(); - return; - } - - Persistable p = ss.getStaticInfo(sessionId, ARBITER_UI_TYPE_ID, GlobalConfigPersistable.GLOBAL_WORKER_ID); - - if(p == null){ - log.info("No static info"); - rc.response().end(); - return; - } - - GlobalConfigPersistable gcp = (GlobalConfigPersistable)p; - OptimizationConfiguration oc = gcp.getOptimizationConfiguration(); - long execStartTime = oc.getExecutionStartTime(); - - - - //Charts: - //Best model score vs. time - //All candidate scores (scatter plot vs. time) - - //How to get this? query all model infos... - - List allModelInfoTemp = new ArrayList<>(ss.getLatestUpdateAllWorkers(sessionId, ARBITER_UI_TYPE_ID)); - List allModelInfo = new ArrayList<>(); - for(Persistable per : allModelInfoTemp){ - ModelInfoPersistable mip = (ModelInfoPersistable)per; - if(mip.getStatus() == CandidateStatus.Complete && mip.getScore() != null && Double.isFinite(mip.getScore())){ - allModelInfo.add(mip); - } - } - - allModelInfo.sort(Comparator.comparingLong(Persistable::getTimeStamp)); - - Pair, ModelInfoPersistable> chartsAndBest = getSummaryChartsAndBest(allModelInfo, oc.getScoreFunction().minimize(), execStartTime ); - - //First: table - number completed, queued, running, failed, total - //Best model index, score, and time - //Total runtime - //Termination conditions - List components = new ArrayList<>(); - - - - List tcs = oc.getTerminationConditions(); - - //TODO: I18N - - long bestTime; - Double bestScore = null; - String bestModelString = null; - if(chartsAndBest.getSecond() != null){ - bestTime = chartsAndBest.getSecond().getTimeStamp(); - bestScore = chartsAndBest.getSecond().getScore(); - String sinceBest = UIUtils.formatDuration(System.currentTimeMillis() - bestTime); - - bestModelString = "Model " + chartsAndBest.getSecond().getModelIdx() + ", Found at " + - TIME_FORMATTER.print(bestTime) + " (" + sinceBest + " ago)"; - } - - String execStartTimeStr = ""; - String execTotalRuntimeStr = ""; - if(execStartTime > 0){ - execStartTimeStr = TIME_FORMATTER.print(execStartTime); - // allModelInfo is sorted by Persistable::getTimeStamp - long lastCompleteTime = execStartTime; - if (!allModelInfo.isEmpty()) { - lastCompleteTime = allModelInfo.get(allModelInfo.size() - 1).getTimeStamp(); - } - execTotalRuntimeStr = UIUtils.formatDuration(lastCompleteTime - execStartTime); - } - - - String[][] table = new String[][]{ - {"Models Completed", String.valueOf(gcp.getCandidatesCompleted())}, - {"Models Queued/Running", String.valueOf(gcp.getCandidatesQueued())}, - {"Models Failed", String.valueOf(gcp.getCandidatesFailed())}, - {"Models Total", String.valueOf(gcp.getCandidatesTotal())}, - {"Best Score", (bestScore != null ? String.valueOf(bestScore) : "")}, - {"Best Scoring Model", bestModelString != null ? 
bestModelString : ""}, - {"Optimization Runner", gcp.getOptimizationRunner()}, - {"Execution Start Time", execStartTimeStr}, - {"Total Runtime", execTotalRuntimeStr} - }; - - - - ComponentTable ct = new ComponentTable.Builder(STYLE_TABLE) - .content(table) - .header("Status", "") - .build(); - - components.add(ct); - - String[][] tcTable = new String[tcs.size()][2]; - for( int i=0; i,ModelInfoPersistable> getSummaryChartsAndBest(List allModelInfo, - boolean minimize, long execStartTime){ - List bestX = new ArrayList<>(); - List bestY = new ArrayList<>(); - - double[] allX = new double[allModelInfo.size()]; - double[] allY = new double[allModelInfo.size()]; - - double bestScore = (minimize ? Double.MAX_VALUE : -Double.MAX_VALUE); - double worstScore = (minimize ? -Double.MAX_VALUE : Double.MAX_VALUE); - double lastTime = -1L; - ModelInfoPersistable bestModel = null; - for(int i=0; i bestScore) || (minimize && currScore < bestScore)){ - bestX.add(t); - bestY.add(bestScore); - bestX.add(t); //TODO non-real time rendering support... - bestY.add(currScore); - - bestScore = currScore; - bestModel = mip; - } - - if((!minimize && currScore < worstScore) || (minimize && currScore > worstScore)){ - worstScore = currScore; - } - - if(t > lastTime){ - lastTime = t; - } - } - - - double[] scatterGraphMinMax = UIUtils.graphNiceRange(Math.max(bestScore, worstScore), Math.min(bestScore, worstScore), 5); - double[] lineGraphMinMax = UIUtils.graphNiceRange( - bestY.stream().mapToDouble(s -> s).max().orElse(0),bestY.stream().mapToDouble(s -> s).min().orElse(0), 5 - ); - - if(bestX.size() > 0) { - bestX.add(lastTime); - bestY.add(bestY.get(bestY.size() - 1)); - } - - - double[] bestXd = new double[bestX.size()]; - double[] bestYd = new double[bestXd.length]; - for( int i=0; i components = new ArrayList<>(2); - - ChartLine cl = new ChartLine.Builder("Best Model Score vs. Time (Minutes)", STYLE_CHART_560_320) - .addSeries("Best Score vs. Time", bestXd, bestYd) - .setYMin(lineGraphMinMax[0]) - .setYMax(lineGraphMinMax[1]) - .build(); - components.add(cl); - - ChartScatter cs = new ChartScatter.Builder("All Candidate Scores vs. Time (Minutes)", STYLE_CHART_560_320) - .addSeries("Candidates", allX, allY) - .setYMin(scatterGraphMinMax[0]) - .setYMax(scatterGraphMinMax[1]) - .build(); - - components.add(cs); - - return new Pair<>(components, bestModel); - } -} diff --git a/arbiter/arbiter-ui/src/main/resources/META-INF/services/org.deeplearning4j.ui.api.UIModule b/arbiter/arbiter-ui/src/main/resources/META-INF/services/org.deeplearning4j.ui.api.UIModule deleted file mode 100644 index 083fd24c9..000000000 --- a/arbiter/arbiter-ui/src/main/resources/META-INF/services/org.deeplearning4j.ui.api.UIModule +++ /dev/null @@ -1,17 +0,0 @@ -################################################################################ -# Copyright (c) 2015-2018 Skymind, Inc. -# -# This program and the accompanying materials are made available under the -# terms of the Apache License, Version 2.0 which is available at -# https://www.apache.org/licenses/LICENSE-2.0. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. 
-# -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -org.deeplearning4j.arbiter.ui.module.ArbiterModule \ No newline at end of file diff --git a/arbiter/arbiter-ui/src/main/resources/deeplearning4jUiAssets/dl4j-ui.js b/arbiter/arbiter-ui/src/main/resources/deeplearning4jUiAssets/dl4j-ui.js deleted file mode 100644 index 4c99517d0..000000000 --- a/arbiter/arbiter-ui/src/main/resources/deeplearning4jUiAssets/dl4j-ui.js +++ /dev/null @@ -1,1319 +0,0 @@ -var __extends = (this && this.__extends) || function (d, b) { - for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; - function __() { this.constructor = d; } - d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); -}; -var Style = (function () { - function Style(jsonObj) { - var _this = this; - this.getWidth = function () { return _this.width; }; - this.getHeight = function () { return _this.height; }; - this.getWidthUnit = function () { return _this.widthUnit; }; - this.getHeightUnit = function () { return _this.heightUnit; }; - this.getMarginTop = function () { return _this.marginTop; }; - this.getMarginBottom = function () { return _this.marginBottom; }; - this.getMarginLeft = function () { return _this.marginLeft; }; - this.getMarginRight = function () { return _this.marginRight; }; - this.getBackgroundColor = function () { return _this.backgroundColor; }; - this.width = jsonObj['width']; - this.height = jsonObj['height']; - this.widthUnit = TSUtils.normalizeLengthUnit(jsonObj['widthUnit']); - this.heightUnit = TSUtils.normalizeLengthUnit(jsonObj['heightUnit']); - this.marginTop = jsonObj['marginTop']; - this.marginBottom = jsonObj['marginBottom']; - this.marginLeft = jsonObj['marginLeft']; - this.marginRight = jsonObj['marginRight']; - this.backgroundColor = jsonObj['backgroundColor']; - } - Style.getMargins = function (s) { - var mTop = (s ? s.getMarginTop() : 0); - var mBottom = (s ? s.getMarginBottom() : 0); - var mLeft = (s ? s.getMarginLeft() : 0); - var mRight = (s ? 
s.getMarginRight() : 0); - return { top: mTop, - right: mRight, - bottom: mBottom, - left: mLeft, - widthExMargins: s.getWidth() - mLeft - mRight, - heightExMargins: s.getHeight() - mTop - mBottom }; - }; - return Style; -}()); -var ComponentType; -(function (ComponentType) { - ComponentType[ComponentType["ComponentText"] = 0] = "ComponentText"; - ComponentType[ComponentType["ComponentTable"] = 1] = "ComponentTable"; - ComponentType[ComponentType["ComponentDiv"] = 2] = "ComponentDiv"; - ComponentType[ComponentType["ChartHistogram"] = 3] = "ChartHistogram"; - ComponentType[ComponentType["ChartHorizontalBar"] = 4] = "ChartHorizontalBar"; - ComponentType[ComponentType["ChartLine"] = 5] = "ChartLine"; - ComponentType[ComponentType["ChartScatter"] = 6] = "ChartScatter"; - ComponentType[ComponentType["ChartStackedArea"] = 7] = "ChartStackedArea"; - ComponentType[ComponentType["ChartTimeline"] = 8] = "ChartTimeline"; - ComponentType[ComponentType["DecoratorAccordion"] = 9] = "DecoratorAccordion"; -})(ComponentType || (ComponentType = {})); -var Component = (function () { - function Component(componentType) { - this.componentType = componentType; - } - Component.prototype.getComponentType = function () { - return this.componentType; - }; - Component.getComponent = function (jsonStr) { - var json = JSON.parse(jsonStr); - var key; - if (json["componentType"]) - key = json["componentType"]; - else - key = Object.keys(json)[0]; - switch (key) { - case ComponentType[ComponentType.ComponentText]: - return new ComponentText(jsonStr); - case ComponentType[ComponentType.ComponentTable]: - return new ComponentTable(jsonStr); - case ComponentType[ComponentType.ChartHistogram]: - return new ChartHistogram(jsonStr); - case ComponentType[ComponentType.ChartHorizontalBar]: - throw new Error("Horizontal bar chart: not yet implemented"); - case ComponentType[ComponentType.ChartLine]: - return new ChartLine(jsonStr); - case ComponentType[ComponentType.ChartScatter]: - return new ChartScatter(jsonStr); - case ComponentType[ComponentType.ChartStackedArea]: - return new ChartStackedArea(jsonStr); - case ComponentType[ComponentType.ChartTimeline]: - return new ChartTimeline(jsonStr); - case ComponentType[ComponentType.DecoratorAccordion]: - return new DecoratorAccordion(jsonStr); - case ComponentType[ComponentType.ComponentDiv]: - return new ComponentDiv(jsonStr); - default: - throw new Error("Unknown component type \"" + key + "\" or invalid JSON: \"" + jsonStr + "\""); - } - }; - return Component; -}()); -var ChartConstants = (function () { - function ChartConstants() { - } - ChartConstants.DEFAULT_CHART_STROKE_WIDTH = 1.0; - ChartConstants.DEFAULT_CHART_POINT_SIZE = 3.0; - ChartConstants.DEFAULT_AXIS_STROKE_WIDTH = 1.0; - ChartConstants.DEFAULT_TITLE_COLOR = "#000000"; - return ChartConstants; -}()); -var TSUtils = (function () { - function TSUtils() { - } - TSUtils.max = function (input) { - var max = -Number.MAX_VALUE; - for (var i = 0; i < input.length; i++) { - for (var j = 0; j < input[i].length; j++) { - max = Math.max(max, input[i][j]); - } - } - return max; - }; - TSUtils.min = function (input) { - var min = Number.MAX_VALUE; - for (var i = 0; i < input.length; i++) { - for (var j = 0; j < input[i].length; j++) { - min = Math.min(min, input[i][j]); - } - } - return min; - }; - TSUtils.normalizeLengthUnit = function (input) { - if (input == null) - return input; - switch (input.toLowerCase()) { - case "px": - return "px"; - case "percent": - case "%": - return "%"; - case "cm": - return "cm"; - case "mm": - 
return "mm"; - case "in": - return "in"; - default: - return input; - } - }; - return TSUtils; -}()); -var Chart = (function (_super) { - __extends(Chart, _super); - function Chart(componentType, jsonStr) { - _super.call(this, componentType); - var jsonOrig = JSON.parse(jsonStr); - var json = JSON.parse(jsonStr); - if (!json["componentType"]) - json = json[ComponentType[componentType]]; - this.suppressAxisHorizontal = json['suppressAxisHorizontal']; - this.suppressAxisVertical = json['suppressAxisVertical']; - this.showLegend = json['showLegend']; - this.title = json['title']; - this.setXMin = json['setXMin']; - this.setXMax = json['setXMax']; - this.setYMin = json['setYMin']; - this.setYMax = json['setYMax']; - this.gridVerticalStrokeWidth = json['gridVerticalStrokeWidth']; - this.gridHorizontalStrokeWidth = json['gridHorizontalStrokeWidth']; - if (json['style']) - this.style = new StyleChart(json['style']); - } - Chart.prototype.getStyle = function () { - return this.style; - }; - Chart.appendTitle = function (svg, title, margin, titleStyle) { - var text = svg.append("text") - .text(title) - .attr("x", (margin.widthExMargins / 2)) - .attr("y", 0 - ((margin.top - 30) / 2)) - .attr("text-anchor", "middle"); - if (titleStyle) { - if (titleStyle.getFont()) - text.attr("font-family", titleStyle.getFont); - if (titleStyle.getFontSize() != null) - text.attr("font-size", titleStyle.getFontSize() + "pt"); - if (titleStyle.getUnderline() != null) - text.style("text-decoration", "underline"); - if (titleStyle.getColor()) - text.style("fill", titleStyle.getColor); - else - text.style("fill", ChartConstants.DEFAULT_TITLE_COLOR); - } - else { - text.style("text-decoration", "underline"); - text.style("fill", ChartConstants.DEFAULT_TITLE_COLOR); - } - }; - return Chart; -}(Component)); -var ChartHistogram = (function (_super) { - __extends(ChartHistogram, _super); - function ChartHistogram(jsonStr) { - _super.call(this, ComponentType.ChartHistogram, jsonStr); - this.render = function (appendToObject) { - var s = this.getStyle(); - var margin = Style.getMargins(s); - var xMin; - var xMax; - var yMin; - var yMax; - if (this.setXMin) - xMin = this.setXMin; - else - xMin = (this.lowerBounds ? d3.min(this.lowerBounds) : 0); - if (this.setXMax) - xMax = this.setXMax; - else - xMax = (this.upperBounds ? d3.max(this.upperBounds) : 1); - if (this.setYMin) - yMin = this.setYMin; - else - yMin = 0; - if (this.setYMax) - yMax = this.setYMax; - else - yMax = (this.yValues ? 
d3.max(this.yValues) : 1); - var xScale = d3.scale.linear() - .domain([xMin, xMax]) - .range([0, margin.widthExMargins]); - var xAxis = d3.svg.axis().scale(xScale) - .orient("bottom").ticks(5); - if (this.gridVerticalStrokeWidth && this.gridVerticalStrokeWidth > 0) { - xAxis.innerTickSize(-margin.heightExMargins); - } - var yScale = d3.scale.linear() - .domain([0, yMax]) - .range([margin.heightExMargins, 0]); - var yAxis = d3.svg.axis().scale(yScale) - .orient("left").ticks(5); - if (this.gridHorizontalStrokeWidth && this.gridHorizontalStrokeWidth > 0) { - yAxis.innerTickSize(-margin.widthExMargins); - } - if (this.suppressAxisHorizontal === true) - xAxis.tickValues([]); - if (this.suppressAxisVertical === true) - yAxis.tickValues([]); - var lowerBounds = this.lowerBounds; - var upperBounds = this.upperBounds; - var yValues = this.yValues; - var data = lowerBounds.map(function (d, i) { - return { 'width': upperBounds[i] - lowerBounds[i], 'height': yValues[i], 'offset': lowerBounds[i] }; - }); - var svg = d3.select("#" + appendToObject.attr("id")) - .append("svg") - .style("fill", "none") - .attr("width", s.getWidth()) - .attr("height", s.getHeight()) - .attr("padding", "20px") - .append("g") - .attr("transform", "translate(" + margin.left + "," + margin.top + ")"); - svg.selectAll(".bin") - .data(data) - .enter().append("rect") - .attr("class", "bin") - .style("fill", "steelblue") - .attr("x", function (d) { return xScale(d.offset); }) - .attr("width", function (d) { return xScale(xMin + d.width) - 1; }) - .attr("y", function (d) { return yScale(d.height); }) - .attr("height", function (d) { return margin.heightExMargins - yScale(d.height); }); - var xAxisNode = svg.append("g") - .attr("class", "x axis") - .attr("transform", "translate(0," + margin.heightExMargins + ")") - .style("stroke", "#000") - .style("stroke-width", (s != null && s.getAxisStrokeWidth() != null ? s.getAxisStrokeWidth() : ChartConstants.DEFAULT_AXIS_STROKE_WIDTH)) - .style("fill", "none") - .call(xAxis); - xAxisNode.selectAll('text').style("stroke-width", 0).style("fill", "#000000"); - if (this.gridVerticalStrokeWidth != null) - xAxisNode.selectAll('.axis line').style({ 'stroke-width': this.gridVerticalStrokeWidth }); - var yAxisNode = svg.append("g") - .attr("class", "y axis") - .style("stroke", "#000") - .style("stroke-width", (s != null && s.getAxisStrokeWidth() != null ? s.getAxisStrokeWidth() : ChartConstants.DEFAULT_AXIS_STROKE_WIDTH)) - .style("fill", "none") - .call(yAxis); - yAxisNode.selectAll('text').style("stroke-width", 0).style("fill", "#000000"); - if (this.gridHorizontalStrokeWidth != null) - yAxisNode.selectAll('.axis line').style({ 'stroke-width': this.gridHorizontalStrokeWidth }); - if (this.title) { - var titleStyle; - if (this.style) - titleStyle = this.style.getTitleStyle(); - Chart.appendTitle(svg, this.title, margin, titleStyle); - } - }; - var json = JSON.parse(jsonStr); - if (!json["componentType"]) - json = json[ComponentType[ComponentType.ChartHistogram]]; - this.lowerBounds = json['lowerBounds']; - this.upperBounds = json['upperBounds']; - this.yValues = json['yvalues']; - } - return ChartHistogram; -}(Chart)); -var ChartLine = (function (_super) { - __extends(ChartLine, _super); - function ChartLine(jsonStr) { - _super.call(this, ComponentType.ChartLine, jsonStr); - this.render = function (appendToObject) { - var nSeries = (!this.xData ? 
0 : this.xData.length); - var s = this.getStyle(); - var margin = Style.getMargins(s); - var xScale = d3.scale.linear().range([0, margin.widthExMargins]); - var yScale = d3.scale.linear().range([margin.heightExMargins, 0]); - var xAxis = d3.svg.axis().scale(xScale) - .orient("bottom").ticks(5); - if (this.gridVerticalStrokeWidth != null && this.gridVerticalStrokeWidth > 0) { - xAxis.innerTickSize(-margin.heightExMargins); - } - var yAxis = d3.svg.axis().scale(yScale) - .orient("left").ticks(5); - if (this.gridHorizontalStrokeWidth != null && this.gridHorizontalStrokeWidth > 0) { - yAxis.innerTickSize(-margin.widthExMargins); - } - if (this.suppressAxisHorizontal === true) - xAxis.tickValues([]); - if (this.suppressAxisVertical === true) - yAxis.tickValues([]); - var valueline = d3.svg.line() - .x(function (d) { - return xScale(d.xPos); - }) - .y(function (d) { - return yScale(d.yPos); - }); - var svg = d3.select("#" + appendToObject.attr("id")) - .append("svg") - .style("stroke-width", (s && s.getStrokeWidth() ? s.getStrokeWidth() : ChartConstants.DEFAULT_CHART_STROKE_WIDTH)) - .style("fill", "none") - .attr("width", s.getWidth()) - .attr("height", s.getHeight()) - .append("g") - .attr("transform", "translate(" + margin.left + "," + margin.top + ")"); - var xMin; - var xMax; - var yMin; - var yMax; - if (this.setXMin != null) - xMin = this.setXMin; - else - xMin = (this.xData ? TSUtils.min(this.xData) : 0); - if (this.setXMax != null) - xMax = this.setXMax; - else - xMax = (this.xData ? TSUtils.max(this.xData) : 1); - if (this.setYMin != null) - yMin = this.setYMin; - else - yMin = (this.yData ? TSUtils.min(this.yData) : 0); - if (this.setYMax != null) - yMax = this.setYMax; - else - yMax = (this.yData ? TSUtils.max(this.yData) : 1); - xScale.domain([xMin, xMax]); - yScale.domain([yMin, yMax]); - var defaultColor = d3.scale.category10(); - for (var i = 0; i < nSeries; i++) { - var xVals = this.xData[i]; - var yVals = this.yData[i]; - var data = xVals.map(function (d, i) { - return { 'xPos': xVals[i], 'yPos': yVals[i] }; - }); - svg.append("path") - .attr("class", "line") - .style("stroke", (s && s.getSeriesColor(i) ? s.getSeriesColor(i) : defaultColor(String(i)))) - .attr("d", valueline(data)); - } - var xAxisNode = svg.append("g") - .attr("class", "x axis") - .attr("transform", "translate(0," + margin.heightExMargins + ")") - .style("stroke", "#000") - .style("stroke-width", (s != null && s.getAxisStrokeWidth() != null ? s.getAxisStrokeWidth() : ChartConstants.DEFAULT_AXIS_STROKE_WIDTH)) - .style("fill", "none") - .call(xAxis); - xAxisNode.selectAll('text').style("stroke-width", 0).style("fill", "#000000"); - if (this.gridVerticalStrokeWidth != null) - xAxisNode.selectAll('.axis line').style({ 'stroke-width': this.gridVerticalStrokeWidth }); - var yAxisNode = svg.append("g") - .attr("class", "y axis") - .style("stroke", "#000") - .style("stroke-width", (s != null && s.getAxisStrokeWidth() != null ? 
s.getAxisStrokeWidth() : ChartConstants.DEFAULT_AXIS_STROKE_WIDTH)) - .style("fill", "none") - .call(yAxis); - yAxisNode.selectAll('text').style("stroke-width", 0).style("fill", "#000000"); - if (this.gridHorizontalStrokeWidth != null) - yAxisNode.selectAll('.axis line').style({ 'stroke-width': this.gridHorizontalStrokeWidth }); - if (this.seriesNames && this.showLegend === true) { - var legendSpace = margin.widthExMargins / i; - for (var i = 0; i < nSeries; i++) { - var values = this.xData[i]; - var yValues = this.yData[i]; - var lastX = values[values.length - 1]; - var lastY = yValues[yValues.length - 1]; - var toDisplay = this.seriesNames[i]; - svg.append("text") - .attr("x", (legendSpace / 2) + i * legendSpace) - .attr("y", margin.heightExMargins + (margin.bottom / 2) + 5) - .attr("class", "legend") - .style("fill", (s && s.getSeriesColor(i) ? s.getSeriesColor(i) : defaultColor(String(i)))) - .text(toDisplay); - } - } - if (this.title) { - var titleStyle; - if (this.style) - titleStyle = this.style.getTitleStyle(); - Chart.appendTitle(svg, this.title, margin, titleStyle); - } - }; - var json = JSON.parse(jsonStr); - if (!json["componentType"]) - json = json[ComponentType[ComponentType.ChartLine]]; - this.xData = json['x']; - this.yData = json['y']; - this.seriesNames = json['seriesNames']; - } - return ChartLine; -}(Chart)); -var ChartScatter = (function (_super) { - __extends(ChartScatter, _super); - function ChartScatter(jsonStr) { - _super.call(this, ComponentType.ChartScatter, jsonStr); - this.render = function (appendToObject) { - var nSeries = (!this.xData ? 0 : this.xData.length); - var s = this.getStyle(); - var margin = Style.getMargins(s); - var xScale = d3.scale.linear().range([0, margin.widthExMargins]); - var yScale = d3.scale.linear().range([margin.heightExMargins, 0]); - var xAxis = d3.svg.axis().scale(xScale) - .innerTickSize(-margin.heightExMargins) - .orient("bottom").ticks(5); - var yAxis = d3.svg.axis().scale(yScale) - .innerTickSize(-margin.widthExMargins) - .orient("left").ticks(5); - if (this.suppressAxisHorizontal === true) - xAxis.tickValues([]); - if (this.suppressAxisVertical === true) - yAxis.tickValues([]); - var svg = d3.select("#" + appendToObject.attr("id")) - .append("svg") - .style("stroke-width", (s && s.getStrokeWidth() ? s.getStrokeWidth() : 1)) - .style("fill", "none") - .attr("width", s.getWidth()) - .attr("height", s.getHeight()) - .attr("padding", "20px") - .append("g") - .attr("transform", "translate(" + margin.left + "," + margin.top + ")"); - var xMin; - var xMax; - var yMin; - var yMax; - if (this.setXMin) - xMin = this.setXMin; - else - xMin = (this.xData ? TSUtils.min(this.xData) : 0); - if (this.setXMax) - xMax = this.setXMax; - else - xMax = (this.xData ? TSUtils.max(this.xData) : 1); - if (this.setYMin) - yMin = this.setYMin; - else - yMin = (this.yData ? TSUtils.min(this.yData) : 0); - if (this.setYMax) - yMax = this.setYMax; - else - yMax = (this.yData ? TSUtils.max(this.yData) : 1); - xScale.domain([xMin, xMax]); - yScale.domain([yMin, yMax]); - var defaultColor = d3.scale.category10(); - for (var i = 0; i < nSeries; i++) { - var xVals = this.xData[i]; - var yVals = this.yData[i]; - var data = xVals.map(function (d, i) { - return { 'xPos': xVals[i], 'yPos': yVals[i] }; - }); - svg.selectAll("circle") - .data(data) - .enter() - .append("circle") - .style("fill", (s && s.getSeriesColor(i) ? s.getSeriesColor(i) : defaultColor(String(i)))) - .attr("r", (s && s.getPointSize() ? 
s.getPointSize() : ChartConstants.DEFAULT_CHART_POINT_SIZE)) - .attr("cx", function (d) { - return xScale(d['xPos']); - }) - .attr("cy", function (d) { - return yScale(d['yPos']); - }); - } - var xAxisNode = svg.append("g") - .attr("class", "x axis") - .attr("transform", "translate(0," + margin.heightExMargins + ")") - .style("stroke", "#000") - .style("stroke-width", (s != null && s.getAxisStrokeWidth() != null ? s.getAxisStrokeWidth() : ChartConstants.DEFAULT_AXIS_STROKE_WIDTH)) - .style("fill", "none") - .call(xAxis); - xAxisNode.selectAll('text').style("stroke-width", 0).style("fill", "#000000"); - if (this.gridVerticalStrokeWidth != null) - xAxisNode.selectAll('.axis line').style({ 'stroke-width': this.gridVerticalStrokeWidth }); - var yAxisNode = svg.append("g") - .attr("class", "y axis") - .style("stroke", "#000") - .style("stroke-width", (s != null && s.getAxisStrokeWidth() != null ? s.getAxisStrokeWidth() : ChartConstants.DEFAULT_AXIS_STROKE_WIDTH)) - .style("fill", "none") - .call(yAxis); - yAxisNode.selectAll('text').style("stroke-width", 0).style("fill", "#000000"); - if (this.gridHorizontalStrokeWidth != null) - yAxisNode.selectAll('.axis line').style({ 'stroke-width': this.gridHorizontalStrokeWidth }); - if (this.seriesNames && this.showLegend === true) { - var legendSpace = margin.widthExMargins / i; - for (var i = 0; i < nSeries; i++) { - var values = this.xData[i]; - var yValues = this.yData[i]; - var lastX = values[values.length - 1]; - var lastY = yValues[yValues.length - 1]; - var toDisplay; - if (!lastX || !lastY) - toDisplay = this.seriesNames[i] + " (no data)"; - else - toDisplay = this.seriesNames[i] + " (" + lastX.toPrecision(5) + "," + lastY.toPrecision(5) + ")"; - svg.append("text") - .attr("x", (legendSpace / 2) + i * legendSpace) - .attr("y", margin.heightExMargins + (margin.bottom / 2) + 5) - .attr("class", "legend") - .style("fill", (s && s.getSeriesColor(i) ? 
s.getSeriesColor(i) : defaultColor(String(i)))) - .text(toDisplay); - } - } - if (this.title) { - var titleStyle; - if (this.style) - titleStyle = this.style.getTitleStyle(); - Chart.appendTitle(svg, this.title, margin, titleStyle); - } - }; - var json = JSON.parse(jsonStr); - if (!json["componentType"]) - json = json[ComponentType[ComponentType.ChartScatter]]; - this.xData = json['x']; - this.yData = json['y']; - this.seriesNames = json['seriesNames']; - } - return ChartScatter; -}(Chart)); -var Legend = (function () { - function Legend() { - } - Legend.offsetX = 15; - Legend.offsetY = 15; - Legend.padding = 8; - Legend.separation = 12; - Legend.boxSize = 10; - Legend.fillColor = "#FFFFFF"; - Legend.legendOpacity = 0.75; - Legend.borderStrokeColor = "#000000"; - Legend.legendFn = (function (g) { - var svg = d3.select(g.property("nearestViewportElement")); - var legendBox = g.selectAll(".outerRect").data([true]); - var legendItems = g.selectAll(".legendElement").data([true]); - legendBox.enter().append("rect").attr("class", "outerRect"); - legendItems.enter().append("g").attr("class", "legendElement"); - var legendElements = []; - svg.selectAll("[data-legend]").each(function () { - var thisVar = d3.select(this); - legendElements.push({ - label: thisVar.attr("data-legend"), - color: thisVar.style("fill") - }); - }); - legendItems.selectAll("rect") - .data(legendElements, function (d) { return d.label; }) - .call(function (d) { d.enter().append("rect"); }) - .call(function (d) { d.exit().remove(); }) - .attr("x", 0) - .attr("y", function (d, i) { return i * Legend.separation - Legend.boxSize + "px"; }) - .attr("width", Legend.boxSize) - .attr("height", Legend.boxSize) - .style("fill", function (d) { return d.color; }); - legendItems.selectAll("text") - .data(legendElements, function (d) { return d.label; }) - .call(function (d) { d.enter().append("text"); }) - .call(function (d) { d.exit().remove(); }) - .attr("y", function (d, i) { return i * Legend.separation + "px"; }) - .attr("x", (Legend.padding + Legend.boxSize) + "px") - .text(function (d) { return d.label; }); - var legendBoundingBox = legendItems[0][0].getBBox(); - legendBox.attr("x", (legendBoundingBox.x - Legend.padding)) - .attr("y", (legendBoundingBox.y - Legend.padding)) - .attr("height", (legendBoundingBox.height + 2 * Legend.padding)) - .attr("width", (legendBoundingBox.width + 2 * Legend.padding)) - .style("fill", Legend.fillColor) - .style("stroke", Legend.borderStrokeColor) - .style("opacity", Legend.legendOpacity); - svg.selectAll(".legend").attr("transform", "translate(" + Legend.offsetX + "," + Legend.offsetY + ")"); - }); - return Legend; -}()); -var ChartStackedArea = (function (_super) { - __extends(ChartStackedArea, _super); - function ChartStackedArea(jsonStr) { - _super.call(this, ComponentType.ChartStackedArea, jsonStr); - this.render = function (appendToObject) { - var nSeries = (!this.xData ? 
0 : this.xData.length); - var s = this.getStyle(); - var margin = Style.getMargins(s); - var xScale = d3.scale.linear().range([0, margin.widthExMargins]); - var yScale = d3.scale.linear().range([margin.heightExMargins, 0]); - var xAxis = d3.svg.axis().scale(xScale) - .orient("bottom").ticks(5); - if (this.gridVerticalStrokeWidth != null && this.gridVerticalStrokeWidth > 0) { - xAxis.innerTickSize(-margin.heightExMargins); - } - var yAxis = d3.svg.axis().scale(yScale) - .orient("left").ticks(5); - if (this.gridHorizontalStrokeWidth != null && this.gridHorizontalStrokeWidth > 0) { - yAxis.innerTickSize(-margin.widthExMargins); - } - if (this.suppressAxisHorizontal === true) - xAxis.tickValues([]); - if (this.suppressAxisVertical === true) - yAxis.tickValues([]); - var data = []; - for (var i = 0; i < this.xData.length; i++) { - var obj = {}; - for (var j = 0; j < this.labels.length; j++) { - obj[this.labels[j]] = this.yData[j][i]; - obj['xValue'] = this.xData[i]; - } - data.push(obj); - } - var area = d3.svg.area() - .x(function (d) { return xScale(d.xValue); }) - .y0(function (d) { return yScale(d.y0); }) - .y1(function (d) { return yScale(d.y0 + d.y); }); - var stack = d3.layout.stack() - .values(function (d) { return d.values; }); - var svg = d3.select("#" + appendToObject.attr("id")).append("svg") - .attr("width", margin.widthExMargins + margin.left + margin.right) - .attr("height", margin.heightExMargins + margin.top + margin.bottom) - .append("g") - .attr("transform", "translate(" + margin.left + "," + margin.top + ")"); - var color = d3.scale.category20(); - color.domain(d3.keys(data[0]).filter(function (key) { - return key !== "xValue"; - })); - var browsers = stack(color.domain().map(function (name) { - return { - name: name, - values: data.map(function (d) { - return { xValue: d.xValue, y: d[name] * 1 }; - }) - }; - })); - var maxX = d3.max(data, function (d) { - var vals = d3.keys(d).map(function (key) { - return key !== "xValue" ? d[key] : 0; - }); - return d3.sum(vals); - }); - xScale.domain(d3.extent(data, function (d) { - return d.xValue; - })); - yScale.domain([0, maxX]); - var browser = svg.selectAll(".browser") - .data(browsers) - .enter().append("g") - .attr("class", "browser"); - var tempLabels = this.labels; - var defaultColor = d3.scale.category20(); - browser.append("path") - .attr("class", "area") - .attr("data-legend", function (d) { return d.name; }) - .attr("d", function (d) { - return area(d.values); - }) - .style("fill", function (d) { - if (s && s.getSeriesColor(tempLabels.indexOf(d.name))) { - return s.getSeriesColor(tempLabels.indexOf(d.name)); - } - else { - return defaultColor(String(tempLabels.indexOf(d.name))); - } - }) - .style({ "stroke-width": "0px" }); - var xAxisNode = svg.append("g") - .attr("class", "x axis") - .style("stroke", "#000") - .style("stroke-width", (s != null && s.getAxisStrokeWidth() != null ? s.getAxisStrokeWidth() : ChartConstants.DEFAULT_AXIS_STROKE_WIDTH)) - .style("fill", "none") - .attr("transform", "translate(0," + margin.heightExMargins + ")") - .call(xAxis); - xAxisNode.selectAll('text').style("stroke-width", 0).style("fill", "#000000"); - var yAxisNode = svg.append("g") - .attr("class", "y axis") - .style("stroke", "#000") - .style("stroke-width", (s != null && s.getAxisStrokeWidth() != null ? 
s.getAxisStrokeWidth() : ChartConstants.DEFAULT_AXIS_STROKE_WIDTH)) - .style("fill", "none") - .call(yAxis); - yAxisNode.selectAll('text').style("stroke-width", 0).style("fill", "#000000"); - if (this.title) { - var titleStyle; - if (this.style) - titleStyle = this.style.getTitleStyle(); - Chart.appendTitle(svg, this.title, margin, titleStyle); - } - var legend = svg.append("g") - .attr("class", "legend") - .attr("transform", "translate(40,40)") - .style("font-size", "12px") - .call(Legend.legendFn); - }; - var json = JSON.parse(jsonStr); - if (!json["componentType"]) - json = json[ComponentType[ComponentType.ChartStackedArea]]; - this.xData = json['x']; - this.yData = json['y']; - this.labels = json['labels']; - } - return ChartStackedArea; -}(Chart)); -var ChartTimeline = (function (_super) { - __extends(ChartTimeline, _super); - function ChartTimeline(jsonStr) { - _super.call(this, ComponentType.ChartTimeline, jsonStr); - this.render = function (appendToObject) { - var instance = this; - var s = this.getStyle(); - var margin = Style.getMargins(s); - this.itemData = []; - var count = 0; - for (var i = 0; i < this.laneData.length; i++) { - for (var j = 0; j < this.laneData[i].length; j++) { - var obj = {}; - obj["start"] = this.laneData[i][j]["startTimeMs"]; - obj["end"] = this.laneData[i][j]["endTimeMs"]; - obj["id"] = count++; - obj["lane"] = i; - obj["color"] = this.laneData[i][j]["color"]; - obj["label"] = this.laneData[i][j]["entryLabel"]; - this.itemData.push(obj); - } - } - this.lanes = []; - for (var i = 0; i < this.laneNames.length; i++) { - var obj = {}; - obj["label"] = this.laneNames[i]; - obj["id"] = i; - this.lanes.push(obj); - } - var svg = d3.select("#" + appendToObject.attr("id")) - .append("svg") - .style("stroke-width", (s && s.getStrokeWidth() ? 
s.getStrokeWidth() : ChartConstants.DEFAULT_CHART_STROKE_WIDTH)) - .style("fill", "none") - .attr("width", s.getWidth()) - .attr("height", s.getHeight()) - .append("g"); - var heightExMargins = s.getHeight() - margin.top - margin.bottom; - var widthExMargins = s.getWidth() - margin.left - margin.right; - var miniHeight = this.laneNames.length * ChartTimeline.MINI_LANE_HEIGHT_PX; - var mainHeight = s.getHeight() - miniHeight - margin.top - margin.bottom - 25; - var minTime = d3.min(this.itemData, function (d) { return d.start; }); - var maxTime = d3.max(this.itemData, function (d) { return d.end; }); - this.x = d3.time.scale() - .domain([minTime, maxTime]) - .range([0, widthExMargins]); - this.x1 = d3.time.scale().range([0, widthExMargins]); - this.y1 = d3.scale.linear().domain([0, this.laneNames.length]).range([0, mainHeight]); - this.y2 = d3.scale.linear().domain([0, this.laneNames.length]).range([0, miniHeight]); - this.rect = svg.append('defs').append('clipPath') - .attr('id', 'clip') - .append('rect') - .attr('width', widthExMargins) - .attr('height', s.getHeight() - 100); - this.mainView = svg.append('g') - .attr('transform', 'translate(' + margin.left + ',' + margin.top + ')') - .attr('width', widthExMargins) - .attr('height', mainHeight) - .attr('font-size', '12px') - .attr('font', 'sans-serif'); - this.miniView = svg.append('g') - .attr('transform', 'translate(' + margin.left + ',' + (mainHeight + margin.top + 25) + ')') - .attr('width', widthExMargins) - .attr('height', miniHeight) - .attr('font-size', '10px') - .attr('font', 'sans-serif'); - this.mainView.append('g').selectAll('.laneLines') - .data(this.lanes) - .enter().append('line') - .attr('x1', 0) - .attr('y1', function (d) { - return d3.round(instance.y1(d.id)) + 0.5; - }) - .attr('x2', widthExMargins) - .attr('y2', function (d) { - return d3.round(instance.y1(d.id)) + 0.5; - }) - .attr('stroke', 'lightgray') - .attr('stroke-width', 1); - this.mainView.append('g').selectAll('.laneText') - .data(this.lanes) - .enter().append('text') - .text(function (d) { - if (d.label) - return d.label; - return ""; - }) - .attr('x', -10) - .attr('y', function (d) { - return instance.y1(d.id + .5); - }) - .attr('text-anchor', 'end') - .attr("font", "8pt sans-serif") - .attr('fill', 'black'); - this.miniView.append('g').selectAll('.laneLines') - .data(this.lanes) - .enter().append('line') - .attr('x1', 0) - .attr('y1', function (d) { return d3.round(instance.y2(d.id)) + 0.5; }) - .attr('x2', widthExMargins) - .attr('y2', function (d) { return d3.round(instance.y2(d.id)) + 0.5; }) - .attr('stroke', 'gray') - .attr('stroke-width', 1.0); - this.miniView.append('g').selectAll('.laneText') - .data(this.lanes) - .enter().append('text') - .text(function (d) { - if (d.label) - return d.label; - return ""; - }) - .attr('x', -10) - .attr('y', function (d) { - return instance.y2(d.id + .5); - }) - .attr('dy', '0.5ex') - .attr('text-anchor', 'end') - .attr('fill', 'black'); - this.xTimeAxis = d3.svg.axis() - .scale(this.x1) - .orient('bottom') - .ticks(d3.time.days, 1) - .tickFormat(d3.time.format('%a %d')) - .tickSize(6, 0); - var temp = this.mainView.append('g') - .attr('transform', 'translate(0,' + mainHeight + ')') - .attr('class', 'timeAxis') - .attr('fill', 'black') - .style("stroke", "black").style("stroke-width", 1.0).style("fill", "black") - .attr("font", "10px sans-serif") - .call(this.xTimeAxis); - temp.selectAll('text').style("stroke-width", 0.0).attr('stroke-width', 0.0); - this.itemRects = this.mainView.append('g') - .attr('clip-path', 
'url(#clip)'); - this.miniView.append('g').selectAll('miniItems') - .data(this.getMiniViewPaths(this.itemData)) - .enter().append('path') - .attr('class', function (d) { - return 'miniItem ' + d.class; - }) - .attr('d', function (d) { - return d.path; - }) - .attr('stroke', 'black') - .attr('stroke-width', 'black'); - this.miniView.append('rect') - .attr('pointer-events', 'painted') - .attr('width', widthExMargins) - .attr('height', miniHeight) - .attr('visibility', 'hidden') - .on('mouseup', this.moveBrush); - this.brush = d3.svg.brush() - .x(this.x) - .extent([minTime, maxTime]) - .on("brush", this.renderChart); - this.miniView.append('g') - .attr('class', 'x brush') - .call(this.brush) - .selectAll('rect') - .attr('y', 1) - .attr('height', miniHeight - 1) - .style('fill', 'gray') - .style('fill-opacity', '0.2') - .style('stroke', 'DarkSlateGray') - .style('stroke-width', 1); - this.miniView.selectAll('rect.background').remove(); - this.renderChart(); - if (this.title) { - var titleStyle; - if (this.style) - titleStyle = this.style.getTitleStyle(); - var text = svg.append("text") - .text(this.title) - .attr("x", (s.getWidth() / 2)) - .attr("y", ((margin.top - 30) / 2)) - .attr("text-anchor", "middle"); - if (titleStyle) { - if (titleStyle.getFont()) - text.attr("font-family", titleStyle.getFont); - if (titleStyle.getFontSize() != null) - text.attr("font-size", titleStyle.getFontSize() + "pt"); - if (titleStyle.getUnderline() != null) - text.style("text-decoration", "underline"); - if (titleStyle.getColor()) - text.style("fill", titleStyle.getColor); - else - text.style("fill", ChartConstants.DEFAULT_TITLE_COLOR); - } - else { - text.style("text-decoration", "underline"); - text.style("fill", ChartConstants.DEFAULT_TITLE_COLOR); - } - } - }; - this.renderChart = function () { - var instance = this; - var extent = this.brush.extent(); - var minExtent = extent[0]; - var maxExtent = extent[1]; - var visibleItems = this.itemData.filter(function (d) { - return d.start < maxExtent && d.end > minExtent; - }); - this.miniView.select('.brush').call(this.brush.extent([minExtent, maxExtent])); - this.x1.domain([minExtent, maxExtent]); - var range = maxExtent - minExtent; - if (range > 2 * ChartTimeline.MILLISEC_PER_WEEK) { - this.xTimeAxis.ticks(d3.time.mondays, 1).tickFormat(d3.time.format('%a %d')); - } - else if (range > 2 * ChartTimeline.MILLISEC_PER_DAY) { - this.xTimeAxis.ticks(d3.time.days, 1).tickFormat(d3.time.format('%a %d')); - } - else if (range > 2 * ChartTimeline.MILLISEC_PER_HOUR) { - this.xTimeAxis.ticks(d3.time.hours, 4).tickFormat(d3.time.format('%H %p')); - } - else if (range > 2 * ChartTimeline.MILLISEC_PER_MINUTE) { - this.xTimeAxis.ticks(d3.time.minutes, 1).tickFormat(d3.time.format('%H:%M')); - } - else if (range >= 30000) { - this.xTimeAxis.ticks(d3.time.seconds, 10).tickFormat(d3.time.format('%H:%M:%S')); - } - else { - this.xTimeAxis.ticks(d3.time.seconds, 1).tickFormat(d3.time.format('%H:%M:%S')); - } - this.mainView.select('.timeAxis').call(this.xTimeAxis); - var rects = this.itemRects.selectAll('rect') - .data(visibleItems, function (d) { return d.id; }) - .attr('x', function (d) { return instance.x1(d.start); }) - .attr('width', function (d) { return instance.x1(d.end) - instance.x1(d.start); }); - rects.enter().append('rect') - .attr('x', function (d) { return instance.x1(d.start); }) - .attr('y', function (d) { return instance.y1(d.lane) + ChartTimeline.ENTRY_LANE_HEIGHT_OFFSET_FRACTION * instance.y1(1) + 0.5; }) - .attr('width', function (d) { return 
instance.x1(d.end) - instance.x1(d.start); }) - .attr('height', function (d) { return ChartTimeline.ENTRY_LANE_HEIGHT_TOTAL_FRACTION * instance.y1(1); }) - .attr('stroke', 'black') - .attr('fill', function (d) { - if (d.color) - return d.color; - return ChartTimeline.DEFAULT_COLOR; - }) - .attr('stroke-width', 1); - rects.exit().remove(); - var labels = this.itemRects.selectAll('text') - .data(visibleItems, function (d) { - return d.id; - }) - .attr('x', function (d) { - return instance.x1(Math.max(d.start, minExtent)) + 2; - }) - .attr('fill', 'black'); - labels.enter().append('text') - .text(function (d) { - if (instance.x1(d.end) - instance.x1(d.start) <= 30) - return ""; - if (d.label) - return d.label; - return ""; - }) - .attr('x', function (d) { - return instance.x1(Math.max(d.start, minExtent)) + 2; - }) - .attr('y', function (d) { - return instance.y1(d.lane) + .4 * instance.y1(1) + 0.5; - }) - .attr('text-anchor', 'start') - .attr('class', 'itemLabel') - .attr('fill', 'black'); - labels.exit().remove(); - }; - this.moveBrush = function () { - var origin = d3.mouse(this.rect[0]); - var time = this.x.invert(origin[0]).getTime(); - var halfExtent = (this.brush.extent()[1].getTime() - this.brush.extent()[0].getTime()) / 2; - this.brush.extent([new Date(time - halfExtent), new Date(time + halfExtent)]); - this.renderChart(); - }; - this.getMiniViewPaths = function (items) { - var paths = {}, d, offset = .5 * this.y2(1) + 0.5, result = []; - for (var i = 0; i < items.length; i++) { - d = items[i]; - if (!paths[d.class]) - paths[d.class] = ''; - paths[d.class] += ['M', this.x(d.start), (this.y2(d.lane) + offset), 'H', this.x(d.end)].join(' '); - } - for (var className in paths) { - result.push({ class: className, path: paths[className] }); - } - return result; - }; - var json = JSON.parse(jsonStr); - if (!json["componentType"]) - json = json[ComponentType[ComponentType.ChartTimeline]]; - this.laneNames = json['laneNames']; - this.laneData = json['laneData']; - } - ChartTimeline.MINI_LANE_HEIGHT_PX = 12; - ChartTimeline.ENTRY_LANE_HEIGHT_OFFSET_FRACTION = 0.05; - ChartTimeline.ENTRY_LANE_HEIGHT_TOTAL_FRACTION = 0.90; - ChartTimeline.MILLISEC_PER_MINUTE = 60 * 1000; - ChartTimeline.MILLISEC_PER_HOUR = 60 * ChartTimeline.MILLISEC_PER_MINUTE; - ChartTimeline.MILLISEC_PER_DAY = 24 * ChartTimeline.MILLISEC_PER_HOUR; - ChartTimeline.MILLISEC_PER_WEEK = 7 * ChartTimeline.MILLISEC_PER_DAY; - ChartTimeline.DEFAULT_COLOR = "LightGrey"; - return ChartTimeline; -}(Chart)); -var StyleChart = (function (_super) { - __extends(StyleChart, _super); - function StyleChart(jsonObj) { - var _this = this; - _super.call(this, jsonObj['StyleChart']); - this.getStrokeWidth = function () { return _this.strokeWidth; }; - this.getPointSize = function () { return _this.pointSize; }; - this.getSeriesColors = function () { return _this.seriesColors; }; - this.getSeriesColor = function (idx) { - if (!this.seriesColors || idx < 0 || idx > this.seriesColors.length) - return null; - return _this.seriesColors[idx]; - }; - this.getAxisStrokeWidth = function () { return _this.axisStrokeWidth; }; - this.getTitleStyle = function () { return _this.titleStyle; }; - var style = jsonObj['StyleChart']; - if (style) { - this.strokeWidth = style['strokeWidth']; - this.pointSize = style['pointSize']; - this.seriesColors = style['seriesColors']; - if (style['titleStyle']) - this.titleStyle = new StyleText(style['titleStyle']); - } - } - return StyleChart; -}(Style)); -var ComponentDiv = (function (_super) { - __extends(ComponentDiv, 
_super); - function ComponentDiv(jsonStr) { - _super.call(this, ComponentType.ComponentDiv); - this.render = function (appendToObject) { - var newDiv = $('<div></div>
      '); - newDiv.uniqueId(); - if (this.style) { - if (this.style.getWidth()) { - var unit = this.style.getWidthUnit(); - newDiv.width(this.style.getWidth() + (unit ? unit : "")); - } - if (this.style.getHeight()) { - var unit = this.style.getHeightUnit(); - newDiv.height(this.style.getHeight() + (unit ? unit : "")); - } - if (this.style.getBackgroundColor()) - newDiv.css("background-color", this.style.getBackgroundColor()); - if (this.style.getFloatValue()) - newDiv.css("float", this.style.getFloatValue()); - } - appendToObject.append(newDiv); - if (this.components) { - for (var i = 0; i < this.components.length; i++) { - this.components[i].render(newDiv); - } - } - }; - var json = JSON.parse(jsonStr); - if (!json["componentType"]) - json = json[ComponentType[ComponentType.ComponentDiv]]; - var components = json['components']; - if (components) { - this.components = []; - for (var i = 0; i < components.length; i++) { - var asStr = JSON.stringify(components[i]); - this.components.push(Component.getComponent(asStr)); - } - } - if (json['style']) - this.style = new StyleDiv(json['style']); - } - return ComponentDiv; -}(Component)); -var StyleDiv = (function (_super) { - __extends(StyleDiv, _super); - function StyleDiv(jsonObj) { - var _this = this; - _super.call(this, jsonObj['StyleDiv']); - this.getFloatValue = function () { return _this.floatValue; }; - if (jsonObj && jsonObj['StyleDiv']) - this.floatValue = jsonObj['StyleDiv']['floatValue']; - } - return StyleDiv; -}(Style)); -var DecoratorAccordion = (function (_super) { - __extends(DecoratorAccordion, _super); - function DecoratorAccordion(jsonStr) { - _super.call(this, ComponentType.DecoratorAccordion); - this.render = function (appendToObject) { - var s = this.style; - var outerDiv = $('
<div></div>'); - outerDiv.uniqueId(); - var titleDiv; - if (this.title) - titleDiv = $('<div>' + this.title + '</div>'); - else - titleDiv = $('<div></div>'); - titleDiv.uniqueId(); - outerDiv.append(titleDiv); - var innerDiv = $('<div></div>
      '); - innerDiv.uniqueId(); - outerDiv.append(innerDiv); - if (this.innerComponents) { - for (var i = 0; i < this.innerComponents.length; i++) { - this.innerComponents[i].render(innerDiv); - } - } - appendToObject.append(outerDiv); - if (this.defaultCollapsed) - outerDiv.accordion({ collapsible: true, heightStyle: "content", active: false }); - else - outerDiv.accordion({ collapsible: true, heightStyle: "content" }); - }; - var json = JSON.parse(jsonStr); - if (!json["componentType"]) - json = json[ComponentType[ComponentType.DecoratorAccordion]]; - this.title = json['title']; - this.defaultCollapsed = json['defaultCollapsed']; - var innerCs = json['innerComponents']; - if (innerCs) { - this.innerComponents = []; - for (var i = 0; i < innerCs.length; i++) { - var asStr = JSON.stringify(innerCs[i]); - this.innerComponents.push(Component.getComponent(asStr)); - } - } - if (json['style']) - this.style = new StyleAccordion(json['style']); - } - return DecoratorAccordion; -}(Component)); -var StyleAccordion = (function (_super) { - __extends(StyleAccordion, _super); - function StyleAccordion(jsonObj) { - _super.call(this, jsonObj['StyleAccordion']); - } - return StyleAccordion; -}(Style)); -var ComponentTable = (function (_super) { - __extends(ComponentTable, _super); - function ComponentTable(jsonStr) { - _super.call(this, ComponentType.ComponentTable); - this.render = function (appendToObject) { - var s = this.style; - var margin = Style.getMargins(s); - var tbl = document.createElement('table'); - tbl.style.width = '100%'; - if (s && s.getBorderWidthPx() != null) - tbl.setAttribute('border', String(s.getBorderWidthPx())); - if (s && s.getBackgroundColor()) - tbl.style.backgroundColor = s.getBackgroundColor(); - if (s && s.getWhitespaceMode()) - tbl.style.whiteSpace = s.getWhitespaceMode(); - if (s && s.getColumnWidths()) { - var colWidths = s.getColumnWidths(); - var unit = TSUtils.normalizeLengthUnit(s.getColumnWidthUnit()); - for (var i = 0; i < colWidths.length; i++) { - var col = document.createElement('col'); - col.setAttribute('width', colWidths[i] + unit); - tbl.appendChild(col); - } - } - var padTop = 1; - var padRight = 1; - var padBottom = 1; - var padLeft = 1; - if (this.header) { - var theader = document.createElement('thead'); - var headerRow = document.createElement('tr'); - if (s && s.getHeaderColor()) - headerRow.style.backgroundColor = s.getHeaderColor(); - for (var i = 0; i < this.header.length; i++) { - var headerd = document.createElement('th'); - headerd.style.padding = padTop + 'px ' + padRight + 'px ' + padBottom + 'px ' + padLeft + 'px'; - headerd.appendChild(document.createTextNode(this.header[i])); - headerRow.appendChild(headerd); - } - tbl.appendChild(headerRow); - } - if (this.content) { - var tbdy = document.createElement('tbody'); - for (var i = 0; i < this.content.length; i++) { - var tr = document.createElement('tr'); - for (var j = 0; j < this.content[i].length; j++) { - var td = document.createElement('td'); - td.style.padding = padTop + 'px ' + padRight + 'px ' + padBottom + 'px ' + padLeft + 'px'; - td.appendChild(document.createTextNode(this.content[i][j])); - tr.appendChild(td); - } - tbdy.appendChild(tr); - } - tbl.appendChild(tbdy); - } - appendToObject.append(tbl); - }; - var json = JSON.parse(jsonStr); - if (!json["componentType"]) - json = json[ComponentType[ComponentType.ComponentTable]]; - this.header = json['header']; - this.content = json['content']; - if (json['style']) - this.style = new StyleTable(json['style']); - } - return 
ComponentTable; -}(Component)); -var StyleTable = (function (_super) { - __extends(StyleTable, _super); - function StyleTable(jsonObj) { - var _this = this; - _super.call(this, jsonObj['StyleTable']); - this.getColumnWidths = function () { return _this.columnWidths; }; - this.getColumnWidthUnit = function () { return _this.columnWidthUnit; }; - this.getBorderWidthPx = function () { return _this.borderWidthPx; }; - this.getHeaderColor = function () { return _this.headerColor; }; - this.getWhitespaceMode = function () { return _this.whitespaceMode; }; - var style = jsonObj['StyleTable']; - if (style) { - this.columnWidths = jsonObj['StyleTable']['columnWidths']; - this.borderWidthPx = jsonObj['StyleTable']['borderWidthPx']; - this.headerColor = jsonObj['StyleTable']['headerColor']; - this.columnWidthUnit = jsonObj['StyleTable']['columnWidthUnit']; - this.whitespaceMode = jsonObj['StyleTable']['whitespaceMode']; - } - } - return StyleTable; -}(Style)); -var ComponentText = (function (_super) { - __extends(ComponentText, _super); - function ComponentText(jsonStr) { - var _this = this; - _super.call(this, ComponentType.ComponentText); - this.render = function (appendToObject) { - var textNode = document.createTextNode(_this.text); - if (_this.style) { - var newSpan = document.createElement('span'); - if (_this.style.getFont()) - newSpan.style.font = _this.style.getFont(); - if (_this.style.getFontSize() != null) - newSpan.style.fontSize = _this.style.getFontSize() + "pt"; - if (_this.style.getUnderline() != null) - newSpan.style.textDecoration = 'underline'; - if (_this.style.getColor()) - newSpan.style.color = _this.style.getColor(); - if (_this.style.getMarginTop()) - newSpan.style.marginTop = _this.style.getMarginTop() + "px"; - if (_this.style.getMarginBottom()) - newSpan.style.marginBottom = _this.style.getMarginBottom() + "px"; - if (_this.style.getMarginLeft()) - newSpan.style.marginLeft = _this.style.getMarginLeft() + "px"; - if (_this.style.getMarginRight()) - newSpan.style.marginRight = _this.style.getMarginRight() + "px"; - if (_this.style.getWhitespacePre()) - newSpan.style.whiteSpace = 'pre'; - newSpan.appendChild(textNode); - appendToObject.append(newSpan); - } - else { - var newSpan = document.createElement('span'); - newSpan.appendChild(textNode); - appendToObject.append(newSpan); - } - }; - var json = JSON.parse(jsonStr); - if (!json["componentType"]) - json = json[ComponentType[ComponentType.ComponentText]]; - this.text = json['text']; - if (json['style']) - this.style = new StyleText(json['style']); - } - return ComponentText; -}(Component)); -var StyleText = (function (_super) { - __extends(StyleText, _super); - function StyleText(jsonObj) { - var _this = this; - _super.call(this, jsonObj['StyleText']); - this.getFont = function () { return _this.font; }; - this.getFontSize = function () { return _this.fontSize; }; - this.getUnderline = function () { return _this.underline; }; - this.getColor = function () { return _this.color; }; - this.getWhitespacePre = function () { return _this.whitespacePre; }; - var style = jsonObj['StyleText']; - if (style) { - this.font = style['font']; - this.fontSize = style['fontSize']; - this.underline = style['underline']; - this.color = style['color']; - this.whitespacePre = style['whitespacePre']; - } - } - return StyleText; -}(Style)); -//# sourceMappingURL=dl4j-ui.js.map \ No newline at end of file diff --git a/arbiter/arbiter-ui/src/main/resources/deeplearning4jUiAssets/dl4j-ui.js.map 
b/arbiter/arbiter-ui/src/main/resources/deeplearning4jUiAssets/dl4j-ui.js.map deleted file mode 100644 index 3545aed31..000000000 --- a/arbiter/arbiter-ui/src/main/resources/deeplearning4jUiAssets/dl4j-ui.js.map +++ /dev/null @@ -1 +0,0 @@ -{"version":3,"file":"dl4j-ui.js","sourceRoot":"","sources":["../../typescript/org/deeplearning4j/ui/api/Style.ts","../../typescript/org/deeplearning4j/ui/api/ComponentType.ts","../../typescript/org/deeplearning4j/ui/api/Component.ts","../../typescript/org/deeplearning4j/ui/api/Constants.ts","../../typescript/org/deeplearning4j/ui/api/Margin.ts","../../typescript/org/deeplearning4j/ui/api/Renderable.ts","../../typescript/org/deeplearning4j/ui/util/TSUtils.ts","../../typescript/org/deeplearning4j/ui/components/chart/Chart.ts","../../typescript/org/deeplearning4j/ui/components/chart/ChartHistogram.ts","../../typescript/org/deeplearning4j/ui/components/chart/ChartLine.ts","../../typescript/org/deeplearning4j/ui/components/chart/ChartScatter.ts","../../typescript/org/deeplearning4j/ui/components/chart/Legend.ts","../../typescript/org/deeplearning4j/ui/components/chart/ChartStackedArea.ts","../../typescript/org/deeplearning4j/ui/components/chart/ChartTimeline.ts","../../typescript/org/deeplearning4j/ui/components/chart/style/StyleChart.ts","../../typescript/org/deeplearning4j/ui/components/component/ComponentDiv.ts","../../typescript/org/deeplearning4j/ui/components/component/style/StyleDiv.ts","../../typescript/org/deeplearning4j/ui/components/decorator/DecoratorAccordion.ts","../../typescript/org/deeplearning4j/ui/components/decorator/style/StyleAccordion.ts","../../typescript/org/deeplearning4j/ui/components/table/ComponentTable.ts","../../typescript/org/deeplearning4j/ui/components/table/style/StyleTable.ts","../../typescript/org/deeplearning4j/ui/components/text/ComponentText.ts","../../typescript/org/deeplearning4j/ui/components/text/style/StyleText.ts"],"names":[],"mappings":";;;;;AAkBA;IAcI,eAAa,OAAY;QAd7B,iBAmDC;QAzBG,aAAQ,GAAG,cAAM,OAAA,KAAI,CAAC,KAAK,EAAV,CAAU,CAAC;QAC5B,cAAS,GAAG,cAAM,OAAA,KAAI,CAAC,MAAM,EAAX,CAAW,CAAC;QAC9B,iBAAY,GAAG,cAAM,OAAA,KAAI,CAAC,SAAS,EAAd,CAAc,CAAC;QACpC,kBAAa,GAAG,cAAM,OAAA,KAAI,CAAC,UAAU,EAAf,CAAe,CAAC;QACtC,iBAAY,GAAG,cAAM,OAAA,KAAI,CAAC,SAAS,EAAd,CAAc,CAAC;QACpC,oBAAe,GAAG,cAAM,OAAA,KAAI,CAAC,YAAY,EAAjB,CAAiB,CAAC;QAC1C,kBAAa,GAAG,cAAM,OAAA,KAAI,CAAC,UAAU,EAAf,CAAe,CAAC;QACtC,mBAAc,GAAG,cAAM,OAAA,KAAI,CAAC,WAAW,EAAhB,CAAgB,CAAC;QACxC,uBAAkB,GAAG,cAAM,OAAA,KAAI,CAAC,eAAe,EAApB,CAAoB,CAAC;QAnB5C,IAAI,CAAC,KAAK,GAAG,OAAO,CAAC,OAAO,CAAC,CAAC;QAC9B,IAAI,CAAC,MAAM,GAAG,OAAO,CAAC,QAAQ,CAAC,CAAC;QAChC,IAAI,CAAC,SAAS,GAAG,OAAO,CAAC,mBAAmB,CAAC,OAAO,CAAC,WAAW,CAAC,CAAC,CAAC;QACnE,IAAI,CAAC,UAAU,GAAG,OAAO,CAAC,mBAAmB,CAAC,OAAO,CAAC,YAAY,CAAC,CAAC,CAAC;QACrE,IAAI,CAAC,SAAS,GAAG,OAAO,CAAC,WAAW,CAAC,CAAC;QACtC,IAAI,CAAC,YAAY,GAAG,OAAO,CAAC,cAAc,CAAC,CAAC;QAC5C,IAAI,CAAC,UAAU,GAAG,OAAO,CAAC,YAAY,CAAC,CAAC;QACxC,IAAI,CAAC,WAAW,GAAG,OAAO,CAAC,aAAa,CAAC,CAAC;QAC1C,IAAI,CAAC,eAAe,GAAG,OAAO,CAAC,iBAAiB,CAAC,CAAC;IACtD,CAAC;IAaM,gBAAU,GAAjB,UAAkB,CAAQ;QACtB,IAAI,IAAI,GAAW,CAAC,CAAC,GAAG,CAAC,CAAC,YAAY,EAAE,GAAG,CAAC,CAAC,CAAC;QAC9C,IAAI,OAAO,GAAW,CAAC,CAAC,GAAG,CAAC,CAAC,eAAe,EAAE,GAAG,CAAC,CAAC,CAAC;QACpD,IAAI,KAAK,GAAW,CAAC,CAAC,GAAG,CAAC,CAAC,aAAa,EAAE,GAAG,CAAC,CAAC,CAAC;QAChD,IAAI,MAAM,GAAW,CAAC,CAAC,GAAG,CAAC,CAAC,cAAc,EAAE,GAAG,CAAC,CAAC,CAAC;QAGlD,MAAM,CAAC,EAAC,GAAG,EAAE,IAAI;YACb,KAAK,EAAE,MAAM;YACb,MAAM,EAAE,OAAO;YACf,IAAI,EAAE,KAAK;YACX,cAAc,EAAE,CAAC,CAAC,QAAQ,EAAE,GAAG,KAAK,GAAG,MAAM;YAC7C,eAAe,EAAE,CAAC,CAAC,SAAS,EAAE,GAAG,
IAAI,GAAG,OAAO,EAAC,CAAC;IACzD,CAAC;IACL,YAAC;AAAD,CAAC,AAnDD,IAmDC;ACjDD,IAAK,aAWJ;AAXD,WAAK,aAAa;IACd,mEAAa,CAAA;IACb,qEAAc,CAAA;IACd,iEAAY,CAAA;IACZ,qEAAc,CAAA;IACd,6EAAkB,CAAA;IAClB,2DAAS,CAAA;IACT,iEAAY,CAAA;IACZ,yEAAgB,CAAA;IAChB,mEAAa,CAAA;IACb,6EAAkB,CAAA;AACtB,CAAC,EAXI,aAAa,KAAb,aAAa,QAWjB;ACTD;IAII,mBAAY,aAA4B;QACpC,IAAI,CAAC,aAAa,GAAG,aAAa,CAAC;IACvC,CAAC;IAEM,oCAAgB,GAAvB;QACI,MAAM,CAAC,IAAI,CAAC,aAAa,CAAC;IAC9B,CAAC;IAKa,sBAAY,GAA1B,UAA2B,OAAe;QAEtC,IAAI,IAAI,GAAQ,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QACpC,IAAI,GAAW,CAAC;QAChB,EAAE,CAAA,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAAC,GAAG,GAAG,IAAI,CAAC,eAAe,CAAC,CAAC;QACtD,IAAI;YAAC,GAAG,GAAG,MAAM,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;QAIhC,MAAM,CAAA,CAAC,GAAG,CAAC,CAAA,CAAC;YACR,KAAK,aAAa,CAAC,aAAa,CAAC,aAAa,CAAC;gBAC3C,MAAM,CAAC,IAAI,aAAa,CAAC,OAAO,CAAC,CAAC;YAEtC,KAAK,aAAa,CAAC,aAAa,CAAC,cAAc,CAAC;gBAC5C,MAAM,CAAC,IAAI,cAAc,CAAC,OAAO,CAAC,CAAC;YAEvC,KAAK,aAAa,CAAC,aAAa,CAAC,cAAc,CAAC;gBAC5C,MAAM,CAAC,IAAI,cAAc,CAAC,OAAO,CAAC,CAAC;YAEvC,KAAK,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC;gBAChD,MAAM,IAAI,KAAK,CAAC,2CAA2C,CAAC,CAAC;YAEjE,KAAK,aAAa,CAAC,aAAa,CAAC,SAAS,CAAC;gBACvC,MAAM,CAAC,IAAI,SAAS,CAAC,OAAO,CAAC,CAAC;YAElC,KAAK,aAAa,CAAC,aAAa,CAAC,YAAY,CAAC;gBAC1C,MAAM,CAAC,IAAI,YAAY,CAAC,OAAO,CAAC,CAAC;YAErC,KAAK,aAAa,CAAC,aAAa,CAAC,gBAAgB,CAAC;gBAC9C,MAAM,CAAC,IAAI,gBAAgB,CAAC,OAAO,CAAC,CAAC;YAEzC,KAAK,aAAa,CAAC,aAAa,CAAC,aAAa,CAAC;gBAC3C,MAAM,CAAC,IAAI,aAAa,CAAC,OAAO,CAAC,CAAC;YAEtC,KAAK,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC;gBAChD,MAAM,CAAC,IAAI,kBAAkB,CAAC,OAAO,CAAC,CAAC;YAE3C,KAAK,aAAa,CAAC,aAAa,CAAC,YAAY,CAAC;gBAC1C,MAAM,CAAC,IAAI,YAAY,CAAC,OAAO,CAAC,CAAC;YAErC;gBACI,MAAM,IAAI,KAAK,CAAC,2BAA2B,GAAG,GAAG,GAAG,wBAAwB,GAAG,OAAO,GAAG,IAAI,CAAC,CAAC;QACvG,CAAC;IACL,CAAC;IACL,gBAAC;AAAD,CAAC,AA3DD,IA2DC;AChED;IAAA;IAMA,CAAC;IAJU,yCAA0B,GAAG,GAAG,CAAC;IACjC,uCAAwB,GAAG,GAAG,CAAC;IAC/B,wCAAyB,GAAG,GAAG,CAAC;IAChC,kCAAmB,GAAG,SAAS,CAAC;IAC3C,qBAAC;AAAD,CAAC,AAND,IAMC;AGJD;IAAA;IA6CA,CAAC;IA1CU,WAAG,GAAV,UAAW,KAAiB;QACxB,IAAI,GAAG,GAAW,CAAC,MAAM,CAAC,SAAS,CAAC;QACpC,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACpC,GAAG,CAAA,CAAE,IAAI,CAAC,GAAC,CAAC,EAAE,CAAC,GAAC,KAAK,CAAC,CAAC,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;gBACnC,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,EAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACpC,CAAC;QACL,CAAC;QACD,MAAM,CAAC,GAAG,CAAC;IACf,CAAC;IAGM,WAAG,GAAV,UAAW,KAAiB;QACxB,IAAI,GAAG,GAAW,MAAM,CAAC,SAAS,CAAC;QACnC,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YACpC,GAAG,CAAA,CAAE,IAAI,CAAC,GAAC,CAAC,EAAE,CAAC,GAAC,KAAK,CAAC,CAAC,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;gBACnC,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,GAAG,EAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YACpC,CAAC;QACL,CAAC;QACD,MAAM,CAAC,GAAG,CAAC;IACf,CAAC;IAGM,2BAAmB,GAA1B,UAA2B,KAAa;QACpC,EAAE,CAAA,CAAC,KAAK,IAAI,IAAI,CAAC;YAAC,MAAM,CAAC,KAAK,CAAC;QAE/B,MAAM,CAAA,CAAC,KAAK,CAAC,WAAW,EAAE,CAAC,CAAA,CAAC;YACxB,KAAK,IAAI;gBACL,MAAM,CAAC,IAAI,CAAC;YAChB,KAAK,SAAS,CAAC;YACf,KAAK,GAAG;gBACJ,MAAM,CAAC,GAAG,CAAC;YACf,KAAK,IAAI;gBACL,MAAM,CAAC,IAAI,CAAC;YAChB,KAAK,IAAI;gBACL,MAAM,CAAC,IAAI,CAAC;YAChB,KAAK,IAAI;gBACL,MAAM,CAAC,IAAI,CAAC;YAChB;gBACI,MAAM,CAAC,KAAK,CAAC;QACrB,CAAC;IAEL,CAAC;IACL,cAAC;AAAD,CAAC,AA7CD,IA6CC;ACxCD;IAA6B,yBAAS;IAiBlC,eAAY,aAA4B,EAAE,OAAe;QACrD,kBAAM,aAAa,CAAC,CAAC;QAErB,IAAI,QAAQ,GAAQ,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QACxC,IAAI,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QAC/B,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAAC,IAAI,GAAG,IAAI,CAAC,aAAa,CAAC,aAAa,CAAC,C
AAC,CAAC;QAErE,IAAI,CAAC,sBAAsB,GAAG,IAAI,CAAC,wBAAwB,CAAC,CAAC;QAC7D,IAAI,CAAC,oBAAoB,GAAG,IAAI,CAAC,sBAAsB,CAAC,CAAC;QACzD,IAAI,CAAC,UAAU,GAAG,IAAI,CAAC,YAAY,CAAC,CAAC;QAErC,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC;QAC3B,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/B,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,SAAS,CAAC,CAAC;QAE/B,IAAI,CAAC,uBAAuB,GAAG,IAAI,CAAC,yBAAyB,CAAC,CAAC;QAC/D,IAAI,CAAC,yBAAyB,GAAG,IAAI,CAAC,2BAA2B,CAAC,CAAC;QAEnE,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;YAAC,IAAI,CAAC,KAAK,GAAG,IAAI,UAAU,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC;IACjE,CAAC;IAED,wBAAQ,GAAR;QACI,MAAM,CAAC,IAAI,CAAC,KAAK,CAAC;IACtB,CAAC;IAEgB,iBAAW,GAA5B,UAA6B,GAAQ,EAAE,KAAa,EAAE,MAAc,EAAE,UAAqB;QACvF,IAAI,IAAI,GAAG,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC;aACxB,IAAI,CAAC,KAAK,CAAC;aACX,IAAI,CAAC,GAAG,EAAE,CAAC,MAAM,CAAC,cAAc,GAAG,CAAC,CAAC,CAAC;aACtC,IAAI,CAAC,GAAG,EAAE,CAAC,GAAG,CAAC,CAAC,MAAM,CAAC,GAAG,GAAG,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC;aACtC,IAAI,CAAC,aAAa,EAAE,QAAQ,CAAC,CAAC;QAEnC,EAAE,CAAA,CAAC,UAAU,CAAC,CAAA,CAAC;YACX,EAAE,CAAA,CAAC,UAAU,CAAC,OAAO,EAAE,CAAC;gBAAC,IAAI,CAAC,IAAI,CAAC,aAAa,EAAC,UAAU,CAAC,OAAO,CAAC,CAAC;YACrE,EAAE,CAAA,CAAC,UAAU,CAAC,WAAW,EAAE,IAAI,IAAI,CAAC;gBAAC,IAAI,CAAC,IAAI,CAAC,WAAW,EAAC,UAAU,CAAC,WAAW,EAAE,GAAC,IAAI,CAAC,CAAC;YAC1F,EAAE,CAAA,CAAC,UAAU,CAAC,YAAY,EAAE,IAAI,IAAI,CAAC;gBAAC,IAAI,CAAC,KAAK,CAAC,iBAAiB,EAAE,WAAW,CAAC,CAAC;YACjF,EAAE,CAAA,CAAC,UAAU,CAAC,QAAQ,EAAE,CAAC;gBAAC,IAAI,CAAC,KAAK,CAAC,MAAM,EAAC,UAAU,CAAC,QAAQ,CAAC,CAAC;YACjE,IAAI;gBAAC,IAAI,CAAC,KAAK,CAAC,MAAM,EAAC,cAAc,CAAC,mBAAmB,CAAC,CAAC;QAC/D,CAAC;QAAC,IAAI,CAAC,CAAC;YACJ,IAAI,CAAC,KAAK,CAAC,iBAAiB,EAAE,WAAW,CAAC,CAAC;YAC3C,IAAI,CAAC,KAAK,CAAC,MAAM,EAAC,cAAc,CAAC,mBAAmB,CAAC,CAAC;QAC1D,CAAC;IACL,CAAC;IACL,YAAC;AAAD,CAAC,AA9DD,CAA6B,SAAS,GA8DrC;AChED;IAA6B,kCAAK;IAM9B,wBAAY,OAAe;QACvB,kBAAM,aAAa,CAAC,cAAc,EAAE,OAAO,CAAC,CAAC;QAYjD,WAAM,GAAG,UAAC,cAAsB;YAC5B,IAAI,CAAC,GAAe,IAAI,CAAC,QAAQ,EAAE,CAAC;YACpC,IAAI,MAAM,GAAW,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YAGzC,IAAI,IAAY,CAAC;YACjB,IAAI,IAAY,CAAC;YACjB,IAAI,IAAY,CAAC;YACjB,IAAI,IAAY,CAAC;YACjB,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YACrC,IAAI;gBAAC,IAAI,GAAG,CAAC,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,CAAC,WAAW,CAAC,GAAG,CAAC,CAAC,CAAC;YAC9D,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YACrC,IAAI;gBAAC,IAAI,GAAG,CAAC,IAAI,CAAC,WAAW,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,CAAC,WAAW,CAAC,GAAG,CAAC,CAAC,CAAC;YAC9D,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YACrC,IAAI;gBAAC,IAAI,GAAG,CAAC,CAAC;YACd,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YACrC,IAAI;gBAAC,IAAI,GAAG,CAAC,IAAI,CAAC,OAAO,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,CAAC,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC;YAGtD,IAAI,MAAM,GAAQ,EAAE,CAAC,KAAK,CAAC,MAAM,EAAE;iBAC9B,MAAM,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC;iBACpB,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,cAAc,CAAC,CAAC,CAAC;YAEvC,IAAI,KAAK,GAAQ,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,KAAK,CAAC,MAAM,CAAC;iBACvC,MAAM,CAAC,QAAQ,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAE/B,EAAE,CAAA,CAAC,IAAI,CAAC,uBAAuB,IAAI,IAAI,CAAC,uBAAuB,GAAG,CAAC,CAAC,CAAA,CAAC;gBACjE,KAAK,CAAC,aAAa,CAAC,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;YACjD,CAAC;YAED,IAAI,MAAM,GAAQ,EAAE,CAAC,KAAK,CAAC,MAAM,EAAE;iBAC9B,MAAM,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC;iBACjB,KAAK,CAAC,CAAC,MAAM,CAAC,eAAe,EAAE,CAAC,CAAC,CAAC,CAAC;YACxC,IAAI,KAAK,GAAQ,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,KAAK,CAAC,MAAM,CAAC;iBACvC,MAAM,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC
;YAC7B,EAAE,CAAA,CAAC,IAAI,CAAC,yBAAyB,IAAI,IAAI,CAAC,yBAAyB,GAAG,CAAC,CAAC,CAAA,CAAC;gBACrE,KAAK,CAAC,aAAa,CAAC,CAAC,MAAM,CAAC,cAAc,CAAC,CAAC;YAChD,CAAC;YAID,EAAE,CAAA,CAAC,IAAI,CAAC,sBAAsB,KAAK,IAAI,CAAC;gBAAC,KAAK,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;YAE9D,EAAE,CAAA,CAAC,IAAI,CAAC,oBAAoB,KAAK,IAAI,CAAC;gBAAC,KAAK,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;YAG5D,IAAI,WAAW,GAAa,IAAI,CAAC,WAAW,CAAC;YAC7C,IAAI,WAAW,GAAa,IAAI,CAAC,WAAW,CAAC;YAC7C,IAAI,OAAO,GAAa,IAAI,CAAC,OAAO,CAAC;YAErC,IAAI,IAAI,GAAQ,WAAW,CAAC,GAAG,CAAC,UAAU,CAAC,EAAE,CAAC;gBAC1C,MAAM,CAAC,EAAC,OAAO,EAAE,WAAW,CAAC,CAAC,CAAC,GAAG,WAAW,CAAC,CAAC,CAAC,EAAE,QAAQ,EAAE,OAAO,CAAC,CAAC,CAAC,EAAE,QAAQ,EAAE,WAAW,CAAC,CAAC,CAAC,EAAC,CAAC;YACtG,CAAC,CAAC,CAAC;YAGH,IAAI,GAAG,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,GAAG,cAAc,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;iBAC/C,MAAM,CAAC,KAAK,CAAC;iBACb,KAAK,CAAC,MAAM,EAAE,MAAM,CAAC;iBACrB,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC,QAAQ,EAAE,CAAC;iBAC3B,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,SAAS,EAAE,CAAC;iBAC7B,IAAI,CAAC,SAAS,EAAE,MAAM,CAAC;iBACvB,MAAM,CAAC,GAAG,CAAC;iBACX,IAAI,CAAC,WAAW,EACb,YAAY,GAAG,MAAM,CAAC,IAAI,GAAG,GAAG,GAAG,MAAM,CAAC,GAAG,GAAG,GAAG,CAAC,CAAC;YAI7D,GAAG,CAAC,SAAS,CAAC,MAAM,CAAC;iBAChB,IAAI,CAAC,IAAI,CAAC;iBACV,KAAK,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;iBACtB,IAAI,CAAC,OAAO,EAAE,KAAK,CAAC;iBACpB,KAAK,CAAC,MAAM,EAAC,WAAW,CAAC;iBACzB,IAAI,CAAC,GAAG,EAAE,UAAS,CAAM,IAAI,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;iBACxD,IAAI,CAAC,OAAO,EAAE,UAAS,CAAM,IAAI,MAAM,CAAC,MAAM,CAAC,IAAI,GAAG,CAAC,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC;iBACtE,IAAI,CAAC,GAAG,EAAE,UAAS,CAAM,IAAI,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;iBACxD,IAAI,CAAC,QAAQ,EAAE,UAAS,CAAM,IAAI,MAAM,CAAC,MAAM,CAAC,eAAe,GAAG,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAG5F,IAAI,SAAS,GAAG,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC;iBAC1B,IAAI,CAAC,OAAO,EAAE,QAAQ,CAAC;iBACvB,IAAI,CAAC,WAAW,EAAE,cAAc,GAAG,MAAM,CAAC,eAAe,GAAG,GAAG,CAAC;iBAChE,KAAK,CAAC,QAAQ,EAAC,MAAM,CAAC;iBACtB,KAAK,CAAC,cAAc,EAAE,CAAC,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,kBAAkB,EAAE,IAAI,IAAI,GAAG,CAAC,CAAC,kBAAkB,EAAE,GAAG,cAAc,CAAC,yBAAyB,CAAC,CAAC;iBACxI,KAAK,CAAC,MAAM,EAAC,MAAM,CAAC;iBACpB,IAAI,CAAC,KAAK,CAAC,CAAC;YACjB,SAAS,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,cAAc,EAAC,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,EAAC,SAAS,CAAC,CAAC;YAE5E,EAAE,CAAA,CAAC,IAAI,CAAC,uBAAuB,IAAI,IAAI,CAAC;gBAAC,SAAS,CAAC,SAAS,CAAC,YAAY,CAAC,CAAC,KAAK,CAAC,EAAC,cAAc,EAAE,IAAI,CAAC,uBAAuB,EAAC,CAAC,CAAC;YAGjI,IAAI,SAAS,GAAG,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC;iBAC1B,IAAI,CAAC,OAAO,EAAE,QAAQ,CAAC;iBACvB,KAAK,CAAC,QAAQ,EAAC,MAAM,CAAC;iBACtB,KAAK,CAAC,cAAc,EAAE,CAAC,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,kBAAkB,EAAE,IAAI,IAAI,GAAG,CAAC,CAAC,kBAAkB,EAAE,GAAG,cAAc,CAAC,yBAAyB,CAAC,CAAC;iBACxI,KAAK,CAAC,MAAM,EAAC,MAAM,CAAC;iBACpB,IAAI,CAAC,KAAK,CAAC,CAAC;YACjB,SAAS,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,cAAc,EAAC,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,EAAC,SAAS,CAAC,CAAC;YAE5E,EAAE,CAAA,CAAC,IAAI,CAAC,yBAAyB,IAAI,IAAI,CAAC;gBAAC,SAAS,CAAC,SAAS,CAAC,YAAY,CAAC,CAAC,KAAK,CAAC,EAAC,cAAc,EAAE,IAAI,CAAC,yBAAyB,EAAC,CAAC,CAAC;YAGrI,EAAE,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC;gBACb,IAAI,UAAqB,CAAC;gBAC1B,EAAE,CAAA,CAAC,IAAI,CAAC,KAAK,CAAC;oBAAC,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,aAAa,EAAE,CAAC;gBACvD,KAAK,CAAC,WAAW,CAAC,GAAG,EAAE,IAAI,CAAC,KAAK,EAAE,MAAM,EAAE,UAAU,CAAC,CAAC;YAC3D,CAAC;QACL,CAAC,CAAA;QApHG,IAAI,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QAC/B,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAAC,IAAI,GAAG,IAAI,CAAC,aAAa,CAAC,aAAa,CAAC,cAAc,CAAC,CAAC,CAAC;QAGpF,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC,aAAa,CAAC,CAAC;QACvC,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC,aAAa,C
AAC,CAAC;QACvC,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,SAAS,CAAC,CAAC;IACnC,CAAC;IA8GL,qBAAC;AAAD,CAAC,AA9HD,CAA6B,KAAK,GA8HjC;AC9HD;IAAwB,6BAAK;IAMzB,mBAAY,OAAe;QACvB,kBAAM,aAAa,CAAC,SAAS,EAAE,OAAO,CAAC,CAAC;QAU5C,WAAM,GAAG,UAAC,cAAsB;YAE5B,IAAI,OAAO,GAAW,CAAC,CAAC,IAAI,CAAC,KAAK,GAAG,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;YAC5D,IAAI,CAAC,GAAe,IAAI,CAAC,QAAQ,EAAE,CAAC;YACpC,IAAI,MAAM,GAAW,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YAGzC,IAAI,MAAM,GAAmC,EAAE,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,cAAc,CAAC,CAAC,CAAC;YACjG,IAAI,MAAM,GAAmC,EAAE,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,eAAe,EAAE,CAAC,CAAC,CAAC,CAAC;YAGlG,IAAI,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,KAAK,CAAC,MAAM,CAAC;iBAClC,MAAM,CAAC,QAAQ,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC/B,EAAE,CAAA,CAAC,IAAI,CAAC,uBAAuB,IAAI,IAAI,IAAI,IAAI,CAAC,uBAAuB,GAAG,CAAC,CAAC,CAAA,CAAC;gBACzE,KAAK,CAAC,aAAa,CAAC,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;YACjD,CAAC;YAGD,IAAI,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,KAAK,CAAC,MAAM,CAAC;iBAClC,MAAM,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC7B,EAAE,CAAA,CAAC,IAAI,CAAC,yBAAyB,IAAI,IAAI,IAAI,IAAI,CAAC,yBAAyB,GAAG,CAAC,CAAC,CAAA,CAAC;gBAC7E,KAAK,CAAC,aAAa,CAAC,CAAC,MAAM,CAAC,cAAc,CAAC,CAAC;YAChD,CAAC;YAED,EAAE,CAAA,CAAC,IAAI,CAAC,sBAAsB,KAAK,IAAI,CAAC;gBAAC,KAAK,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;YAE9D,EAAE,CAAA,CAAC,IAAI,CAAC,oBAAoB,KAAK,IAAI,CAAC;gBAAC,KAAK,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;YAG5D,IAAI,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE;iBACxB,CAAC,CAAC,UAAU,CAAM;gBACf,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;YAC1B,CAAC,CAAC;iBACD,CAAC,CAAC,UAAU,CAAM;gBACf,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC;YAC1B,CAAC,CAAC,CAAC;YAIP,IAAI,GAAG,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,GAAG,cAAc,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;iBAC/C,MAAM,CAAC,KAAK,CAAC;iBACb,KAAK,CAAC,cAAc,EAAE,CAAE,CAAC,IAAI,CAAC,CAAC,cAAc,EAAE,GAAG,CAAC,CAAC,cAAc,EAAE,GAAG,cAAc,CAAC,0BAA0B,CAAC,CAAC;iBAClH,KAAK,CAAC,MAAM,EAAE,MAAM,CAAC;iBACrB,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC,QAAQ,EAAE,CAAC;iBAC3B,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,SAAS,EAAE,CAAC;iBAC7B,MAAM,CAAC,GAAG,CAAC;iBACX,IAAI,CAAC,WAAW,EAAE,YAAY,GAAG,MAAM,CAAC,IAAI,GAAG,GAAG,GAAG,MAAM,CAAC,GAAG,GAAG,GAAG,CAAC,CAAC;YAG5E,IAAI,IAAY,CAAC;YACjB,IAAI,IAAY,CAAC;YACjB,IAAI,IAAY,CAAC;YACjB,IAAI,IAAY,CAAC;YACjB,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,IAAI,IAAI,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YAC7C,IAAI;gBAAC,IAAI,GAAG,CAAC,IAAI,CAAC,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC;YACvD,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,IAAI,IAAI,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YAC7C,IAAI;gBAAC,IAAI,GAAG,CAAC,IAAI,CAAC,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC;YACvD,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,IAAI,IAAI,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YAC7C,IAAI;gBAAC,IAAI,GAAG,CAAC,IAAI,CAAC,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC;YACvD,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,IAAI,IAAI,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YAC7C,IAAI;gBAAC,IAAI,GAAG,CAAC,IAAI,CAAC,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC;YAEvD,MAAM,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC;YAC5B,MAAM,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC;YAG5B,IAAI,YAAY,GAA2B,EAAE,CAAC,KAAK,CAAC,UAAU,EAAE,CAAC;YACjE,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC/B,IAAI,KAAK,GAAa,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;gBACpC,IAAI,KAAK,GAAa,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;gBAEpC,IAAI,IAAI,GAAU,KAAK,CAAC,GAAG,CAAC,UAAU,CAAC,EAAE,CAAC;oBACtC,MAAM,CAAC,EAAC,MAAM,EAAE,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,EAAE
,KAAK,CAAC,CAAC,CAAC,EAAC,CAAC;gBAChD,CAAC,CAAC,CAAC;gBAEH,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC;qBACb,IAAI,CAAC,OAAO,EAAE,MAAM,CAAC;qBACrB,KAAK,CAAC,QAAQ,EAAE,CAAC,CAAC,IAAI,CAAC,CAAC,cAAc,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,cAAc,CAAC,CAAC,CAAC,GAAG,YAAY,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;qBAC3F,IAAI,CAAC,GAAG,EAAE,SAAS,CAAC,IAAI,CAAC,CAAC,CAAC;YACpC,CAAC;YAGD,IAAI,SAAS,GAAG,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC;iBAC1B,IAAI,CAAC,OAAO,EAAE,QAAQ,CAAC;iBACvB,IAAI,CAAC,WAAW,EAAE,cAAc,GAAG,MAAM,CAAC,eAAe,GAAG,GAAG,CAAC;iBAChE,KAAK,CAAC,QAAQ,EAAC,MAAM,CAAC;iBACtB,KAAK,CAAC,cAAc,EAAE,CAAC,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,kBAAkB,EAAE,IAAI,IAAI,GAAG,CAAC,CAAC,kBAAkB,EAAE,GAAG,cAAc,CAAC,yBAAyB,CAAC,CAAC;iBACxI,KAAK,CAAC,MAAM,EAAC,MAAM,CAAC;iBACpB,IAAI,CAAC,KAAK,CAAC,CAAC;YACjB,SAAS,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,cAAc,EAAC,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,EAAC,SAAS,CAAC,CAAC;YAE5E,EAAE,CAAA,CAAC,IAAI,CAAC,uBAAuB,IAAI,IAAI,CAAC;gBAAC,SAAS,CAAC,SAAS,CAAC,YAAY,CAAC,CAAC,KAAK,CAAC,EAAC,cAAc,EAAE,IAAI,CAAC,uBAAuB,EAAC,CAAC,CAAC;YAGjI,IAAI,SAAS,GAAG,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC;iBAC1B,IAAI,CAAC,OAAO,EAAE,QAAQ,CAAC;iBACvB,KAAK,CAAC,QAAQ,EAAC,MAAM,CAAC;iBACtB,KAAK,CAAC,cAAc,EAAE,CAAC,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,kBAAkB,EAAE,IAAI,IAAI,GAAG,CAAC,CAAC,kBAAkB,EAAE,GAAG,cAAc,CAAC,yBAAyB,CAAC,CAAC;iBACxI,KAAK,CAAC,MAAM,EAAC,MAAM,CAAC;iBACpB,IAAI,CAAC,KAAK,CAAC,CAAC;YACjB,SAAS,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,cAAc,EAAC,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,EAAC,SAAS,CAAC,CAAC;YAE5E,EAAE,CAAA,CAAC,IAAI,CAAC,yBAAyB,IAAI,IAAI,CAAC;gBAAC,SAAS,CAAC,SAAS,CAAC,YAAY,CAAC,CAAC,KAAK,CAAC,EAAC,cAAc,EAAE,IAAI,CAAC,yBAAyB,EAAC,CAAC,CAAC;YAGrI,EAAE,CAAC,CAAC,IAAI,CAAC,WAAW,IAAI,IAAI,CAAC,UAAU,KAAK,IAAI,CAAC,CAAC,CAAC;gBAC/C,IAAI,WAAW,GAAG,MAAM,CAAC,cAAc,GAAG,CAAC,CAAC;gBAC5C,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,EAAE,CAAC,EAAE,EAAE,CAAC;oBAC/B,IAAI,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;oBAC3B,IAAI,OAAO,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;oBAC5B,IAAI,KAAK,GAAG,MAAM,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;oBACtC,IAAI,KAAK,GAAG,OAAO,CAAC,OAAO,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;oBACxC,IAAI,SAAS,GAAG,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,CAAC;oBACpC,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC;yBACb,IAAI,CAAC,GAAG,EAAE,CAAC,WAAW,GAAG,CAAC,CAAC,GAAG,CAAC,GAAG,WAAW,CAAC;yBAC9C,IAAI,CAAC,GAAG,EAAE,MAAM,CAAC,eAAe,GAAG,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;yBAC3D,IAAI,CAAC,OAAO,EAAE,QAAQ,CAAC;yBACvB,KAAK,CAAC,MAAM,EAAE,CAAC,CAAC,IAAI,CAAC,CAAC,cAAc,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,cAAc,CAAC,CAAC,CAAC,GAAG,YAAY,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;yBACzF,IAAI,CAAC,SAAS,CAAC,CAAC;gBACzB,CAAC;YACL,CAAC;YAGD,EAAE,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC;gBACb,IAAI,UAAqB,CAAC;gBAC1B,EAAE,CAAA,CAAC,IAAI,CAAC,KAAK,CAAC;oBAAC,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,aAAa,EAAE,CAAC;gBACvD,KAAK,CAAC,WAAW,CAAC,GAAG,EAAE,IAAI,CAAC,KAAK,EAAE,MAAM,EAAE,UAAU,CAAC,CAAC;YAC3D,CAAC;QACL,CAAC,CAAA;QAxIG,IAAI,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QAC/B,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAAC,IAAI,GAAG,IAAI,CAAC,aAAa,CAAC,aAAa,CAAC,SAAS,CAAC,CAAC,CAAC;QAE/E,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC;QACvB,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC;QACvB,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC,aAAa,CAAC,CAAC;IAC3C,CAAC;IAmIL,gBAAC;AAAD,CAAC,AAlJD,CAAwB,KAAK,GAkJ5B;AClJD;IAA2B,gCAAK;IAM5B,sBAAY,OAAc;QACtB,kBAAM,aAAa,CAAC,YAAY,EAAE,OAAO,CAAC,CAAC;QAW/C,WAAM,GAAG,UAAC,cAAqB;YAE3B,IAAI,OAAO,GAAU,CAAC,CAAC,IAAI,CAAC,KAAK,GAAG,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;YAC3D,IAAI,CAAC,GAAc,IAAI,CAAC,QAAQ,EAAE,CAAC;YACnC,IAAI,MAAM,GAAU,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,C
AAC;YAGxC,IAAI,MAAM,GAAkC,EAAE,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,cAAc,CAAC,CAAC,CAAC;YAChG,IAAI,MAAM,GAAkC,EAAE,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,eAAe,EAAE,CAAC,CAAC,CAAC,CAAC;YAGjG,IAAI,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,KAAK,CAAC,MAAM,CAAC;iBAClC,aAAa,CAAC,CAAC,MAAM,CAAC,eAAe,CAAC;iBACtC,MAAM,CAAC,QAAQ,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC/B,IAAI,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,KAAK,CAAC,MAAM,CAAC;iBAClC,aAAa,CAAC,CAAC,MAAM,CAAC,cAAc,CAAC;iBACrC,MAAM,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAE7B,EAAE,CAAC,CAAC,IAAI,CAAC,sBAAsB,KAAK,IAAI,CAAC;gBAAC,KAAK,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;YAE/D,EAAE,CAAC,CAAC,IAAI,CAAC,oBAAoB,KAAK,IAAI,CAAC;gBAAC,KAAK,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;YAI7D,IAAI,GAAG,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,GAAG,cAAc,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;iBAC/C,MAAM,CAAC,KAAK,CAAC;iBACb,KAAK,CAAC,cAAc,EAAE,CAAE,CAAC,IAAI,CAAC,CAAC,cAAc,EAAE,GAAG,CAAC,CAAC,cAAc,EAAE,GAAG,CAAC,CAAC,CAAC;iBAC1E,KAAK,CAAC,MAAM,EAAE,MAAM,CAAC;iBACrB,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC,QAAQ,EAAE,CAAC;iBAC3B,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,SAAS,EAAE,CAAC;iBAC7B,IAAI,CAAC,SAAS,EAAE,MAAM,CAAC;iBACvB,MAAM,CAAC,GAAG,CAAC;iBACX,IAAI,CAAC,WAAW,EACb,YAAY,GAAG,MAAM,CAAC,IAAI,GAAG,GAAG,GAAG,MAAM,CAAC,GAAG,GAAG,GAAG,CAAC,CAAC;YAG7D,IAAI,IAAW,CAAC;YAChB,IAAI,IAAW,CAAC;YAChB,IAAI,IAAW,CAAC;YAChB,IAAI,IAAW,CAAC;YAChB,EAAE,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YACtC,IAAI;gBAAC,IAAI,GAAG,CAAC,IAAI,CAAC,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC;YACvD,EAAE,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YACtC,IAAI;gBAAC,IAAI,GAAG,CAAC,IAAI,CAAC,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC;YACvD,EAAE,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YACtC,IAAI;gBAAC,IAAI,GAAG,CAAC,IAAI,CAAC,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC;YACvD,EAAE,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC;gBAAC,IAAI,GAAG,IAAI,CAAC,OAAO,CAAC;YACtC,IAAI;gBAAC,IAAI,GAAG,CAAC,IAAI,CAAC,KAAK,GAAG,OAAO,CAAC,GAAG,CAAC,IAAI,CAAC,KAAK,CAAC,GAAG,CAAC,CAAC,CAAC;YAEvD,MAAM,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC;YAC5B,MAAM,CAAC,MAAM,CAAC,CAAC,IAAI,EAAE,IAAI,CAAC,CAAC,CAAC;YAG5B,IAAI,YAAY,GAA0B,EAAE,CAAC,KAAK,CAAC,UAAU,EAAE,CAAC;YAChE,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC/B,IAAI,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;gBAC1B,IAAI,KAAK,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;gBAE1B,IAAI,IAAI,GAAG,KAAK,CAAC,GAAG,CAAC,UAAU,CAAC,EAAE,CAAC;oBAC/B,MAAM,CAAC,EAAC,MAAM,EAAE,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,EAAE,KAAK,CAAC,CAAC,CAAC,EAAC,CAAC;gBAChD,CAAC,CAAC,CAAC;gBAEH,GAAG,CAAC,SAAS,CAAC,QAAQ,CAAC;qBAClB,IAAI,CAAC,IAAI,CAAC;qBACV,KAAK,EAAE;qBACP,MAAM,CAAC,QAAQ,CAAC;qBAChB,KAAK,CAAC,MAAM,EAAE,CAAC,CAAC,IAAI,CAAC,CAAC,cAAc,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,cAAc,CAAC,CAAC,CAAC,GAAG,YAAY,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;qBACzF,IAAI,CAAC,GAAG,EAAE,CAAC,CAAC,IAAI,CAAC,CAAC,YAAY,EAAE,GAAG,CAAC,CAAC,YAAY,EAAE,GAAG,cAAc,CAAC,wBAAwB,CAAC,CAAC;qBAC/F,IAAI,CAAC,IAAI,EAAE,UAAU,CAAC;oBACnB,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC;gBAC7B,CAAC,CAAC;qBACD,IAAI,CAAC,IAAI,EAAE,UAAU,CAAC;oBACnB,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC;gBAC7B,CAAC,CAAC,CAAC;YACX,CAAC;YAGD,IAAI,SAAS,GAAG,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC;iBAC1B,IAAI,CAAC,OAAO,EAAE,QAAQ,CAAC;iBACvB,IAAI,CAAC,WAAW,EAAE,cAAc,GAAG,MAAM,CAAC,eAAe,GAAG,GAAG,CAAC;iBAChE,KAAK,CAAC,QAAQ,EAAE,MAAM,CAAC;iBACvB,KAAK,CAAC,cAAc,EAAE,CAAC,CAAC,IAAI,IAAI,IAAI,CAAC
,CAAC,kBAAkB,EAAE,IAAI,IAAI,GAAG,CAAC,CAAC,kBAAkB,EAAE,GAAG,cAAc,CAAC,yBAAyB,CAAC,CAAC;iBACxI,KAAK,CAAC,MAAM,EAAE,MAAM,CAAC;iBACrB,IAAI,CAAC,KAAK,CAAC,CAAC;YACjB,SAAS,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,cAAc,EAAC,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,EAAC,SAAS,CAAC,CAAC;YAE5E,EAAE,CAAC,CAAC,IAAI,CAAC,uBAAuB,IAAI,IAAI,CAAC;gBAAC,SAAS,CAAC,SAAS,CAAC,YAAY,CAAC,CAAC,KAAK,CAAC,EAAC,cAAc,EAAE,IAAI,CAAC,uBAAuB,EAAC,CAAC,CAAC;YAGlI,IAAI,SAAS,GAAG,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC;iBAC1B,IAAI,CAAC,OAAO,EAAE,QAAQ,CAAC;iBACvB,KAAK,CAAC,QAAQ,EAAE,MAAM,CAAC;iBACvB,KAAK,CAAC,cAAc,EAAE,CAAC,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,kBAAkB,EAAE,IAAI,IAAI,GAAG,CAAC,CAAC,kBAAkB,EAAE,GAAG,cAAc,CAAC,yBAAyB,CAAC,CAAC;iBACxI,KAAK,CAAC,MAAM,EAAE,MAAM,CAAC;iBACrB,IAAI,CAAC,KAAK,CAAC,CAAC;YACjB,SAAS,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,cAAc,EAAC,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,EAAC,SAAS,CAAC,CAAC;YAE5E,EAAE,CAAC,CAAC,IAAI,CAAC,yBAAyB,IAAI,IAAI,CAAC;gBAAC,SAAS,CAAC,SAAS,CAAC,YAAY,CAAC,CAAC,KAAK,CAAC,EAAC,cAAc,EAAE,IAAI,CAAC,yBAAyB,EAAC,CAAC,CAAC;YAGtI,EAAE,CAAC,CAAC,IAAI,CAAC,WAAW,IAAI,IAAI,CAAC,UAAU,KAAK,IAAI,CAAC,CAAC,CAAC;gBAC/C,IAAI,WAAW,GAAG,MAAM,CAAC,cAAc,GAAG,CAAC,CAAC;gBAC5C,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,OAAO,EAAE,CAAC,EAAE,EAAE,CAAC;oBAC/B,IAAI,MAAM,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;oBAC3B,IAAI,OAAO,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;oBAC5B,IAAI,KAAK,GAAG,MAAM,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;oBACtC,IAAI,KAAK,GAAG,OAAO,CAAC,OAAO,CAAC,MAAM,GAAG,CAAC,CAAC,CAAC;oBACxC,IAAI,SAAS,CAAC;oBACd,EAAE,CAAC,CAAC,CAAC,KAAK,IAAI,CAAC,KAAK,CAAC;wBAAC,SAAS,GAAG,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,GAAG,YAAY,CAAC;oBACrE,IAAI;wBAAC,SAAS,GAAG,IAAI,CAAC,WAAW,CAAC,CAAC,CAAC,GAAG,IAAI,GAAG,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,GAAG,GAAG,GAAG,KAAK,CAAC,WAAW,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC;oBACtG,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC;yBACb,IAAI,CAAC,GAAG,EAAE,CAAC,WAAW,GAAG,CAAC,CAAC,GAAG,CAAC,GAAG,WAAW,CAAC;yBAC9C,IAAI,CAAC,GAAG,EAAE,MAAM,CAAC,eAAe,GAAG,CAAC,MAAM,CAAC,MAAM,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC;yBAC3D,IAAI,CAAC,OAAO,EAAE,QAAQ,CAAC;yBACvB,KAAK,CAAC,MAAM,EAAE,CAAC,CAAC,IAAI,CAAC,CAAC,cAAc,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,cAAc,CAAC,CAAC,CAAC,GAAG,YAAY,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;yBACzF,IAAI,CAAC,SAAS,CAAC,CAAC;gBACzB,CAAC;YACL,CAAC;YAGD,EAAE,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC;gBACb,IAAI,UAAqB,CAAC;gBAC1B,EAAE,CAAA,CAAC,IAAI,CAAC,KAAK,CAAC;oBAAC,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,aAAa,EAAE,CAAC;gBACvD,KAAK,CAAC,WAAW,CAAC,GAAG,EAAE,IAAI,CAAC,KAAK,EAAE,MAAM,EAAE,UAAU,CAAC,CAAC;YAC3D,CAAC;QACL,CAAC,CAAA;QAtIG,IAAI,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QAC/B,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAAC,IAAI,GAAG,IAAI,CAAC,aAAa,CAAC,aAAa,CAAC,YAAY,CAAC,CAAC,CAAC;QAElF,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC;QACvB,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC;QACvB,IAAI,CAAC,WAAW,GAAG,IAAI,CAAC,aAAa,CAAC,CAAC;IAC3C,CAAC;IAiIL,mBAAC;AAAD,CAAC,AAhJD,CAA2B,KAAK,GAgJ/B;ACpJD;IAAA;IAiEA,CAAC;IA9DkB,cAAO,GAAW,EAAE,CAAC;IACrB,cAAO,GAAW,EAAE,CAAC;IACrB,cAAO,GAAW,CAAC,CAAC;IACpB,iBAAU,GAAW,EAAE,CAAC;IACxB,cAAO,GAAW,EAAE,CAAC;IACrB,gBAAS,GAAW,SAAS,CAAC;IAC9B,oBAAa,GAAW,IAAI,CAAC;IAC7B,wBAAiB,GAAW,SAAS,CAAC;IAG9C,eAAQ,GAAG,CAAC,UAAS,CAAM;QAE9B,IAAI,GAAG,GAAG,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,QAAQ,CAAC,wBAAwB,CAAC,CAAC,CAAC;QAC1D,IAAI,SAAS,GAAG,CAAC,CAAC,SAAS,CAAC,YAAY,CAAC,CAAC,IAAI,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC;QACvD,IAAI,WAAW,GAAG,CAAC,CAAC,SAAS,CAAC,gBAAgB,CAAC,CAAC,IAAI,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC;QAE7D,SAAS,CAAC,KAAK,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,OAAO,EAAC,WAAW,CAAC,CAAC;QAC3D,WAAW,CAAC,KAAK,EAAE,CAAC,MAAM
,CAAC,GAAG,CAAC,CAAC,IAAI,CAAC,OAAO,EAAC,eAAe,CAAC,CAAC;QAE9D,IAAI,cAAc,GAAU,EAAE,CAAC;QAC/B,GAAG,CAAC,SAAS,CAAC,eAAe,CAAC,CAAC,IAAI,CAAC;YAChC,IAAI,OAAO,GAAG,EAAE,CAAC,MAAM,CAAC,IAAI,CAAC,CAAC;YAC9B,cAAc,CAAC,IAAI,CAAC;gBAChB,KAAK,EAAE,OAAO,CAAC,IAAI,CAAC,aAAa,CAAC;gBAClC,KAAK,EAAE,OAAO,CAAC,KAAK,CAAC,MAAM,CAAC;aAC/B,CAAC,CAAC;QACP,CAAC,CAAC,CAAC;QAIH,WAAW,CAAC,SAAS,CAAC,MAAM,CAAC;aACxB,IAAI,CAAC,cAAc,EAAC,UAAS,CAAC,IAAI,MAAM,CAAC,CAAC,CAAC,KAAK,CAAA,CAAA,CAAC,CAAC;aAClD,IAAI,CAAC,UAAS,CAAC,IAAI,CAAC,CAAC,KAAK,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC,CAAA,CAAA,CAAC,CAAC;aAC7C,IAAI,CAAC,UAAS,CAAC,IAAI,CAAC,CAAC,IAAI,EAAE,CAAC,MAAM,EAAE,CAAA,CAAA,CAAC,CAAC;aACtC,IAAI,CAAC,GAAG,EAAC,CAAC,CAAC;aACX,IAAI,CAAC,GAAG,EAAC,UAAS,CAAC,EAAC,CAAC,IAAI,MAAM,CAAC,CAAC,GAAC,MAAM,CAAC,UAAU,GAAC,MAAM,CAAC,OAAO,GAAC,IAAI,CAAA,CAAA,CAAC,CAAC;aACzE,IAAI,CAAC,OAAO,EAAC,MAAM,CAAC,OAAO,CAAC;aAC5B,IAAI,CAAC,QAAQ,EAAC,MAAM,CAAC,OAAO,CAAC;aAE7B,KAAK,CAAC,MAAM,EAAC,UAAS,CAAC,IAAI,MAAM,CAAC,CAAC,CAAC,KAAK,CAAA,CAAA,CAAC,CAAC,CAAC;QAGjD,WAAW,CAAC,SAAS,CAAC,MAAM,CAAC;aACxB,IAAI,CAAC,cAAc,EAAC,UAAS,CAAC,IAAI,MAAM,CAAC,CAAC,CAAC,KAAK,CAAA,CAAA,CAAC,CAAC;aAClD,IAAI,CAAC,UAAS,CAAC,IAAI,CAAC,CAAC,KAAK,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC,CAAA,CAAA,CAAC,CAAC;aAC7C,IAAI,CAAC,UAAS,CAAC,IAAI,CAAC,CAAC,IAAI,EAAE,CAAC,MAAM,EAAE,CAAA,CAAA,CAAC,CAAC;aACtC,IAAI,CAAC,GAAG,EAAC,UAAS,CAAC,EAAC,CAAC,IAAI,MAAM,CAAC,CAAC,GAAC,MAAM,CAAC,UAAU,GAAG,IAAI,CAAA,CAAA,CAAC,CAAC;aAC5D,IAAI,CAAC,GAAG,EAAC,CAAC,MAAM,CAAC,OAAO,GAAG,MAAM,CAAC,OAAO,CAAC,GAAG,IAAI,CAAC;aAClD,IAAI,CAAC,UAAS,CAAC,IAAI,MAAM,CAAC,CAAC,CAAC,KAAK,CAAA,CAAA,CAAC,CAAC,CAAC;QAGzC,IAAI,iBAAiB,GAAQ,WAAW,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;QACzD,SAAS,CAAC,IAAI,CAAC,GAAG,EAAC,CAAC,iBAAiB,CAAC,CAAC,GAAC,MAAM,CAAC,OAAO,CAAC,CAAC;aACnD,IAAI,CAAC,GAAG,EAAC,CAAC,iBAAiB,CAAC,CAAC,GAAC,MAAM,CAAC,OAAO,CAAC,CAAC;aAC9C,IAAI,CAAC,QAAQ,EAAC,CAAC,iBAAiB,CAAC,MAAM,GAAC,CAAC,GAAC,MAAM,CAAC,OAAO,CAAC,CAAC;aAC1D,IAAI,CAAC,OAAO,EAAC,CAAC,iBAAiB,CAAC,KAAK,GAAC,CAAC,GAAC,MAAM,CAAC,OAAO,CAAC,CAAC;aACxD,KAAK,CAAC,MAAM,EAAC,MAAM,CAAC,SAAS,CAAC;aAC9B,KAAK,CAAC,QAAQ,EAAC,MAAM,CAAC,iBAAiB,CAAC;aACxC,KAAK,CAAC,SAAS,EAAC,MAAM,CAAC,aAAa,CAAC,CAAC;QAE3C,GAAG,CAAC,SAAS,CAAC,SAAS,CAAC,CAAC,IAAI,CAAC,WAAW,EAAC,YAAY,GAAG,MAAM,CAAC,OAAO,GAAG,GAAG,GAAG,MAAM,CAAC,OAAO,GAAG,GAAG,CAAC,CAAC;IAC1G,CAAC,CAAC,CAAC;IACP,aAAC;AAAD,CAAC,AAjED,IAiEC;AC1DD;IAA+B,oCAAK;IAKhC,0BAAY,OAAe;QACvB,kBAAM,aAAa,CAAC,gBAAgB,EAAE,OAAO,CAAC,CAAC;QAYnD,WAAM,GAAG,UAAC,cAAsB;YAE5B,IAAI,OAAO,GAAW,CAAC,CAAC,IAAI,CAAC,KAAK,GAAG,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC;YAC5D,IAAI,CAAC,GAAe,IAAI,CAAC,QAAQ,EAAE,CAAC;YACpC,IAAI,MAAM,GAAW,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YAGzC,IAAI,MAAM,GAAmC,EAAE,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,MAAM,CAAC,cAAc,CAAC,CAAC,CAAC;YACjG,IAAI,MAAM,GAAmC,EAAE,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,KAAK,CAAC,CAAC,MAAM,CAAC,eAAe,EAAE,CAAC,CAAC,CAAC,CAAC;YAGlG,IAAI,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,KAAK,CAAC,MAAM,CAAC;iBAClC,MAAM,CAAC,QAAQ,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC/B,EAAE,CAAA,CAAC,IAAI,CAAC,uBAAuB,IAAI,IAAI,IAAI,IAAI,CAAC,uBAAuB,GAAG,CAAC,CAAC,CAAA,CAAC;gBACzE,KAAK,CAAC,aAAa,CAAC,CAAC,MAAM,CAAC,eAAe,CAAC,CAAC;YACjD,CAAC;YAGD,IAAI,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE,CAAC,KAAK,CAAC,MAAM,CAAC;iBAClC,MAAM,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAC7B,EAAE,CAAA,CAAC,IAAI,CAAC,yBAAyB,IAAI,IAAI,IAAI,IAAI,CAAC,yBAAyB,GAAG,CAAC,CAAC,CAAA,CAAC;gBAC7E,KAAK,CAAC,aAAa,CAAC,CAAC,MAAM,CAAC,cAAc,CAAC,CAAC;YAChD,CAAC;YAED,EAAE,CAAA,CAAC,IAAI,CAAC,sBAAsB,KAAK,IAAI,CAAC;gBAAC,KAAK,CAAC,UAAU,CAAC,EAAE,
CAAC,CAAC;YAE9D,EAAE,CAAA,CAAC,IAAI,CAAC,oBAAoB,KAAK,IAAI,CAAC;gBAAC,KAAK,CAAC,UAAU,CAAC,EAAE,CAAC,CAAC;YAE5D,IAAI,IAAI,GAAU,EAAE,CAAC;YACrB,GAAG,CAAA,CAAC,IAAI,CAAC,GAAC,CAAC,EAAE,CAAC,GAAC,IAAI,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;gBACpC,IAAI,GAAG,GAAG,EAAE,CAAC;gBACb,GAAG,CAAA,CAAE,IAAI,CAAC,GAAC,CAAC,EAAE,CAAC,GAAC,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;oBACtC,GAAG,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;oBACvC,GAAG,CAAC,QAAQ,CAAC,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;gBAClC,CAAC;gBACD,IAAI,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;YACnB,CAAC;YAED,IAAI,IAAI,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE;iBACnB,CAAC,CAAC,UAAS,CAAM,IAAI,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;iBAChD,EAAE,CAAC,UAAS,CAAM,IAAI,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC;iBAC7C,EAAE,CAAC,UAAS,CAAM,IAAI,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAEzD,IAAI,KAAK,GAAG,EAAE,CAAC,MAAM,CAAC,KAAK,EAAE;iBACxB,MAAM,CAAC,UAAS,CAAM,IAAI,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC;YAEnD,IAAI,GAAG,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,GAAG,cAAc,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,MAAM,CAAC,KAAK,CAAC;iBAC7D,IAAI,CAAC,OAAO,EAAE,MAAM,CAAC,cAAc,GAAG,MAAM,CAAC,IAAI,GAAG,MAAM,CAAC,KAAK,CAAC;iBACjE,IAAI,CAAC,QAAQ,EAAE,MAAM,CAAC,eAAe,GAAG,MAAM,CAAC,GAAG,GAAG,MAAM,CAAC,MAAM,CAAC;iBACnE,MAAM,CAAC,GAAG,CAAC;iBACX,IAAI,CAAC,WAAW,EAAE,YAAY,GAAG,MAAM,CAAC,IAAI,GAAG,GAAG,GAAG,MAAM,CAAC,GAAG,GAAG,GAAG,CAAC,CAAC;YAE5E,IAAI,KAAK,GAAQ,EAAE,CAAC,KAAK,CAAC,UAAU,EAAE,CAAC;YACvC,KAAK,CAAC,MAAM,CAAC,EAAE,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,UAAU,GAAG;gBAC9C,MAAM,CAAC,GAAG,KAAK,QAAQ,CAAC;YAC5B,CAAC,CAAC,CAAC,CAAC;YAEJ,IAAI,QAAQ,GAAG,KAAK,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,GAAG,CAAC,UAAU,IAAI;gBAClD,MAAM,CAAC;oBACH,IAAI,EAAE,IAAI;oBACV,MAAM,EAAE,IAAI,CAAC,GAAG,CAAC,UAAU,CAAC;wBACxB,MAAM,CAAC,EAAC,MAAM,EAAE,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAC,CAAC;oBAC9C,CAAC,CAAC;iBACL,CAAC;YACN,CAAC,CAAC,CAAC,CAAC;YAGJ,IAAI,IAAI,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE,UAAU,CAAC;gBAC/B,IAAI,IAAI,GAAG,EAAE,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,UAAU,GAAG;oBACnC,MAAM,CAAC,GAAG,KAAK,QAAQ,GAAG,CAAC,CAAC,GAAG,CAAC,GAAG,CAAC,CAAA;gBACxC,CAAC,CAAC,CAAC;gBACH,MAAM,CAAC,EAAE,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;YACxB,CAAC,CAAC,CAAC;YAGH,MAAM,CAAC,MAAM,CAAC,EAAE,CAAC,MAAM,CAAC,IAAI,EAAE,UAAU,CAAC;gBACrC,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC;YACpB,CAAC,CAAC,CAAC,CAAC;YAEJ,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,CAAC,CAAC;YAEzB,IAAI,OAAO,GAAG,GAAG,CAAC,SAAS,CAAC,UAAU,CAAC;iBAClC,IAAI,CAAC,QAAQ,CAAC;iBACd,KAAK,EAAE,CAAC,MAAM,CAAC,GAAG,CAAC;iBACnB,IAAI,CAAC,OAAO,EAAE,SAAS,CAAC,CAAC;YAE9B,IAAI,UAAU,GAAG,IAAI,CAAC,MAAM,CAAC;YAE7B,IAAI,YAAY,GAA2B,EAAE,CAAC,KAAK,CAAC,UAAU,EAAE,CAAC;YACjE,OAAO,CAAC,MAAM,CAAC,MAAM,CAAC;iBACjB,IAAI,CAAC,OAAO,EAAE,MAAM,CAAC;iBACrB,IAAI,CAAC,aAAa,EAAC,UAAS,CAAM,IAAI,MAAM,CAAC,CAAC,CAAC,IAAI,CAAA,CAAA,CAAC,CAAC;iBACrD,IAAI,CAAC,GAAG,EAAE,UAAU,CAAM;gBACvB,MAAM,CAAC,IAAI,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC;YAC1B,CAAC,CAAC;iBACD,KAAK,CAAC,MAAM,EAAE,UAAS,CAAM;gBAC1B,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,CAAC,cAAc,CAAC,UAAU,CAAC,OAAO,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAA,CAAC;oBAClD,MAAM,CAAC,CAAC,CAAC,cAAc,CAAC,UAAU,CAAC,OAAO,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC;gBACxD,CAAC;gBAAC,IAAI,CAAA,CAAC;oBACH,MAAM,CAAC,YAAY,CAAC,MAAM,CAAC,UAAU,CAAC,OAAO,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAA;gBAC3D,CAAC;YACL,CAAC,CAAC;iBACD,KAAK,CAAC,EAAC,cAAc,EAAE,KAAK,EAAC,CAAC,CAAC;YAGpC,IAAI,SAAS,GAAG,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC;iBAC1B,IA
AI,CAAC,OAAO,EAAE,QAAQ,CAAC;iBACvB,KAAK,CAAC,QAAQ,EAAC,MAAM,CAAC;iBACtB,KAAK,CAAC,cAAc,EAAE,CAAC,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,kBAAkB,EAAE,IAAI,IAAI,GAAG,CAAC,CAAC,kBAAkB,EAAE,GAAG,cAAc,CAAC,yBAAyB,CAAC,CAAC;iBACxI,KAAK,CAAC,MAAM,EAAC,MAAM,CAAC;iBACpB,IAAI,CAAC,WAAW,EAAE,cAAc,GAAG,MAAM,CAAC,eAAe,GAAG,GAAG,CAAC;iBAChE,IAAI,CAAC,KAAK,CAAC,CAAC;YACjB,SAAS,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,cAAc,EAAC,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,EAAC,SAAS,CAAC,CAAC;YAG5E,IAAI,SAAS,GAAG,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC;iBAC1B,IAAI,CAAC,OAAO,EAAE,QAAQ,CAAC;iBACvB,KAAK,CAAC,QAAQ,EAAC,MAAM,CAAC;iBACtB,KAAK,CAAC,cAAc,EAAE,CAAC,CAAC,IAAI,IAAI,IAAI,CAAC,CAAC,kBAAkB,EAAE,IAAI,IAAI,GAAG,CAAC,CAAC,kBAAkB,EAAE,GAAG,cAAc,CAAC,yBAAyB,CAAC,CAAC;iBACxI,KAAK,CAAC,MAAM,EAAC,MAAM,CAAC;iBACpB,IAAI,CAAC,KAAK,CAAC,CAAC;YACjB,SAAS,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,cAAc,EAAC,CAAC,CAAC,CAAC,KAAK,CAAC,MAAM,EAAC,SAAS,CAAC,CAAC;YAG5E,EAAE,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC;gBACb,IAAI,UAAqB,CAAC;gBAC1B,EAAE,CAAA,CAAC,IAAI,CAAC,KAAK,CAAC;oBAAC,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,aAAa,EAAE,CAAC;gBACvD,KAAK,CAAC,WAAW,CAAC,GAAG,EAAE,IAAI,CAAC,KAAK,EAAE,MAAM,EAAE,UAAU,CAAC,CAAC;YAC3D,CAAC;YAGD,IAAI,MAAM,GAAQ,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC;iBAC5B,IAAI,CAAC,OAAO,EAAC,QAAQ,CAAC;iBACtB,IAAI,CAAC,WAAW,EAAC,kBAAkB,CAAC;iBACpC,KAAK,CAAC,WAAW,EAAC,MAAM,CAAC;iBACzB,IAAI,CAAC,MAAM,CAAC,QAAQ,CAAC,CAAC;QAC/B,CAAC,CAAA;QAlJG,IAAI,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QAC/B,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAAC,IAAI,GAAG,IAAI,CAAC,aAAa,CAAC,aAAa,CAAC,gBAAgB,CAAC,CAAC,CAAC;QAGtF,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC;QACvB,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC;QACvB,IAAI,CAAC,MAAM,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;IACjC,CAAC;IA4IL,uBAAC;AAAD,CAAC,AA3JD,CAA+B,KAAK,GA2JnC;AC9JD;IAA4B,iCAAK;IAgC7B,uBAAY,OAAc;QACtB,kBAAM,aAAa,CAAC,aAAa,EAAE,OAAO,CAAC,CAAC;QAUhD,WAAM,GAAG,UAAC,cAAqB;YAC3B,IAAI,QAAQ,GAAG,IAAI,CAAC;YACpB,IAAI,CAAC,GAAc,IAAI,CAAC,QAAQ,EAAE,CAAC;YACnC,IAAI,MAAM,GAAU,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YAGxC,IAAI,CAAC,QAAQ,GAAG,EAAE,CAAC;YACnB,IAAI,KAAK,GAAG,CAAC,CAAC;YACd,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,QAAQ,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC5C,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;oBAC/C,IAAI,GAAG,GAAG,EAAE,CAAC;oBACb,GAAG,CAAC,OAAO,CAAC,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,aAAa,CAAC,CAAC;oBAClD,GAAG,CAAC,KAAK,CAAC,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,WAAW,CAAC,CAAC;oBAC9C,GAAG,CAAC,IAAI,CAAC,GAAG,KAAK,EAAE,CAAC;oBACpB,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;oBAChB,GAAG,CAAC,OAAO,CAAC,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,CAAC,CAAC;oBAC5C,GAAG,CAAC,OAAO,CAAC,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,YAAY,CAAC,CAAC;oBACjD,IAAI,CAAC,QAAQ,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;gBAC5B,CAAC;YACL,CAAC;YAED,IAAI,CAAC,KAAK,GAAG,EAAE,CAAC;YAChB,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,SAAS,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC7C,IAAI,GAAG,GAAG,EAAE,CAAC;gBACb,GAAG,CAAC,OAAO,CAAC,GAAG,IAAI,CAAC,SAAS,CAAC,CAAC,CAAC,CAAC;gBACjC,GAAG,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;gBACd,IAAI,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;YACzB,CAAC;YAID,IAAI,GAAG,GAAG,EAAE,CAAC,MAAM,CAAC,GAAG,GAAG,cAAc,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC;iBAC/C,MAAM,CAAC,KAAK,CAAC;iBACb,KAAK,CAAC,cAAc,EAAE,CAAE,CAAC,IAAI,CAAC,CAAC,cAAc,EAAE,GAAG,CAAC,CAAC,cAAc,EAAE,GAAG,cAAc,CAAC,0BAA0B,CAAC,CAAC;iBAClH,KAAK,CAAC,MAAM,EAAE,MAAM,CAAC;iBACrB,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC,QAAQ,EAAE,CAAC;iBAC3B,
IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,SAAS,EAAE,CAAC;iBAC7B,MAAM,CAAC,GAAG,CAAC,CAAC;YAEjB,IAAI,eAAe,GAAG,CAAC,CAAC,SAAS,EAAE,GAAG,MAAM,CAAC,GAAG,GAAG,MAAM,CAAC,MAAM,CAAC;YACjE,IAAI,cAAc,GAAG,CAAC,CAAC,QAAQ,EAAE,GAAG,MAAM,CAAC,IAAI,GAAG,MAAM,CAAC,KAAK,CAAC;YAC/D,IAAI,UAAU,GAAG,IAAI,CAAC,SAAS,CAAC,MAAM,GAAG,aAAa,CAAC,mBAAmB,CAAC;YAC3E,IAAI,UAAU,GAAG,CAAC,CAAC,SAAS,EAAE,GAAG,UAAU,GAAG,MAAM,CAAC,GAAG,GAAG,MAAM,CAAC,MAAM,GAAG,EAAE,CAAC;YAE9E,IAAI,OAAO,GAAU,EAAE,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,EAAE,UAAU,CAAK,IAAI,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;YACjF,IAAI,OAAO,GAAU,EAAE,CAAC,GAAG,CAAC,IAAI,CAAC,QAAQ,EAAE,UAAU,CAAK,IAAI,MAAM,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC;YAC/E,IAAI,CAAC,CAAC,GAAG,EAAE,CAAC,IAAI,CAAC,KAAK,EAAE;iBACnB,MAAM,CAAC,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;iBAC1B,KAAK,CAAC,CAAC,CAAC,EAAE,cAAc,CAAC,CAAC,CAAC;YAChC,IAAI,CAAC,EAAE,GAAG,EAAE,CAAC,IAAI,CAAC,KAAK,EAAE,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,cAAc,CAAC,CAAC,CAAC;YAErD,IAAI,CAAC,EAAE,GAAG,EAAE,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC;YACtF,IAAI,CAAC,EAAE,GAAG,EAAE,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,MAAM,CAAC,CAAC,CAAC,EAAE,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,EAAE,UAAU,CAAC,CAAC,CAAC;YAGtF,IAAI,CAAC,IAAI,GAAG,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,MAAM,CAAC,UAAU,CAAC;iBAC5C,IAAI,CAAC,IAAI,EAAE,MAAM,CAAC;iBAClB,MAAM,CAAC,MAAM,CAAC;iBACd,IAAI,CAAC,OAAO,EAAE,cAAc,CAAC;iBAC7B,IAAI,CAAC,QAAQ,EAAE,CAAC,CAAC,SAAS,EAAE,GAAG,GAAG,CAAC,CAAC;YAEzC,IAAI,CAAC,QAAQ,GAAG,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC;iBAC1B,IAAI,CAAC,WAAW,EAAE,YAAY,GAAG,MAAM,CAAC,IAAI,GAAG,GAAG,GAAG,MAAM,CAAC,GAAG,GAAG,GAAG,CAAC;iBACtE,IAAI,CAAC,OAAO,EAAE,cAAc,CAAC;iBAC7B,IAAI,CAAC,QAAQ,EAAE,UAAU,CAAC;iBAC1B,IAAI,CAAC,WAAW,EAAE,MAAM,CAAC;iBACzB,IAAI,CAAC,MAAM,EAAE,YAAY,CAAC,CAAC;YAEhC,IAAI,CAAC,QAAQ,GAAG,GAAG,CAAC,MAAM,CAAC,GAAG,CAAC;iBAC1B,IAAI,CAAC,WAAW,EAAE,YAAY,GAAG,MAAM,CAAC,IAAI,GAAG,GAAG,GAAG,CAAC,UAAU,GAAG,MAAM,CAAC,GAAG,GAAG,EAAE,CAAC,GAAG,GAAG,CAAC;iBAC1F,IAAI,CAAC,OAAO,EAAE,cAAc,CAAC;iBAC7B,IAAI,CAAC,QAAQ,EAAE,UAAU,CAAC;iBAC1B,IAAI,CAAC,WAAW,EAAE,MAAM,CAAC;iBACzB,IAAI,CAAC,MAAM,EAAE,YAAY,CAAC,CAAC;YAGhC,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,SAAS,CAAC,YAAY,CAAC;iBAC5C,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC;iBAChB,KAAK,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;iBACtB,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC;iBACb,IAAI,CAAC,IAAI,EAAE,UAAU,CAAK;gBACvB,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,GAAG,CAAC;YAC7C,CAAC,CAAC;iBACD,IAAI,CAAC,IAAI,EAAE,cAAc,CAAC;iBAC1B,IAAI,CAAC,IAAI,EAAE,UAAU,CAAK;gBACvB,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,GAAG,CAAC;YAC7C,CAAC,CAAC;iBACD,IAAI,CAAC,QAAQ,EAAE,WAAW,CAAC;iBAC3B,IAAI,CAAC,cAAc,EAAE,CAAC,CAAC,CAAC;YAG7B,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,SAAS,CAAC,WAAW,CAAC;iBAC3C,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC;iBAChB,KAAK,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;iBACtB,IAAI,CAAC,UAAU,CAAK;gBACjB,EAAE,CAAA,CAAC,CAAC,CAAC,KAAK,CAAC;oBAAC,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC;gBAC3B,MAAM,CAAC,EAAE,CAAC;YACd,CAAC,CAAC;iBACD,IAAI,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC;iBACd,IAAI,CAAC,GAAG,EAAE,UAAU,CAAK;gBACtB,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,GAAG,EAAE,CAAC,CAAC;YAClC,CAAC,CAAC;iBACD,IAAI,CAAC,aAAa,EAAE,KAAK,CAAC;iBAC1B,IAAI,CAAC,MAAM,EAAC,gBAAgB,CAAC;iBAC7B,IAAI,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;YAG3B,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,SAAS,CAAC,YAAY,CAAC;iBAC5C,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC;iBAChB,KAAK,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;iBACtB,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC;iBACb,IAAI,CAAC,
IAAI,EAAE,UAAU,CAAK,IAAI,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC,CAAC;iBAC1E,IAAI,CAAC,IAAI,EAAE,cAAc,CAAC;iBAC1B,IAAI,CAAC,IAAI,EAAE,UAAU,CAAK,IAAI,MAAM,CAAC,EAAE,CAAC,KAAK,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC,CAAC;iBAC1E,IAAI,CAAC,QAAQ,EAAE,MAAM,CAAC;iBACtB,IAAI,CAAC,cAAc,EAAE,GAAG,CAAC,CAAC;YAG/B,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,SAAS,CAAC,WAAW,CAAC;iBAC3C,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC;iBAChB,KAAK,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;iBACtB,IAAI,CAAC,UAAU,CAAK;gBACjB,EAAE,CAAA,CAAC,CAAC,CAAC,KAAK,CAAC;oBAAC,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC;gBAC3B,MAAM,CAAC,EAAE,CAAC;YACd,CAAC,CAAC;iBACD,IAAI,CAAC,GAAG,EAAE,CAAC,EAAE,CAAC;iBACd,IAAI,CAAC,GAAG,EAAE,UAAU,CAAK;gBACtB,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,EAAE,GAAG,EAAE,CAAC,CAAC;YAClC,CAAC,CAAC;iBACD,IAAI,CAAC,IAAI,EAAE,OAAO,CAAC;iBACnB,IAAI,CAAC,aAAa,EAAE,KAAK,CAAC;iBAC1B,IAAI,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;YAG3B,IAAI,CAAC,SAAS,GAAG,EAAE,CAAC,GAAG,CAAC,IAAI,EAAE;iBACzB,KAAK,CAAC,IAAI,CAAC,EAAE,CAAC;iBACd,MAAM,CAAC,QAAQ,CAAC;iBAChB,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC;iBACtB,UAAU,CAAC,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC;iBACnC,QAAQ,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC;YAGpB,IAAI,IAAI,GAAO,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,GAAG,CAAC;iBACnC,IAAI,CAAC,WAAW,EAAE,cAAc,GAAG,UAAU,GAAG,GAAG,CAAC;iBAEpD,IAAI,CAAC,OAAO,EAAE,UAAU,CAAC;iBACzB,IAAI,CAAC,MAAM,EAAE,OAAO,CAAC;iBACrB,KAAK,CAAC,QAAQ,EAAE,OAAO,CAAC,CAAC,KAAK,CAAC,cAAc,EAAE,GAAG,CAAC,CAAC,KAAK,CAAC,MAAM,EAAE,OAAO,CAAC;iBAC1E,IAAI,CAAC,MAAM,EAAE,iBAAiB,CAAC;iBAC/B,IAAI,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;YAC1B,IAAI,CAAC,SAAS,CAAC,MAAM,CAAC,CAAC,KAAK,CAAC,cAAc,EAAE,GAAG,CAAC,CAAC,IAAI,CAAC,cAAc,EAAE,GAAG,CAAC,CAAC;YAG5E,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,GAAG,CAAC;iBACrC,IAAI,CAAC,WAAW,EAAE,YAAY,CAAC,CAAC;YAGrC,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC,SAAS,CAAC,WAAW,CAAC;iBAC3C,IAAI,CAAC,IAAI,CAAC,gBAAgB,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;iBAC1C,KAAK,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;iBACtB,IAAI,CAAC,OAAO,EAAE,UAAU,CAAK;gBAC1B,MAAM,CAAC,WAAW,GAAG,CAAC,CAAC,KAAK,CAAC;YACjC,CAAC,CAAC;iBACD,IAAI,CAAC,GAAG,EAAE,UAAU,CAAK;gBACtB,MAAM,CAAC,CAAC,CAAC,IAAI,CAAC;YAClB,CAAC,CAAC;iBACD,IAAI,CAAC,QAAQ,EAAE,OAAO,CAAC;iBACvB,IAAI,CAAC,cAAc,EAAE,OAAO,CAAC,CAAC;YAGnC,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,MAAM,CAAC;iBACvB,IAAI,CAAC,gBAAgB,EAAE,SAAS,CAAC;iBACjC,IAAI,CAAC,OAAO,EAAE,cAAc,CAAC;iBAC7B,IAAI,CAAC,QAAQ,EAAE,UAAU,CAAC;iBAC1B,IAAI,CAAC,YAAY,EAAE,QAAQ,CAAC;iBAC5B,EAAE,CAAC,SAAS,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;YACnC,IAAI,CAAC,KAAK,GAAG,EAAE,CAAC,GAAG,CAAC,KAAK,EAAE;iBACtB,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC;iBACT,MAAM,CAAC,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;iBAC1B,EAAE,CAAC,OAAO,EAAE,IAAI,CAAC,WAAW,CAAC,CAAC;YACnC,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,GAAG,CAAC;iBACpB,IAAI,CAAC,OAAO,EAAE,SAAS,CAAC;iBACxB,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC;iBAChB,SAAS,CAAC,MAAM,CAAC;iBACjB,IAAI,CAAC,GAAG,EAAE,CAAC,CAAC;iBACZ,IAAI,CAAC,QAAQ,EAAE,UAAU,GAAG,CAAC,CAAC;iBAC9B,KAAK,CAAC,MAAM,EAAC,MAAM,CAAC;iBACpB,KAAK,CAAC,cAAc,EAAC,KAAK,CAAC;iBAC3B,KAAK,CAAC,QAAQ,EAAC,eAAe,CAAC;iBAC/B,KAAK,CAAC,cAAc,EAAC,CAAC,CAAC,CAAC;YAG7B,IAAI,CAAC,QAAQ,CAAC,SAAS,CAAC,iBAAiB,CAAC,CAAC,MAAM,EAAE,CAAC;YACpD,IAAI,CAAC,WAAW,EAAE,CAAC;YAGnB,EAAE,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC,CAAC,CAAC;gBACb,IAAI,UAAoB,CAAC;gBACzB,EAAE,CAAC,CAAC,IAAI,CAAC,KAAK,CAAC;oBAAC,UAAU,GAAG,IAAI,CAAC,KAAK,CAAC,aAAa,EAAE,CAAC;gBACxD,IAAI,IAAI,GAAG,GAAG,CAAC,MAAM,CAAC,MAAM,CAAC;qBACxB,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC;qBAChB,IAAI,CAAC,GAAG,EAAE,CAAC,CAAC,CAAC,QAAQ,EAAE,GAAG,CAAC,CAAC,CAAC;qBAC7B,IAAI,CA
AC,GAAG,EAAE,CAAC,CAAC,MAAM,CAAC,GAAG,GAAG,EAAE,CAAC,GAAG,CAAC,CAAC,CAAC;qBAClC,IAAI,CAAC,aAAa,EAAE,QAAQ,CAAC,CAAC;gBAEnC,EAAE,CAAC,CAAC,UAAU,CAAC,CAAC,CAAC;oBACb,EAAE,CAAC,CAAC,UAAU,CAAC,OAAO,EAAE,CAAC;wBAAC,IAAI,CAAC,IAAI,CAAC,aAAa,EAAE,UAAU,CAAC,OAAO,CAAC,CAAC;oBACvE,EAAE,CAAC,CAAC,UAAU,CAAC,WAAW,EAAE,IAAI,IAAI,CAAC;wBAAC,IAAI,CAAC,IAAI,CAAC,WAAW,EAAE,UAAU,CAAC,WAAW,EAAE,GAAG,IAAI,CAAC,CAAC;oBAC9F,EAAE,CAAC,CAAC,UAAU,CAAC,YAAY,EAAE,IAAI,IAAI,CAAC;wBAAC,IAAI,CAAC,KAAK,CAAC,iBAAiB,EAAE,WAAW,CAAC,CAAC;oBAClF,EAAE,CAAC,CAAC,UAAU,CAAC,QAAQ,EAAE,CAAC;wBAAC,IAAI,CAAC,KAAK,CAAC,MAAM,EAAE,UAAU,CAAC,QAAQ,CAAC,CAAC;oBACnE,IAAI;wBAAC,IAAI,CAAC,KAAK,CAAC,MAAM,EAAE,cAAc,CAAC,mBAAmB,CAAC,CAAC;gBAChE,CAAC;gBAAC,IAAI,CAAC,CAAC;oBACJ,IAAI,CAAC,KAAK,CAAC,iBAAiB,EAAE,WAAW,CAAC,CAAC;oBAC3C,IAAI,CAAC,KAAK,CAAC,MAAM,EAAE,cAAc,CAAC,mBAAmB,CAAC,CAAC;gBAC3D,CAAC;YACL,CAAC;QACL,CAAC,CAAC;QAGF,gBAAW,GAAG;YACV,IAAI,QAAQ,GAAO,IAAI,CAAC;YAExB,IAAI,MAAM,GAAY,IAAI,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC;YAC1C,IAAI,SAAS,GAAU,MAAM,CAAC,CAAC,CAAC,CAAC;YACjC,IAAI,SAAS,GAAU,MAAM,CAAC,CAAC,CAAC,CAAC;YAEjC,IAAI,YAAY,GAAO,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,UAAU,CAAC;gBACnD,MAAM,CAAC,CAAC,CAAC,KAAK,GAAG,SAAS,IAAI,CAAC,CAAC,GAAG,GAAG,SAAS,CAAA;YACnD,CAAC,CAAC,CAAC;YAEH,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,QAAQ,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC,CAAC;YAE/E,IAAI,CAAC,EAAE,CAAC,MAAM,CAAC,CAAC,SAAS,EAAE,SAAS,CAAC,CAAC,CAAC;YAGvC,IAAI,KAAK,GAAG,SAAS,GAAG,SAAS,CAAC;YAClC,EAAE,CAAC,CAAC,KAAK,GAAG,CAAC,GAAG,aAAa,CAAC,iBAAiB,CAAC,CAAC,CAAC;gBAC9C,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC,UAAU,CAAC,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC,CAAC;YACjF,CAAC;YAAC,IAAI,CAAC,EAAE,CAAC,CAAC,KAAK,GAAG,CAAC,GAAG,aAAa,CAAC,gBAAgB,CAAC,CAAC,CAAC;gBACpD,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,IAAI,EAAE,CAAC,CAAC,CAAC,UAAU,CAAC,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC,CAAC;YAC9E,CAAC;YAAC,IAAI,CAAC,EAAE,CAAC,CAAC,KAAK,GAAG,CAAC,GAAG,aAAa,CAAC,iBAAiB,CAAC,CAAC,CAAC;gBACrD,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,KAAK,EAAE,CAAC,CAAC,CAAC,UAAU,CAAC,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC,CAAC;YAC/E,CAAC;YAAC,IAAI,CAAC,EAAE,CAAC,CAAC,KAAK,GAAG,CAAC,GAAG,aAAa,CAAC,mBAAmB,CAAC,CAAC,CAAC;gBACvD,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC,UAAU,CAAC,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC,CAAC;YACjF,CAAC;YAAC,IAAI,CAAC,EAAE,CAAC,CAAC,KAAK,IAAI,KAAK,CAAC,CAAC,CAAC;gBACxB,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,OAAO,EAAE,EAAE,CAAC,CAAC,UAAU,CAAC,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC,CAAC,CAAC;YACrF,CAAC;YAAC,IAAI,CAAC,CAAC;gBACJ,IAAI,CAAC,SAAS,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,OAAO,EAAE,CAAC,CAAC,CAAC,UAAU,CAAC,EAAE,CAAC,IAAI,CAAC,MAAM,CAAC,UAAU,CAAC,CAAC,CAAC;YACpF,CAAC;YAGD,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,WAAW,CAAC,CAAC,IAAI,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;YAGvD,IAAI,KAAK,GAAO,IAAI,CAAC,SAAS,CAAC,SAAS,CAAC,MAAM,CAAC;iBACvC,IAAI,CAAC,YAAY,EAAE,UAAU,CAAC,IAAI,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC;iBACjD,IAAI,CAAC,GAAG,EAAE,UAAU,CAAC,IAAI,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;iBACxD,IAAI,CAAC,OAAO,EAAE,UAAU,CAAC,IAAI,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;YAG3F,KAAK,CAAC,KAAK,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;iBACvB,IAAI,CAAC,GAAG,EAAE,UAAU,CAAC,IAAI,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;iBACxD,IAAI,CAAC,GAAG,EAAE,UAAU,CAAC,IAAI,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,
aAAa,CAAC,iCAAiC,GAAG,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC,CAAC,CAAC;iBAChI,IAAI,CAAC,OAAO,EAAE,UAAU,CAAC,IAAI,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC,CAAC;iBACjF,IAAI,CAAC,QAAQ,EAAE,UAAU,CAAC,IAAI,MAAM,CAAC,aAAa,CAAC,gCAAgC,GAAG,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;iBACxG,IAAI,CAAC,QAAQ,EAAE,OAAO,CAAC;iBACvB,IAAI,CAAC,MAAM,EAAE,UAAS,CAAC;gBACpB,EAAE,CAAA,CAAC,CAAC,CAAC,KAAK,CAAC;oBAAC,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC;gBAC3B,MAAM,CAAC,aAAa,CAAC,aAAa,CAAC;YACvC,CAAC,CAAC;iBACD,IAAI,CAAC,cAAc,EAAE,CAAC,CAAC,CAAC;YAC7B,KAAK,CAAC,IAAI,EAAE,CAAC,MAAM,EAAE,CAAC;YAGtB,IAAI,MAAM,GAAO,IAAI,CAAC,SAAS,CAAC,SAAS,CAAC,MAAM,CAAC;iBAC5C,IAAI,CAAC,YAAY,EAAE,UAAU,CAAC;gBAC3B,MAAM,CAAC,CAAC,CAAC,EAAE,CAAC;YAChB,CAAC,CAAC;iBACD,IAAI,CAAC,GAAG,EAAE,UAAU,CAAC;gBAClB,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,KAAK,EAAE,SAAS,CAAC,CAAC,GAAG,CAAC,CAAC;YACzD,CAAC,CAAC;iBACD,IAAI,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;YAE3B,MAAM,CAAC,KAAK,EAAE,CAAC,MAAM,CAAC,MAAM,CAAC;iBACxB,IAAI,CAAC,UAAU,CAAC;gBACb,EAAE,CAAA,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,CAAC,GAAG,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,KAAK,CAAC,IAAI,EAAE,CAAC;oBAAC,MAAM,CAAC,EAAE,CAAC;gBAC9D,EAAE,CAAA,CAAC,CAAC,CAAC,KAAK,CAAC;oBAAC,MAAM,CAAC,CAAC,CAAC,KAAK,CAAC;gBAC3B,MAAM,CAAC,EAAE,CAAC;YACd,CAAC,CAAC;iBACD,IAAI,CAAC,GAAG,EAAE,UAAU,CAAC;gBAClB,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,KAAK,EAAE,SAAS,CAAC,CAAC,GAAG,CAAC,CAAC;YACzD,CAAC,CAAC;iBACD,IAAI,CAAC,GAAG,EAAE,UAAU,CAAC;gBAClB,MAAM,CAAC,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,EAAE,GAAG,QAAQ,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC;YAC3D,CAAC,CAAC;iBACD,IAAI,CAAC,aAAa,EAAE,OAAO,CAAC;iBAC5B,IAAI,CAAC,OAAO,EAAE,WAAW,CAAC;iBAC1B,IAAI,CAAC,MAAM,EAAE,OAAO,CAAC,CAAC;YAE3B,MAAM,CAAC,IAAI,EAAE,CAAC,MAAM,EAAE,CAAC;QAC3B,CAAC,CAAC;QAEF,cAAS,GAAG;YACR,IAAI,MAAM,GAAO,EAAE,CAAC,KAAK,CAAC,IAAI,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;YACxC,IAAI,IAAI,GAAQ,IAAI,CAAC,CAAC,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC;YACnD,IAAI,UAAU,GAAW,CAAC,IAAI,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,GAAG,IAAI,CAAC,KAAK,CAAC,MAAM,EAAE,CAAC,CAAC,CAAC,CAAC,OAAO,EAAE,CAAC,GAAG,CAAC,CAAC;YAEnG,IAAI,CAAC,KAAK,CAAC,MAAM,CAAC,CAAC,IAAI,IAAI,CAAC,IAAI,GAAG,UAAU,CAAC,EAAE,IAAI,IAAI,CAAC,IAAI,GAAG,UAAU,CAAC,CAAC,CAAC,CAAC;YAC9E,IAAI,CAAC,WAAW,EAAE,CAAC;QACvB,CAAC,CAAC;QAEF,qBAAgB,GAAG,UAAC,KAAS;YACzB,IAAI,KAAK,GAAG,EAAE,EAAE,CAAC,EAAE,MAAM,GAAG,EAAE,GAAG,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,GAAG,EAAE,MAAM,GAAG,EAAE,CAAC;YAC/D,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;gBACpC,CAAC,GAAG,KAAK,CAAC,CAAC,CAAC,CAAC;gBACb,EAAE,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC;oBAAC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,GAAG,EAAE,CAAC;gBACzC,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,IAAI,CAAC,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,MAAM,CAAC,EAAE,GAAG,EAAE,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC;YACvG,CAAC;YAED,GAAG,CAAC,CAAC,IAAI,SAAS,IAAI,KAAK,CAAC,CAAC,CAAC;gBAC1B,MAAM,CAAC,IAAI,CAAC,EAAC,KAAK,EAAE,SAAS,EAAE,IAAI,EAAE,KAAK,CAAC,SAAS,CAAC,EAAC,CAAC,CAAC;YAC5D,CAAC;YACD,MAAM,CAAC,MAAM,CAAC;QAClB,CAAC,CAAA;QA3UG,IAAI,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QAC/B,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAAC,IAAI,GAAG,IAAI,CAAC,aAAa,CAAC,aAAa,CAAC,aAAa,CAAC,CAAC,CAAC;QAEpF,IAAI,CAAC,SAAS,GAAG,IAAI,CAAC,WAAW,CAAC,CAAC;QACnC,IAAI,CAAC,QAAQ,GAAG,IAAI,CAAC,UAAU,CAAC,CAAC;IACrC,CAAC;IApB
c,iCAAmB,GAAG,EAAE,CAAC;IACzB,+CAAiC,GAAU,IAAI,CAAC;IAChD,8CAAgC,GAAU,IAAI,CAAC;IAE/C,iCAAmB,GAAU,EAAE,GAAG,IAAI,CAAC;IACvC,+BAAiB,GAAU,EAAE,GAAG,aAAa,CAAC,mBAAmB,CAAC;IAClE,8BAAgB,GAAU,EAAE,GAAG,aAAa,CAAC,iBAAiB,CAAC;IAC/D,+BAAiB,GAAU,CAAC,GAAG,aAAa,CAAC,gBAAgB,CAAC;IAE9D,2BAAa,GAAG,WAAW,CAAC;IAkV/C,oBAAC;AAAD,CAAC,AA/WD,CAA4B,KAAK,GA+WhC;ACpXD;IAAyB,8BAAK;IAQ1B,oBAAa,OAAY;QAR7B,iBAgCC;QAvBO,kBAAM,OAAO,CAAC,YAAY,CAAC,CAAC,CAAC;QAYjC,mBAAc,GAAG,cAAM,OAAA,KAAI,CAAC,WAAW,EAAhB,CAAgB,CAAC;QACxC,iBAAY,GAAG,cAAM,OAAA,KAAI,CAAC,SAAS,EAAd,CAAc,CAAC;QACpC,oBAAe,GAAG,cAAM,OAAA,KAAI,CAAC,YAAY,EAAjB,CAAiB,CAAC;QAE1C,mBAAc,GAAG,UAAC,GAAW;YACzB,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,YAAY,IAAI,GAAG,GAAG,CAAC,IAAI,GAAG,GAAG,IAAI,CAAC,YAAY,CAAC,MAAM,CAAC;gBAAC,MAAM,CAAC,IAAI,CAAC;YAChF,MAAM,CAAC,KAAI,CAAC,YAAY,CAAC,GAAG,CAAC,CAAC;QAClC,CAAC,CAAC;QAEF,uBAAkB,GAAG,cAAM,OAAA,KAAI,CAAC,eAAe,EAApB,CAAoB,CAAC;QAChD,kBAAa,GAAG,cAAM,OAAA,KAAI,CAAC,UAAU,EAAf,CAAe,CAAC;QApBlC,IAAI,KAAK,GAAQ,OAAO,CAAC,YAAY,CAAC,CAAC;QAEvC,EAAE,CAAA,CAAC,KAAK,CAAC,CAAA,CAAC;YACN,IAAI,CAAC,WAAW,GAAG,KAAK,CAAC,aAAa,CAAC,CAAC;YACxC,IAAI,CAAC,SAAS,GAAG,KAAK,CAAC,WAAW,CAAC,CAAC;YACpC,IAAI,CAAC,YAAY,GAAG,KAAK,CAAC,cAAc,CAAC,CAAC;YAC1C,EAAE,CAAA,CAAC,KAAK,CAAC,YAAY,CAAC,CAAC;gBAAC,IAAI,CAAC,UAAU,GAAG,IAAI,SAAS,CAAC,KAAK,CAAC,YAAY,CAAC,CAAC,CAAC;QACjF,CAAC;IACL,CAAC;IAaL,iBAAC;AAAD,CAAC,AAhCD,CAAyB,KAAK,GAgC7B;AC/BD;IAA2B,gCAAS;IAKhC,sBAAY,OAAe;QACvB,kBAAM,aAAa,CAAC,YAAY,CAAC,CAAC;QAoBtC,WAAM,GAAG,UAAC,cAAsB;YAE5B,IAAI,MAAM,GAAW,CAAC,CAAC,aAAa,CAAC,CAAC;YACtC,MAAM,CAAC,QAAQ,EAAE,CAAC;YAElB,EAAE,CAAA,CAAC,IAAI,CAAC,KAAK,CAAC,CAAA,CAAC;gBAEX,EAAE,CAAA,CAAC,IAAI,CAAC,KAAK,CAAC,QAAQ,EAAE,CAAC,CAAA,CAAC;oBACtB,IAAI,IAAI,GAAW,IAAI,CAAC,KAAK,CAAC,YAAY,EAAE,CAAC;oBAC7C,MAAM,CAAC,KAAK,CAAC,IAAI,CAAC,KAAK,CAAC,QAAQ,EAAE,GAAG,CAAC,IAAI,GAAG,IAAI,GAAG,EAAE,CAAC,CAAC,CAAC;gBAC7D,CAAC;gBACD,EAAE,CAAA,CAAC,IAAI,CAAC,KAAK,CAAC,SAAS,EAAE,CAAC,CAAA,CAAC;oBACvB,IAAI,IAAI,GAAW,IAAI,CAAC,KAAK,CAAC,aAAa,EAAE,CAAC;oBAC9C,MAAM,CAAC,MAAM,CAAC,IAAI,CAAC,KAAK,CAAC,SAAS,EAAE,GAAG,CAAC,IAAI,GAAG,IAAI,GAAG,EAAE,CAAC,CAAC,CAAC;gBAC/D,CAAC;gBACD,EAAE,CAAA,CAAC,IAAI,CAAC,KAAK,CAAC,kBAAkB,EAAE,CAAC;oBAAC,MAAM,CAAC,GAAG,CAAC,kBAAkB,EAAC,IAAI,CAAC,KAAK,CAAC,kBAAkB,EAAE,CAAC,CAAC;gBACnG,EAAE,CAAA,CAAC,IAAI,CAAC,KAAK,CAAC,aAAa,EAAE,CAAC;oBAAC,MAAM,CAAC,GAAG,CAAC,OAAO,EAAE,IAAI,CAAC,KAAK,CAAC,aAAa,EAAE,CAAC,CAAC;YACnF,CAAC;YAGD,cAAc,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC;YAG9B,EAAE,CAAA,CAAC,IAAI,CAAC,UAAU,CAAC,CAAA,CAAC;gBAChB,GAAG,CAAA,CAAE,IAAI,CAAC,GAAC,CAAC,EAAE,CAAC,GAAC,IAAI,CAAC,UAAU,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;oBAC1C,IAAI,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC;gBACtC,CAAC;YACL,CAAC;QACL,CAAC,CAAA;QA9CG,IAAI,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QAC/B,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAAC,IAAI,GAAG,IAAI,CAAC,aAAa,CAAC,aAAa,CAAC,YAAY,CAAC,CAAC,CAAC;QAElF,IAAI,UAAU,GAAU,IAAI,CAAC,YAAY,CAAC,CAAC;QAE3C,EAAE,CAAA,CAAC,UAAU,CAAC,CAAA,CAAC;YACX,IAAI,CAAC,UAAU,GAAG,EAAE,CAAC;YACrB,GAAG,CAAA,CAAE,IAAI,CAAC,GAAC,CAAC,EAAE,CAAC,GAAC,UAAU,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;gBACrC,IAAI,KAAK,GAAW,IAAI,CAAC,SAAS,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC,CAAC;gBAClD,IAAI,CAAC,UAAU,CAAC,IAAI,CAAC,SAAS,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC;YACxD,CAAC;QACL,CAAC;QAED,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;YAAC,IAAI,CAAC,KAAK,GAAG,IAAI,QAAQ,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC;IAG/D,CAAC;IAgCL,mBAAC;AAAD,CAAC,AAxDD,CAA2B,SAAS,GAwDnC;ACxDD;IAAuB,4BAAK;IAIxB,kBAAa,OAAY;QAJ7B,iBAcC;QATO,kBAAM,OAAO,CAAC,UAAU,CAAC,CAAC,CAAC;QAM/B,kBAAa,GAAG,cAAM,OAAA,KAAI,CAAC,UAAU,EAAf,CAAe,CAAC;QAJlC,EAAE
,CAAA,CAAC,OAAO,IAAI,OAAO,CAAC,UAAU,CAAC,CAAC;YAAC,IAAI,CAAC,UAAU,GAAG,OAAO,CAAC,UAAU,CAAC,CAAC,YAAY,CAAC,CAAC;IAE3F,CAAC;IAKL,eAAC;AAAD,CAAC,AAdD,CAAuB,KAAK,GAc3B;ACRD;IAAiC,sCAAS;IAOtC,4BAAY,OAAe;QACvB,kBAAM,aAAa,CAAC,kBAAkB,CAAC,CAAC;QAqB5C,WAAM,GAAG,UAAC,cAAsB;YAE5B,IAAI,CAAC,GAAkB,IAAI,CAAC,KAAK,CAAC;YAElC,IAAI,QAAQ,GAAW,CAAC,CAAC,aAAa,CAAC,CAAC;YACxC,QAAQ,CAAC,QAAQ,EAAE,CAAC;YAEpB,IAAI,QAAgB,CAAC;YACrB,EAAE,CAAA,CAAC,IAAI,CAAC,KAAK,CAAC;gBAAC,QAAQ,GAAG,CAAC,CAAC,OAAO,GAAG,IAAI,CAAC,KAAK,GAAG,QAAQ,CAAC,CAAC;YAC7D,IAAI;gBAAC,QAAQ,GAAG,CAAC,CAAC,aAAa,CAAC,CAAC;YACjC,QAAQ,CAAC,QAAQ,EAAE,CAAC;YACpB,QAAQ,CAAC,MAAM,CAAC,QAAQ,CAAC,CAAC;YAE1B,IAAI,QAAQ,GAAW,CAAC,CAAC,aAAa,CAAC,CAAC;YACxC,QAAQ,CAAC,QAAQ,EAAE,CAAC;YACpB,QAAQ,CAAC,MAAM,CAAC,QAAQ,CAAC,CAAC;YAG1B,EAAE,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC,CAAC;gBACvB,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,eAAe,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;oBAEnD,IAAI,CAAC,eAAe,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,QAAQ,CAAC,CAAC;gBAC7C,CAAC;YACL,CAAC;YAED,cAAc,CAAC,MAAM,CAAC,QAAQ,CAAC,CAAC;YAEhC,EAAE,CAAA,CAAC,IAAI,CAAC,gBAAgB,CAAC;gBAAC,QAAQ,CAAC,SAAS,CAAC,EAAC,WAAW,EAAE,IAAI,EAAE,WAAW,EAAE,SAAS,EAAE,MAAM,EAAE,KAAK,EAAC,CAAC,CAAC;YACzG,IAAI;gBAAC,QAAQ,CAAC,SAAS,CAAC,EAAC,WAAW,EAAE,IAAI,EAAE,WAAW,EAAE,SAAS,EAAC,CAAC,CAAC;QASzE,CAAC,CAAA;QAxDG,IAAI,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QAC/B,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAAC,IAAI,GAAG,IAAI,CAAC,aAAa,CAAC,aAAa,CAAC,kBAAkB,CAAC,CAAC,CAAC;QAExF,IAAI,CAAC,KAAK,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC;QAC3B,IAAI,CAAC,gBAAgB,GAAG,IAAI,CAAC,kBAAkB,CAAC,CAAC;QAEjD,IAAI,OAAO,GAAU,IAAI,CAAC,iBAAiB,CAAC,CAAC;QAE7C,EAAE,CAAA,CAAC,OAAO,CAAC,CAAA,CAAC;YACR,IAAI,CAAC,eAAe,GAAG,EAAE,CAAC;YAC1B,GAAG,CAAA,CAAE,IAAI,CAAC,GAAC,CAAC,EAAE,CAAC,GAAC,OAAO,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;gBAClC,IAAI,KAAK,GAAW,IAAI,CAAC,SAAS,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC;gBAC/C,IAAI,CAAC,eAAe,CAAC,IAAI,CAAC,SAAS,CAAC,YAAY,CAAC,KAAK,CAAC,CAAC,CAAC;YAC7D,CAAC;QACL,CAAC;QAED,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;YAAC,IAAI,CAAC,KAAK,GAAG,IAAI,cAAc,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC;IACrE,CAAC;IA0CL,yBAAC;AAAD,CAAC,AArED,CAAiC,SAAS,GAqEzC;AC3ED;IAA6B,kCAAK;IAE9B,wBAAa,OAAY;QACrB,kBAAM,OAAO,CAAC,gBAAgB,CAAC,CAAC,CAAC;IAGrC,CAAC;IAEL,qBAAC;AAAD,CAAC,AARD,CAA6B,KAAK,GAQjC;ACLD;IAA6B,kCAAS;IAOlC,wBAAY,OAAe;QACvB,kBAAM,aAAa,CAAC,cAAc,CAAC,CAAC;QAUxC,WAAM,GAAG,UAAC,cAAsB;YAE5B,IAAI,CAAC,GAAe,IAAI,CAAC,KAAK,CAAC;YAC/B,IAAI,MAAM,GAAW,KAAK,CAAC,UAAU,CAAC,CAAC,CAAC,CAAC;YAEzC,IAAI,GAAG,GAAG,QAAQ,CAAC,aAAa,CAAC,OAAO,CAAC,CAAC;YAE1C,GAAG,CAAC,KAAK,CAAC,KAAK,GAAG,MAAM,CAAC;YACzB,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,CAAC,gBAAgB,EAAE,IAAI,IAAK,CAAC;gBAAC,GAAG,CAAC,YAAY,CAAC,QAAQ,EAAE,MAAM,CAAC,CAAC,CAAC,gBAAgB,EAAE,CAAC,CAAC,CAAC;YAChG,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,CAAC,kBAAkB,EAAE,CAAC;gBAAC,GAAG,CAAC,KAAK,CAAC,eAAe,GAAG,CAAC,CAAC,kBAAkB,EAAE,CAAC;YACnF,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,CAAC,iBAAiB,EAAE,CAAC;gBAAC,GAAG,CAAC,KAAK,CAAC,UAAU,GAAG,CAAC,CAAC,iBAAiB,EAAE,CAAC;YAE5E,EAAE,CAAC,CAAC,CAAC,IAAI,CAAC,CAAC,eAAe,EAAE,CAAC,CAAC,CAAC;gBAE3B,IAAI,SAAS,GAAa,CAAC,CAAC,eAAe,EAAE,CAAC;gBAC9C,IAAI,IAAI,GAAW,OAAO,CAAC,mBAAmB,CAAC,CAAC,CAAC,kBAAkB,EAAE,CAAC,CAAC;gBACvE,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,SAAS,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;oBACxC,IAAI,GAAG,GAAG,QAAQ,CAAC,aAAa,CAAC,KAAK,CAAC,CAAC;oBACxC,GAAG,CAAC,YAAY,CAAC,OAAO,EAAE,SAAS,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC;oBAC/C,GAAG,CAAC,WAAW,CAAC,GAAG,CAAC,CAAC;gBACzB,CAAC;YACL,CAAC;YAGD,IAAI,MAAM,GAAG,CAAC,CAAC;YACf,IAAI,QAAQ,GAAG,CAAC,CAAC;YACjB,IAAI,SAAS,GAAG,CAAC,CAAC;YAClB,IAAI,OAAO,
GAAG,CAAC,CAAC;YAEhB,EAAE,CAAC,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC;gBACd,IAAI,OAAO,GAAG,QAAQ,CAAC,aAAa,CAAC,OAAO,CAAC,CAAC;gBAC9C,IAAI,SAAS,GAAG,QAAQ,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;gBAE7C,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,CAAC,cAAc,EAAE,CAAC;oBAAC,SAAS,CAAC,KAAK,CAAC,eAAe,GAAG,CAAC,CAAC,cAAc,EAAE,CAAC;gBAEjF,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,MAAM,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;oBAC1C,IAAI,OAAO,GAAG,QAAQ,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;oBAC3C,OAAO,CAAC,KAAK,CAAC,OAAO,GAAG,MAAM,GAAG,KAAK,GAAG,QAAQ,GAAG,KAAK,GAAG,SAAS,GAAG,KAAK,GAAG,OAAO,GAAG,IAAI,CAAC;oBAC/F,OAAO,CAAC,WAAW,CAAC,QAAQ,CAAC,cAAc,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;oBAC7D,SAAS,CAAC,WAAW,CAAC,OAAO,CAAC,CAAC;gBACnC,CAAC;gBACD,GAAG,CAAC,WAAW,CAAC,SAAS,CAAC,CAAC;YAC/B,CAAC;YAGD,EAAE,CAAC,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC;gBAEf,IAAI,IAAI,GAAG,QAAQ,CAAC,aAAa,CAAC,OAAO,CAAC,CAAC;gBAC3C,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;oBAC3C,IAAI,EAAE,GAAG,QAAQ,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;oBAEtC,GAAG,CAAC,CAAC,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;wBAC9C,IAAI,EAAE,GAAG,QAAQ,CAAC,aAAa,CAAC,IAAI,CAAC,CAAC;wBACtC,EAAE,CAAC,KAAK,CAAC,OAAO,GAAG,MAAM,GAAG,KAAK,GAAG,QAAQ,GAAG,KAAK,GAAG,SAAS,GAAG,KAAK,GAAG,OAAO,GAAG,IAAI,CAAC;wBAC1F,EAAE,CAAC,WAAW,CAAC,QAAQ,CAAC,cAAc,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;wBAC5D,EAAE,CAAC,WAAW,CAAC,EAAE,CAAC,CAAC;oBACvB,CAAC;oBAED,IAAI,CAAC,WAAW,CAAC,EAAE,CAAC,CAAC;gBACzB,CAAC;gBACD,GAAG,CAAC,WAAW,CAAC,IAAI,CAAC,CAAC;YAC1B,CAAC;YAED,cAAc,CAAC,MAAM,CAAC,GAAG,CAAC,CAAC;QAC/B,CAAC,CAAA;QAxEG,IAAI,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QAC/B,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAAC,IAAI,GAAG,IAAI,CAAC,aAAa,CAAC,aAAa,CAAC,cAAc,CAAC,CAAC,CAAC;QAEpF,IAAI,CAAC,MAAM,GAAG,IAAI,CAAC,QAAQ,CAAC,CAAC;QAC7B,IAAI,CAAC,OAAO,GAAG,IAAI,CAAC,SAAS,CAAC,CAAC;QAC/B,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;YAAC,IAAI,CAAC,KAAK,GAAG,IAAI,UAAU,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC;IACjE,CAAC;IAqEL,qBAAC;AAAD,CAAC,AArFD,CAA6B,SAAS,GAqFrC;ACzFD;IAAyB,8BAAK;IAQ1B,oBAAa,OAAY;QAR7B,iBA0BC;QAjBO,kBAAM,OAAO,CAAC,YAAY,CAAC,CAAC,CAAC;QAYjC,oBAAe,GAAG,cAAM,OAAA,KAAI,CAAC,YAAY,EAAjB,CAAiB,CAAC;QAC1C,uBAAkB,GAAG,cAAM,OAAA,KAAI,CAAC,eAAe,EAApB,CAAoB,CAAC;QAChD,qBAAgB,GAAG,cAAM,OAAA,KAAI,CAAC,aAAa,EAAlB,CAAkB,CAAC;QAC5C,mBAAc,GAAG,cAAM,OAAA,KAAI,CAAC,WAAW,EAAhB,CAAgB,CAAC;QACxC,sBAAiB,GAAG,cAAM,OAAA,KAAI,CAAC,cAAc,EAAnB,CAAmB,CAAC;QAd1C,IAAI,KAAK,GAAQ,OAAO,CAAC,YAAY,CAAC,CAAC;QACvC,EAAE,CAAA,CAAC,KAAK,CAAC,CAAA,CAAC;YACN,IAAI,CAAC,YAAY,GAAG,OAAO,CAAC,YAAY,CAAC,CAAC,cAAc,CAAC,CAAC;YAC1D,IAAI,CAAC,aAAa,GAAG,OAAO,CAAC,YAAY,CAAC,CAAC,eAAe,CAAC,CAAC;YAC5D,IAAI,CAAC,WAAW,GAAG,OAAO,CAAC,YAAY,CAAC,CAAC,aAAa,CAAC,CAAC;YACxD,IAAI,CAAC,eAAe,GAAG,OAAO,CAAC,YAAY,CAAC,CAAC,iBAAiB,CAAC,CAAC;YAChE,IAAI,CAAC,cAAc,GAAG,OAAO,CAAC,YAAY,CAAC,CAAC,gBAAgB,CAAC,CAAC;QAClE,CAAC;IACL,CAAC;IAOL,iBAAC;AAAD,CAAC,AA1BD,CAAyB,KAAK,GA0B7B;ACtBD;IAA4B,iCAAS;IAKjC,uBAAY,OAAe;QAL/B,iBAwCC;QAlCO,kBAAM,aAAa,CAAC,aAAa,CAAC,CAAC;QASvC,WAAM,GAAG,UAAC,cAAsB;YAE5B,IAAI,QAAQ,GAAS,QAAQ,CAAC,cAAc,CAAC,KAAI,CAAC,IAAI,CAAC,CAAC;YACxD,EAAE,CAAA,CAAC,KAAI,CAAC,KAAK,CAAC,CAAA,CAAC;gBACX,IAAI,OAAO,GAAoB,QAAQ,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;gBAC9D,EAAE,CAAA,CAAC,KAAI,CAAC,KAAK,CAAC,OAAO,EAAE,CAAC;oBAAC,OAAO,CAAC,KAAK,CAAC,IAAI,GAAG,KAAI,CAAC,KAAK,CAAC,OAAO,EAAE,CAAC;gBACnE,EAAE,CAAA,CAAC,KAAI,CAAC,KAAK,CAAC,WAAW,EAAE,IAAI,IAAI,CAAC;oBAAC,OAAO,CAAC,KAAK,CAAC,QAAQ,GAAG,KAAI,CAAC,KAAK,CAAC,WAAW,EAAE,GAAG,IAAI,CAAC;gBAC9F,E
AAE,CAAA,CAAC,KAAI,CAAC,KAAK,CAAC,YAAY,EAAE,IAAI,IAAI,CAAC;oBAAC,OAAO,CAAC,KAAK,CAAC,cAAc,GAAC,WAAW,CAAC;gBAC/E,EAAE,CAAA,CAAC,KAAI,CAAC,KAAK,CAAC,QAAQ,EAAE,CAAC;oBAAC,OAAO,CAAC,KAAK,CAAC,KAAK,GAAG,KAAI,CAAC,KAAK,CAAC,QAAQ,EAAE,CAAC;gBACtE,EAAE,CAAA,CAAC,KAAI,CAAC,KAAK,CAAC,YAAY,EAAE,CAAC;oBAAC,OAAO,CAAC,KAAK,CAAC,SAAS,GAAG,KAAI,CAAC,KAAK,CAAC,YAAY,EAAE,GAAG,IAAI,CAAC;gBACzF,EAAE,CAAA,CAAC,KAAI,CAAC,KAAK,CAAC,eAAe,EAAE,CAAC;oBAAC,OAAO,CAAC,KAAK,CAAC,YAAY,GAAG,KAAI,CAAC,KAAK,CAAC,eAAe,EAAE,GAAG,IAAI,CAAC;gBAClG,EAAE,CAAA,CAAC,KAAI,CAAC,KAAK,CAAC,aAAa,EAAE,CAAC;oBAAC,OAAO,CAAC,KAAK,CAAC,UAAU,GAAG,KAAI,CAAC,KAAK,CAAC,aAAa,EAAE,GAAG,IAAI,CAAC;gBAC5F,EAAE,CAAA,CAAC,KAAI,CAAC,KAAK,CAAC,cAAc,EAAE,CAAC;oBAAC,OAAO,CAAC,KAAK,CAAC,WAAW,GAAG,KAAI,CAAC,KAAK,CAAC,cAAc,EAAE,GAAG,IAAI,CAAC;gBAC/F,EAAE,CAAA,CAAC,KAAI,CAAC,KAAK,CAAC,gBAAgB,EAAE,CAAC;oBAAC,OAAO,CAAC,KAAK,CAAC,UAAU,GAAG,KAAK,CAAC;gBAEnE,OAAO,CAAC,WAAW,CAAC,QAAQ,CAAC,CAAC;gBAC9B,cAAc,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC;YACnC,CAAC;YAAC,IAAI,CAAC,CAAC;gBACJ,IAAI,OAAO,GAAoB,QAAQ,CAAC,aAAa,CAAC,MAAM,CAAC,CAAC;gBAE9D,OAAO,CAAC,WAAW,CAAC,QAAQ,CAAC,CAAC;gBAC9B,cAAc,CAAC,MAAM,CAAC,OAAO,CAAC,CAAC;YACnC,CAAC;QACL,CAAC,CAAA;QA/BG,IAAI,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,OAAO,CAAC,CAAC;QAC/B,EAAE,CAAA,CAAC,CAAC,IAAI,CAAC,eAAe,CAAC,CAAC;YAAC,IAAI,GAAG,IAAI,CAAC,aAAa,CAAC,aAAa,CAAC,aAAa,CAAC,CAAC,CAAC;QAEnF,IAAI,CAAC,IAAI,GAAG,IAAI,CAAC,MAAM,CAAC,CAAC;QAEzB,EAAE,CAAA,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;YAAC,IAAI,CAAC,KAAK,GAAG,IAAI,SAAS,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC,CAAC;IAChE,CAAC;IA2BL,oBAAC;AAAD,CAAC,AAxCD,CAA4B,SAAS,GAwCpC;AC5CD;IAAwB,6BAAK;IAQzB,mBAAa,OAAY;QAR7B,iBA0BC;QAjBO,kBAAM,OAAO,CAAC,WAAW,CAAC,CAAC,CAAC;QAYhC,YAAO,GAAG,cAAM,OAAA,KAAI,CAAC,IAAI,EAAT,CAAS,CAAC;QAC1B,gBAAW,GAAG,cAAM,OAAA,KAAI,CAAC,QAAQ,EAAb,CAAa,CAAC;QAClC,iBAAY,GAAG,cAAM,OAAA,KAAI,CAAC,SAAS,EAAd,CAAc,CAAC;QACpC,aAAQ,GAAG,cAAM,OAAA,KAAI,CAAC,KAAK,EAAV,CAAU,CAAC;QAC5B,qBAAgB,GAAG,cAAM,OAAA,KAAI,CAAC,aAAa,EAAlB,CAAkB,CAAC;QAdxC,IAAI,KAAK,GAAQ,OAAO,CAAC,WAAW,CAAC,CAAC;QACtC,EAAE,CAAA,CAAC,KAAK,CAAC,CAAA,CAAC;YACN,IAAI,CAAC,IAAI,GAAG,KAAK,CAAC,MAAM,CAAC,CAAC;YAC1B,IAAI,CAAC,QAAQ,GAAG,KAAK,CAAC,UAAU,CAAC,CAAC;YAClC,IAAI,CAAC,SAAS,GAAG,KAAK,CAAC,WAAW,CAAC,CAAC;YACpC,IAAI,CAAC,KAAK,GAAG,KAAK,CAAC,OAAO,CAAC,CAAC;YAC5B,IAAI,CAAC,aAAa,GAAG,KAAK,CAAC,eAAe,CAAC,CAAC;QAChD,CAAC;IACL,CAAC;IAOL,gBAAC;AAAD,CAAC,AA1BD,CAAwB,KAAK,GA0B5B"} \ No newline at end of file diff --git a/arbiter/arbiter-ui/src/main/resources/templates/ArbiterUI.html b/arbiter/arbiter-ui/src/main/resources/templates/ArbiterUI.html deleted file mode 100644 index a1b4e92a0..000000000 --- a/arbiter/arbiter-ui/src/main/resources/templates/ArbiterUI.html +++ /dev/null @@ -1,638 +0,0 @@ - - - - - - DL4J - Arbiter UI - - - - - - - - - - - - - -

[Deleted ArbiterUI.html body not reproduced here: the 638 removed lines of HTML markup lost their tags during extraction and cannot be reconstructed. The recoverable fragments of the removed template are the page header "Deeplearning4J - Arbiter UI", a small whitespace-trimming example table (columns "Original" and "Regex", e.g. '  4.25 ' trimmed by ^\s+|\s+$ to '4.25'), and the section headings "Summary", "Optimization Settings", "Results", and "Selected Result".]
- - \ No newline at end of file diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java deleted file mode 100644 index fcf3066e2..000000000 --- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,50 +0,0 @@ -/* ****************************************************************************** - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ -package org.deeplearning4j.arbiter.optimize; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import org.deeplearning4j.BaseDL4JTest; - -import java.util.*; - -/** - * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) - * extends BaseDl4jTest - either directly or indirectly. - * Other than a small set of exceptions, all tests must extend this - * - * @author Alex Black - */ - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.deeplearning4j.arbiter.optimize"; - } - - @Override - protected Class getBaseClass() { - return BaseDL4JTest.class; - } -} diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java deleted file mode 100644 index b7502c84b..000000000 --- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java +++ /dev/null @@ -1,791 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.arbiter.optimize; - -import io.netty.handler.codec.http.HttpResponseStatus; -import lombok.NonNull; -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.core.storage.StatsStorage; -import org.deeplearning4j.arbiter.ComputationGraphSpace; -import org.deeplearning4j.arbiter.MultiLayerSpace; -import org.deeplearning4j.arbiter.conf.updater.SgdSpace; -import org.deeplearning4j.arbiter.layers.ConvolutionLayerSpace; -import org.deeplearning4j.arbiter.layers.DenseLayerSpace; -import org.deeplearning4j.arbiter.layers.OutputLayerSpace; -import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; -import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; -import org.deeplearning4j.arbiter.optimize.api.data.DataProvider; -import org.deeplearning4j.arbiter.optimize.api.data.DataSource; -import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition; -import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition; -import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; -import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; -import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.discrete.DiscreteParameterSpace; -import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; -import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; -import org.deeplearning4j.arbiter.optimize.runner.listener.StatusListener; -import org.deeplearning4j.arbiter.saver.local.FileModelSaver; -import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction; -import org.deeplearning4j.arbiter.scoring.impl.TestSetLossScoreFunction; -import org.deeplearning4j.arbiter.task.ComputationGraphTaskCreator; -import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator; -import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener; -import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.ui.api.UIServer; -import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; - -import org.junit.jupiter.api.Test; -import org.nd4j.common.function.Function; -import org.nd4j.evaluation.classification.Evaluation; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.LossFunctions; - -import java.io.File; -import java.io.IOException; -import java.io.UnsupportedEncodingException; -import java.net.HttpURLConnection; -import java.net.URL; -import java.net.URLEncoder; -import java.util.*; -import java.util.concurrent.TimeUnit; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -/** - * Created by Alex on 19/07/2017. 
- */ -@Slf4j -public class TestBasic extends BaseDL4JTest { - - @Override - public long getTimeoutMilliseconds() { - return 3600_000L; - } - - @Test - //@Ignore - public void testBasicUiOnly() throws Exception { - - UIServer.getInstance(); - - Thread.sleep(1000_000); - } - - @Test - //@Ignore - public void testBasicMnist() throws Exception { - Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - - MultiLayerSpace mls = getMultiLayerSpaceMnist(); - Map commands = new HashMap<>(); -// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); - DataProvider dataProvider = new MnistDataSetProvider(); - - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = - new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - IOptimizationRunner runner = - new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - - StatsStorage ss = new InMemoryStatsStorage(); - StatusListener sl = new ArbiterStatusListener(ss); - runner.addListeners(sl); - - UIServer.getInstance().attach(ss); - - runner.execute(); - Thread.sleep(1000_000); - } - - private static MultiLayerSpace getMultiLayerSpaceMnist() { - return new MultiLayerSpace.Builder() - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) - .l2(new ContinuousParameterSpace(0.0001, 0.05)) - .addLayer( - new ConvolutionLayerSpace.Builder().nIn(1) - .nOut(new IntegerParameterSpace(5, 30)) - .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, - new int[]{4, 4}, new int[]{5, 5})) - .stride(new DiscreteParameterSpace<>(new int[]{1, 1}, - new int[]{2, 2})) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.SOFTPLUS, Activation.LEAKYRELU)) - .build()) - .addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers - .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) - .build(); - } - - @Test - //@Ignore - public void testBasicMnistDataSource() throws InterruptedException { - ParameterSpace learningRateHyperparam = new ContinuousParameterSpace(0.0001, 0.1); - ParameterSpace layerSizeHyperparam = new IntegerParameterSpace(16, 256); - - MultiLayerSpace hyperparameterSpace = new MultiLayerSpace.Builder() - .weightInit(WeightInit.XAVIER) - .l2(0.0001) - .updater(new SgdSpace(learningRateHyperparam)) - .addLayer(new DenseLayerSpace.Builder() - .nIn(784) - .activation(Activation.LEAKYRELU) - .nOut(layerSizeHyperparam) - .build()) - .addLayer(new OutputLayerSpace.Builder() - .nOut(10) - .activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT) - .build()) - .build(); - CandidateGenerator 
candidateGenerator = new RandomSearchGenerator(hyperparameterSpace, null); - ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.ACCURACY); - TerminationCondition[] terminationConditions = { - new MaxTimeCondition(5, TimeUnit.MINUTES), - new MaxCandidatesCondition(2)}; - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - if (!f.exists()) - throw new RuntimeException(); - Class ds = MnistDataSource.class; - Properties dsp = new Properties(); - dsp.setProperty("minibatch", "8"); - OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataSource(ds, dsp) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(scoreFunction) - .terminationConditions(terminationConditions) - .build(); - - IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - - StatsStorage ss = new InMemoryStatsStorage(); - StatusListener sl = new ArbiterStatusListener(ss); - runner.addListeners(sl); - - UIServer.getInstance().attach(ss); - - runner.execute(); - Thread.sleep(90000); - } - - - @Test - //@Ignore - public void testBasicMnistCompGraph() throws Exception { - - ComputationGraphSpace cgs = new ComputationGraphSpace.Builder() - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) - .l2(new ContinuousParameterSpace(0.0001, 0.05)) - .addInputs("in") - .addLayer("0", - new ConvolutionLayerSpace.Builder().nIn(1) - .nOut(new IntegerParameterSpace(5, 30)) - .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, - new int[]{4, 4}, new int[]{5, 5})) - .stride(new DiscreteParameterSpace<>(new int[]{1, 1}, - new int[]{2, 2})) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.SOFTPLUS, Activation.LEAKYRELU)) - .build(), "in") - .addLayer("1", new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), "0") - .addLayer("out", new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "1") - .setOutputs("out") - .setInputTypes(InputType.convolutionalFlat(28, 28, 1)) - .build(); - - //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs); - DataProvider dataProvider = new MnistDataSetProvider(); - - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnistCG\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = - new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - IOptimizationRunner runner = - new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator()); - - StatsStorage ss = new InMemoryStatsStorage(); - StatusListener sl = new ArbiterStatusListener(ss); - runner.addListeners(sl); - - UIServer.getInstance().attach(ss); - - runner.execute(); - Thread.sleep(100000); - } - - - @Test - //@Ignore - 
public void testCandidateGenerationExceptionsMnist() throws Exception { - - //Idea: Create a configuration that is not physically realizable, which should throw an exception - // during the candidate generation phase - //This exception should be visible in UI, but training should continue otherwise - - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) - .l2(new ContinuousParameterSpace(0.0001, 0.05)) - .dropOut(new ContinuousParameterSpace(0.2, 0.7)) - .addLayer( - new ConvolutionLayerSpace.Builder().nIn(1) - .nOut(new IntegerParameterSpace(5, 5)) - .kernelSize(new DiscreteParameterSpace<>(new int[]{14, 14}, new int[]{30, 30})) - .stride(2, 2) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.SOFTPLUS, Activation.LEAKYRELU)) - .build()) - .addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers - .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) - .build(); - Map commands = new HashMap<>(); -// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); - DataProvider dataProvider = new MnistDataSetProvider(); - - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = - new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - IOptimizationRunner runner = - new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - - StatsStorage ss = new InMemoryStatsStorage(); - StatusListener sl = new ArbiterStatusListener(ss); - runner.addListeners(sl); - - UIServer.getInstance().attach(ss); - - runner.execute(); - Thread.sleep(1000_000); - } - - - @Test - //@Ignore - public void testCandidateExecutionExceptionsMnist() throws Exception { - //Idea: Create a configuration that will throw an exception in the *execution* stage - // How? 
let's set wrong nOut - //This exception should be visible in UI, but training should continue otherwise - - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) - .l2(new ContinuousParameterSpace(0.0001, 0.05)) - .dropOut(new ContinuousParameterSpace(0.2, 0.7)) - .addLayer( - new ConvolutionLayerSpace.Builder().nIn(1) - .nOut(new IntegerParameterSpace(5, 5)) - .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, new int[]{4, 4})) - .stride(2, 2) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.SOFTPLUS, Activation.LEAKYRELU)) - .build()) - .addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 64)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers - .addLayer(new OutputLayerSpace.Builder().nOut(99).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) - .build(); - Map commands = new HashMap<>(); -// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); - DataProvider dataProvider = new MnistDataSetProvider(); - - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = - new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - IOptimizationRunner runner = - new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - - StatsStorage ss = new InMemoryStatsStorage(); - StatusListener sl = new ArbiterStatusListener(ss); - runner.addListeners(sl); - - UIServer.getInstance().attach(ss); - - runner.execute(); - Thread.sleep(1000_000); - } - - - @Test - //@Ignore - public void testExecutionExceptionMnistCompGraph() throws Exception { - - //Idea: Create a configuration that will throw an exception in the *execution* stage - // How? 
let's set wrong nOut - //This exception should be visible in UI, but training should continue otherwise - - ComputationGraphSpace cgs = new ComputationGraphSpace.Builder() - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) - .l2(new ContinuousParameterSpace(0.0001, 0.05)) - .dropOut(new ContinuousParameterSpace(0.2, 0.7)) - .addInputs("in") - .addLayer("0", - new ConvolutionLayerSpace.Builder().nIn(1) - .nOut(new IntegerParameterSpace(5, 30)) - .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, - new int[]{4, 4}, new int[]{5, 5})) - .stride(new DiscreteParameterSpace<>(new int[]{1, 1}, - new int[]{2, 2})) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.SOFTPLUS, Activation.LEAKYRELU)) - .build(), "in") - .addLayer("1", new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 64)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), "0") - .addLayer("out", new OutputLayerSpace.Builder().nIn(99).nOut(10).activation(Activation.SOFTMAX) - .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "1") - .setOutputs("out") - .setInputTypes(InputType.convolutionalFlat(28, 28, 1)) - .build(); - - //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(cgs); - DataProvider dataProvider = new MnistDataSetProvider(); - - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnistCG\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = - new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataProvider(dataProvider) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxTimeCondition(120, TimeUnit.MINUTES), - new MaxCandidatesCondition(100)) - .build(); - - IOptimizationRunner runner = - new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator()); - - StatsStorage ss = new InMemoryStatsStorage(); - StatusListener sl = new ArbiterStatusListener(ss); - runner.addListeners(sl); - - UIServer.getInstance().attach(ss); - - runner.execute(); - Thread.sleep(1000_000); - } - - - /** - * Visualize multiple optimization sessions run one after another on single-session mode UI - * @throws InterruptedException if current thread has been interrupted - */ - @Test - //@Ignore - public void testBasicMnistMultipleSessions() throws InterruptedException { - - MultiLayerSpace mls = new MultiLayerSpace.Builder() - .updater(new SgdSpace(new ContinuousParameterSpace(0.0001, 0.2))) - .l2(new ContinuousParameterSpace(0.0001, 0.05)) - .dropOut(new ContinuousParameterSpace(0.2, 0.7)) - .addLayer( - new ConvolutionLayerSpace.Builder().nIn(1) - .nOut(new IntegerParameterSpace(5, 30)) - .kernelSize(new DiscreteParameterSpace<>(new int[]{3, 3}, - new int[]{4, 4}, new int[]{5, 5})) - .stride(new DiscreteParameterSpace<>(new int[]{1, 1}, - new int[]{2, 2})) - .activation(new DiscreteParameterSpace<>(Activation.RELU, - Activation.SOFTPLUS, Activation.LEAKYRELU)) - .build()) - .addLayer(new DenseLayerSpace.Builder().nOut(new IntegerParameterSpace(32, 128)) - .activation(new DiscreteParameterSpace<>(Activation.RELU, Activation.TANH)) - .build(), new IntegerParameterSpace(0, 1), true) //0 to 1 layers - .addLayer(new OutputLayerSpace.Builder().nOut(10).activation(Activation.SOFTMAX) - 
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) - .build(); - Map commands = new HashMap<>(); -// commands.put(DataSetIteratorFactoryProvider.FACTORY_KEY, TestDataFactoryProviderMnist.class.getCanonicalName()); - - //Define configuration: - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls, commands); - - Class ds = MnistDataSource.class; - Properties dsp = new Properties(); - dsp.setProperty("minibatch", "8"); - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\").getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = - new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataSource(ds, dsp) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), - new MaxCandidatesCondition(3)) - .build(); - - IOptimizationRunner runner = - new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - - StatsStorage ss = new InMemoryStatsStorage(); - - - StatusListener sl = new ArbiterStatusListener(ss); - runner.addListeners(sl); - - UIServer.getInstance().attach(ss); - runner.execute(); - - - candidateGenerator = new RandomSearchGenerator(mls, commands); - configuration = new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataSource(ds, dsp) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxTimeCondition(1, TimeUnit.MINUTES), - new MaxCandidatesCondition(3)) - .build(); - - runner = new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - sl = new ArbiterStatusListener(ss); - runner.addListeners(sl); - - UIServer.getInstance().attach(ss); - - runner.execute(); - - Thread.sleep(1000_000); - } - - /** - * Auto-attach multiple optimization sessions to multi-session mode UI - * @throws IOException if could not connect to the server - */ - @Test - public void testUiMultiSessionAutoAttach() throws IOException { - - //Define configuration: - MultiLayerSpace mls = getMultiLayerSpaceMnist(); - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); - - Class ds = MnistDataSource.class; - Properties dsp = new Properties(); - dsp.setProperty("minibatch", "8"); - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestMultiSessionAutoAttach\\") - .getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = - new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataSource(ds, dsp) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxTimeCondition(10, TimeUnit.SECONDS), - new MaxCandidatesCondition(1)) - .build(); - - IOptimizationRunner runner = - new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - - // add 3 different sessions to the same execution - HashMap statsStorageForSession = new HashMap<>(); - for (int i = 0; i < 3; i++) { - StatsStorage ss = new InMemoryStatsStorage(); - @NonNull String sessionId = "sid" + i; - 
statsStorageForSession.put(sessionId, ss); - StatusListener sl = new ArbiterStatusListener(sessionId, ss); - runner.addListeners(sl); - } - - Function statsStorageProvider = statsStorageForSession::get; - UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); - String serverAddress = uIServer.getAddress(); - - runner.execute(); - - for (String sessionId : statsStorageForSession.keySet()) { - /* - * Visiting /arbiter/:sessionId to auto-attach StatsStorage - */ - String sessionUrl = sessionUrl(uIServer.getAddress(), sessionId); - HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection(); - conn.connect(); - - log.info("Checking auto-attaching Arbiter session at {}", sessionUrl(serverAddress, sessionId)); - assertEquals(HttpResponseStatus.OK.code(), conn.getResponseCode()); - assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId))); - } - } - - /** - * Attach multiple optimization sessions to multi-session mode UI by manually visiting session URL - * @throws Exception if an error occurred - */ - @Test - //@Ignore - public void testUiMultiSessionManualAttach() throws Exception { - Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - - //Define configuration: - MultiLayerSpace mls = getMultiLayerSpaceMnist(); - CandidateGenerator candidateGenerator = new RandomSearchGenerator(mls); - - Class ds = MnistDataSource.class; - Properties dsp = new Properties(); - dsp.setProperty("minibatch", "8"); - - String modelSavePath = new File(System.getProperty("java.io.tmpdir"), "ArbiterUiTestBasicMnist\\") - .getAbsolutePath(); - - File f = new File(modelSavePath); - if (f.exists()) - f.delete(); - f.mkdir(); - if (!f.exists()) - throw new RuntimeException(); - - OptimizationConfiguration configuration = - new OptimizationConfiguration.Builder() - .candidateGenerator(candidateGenerator).dataSource(ds, dsp) - .modelSaver(new FileModelSaver(modelSavePath)) - .scoreFunction(new TestSetLossScoreFunction(true)) - .terminationConditions(new MaxTimeCondition(10, TimeUnit.MINUTES), - new MaxCandidatesCondition(10)) - .build(); - - - // parallel execution of multiple optimization sessions - HashMap statsStorageForSession = new HashMap<>(); - for (int i = 0; i < 3; i++) { - String sessionId = "sid" + i; - IOptimizationRunner runner = - new LocalOptimizationRunner(configuration, new MultiLayerNetworkTaskCreator()); - StatsStorage ss = new InMemoryStatsStorage(); - statsStorageForSession.put(sessionId, ss); - StatusListener sl = new ArbiterStatusListener(sessionId, ss); - runner.addListeners(sl); - // Asynchronous execution - new Thread(runner::execute).start(); - } - - Function statsStorageProvider = statsStorageForSession::get; - UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); - String serverAddress = uIServer.getAddress(); - - for (String sessionId : statsStorageForSession.keySet()) { - log.info("Arbiter session can be attached at {}", sessionUrl(serverAddress, sessionId)); - } - - Thread.sleep(1000_000); - } - - - /** - * Get URL for arbiter session on given server address - * @param serverAddress server address, e.g.: http://localhost:9000 - * @param sessionId session ID (will be URL-encoded) - * @return URL - * @throws UnsupportedEncodingException if the character encoding is not supported - */ - private static String sessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException { - return String.format("%s/arbiter/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8")); - } - - private static class 
MnistDataSetProvider implements DataProvider { - - @Override - public DataSetIterator trainData(Map dataParameters) { - try { - if (dataParameters == null || dataParameters.isEmpty()) { - return new MnistDataSetIterator(64, 10000, false, true, true, 123); - } - if (dataParameters.containsKey("batchsize")) { - int b = (Integer) dataParameters.get("batchsize"); - return new MnistDataSetIterator(b, 10000, false, true, true, 123); - } - return new MnistDataSetIterator(64, 10000, false, true, true, 123); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public DataSetIterator testData(Map dataParameters) { - return trainData(dataParameters); - } - - @Override - public Class getDataType() { - return DataSetIterator.class; - } - - @Override - public String toString() { - return "MnistDataSetProvider()"; - } - } - - public static class MnistDataSource implements DataSource { - private int minibatch; - - public MnistDataSource() { - - } - - @Override - public void configure(Properties properties) { - this.minibatch = Integer.parseInt(properties.getProperty("minibatch", "16")); - } - - @Override - public Object trainData() { - try { - return new EarlyTerminationDataSetIterator(new MnistDataSetIterator(minibatch, true, 12345), 3); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public Object testData() { - try { - return new EarlyTerminationDataSetIterator(new MnistDataSetIterator(minibatch, true, 12345), 3); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public Class getDataType() { - return DataSetIterator.class; - } - } - -} diff --git a/arbiter/arbiter-ui/src/test/resources/logback.xml b/arbiter/arbiter-ui/src/test/resources/logback.xml deleted file mode 100644 index 410bdaae9..000000000 --- a/arbiter/arbiter-ui/src/test/resources/logback.xml +++ /dev/null @@ -1,51 +0,0 @@ - - - - - - logs/application.log - - %date - [%level] - from %logger in %thread - %n%message%n%xException%n - - - - - - %logger{15} - %message%n%xException{5} - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/arbiter/buildmultiplescalaversions.sh b/arbiter/buildmultiplescalaversions.sh deleted file mode 100644 index e04610a02..000000000 --- a/arbiter/buildmultiplescalaversions.sh +++ /dev/null @@ -1,53 +0,0 @@ -#! /bin/bash -################################################################################ -# Copyright (c) 2015-2018 Skymind, Inc. -# -# This program and the accompanying materials are made available under the -# terms of the Apache License, Version 2.0 which is available at -# https://www.apache.org/licenses/LICENSE-2.0. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -# -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -BASEDIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -function echoError() { - (>&2 echo "$1") -} - -function scalaError() { - echoError "Changing Scala major version to 2.10 in the build did not change the state of your working copy, is Scala 2.11 still the default ?" 
- exit 2 -} - -function whatchanged() { - cd "$BASEDIR" - for i in $(git status -s --porcelain -- $(find ./ -mindepth 2 -name pom.xml)|awk '{print $2}'); do - echo "$(dirname $i)" - cd "$BASEDIR" - done -} - -set -eu -./change-scala-versions.sh 2.11 # should be idempotent, this is the default -mvn "$@" -./change-scala-versions.sh 2.10 -if [ -z "$(whatchanged)" ]; then - scalaError; -else - if [[ "${@#-pl}" = "$@" ]]; then - mvn -Dmaven.clean.skip=true -pl $(whatchanged| tr '\n' ',') -amd "$@" - else - # the arguments already tweak the project list ! don't tweak them more - # as this can lead to conflicts (excluding a project that's not part of - # the reactor) - mvn "$@" - fi -fi -./change-scala-versions.sh 2.11 # back to the default diff --git a/arbiter/contrib/formatter.xml b/arbiter/contrib/formatter.xml deleted file mode 100644 index d6cc96bf6..000000000 --- a/arbiter/contrib/formatter.xml +++ /dev/null @@ -1,353 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/arbiter/pom.xml b/arbiter/pom.xml deleted file mode 100644 index a3321c0ab..000000000 --- a/arbiter/pom.xml +++ /dev/null @@ -1,182 +0,0 @@ - - - - - - - - net.brutex.ai - deeplearning4j - 1.0.0-SNAPSHOT - - - 4.0.0 - - net.brutex.ai - arbiter - pom - - Arbiter - Model Evaluation and Testing - - - - Apache License, Version 2.0 - http://www.apache.org/licenses/LICENSE-2.0.txt - repo - - - - - arbiter-deeplearning4j - arbiter-core - arbiter-server - arbiter-ui - - - - - - org.apache.maven.plugins - maven-javadoc-plugin - - - generate-javadoc - prepare-package - - javadoc - - - - - - - - - - - net.alchim31.maven - scala-maven-plugin - ${maven-scala-plugin.version} - - - -deprecation - -explaintypes - -nobootcp - - - - - scala-compile-first - process-resources - - add-source - compile - - - - scala-test-compile - process-test-resources - - add-source - testCompile - - - - - - - - - - - test-nd4j-native - - - net.brutex.ai - nd4j-native - ${project.version} - test - - - net.brutex.ai - dl4j-test-resources - ${dl4j-test-resources.version} - test - - - - - - test-nd4j-cuda-10.2 - - - net.brutex.ai - nd4j-cuda-${cuda.version} - ${project.version} - test - - - net.brutex.ai - dl4j-test-resources - ${dl4j-test-resources.version} - test - - - - - only-eclipse - - - m2e.version - - - - - - - org.eclipse.m2e - lifecycle-mapping - 1.0.0 - - - - - - com.lewisd - lint-maven-plugin - [0.0.11,) - - check - - - - - - - - - - - - - - - - diff --git a/cavis-datavec/cavis-datavec-api/build.gradle b/cavis-datavec/cavis-datavec-api/build.gradle index ecad3eda0..c7f07f255 100644 --- a/cavis-datavec/cavis-datavec-api/build.gradle +++ b/cavis-datavec/cavis-datavec-api/build.gradle @@ -2,55 +2,10 @@ plugins { id 'java-library' id 'maven-publish' id 'signing' - id 'idea' } -ext { - buildTarget = rootProject.ext.buildTarget -} -idea { - module { - downloadJavadoc = true // defaults to false - downloadSources = true - } 
-} - -apply from: "../../chooseBackend.gradle" - -chipList.each { thisChip -> - configurations.register("${thisChip}TestImplementation") { - it.extendsFrom configurations.testImplementation - it.extendsFrom configurations.implementation - } - configurations.register("${thisChip}TestRuntime") { - it.extendsFrom configurations.testRuntimeOnly - it.extendsFrom configurations.api - it.extendsFrom configurations.implementation - it.extendsFrom configurations.testImplementation - } - - tasks.register("${thisChip}Test", Test) { - it.testClassesDirs = sourceSets.test.output.classesDirs - it.useJUnitPlatform() - it.classpath = configurations.getByName("${thisChip}TestRuntime") - it.classpath += sourceSets.test.output.classesDirs - it.classpath += sourceSets.main.output.classesDirs - it.ignoreFailures = true - it.testLogging { - events "PASSED", "SKIPPED", "FAILED", "STANDARD_OUT", "STANDARD_ERROR" - } - //it.jvmArgs("-Dorg.bytedeco.javacpp.logger.debug=true") - - // it.debug = true - } - - tasks.test.dependsOn "${thisChip}Test" -} - -test { - enabled = false -} +apply from: "${rootProject.projectDir.path}/createTestBackends.gradle" dependencies { testImplementation 'org.junit.jupiter:junit-jupiter-params' @@ -59,32 +14,6 @@ dependencies { testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT" - if(withCuda()) { - cudaTestRuntime platform(project(":cavis-common-platform")) - cudaTestRuntime project(":cavis-native:cavis-native-jcublas") - cudaTestRuntime group: "org.bytedeco", name: "openblas" - cudaTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget - cudaTestRuntime "org.bytedeco:cuda" - cudaTestRuntime (project(":cavis-native:cavis-native-lib")) { - capabilities{ - it.requireCapabilities "net.brutex.cavis-native:cavis-native-lib-cuda-support:1.0.0-SNAPSHOT" - } - } - } - - if(withCpu()) { - cpuTestRuntime project(":cavis-native:cavis-native-cpu") - cpuTestRuntime group: "org.bytedeco", name: "openblas" - cpuTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget - cpuTestRuntime (project(":cavis-native:cavis-native-lib")) { - capabilities{ - it.requireCapabilities "net.brutex.cavis-native:cavis-native-lib-cpu-support:1.0.0-SNAPSHOT" - } - } - } - - - implementation platform(project(':cavis-common-platform')) implementation project(':cavis-dnn:cavis-dnn-common') @@ -113,7 +42,6 @@ dependencies { testImplementation 'org.hamcrest:hamcrest-core:1.3' - implementation 'org.bytedeco:javacpp' } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/package-info.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/package-info.java new file mode 100644 index 000000000..fe3929a3d --- /dev/null +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/package-info.java @@ -0,0 +1,6 @@ +/** + * This is the core data vectorisation API (cavis-datavec-api). The main concept is to have + * {@link org.datavec.api.Writable} forming a {@link org.datavec.api.Record}. 
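For illustration, a minimal sketch of that concept (the concrete writable classes and their packages are inferred from the surrounding diffs and are an assumption for this fork): a record for one row of tabular data is simply an ordered list of writables, one per column.

import java.util.Arrays;
import java.util.List;
import org.datavec.api.Writable;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.Text;

public class RecordSketch {
    public static void main(String[] args) {
        // One row such as "42,3.14,cat" expressed as an ordered list of Writables (one per column):
        List<Writable> row = Arrays.asList(new IntWritable(42), new DoubleWritable(3.14), new Text("cat"));
        System.out.println(row);
    }
}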
+ */ + +package org.datavec.api; \ No newline at end of file diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java index 332358f74..a124ade3e 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/RecordReader.java @@ -35,129 +35,130 @@ import java.util.List; public interface RecordReader extends AutoCloseable, Serializable, Configurable { - String NAME_SPACE = RecordReader.class.getName(); + String NAME_SPACE = RecordReader.class.getName(); - String APPEND_LABEL = NAME_SPACE + ".appendlabel"; - String LABELS = NAME_SPACE + ".labels"; + String APPEND_LABEL = NAME_SPACE + ".appendlabel"; + String LABELS = NAME_SPACE + ".labels"; - /** - * Called once at initialization. - * - * @param split the split that defines the range of records to read - * @throws IOException - * @throws InterruptedException - */ - void initialize(InputSplit split) throws IOException, InterruptedException; + /** + * Called once at initialization. + * + * @param split the split that defines the range of records to read + * @throws IOException + * @throws InterruptedException + */ + void initialize(InputSplit split) throws IOException, InterruptedException; - /** - * Called once at initialization. - * - * @param conf a configuration for initialization - * @param split the split that defines the range of records to read - * @throws IOException - * @throws InterruptedException - */ - void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException; + /** + * Called once at initialization. + * + * @param conf a configuration for initialization + * @param split the split that defines the range of records to read + * @throws IOException + * @throws InterruptedException + */ + void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException; - /** - * This method returns true, if next(int) signature is supported by this RecordReader implementation. - * - * @return - */ - boolean batchesSupported(); + /** + * This method returns true, if next(int) signature is supported by this RecordReader + * implementation. + * + * @return + */ + boolean batchesSupported(); - /** - * This method will be used, if batchesSupported() returns true. - * - * @param num - * @return - */ - List> next(int num); + /** + * This method will be used, if batchesSupported() returns true. + * + * @param num + * @return + */ + List> next(int num); - /** - * Get the next record - * - * @return - */ - List next(); + /** + * Get the next record + * + * @return + */ + List next(); + /** + * Whether there are anymore records + * + * @return + */ + boolean hasNext(); - /** - * Whether there are anymore records - * - * @return - */ - boolean hasNext(); + /** + * List of label strings + * + * @return + */ + List getLabels(); - /** - * List of label strings - * - * @return - */ - List getLabels(); + /** + * Reset record reader iterator + */ + void reset(); - /** - * Reset record reader iterator - */ - void reset(); + /** + * @return True if the record reader can be reset, false otherwise. 
Note that some record readers + * cannot be reset - for example, if they are backed by a non-resettable input split (such as + * certain types of streams) + */ + boolean resetSupported(); - /** - * @return True if the record reader can be reset, false otherwise. Note that some record readers cannot be reset - - * for example, if they are backed by a non-resettable input split (such as certain types of streams) - */ - boolean resetSupported(); - - /** - * Load the record from the given DataInputStream - * Unlike {@link #next()} the internal state of the RecordReader is not modified - * Implementations of this method should not close the DataInputStream - * - * @throws IOException if error occurs during reading from the input stream - */ - List record(URI uri, DataInputStream dataInputStream) throws IOException; + /** + * Load the record from the given DataInputStream Unlike {@link #next()} the internal state of the + * RecordReader is not modified Implementations of this method should not close the + * DataInputStream + * + * @throws IOException if error occurs during reading from the input stream + */ + List record(URI uri, DataInputStream dataInputStream) throws IOException; - /** - * Similar to {@link #next()}, but returns a {@link Record} object, that may include metadata such as the source - * of the data - * - * @return next record - */ - Record nextRecord(); + /** + * Similar to {@link #next()}, but returns a {@link Record} object, that may include metadata such + * as the source of the data + * + * @return next record + */ + Record nextRecord(); - /** - * Load a single record from the given {@link RecordMetaData} instance
- * Note: that for data that isn't splittable (i.e., text data that needs to be scanned/split), it is more efficient to - * load multiple records at once using {@link #loadFromMetaData(List)} - * - * @param recordMetaData Metadata for the record that we want to load from - * @return Single record for the given RecordMetaData instance - * @throws IOException If I/O error occurs during loading - */ - Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException; + /** + * Load a single record from the given {@link RecordMetaData} instance
Note that for data + * that isn't splittable (i.e., text data that needs to be scanned/split), it is more efficient to + * load multiple records at once using {@link #loadFromMetaData(List)} + * + * @param recordMetaData Metadata for the record that we want to load from + * @return Single record for the given RecordMetaData instance + * @throws IOException If I/O error occurs during loading + */ + Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException; - /** - * Load multiple records from the given a list of {@link RecordMetaData} instances
- * - * @param recordMetaDatas Metadata for the records that we want to load from - * @return Multiple records for the given RecordMetaData instances - * @throws IOException If I/O error occurs during loading - */ - List loadFromMetaData(List recordMetaDatas) throws IOException; + /** + * Load multiple records from a given list of {@link RecordMetaData} instances
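A rough usage sketch of the loadFromMetaData contract documented here (the reader implementation, the input path and the exact package locations of Record, RecordMetaData and FileSplit are assumptions for illustration; exception handling is left out):

import java.io.File;
import java.util.ArrayList;
import java.util.List;
import org.datavec.api.records.Record;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.FileRecordReader;
import org.datavec.api.split.FileSplit;

public class MetaDataSketch {
    public static void main(String[] args) throws Exception {
        RecordReader rr = new FileRecordReader();                    // assumed implementation choice
        rr.initialize(new FileSplit(new File("/tmp/example-data"))); // hypothetical input location
        List<RecordMetaData> metas = new ArrayList<>();
        while (rr.hasNext()) {
            Record r = rr.nextRecord();   // record plus metadata about where it came from
            metas.add(r.getMetaData());
        }
        // Reload exactly the same records later in a single call instead of re-scanning the split:
        List<Record> reloaded = rr.loadFromMetaData(metas);
        System.out.println(reloaded.size() + " records reloaded");
    }
}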
+ * + * @param recordMetaDatas Metadata for the records that we want to load from + * @return Multiple records for the given RecordMetaData instances + * @throws IOException If I/O error occurs during loading + */ + List loadFromMetaData(List recordMetaDatas) throws IOException; - /** - * Get the record listeners for this record reader. - */ - List getListeners(); + /** + * Get the record listeners for this record reader. + */ + List getListeners(); - /** - * Set the record listeners for this record reader. - */ - void setListeners(RecordListener... listeners); + /** + * Set the record listeners for this record reader. + */ + void setListeners(RecordListener... listeners); - /** - * Set the record listeners for this record reader. - */ - void setListeners(Collection listeners); + /** + * Set the record listeners for this record reader. + */ + void setListeners(Collection listeners); } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java index fb010f258..33a79b0c2 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/records/reader/impl/FileRecordReader.java @@ -40,202 +40,206 @@ import java.util.*; /** * File reader/writer - * - * @author Adam Gibson */ public class FileRecordReader extends BaseRecordReader { - protected Iterator locationsIterator; - protected Configuration conf; - protected URI currentUri; - protected List labels; - protected boolean appendLabel = false; - @Getter @Setter - protected String charset = StandardCharsets.UTF_8.name(); //Using String as StandardCharsets.UTF_8 is not serializable + protected Iterator locationsIterator; + protected Configuration conf; + protected URI currentUri; + protected List labels; + protected boolean appendLabel = false; + @Getter + @Setter + protected String charset = StandardCharsets.UTF_8.name(); //Using String as StandardCharsets.UTF_8 is not serializable - public FileRecordReader() {} + public FileRecordReader() { + } - @Override - public void initialize(InputSplit split) throws IOException, InterruptedException { - super.initialize(split); - doInitialize(split); - } + @Override + public void initialize(InputSplit split) throws IOException, InterruptedException { + super.initialize(split); + doInitialize(split); + } - protected void doInitialize(InputSplit split) { + protected void doInitialize(InputSplit split) { - if (labels == null && appendLabel) { - URI[] locations = split.locations(); - if (locations.length > 0) { - Set labels = new HashSet<>(); - for(URI u : locations){ - String[] pathSplit = u.toString().split("[/\\\\]"); - labels.add(pathSplit[pathSplit.length-2]); - } - this.labels = new ArrayList<>(labels); - Collections.sort(this.labels); - } + if (labels == null && appendLabel) { + URI[] locations = split.locations(); + if (locations.length > 0) { + Set labels = new HashSet<>(); + for (URI u : locations) { + String[] pathSplit = u.toString().split("[/\\\\]"); + labels.add(pathSplit[pathSplit.length - 2]); } - locationsIterator = split.locationsIterator(); + this.labels = new ArrayList<>(labels); + Collections.sort(this.labels); + } + } + locationsIterator = split.locationsIterator(); + } + + @Override + public void initialize(Configuration conf, InputSplit split) + throws IOException, InterruptedException { + appendLabel = 
conf.getBoolean(APPEND_LABEL, true); + doInitialize(split); + this.inputSplit = split; + this.conf = conf; + } + + @Override + public List next() { + return nextRecord().getRecord(); + } + + private List loadFromStream(URI uri, InputStream next, Charset charset) { + List ret = new ArrayList<>(); + try { + if (!(next instanceof BufferedInputStream)) { + next = new BufferedInputStream(next); + } + String s = org.apache.commons.io.IOUtils.toString(next, charset); + ret.add(new Text(s)); + if (appendLabel) { + int idx = getLabel(uri); + ret.add(new IntWritable(idx)); + } + } catch (IOException e) { + throw new IllegalStateException("Error reading from input stream: " + uri); + } + return ret; + } + + /** + * Return the current label. The index of the current file's parent directory in the label list + * + * @return The index of the current file's parent directory + */ + public int getCurrentLabel() { + return getLabel(currentUri); + } + + public int getLabel(URI uri) { + String s = uri.toString(); + int lastIdx = Math.max(s.lastIndexOf('/'), + s.lastIndexOf('\\')); //Note: if neither are found, -1 is fine here + String sub = s.substring(0, lastIdx); + int secondLastIdx = Math.max(sub.lastIndexOf('/'), sub.lastIndexOf('\\')); + String name = s.substring(secondLastIdx + 1, lastIdx); + return labels.indexOf(name); + } + + public List getLabels() { + return labels; + } + + public void setLabels(List labels) { + this.labels = labels; + } + + @Override + public boolean hasNext() { + return locationsIterator.hasNext(); + } + + @Override + public void close() throws IOException { + + } + + @Override + public void setConf(Configuration conf) { + this.conf = conf; + } + + @Override + public Configuration getConf() { + return conf; + } + + @Override + public List> next(int num) { + List> ret = new ArrayList<>(num); + int numBatches = 0; + while (hasNext() && numBatches < num) { + ret.add(next()); } - @Override - public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException { - appendLabel = conf.getBoolean(APPEND_LABEL, true); - doInitialize(split); - this.inputSplit = split; - this.conf = conf; + return ret; + } + + @Override + public void reset() { + if (inputSplit == null) { + throw new UnsupportedOperationException("Cannot reset without first initializing"); + } + try { + doInitialize(inputSplit); + } catch (Exception e) { + throw new RuntimeException("Error during LineRecordReader reset", e); + } + } + + @Override + public boolean resetSupported() { + if (inputSplit != null) { + return inputSplit.resetSupported(); + } + return false; //reset() throws exception on reset() if inputSplit is null + } + + @Override + public List record(URI uri, DataInputStream dataInputStream) throws IOException { + invokeListeners(uri); + //Here: reading the entire file to a Text writable + BufferedReader br = new BufferedReader(new InputStreamReader(dataInputStream)); + StringBuilder sb = new StringBuilder(); + String line; + while ((line = br.readLine()) != null) { + sb.append(line).append("\n"); + } + return Collections.singletonList(new Text(sb.toString())); + } + + @Override + public Record nextRecord() { + URI next = locationsIterator.next(); + invokeListeners(next); + + List ret; + try (InputStream s = streamCreatorFn.apply(next)) { + ret = loadFromStream(next, s, Charset.forName(charset)); + } catch (IOException e) { + throw new RuntimeException("Error reading from stream for URI: " + next); } - @Override - public List next() { - return nextRecord().getRecord(); + return 
new org.datavec.api.records.impl.Record(ret, + new RecordMetaDataURI(next, FileRecordReader.class)); + } + + @Override + public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException { + return loadFromMetaData(Collections.singletonList(recordMetaData)).get(0); + } + + @Override + public List loadFromMetaData(List recordMetaDatas) throws IOException { + List out = new ArrayList<>(); + + for (RecordMetaData meta : recordMetaDatas) { + URI uri = meta.getURI(); + + List list; + try (InputStream s = streamCreatorFn.apply(uri)) { + list = loadFromStream(uri, s, Charset.forName(charset)); + } catch (IOException e) { + throw new RuntimeException("Error reading from stream for URI: " + uri); + } + + out.add(new org.datavec.api.records.impl.Record(list, meta)); } - private List loadFromStream(URI uri, InputStream next, Charset charset) { - List ret = new ArrayList<>(); - try { - if(!(next instanceof BufferedInputStream)){ - next = new BufferedInputStream(next); - } - String s = org.apache.commons.io.IOUtils.toString(next, charset); - ret.add(new Text(s)); - if (appendLabel) { - int idx = getLabel(uri); - ret.add(new IntWritable(idx)); - } - } catch (IOException e) { - throw new IllegalStateException("Error reading from input stream: " + uri); - } - return ret; - } - - /** - * Return the current label. - * The index of the current file's parent directory - * in the label list - * @return The index of the current file's parent directory - */ - public int getCurrentLabel() { - return getLabel(currentUri); - } - - public int getLabel(URI uri){ - String s = uri.toString(); - int lastIdx = Math.max(s.lastIndexOf('/'), s.lastIndexOf('\\')); //Note: if neither are found, -1 is fine here - String sub = s.substring(0, lastIdx); - int secondLastIdx = Math.max(sub.lastIndexOf('/'), sub.lastIndexOf('\\')); - String name = s.substring(secondLastIdx+1, lastIdx); - return labels.indexOf(name); - } - - public List getLabels() { - return labels; - } - - public void setLabels(List labels) { - this.labels = labels; - } - - @Override - public boolean hasNext() { - return locationsIterator.hasNext(); - } - - @Override - public void close() throws IOException { - - } - - @Override - public void setConf(Configuration conf) { - this.conf = conf; - } - - @Override - public Configuration getConf() { - return conf; - } - - @Override - public List> next(int num) { - List> ret = new ArrayList<>(num); - int numBatches = 0; - while (hasNext() && numBatches < num) { - ret.add(next()); - } - - return ret; - } - @Override - public void reset() { - if (inputSplit == null) - throw new UnsupportedOperationException("Cannot reset without first initializing"); - try { - doInitialize(inputSplit); - } catch (Exception e) { - throw new RuntimeException("Error during LineRecordReader reset", e); - } - } - - @Override - public boolean resetSupported() { - if(inputSplit != null){ - return inputSplit.resetSupported(); - } - return false; //reset() throws exception on reset() if inputSplit is null - } - - @Override - public List record(URI uri, DataInputStream dataInputStream) throws IOException { - invokeListeners(uri); - //Here: reading the entire file to a Text writable - BufferedReader br = new BufferedReader(new InputStreamReader(dataInputStream)); - StringBuilder sb = new StringBuilder(); - String line; - while ((line = br.readLine()) != null) { - sb.append(line).append("\n"); - } - return Collections.singletonList(new Text(sb.toString())); - } - - @Override - public Record nextRecord() { - URI next = 
locationsIterator.next(); - invokeListeners(next); - - List ret; - try(InputStream s = streamCreatorFn.apply(next)) { - ret = loadFromStream(next, s, Charset.forName(charset)); - } catch (IOException e){ - throw new RuntimeException("Error reading from stream for URI: " + next); - } - - return new org.datavec.api.records.impl.Record(ret,new RecordMetaDataURI(next, FileRecordReader.class)); - } - - @Override - public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException { - return loadFromMetaData(Collections.singletonList(recordMetaData)).get(0); - } - - @Override - public List loadFromMetaData(List recordMetaDatas) throws IOException { - List out = new ArrayList<>(); - - for (RecordMetaData meta : recordMetaDatas) { - URI uri = meta.getURI(); - - List list; - try(InputStream s = streamCreatorFn.apply(uri)) { - list = loadFromStream(uri, s, Charset.forName(charset)); - } catch (IOException e){ - throw new RuntimeException("Error reading from stream for URI: " + uri); - } - - out.add(new org.datavec.api.records.impl.Record(list, meta)); - } - - return out; - } + return out; + } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableFactory.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableFactory.java index 2b7d419ea..eba29e63c 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableFactory.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableFactory.java @@ -32,95 +32,99 @@ import org.datavec.api.Writable; public class WritableFactory { - private static final WritableFactory INSTANCE = new WritableFactory(); + private static final WritableFactory INSTANCE = new WritableFactory(); - private final Map> map = new ConcurrentHashMap<>(); - private final Map> constructorMap = new ConcurrentHashMap<>(); + private final Map> map = new ConcurrentHashMap<>(); + private final Map> constructorMap = new ConcurrentHashMap<>(); - private WritableFactory() { - for (WritableType wt : WritableType.values()) { - if (wt.isCoreWritable()) { - registerWritableType((short) wt.ordinal(), wt.getWritableClass()); - } - } + private WritableFactory() { + for (WritableType wt : WritableType.values()) { + if (wt.isCoreWritable()) { + registerWritableType((short) wt.ordinal(), wt.getWritableClass()); + } + } + } + + /** + * @return Singleton WritableFactory instance + */ + public static WritableFactory getInstance() { + return INSTANCE; + } + + /** + * Register a writable class with a specific key (as a short). Note that key values must be unique + * for each type of Writable, as they are used as type information in certain types of + * serialisation. Consequently, an exception will be thrown If the key value is not unique or is + * already assigned.
Note that in general, this method needs to only be used for custom + * Writable types; Care should be taken to ensure that the given key does not change once + * assigned. + * + * @param writableTypeKey Key for the Writable + * @param writableClass Class for the given key. Must have a no-arg constructor + */ + public void registerWritableType(short writableTypeKey, + @NonNull Class writableClass) { + if (map.containsKey(writableTypeKey)) { + throw new UnsupportedOperationException( + "Key " + writableTypeKey + " is already registered to type " + + map.get(writableTypeKey) + " and cannot be registered to " + writableClass); } - /** - * @return Singleton WritableFactory instance - */ - public static WritableFactory getInstance() { - return INSTANCE; + Constructor c; + try { + c = writableClass.getDeclaredConstructor(); + } catch (NoSuchMethodException e) { + throw new RuntimeException("Cannot find no-arg constructor for class " + writableClass); } - /** - * Register a writable class with a specific key (as a short). Note that key values must be unique for each type of - * Writable, as they are used as type information in certain types of serialisation. Consequently, an exception will - * be thrown If the key value is not unique or is already assigned.
- * Note that in general, this method needs to only be used for custom Writable types; Care should be taken to ensure - * that the given key does not change once assigned. - * - * @param writableTypeKey Key for the Writable - * @param writableClass Class for the given key. Must have a no-arg constructor - */ - public void registerWritableType(short writableTypeKey, @NonNull Class writableClass) { - if (map.containsKey(writableTypeKey)) { - throw new UnsupportedOperationException("Key " + writableTypeKey + " is already registered to type " - + map.get(writableTypeKey) + " and cannot be registered to " + writableClass); - } + map.put(writableTypeKey, writableClass); + constructorMap.put(writableTypeKey, c); + } - Constructor c; - try { - c = writableClass.getDeclaredConstructor(); - } catch (NoSuchMethodException e) { - throw new RuntimeException("Cannot find no-arg constructor for class " + writableClass); - } - - map.put(writableTypeKey, writableClass); - constructorMap.put(writableTypeKey, c); + /** + * Create a new writable instance (using reflection) given the specified key + * + * @param writableTypeKey Key to create a new writable instance for + * @return A new (empty/default) Writable instance + */ + public Writable newWritable(short writableTypeKey) { + Constructor c = constructorMap.get(writableTypeKey); + if (c == null) { + throw new IllegalStateException("Unknown writable key: " + writableTypeKey); } - - /** - * Create a new writable instance (using reflection) given the specified key - * - * @param writableTypeKey Key to create a new writable instance for - * @return A new (empty/default) Writable instance - */ - public Writable newWritable(short writableTypeKey) { - Constructor c = constructorMap.get(writableTypeKey); - if (c == null) { - throw new IllegalStateException("Unknown writable key: " + writableTypeKey); - } - try { - return c.newInstance(); - } catch (Exception e) { - throw new RuntimeException("Could not create new Writable instance"); - } + try { + return c.newInstance(); + } catch (Exception e) { + throw new RuntimeException("Could not create new Writable instance"); } + } - /** - * A convenience method for writing a given Writable object to a DataOutput. The key is 1st written (a single short) - * followed by the value from writable. - * - * @param w Writable value - * @param dataOutput DataOutput to write both key and value to - * @throws IOException If an error occurs during writing to the DataOutput - */ - public void writeWithType(Writable w, DataOutput dataOutput) throws IOException { - w.writeType(dataOutput); - w.write(dataOutput); - } + /** + * A convenience method for writing a given Writable object to a DataOutput. The key is 1st + * written (a single short) followed by the value from writable. 
+ * + * @param w Writable value + * @param dataOutput DataOutput to write both key and value to + * @throws IOException If an error occurs during writing to the DataOutput + */ + public void writeWithType(Writable w, DataOutput dataOutput) throws IOException { + w.writeType(dataOutput); + w.write(dataOutput); + } - /** - * Read a Writable From the DataInput, where the Writable was previously written using {@link #writeWithType(Writable, DataOutput)} - * - * @param dataInput DataInput to read the Writable from - * @return Writable from the DataInput - * @throws IOException In an error occurs during reading - */ - public Writable readWithType(DataInput dataInput) throws IOException { - Writable w = newWritable(dataInput.readShort()); - w.readFields(dataInput); - return w; - } + /** + * Read a Writable From the DataInput, where the Writable was previously written using + * {@link #writeWithType(Writable, DataOutput)} + * + * @param dataInput DataInput to read the Writable from + * @return Writable from the DataInput + * @throws IOException In an error occurs during reading + */ + public Writable readWithType(DataInput dataInput) throws IOException { + Writable w = newWritable(dataInput.readShort()); + w.readFields(dataInput); + return w; + } } diff --git a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableType.java b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableType.java index 3afbe4088..a889dfd39 100644 --- a/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableType.java +++ b/cavis-datavec/cavis-datavec-api/src/main/java/org/datavec/api/writable/WritableType.java @@ -23,72 +23,71 @@ package org.datavec.api.writable; import org.datavec.api.Writable; public enum WritableType { - Boolean, Byte, Double, Float, Int, Long, Null, Text, NDArray, Image,Arrow,Bytes; + Boolean, Byte, Double, Float, Int, Long, Null, Text, NDArray, Image, Arrow, Bytes; - //NOTE TO DEVELOPERS: - //In the current implementation, the order (ordinal idx) for the WritableType values matters. - //New writables can be added to the end of the list, but not between exiting types, as this will change the - //ordinal value for all writable types that follow, which will mess up serialization in some cases (like Spark - // sequence and map files) - //Alternatively, modify WritableType.typeIdx() to ensure backward compatibility + //NOTE TO DEVELOPERS: + //In the current implementation, the order (ordinal idx) for the WritableType values matters. 
+ //New writables can be added to the end of the list, but not between exiting types, as this will change the + //ordinal value for all writable types that follow, which will mess up serialization in some cases (like Spark + // sequence and map files) + //Alternatively, modify WritableType.typeIdx() to ensure backward compatibility - /** - * - * @return True if Writable is defined in datavec-api, false otherwise - */ - public boolean isCoreWritable() { - switch (this) { - case Image: - case Arrow: - return false; - default: - return true; - } + /** + * @return True if Writable is defined in datavec-api, false otherwise + */ + public boolean isCoreWritable() { + switch (this) { + case Image: + case Arrow: + return false; + default: + return true; } + } - /** - * Return a unique type index for the given writable - * - * @return Type index for the writable - */ - public short typeIdx() { - return (short) this.ordinal(); - } + /** + * Return a unique type index for the given writable + * + * @return Type index for the writable + */ + public short typeIdx() { + return (short) this.ordinal(); + } - /** - * Return the class of the implementation corresponding to each WritableType. - * Note that if {@link #isCoreWritable()} returns false, null will be returned by this method. - * - * @return Class for the given WritableType - */ - public Class getWritableClass() { - switch (this) { - case Boolean: - return BooleanWritable.class; - case Byte: - return ByteWritable.class; - case Double: - return DoubleWritable.class; - case Float: - return FloatWritable.class; - case Int: - return IntWritable.class; - case Long: - return LongWritable.class; - case Null: - return NullWritable.class; - case Text: - return Text.class; - case NDArray: - return NDArrayWritable.class; - case Bytes: - return ByteWritable.class; - case Image: - case Arrow: - default: - return null; - } + /** + * Return the class of the implementation corresponding to each WritableType. Note that if + * {@link #isCoreWritable()} returns false, null will be returned by this method. 
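A small round-trip sketch of the key-plus-value serialisation handled by WritableFactory above (stream setup is purely illustrative; only methods shown in these diffs are used, and the package of Text is assumed from the surrounding file paths):

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import org.datavec.api.Writable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.WritableFactory;

public class WritableRoundTripSketch {
    public static void main(String[] args) throws Exception {
        WritableFactory factory = WritableFactory.getInstance();

        // Write: a single short type key first, then the writable's own payload.
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            factory.writeWithType(new Text("hello"), out);
        }

        // Read: the short key is consumed first and used to instantiate the matching class.
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            Writable restored = factory.readWithType(in);   // a Text containing "hello"
            System.out.println(restored);
        }
    }
}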
+ * + * @return Class for the given WritableType + */ + public Class getWritableClass() { + switch (this) { + case Boolean: + return BooleanWritable.class; + case Byte: + return ByteWritable.class; + case Double: + return DoubleWritable.class; + case Float: + return FloatWritable.class; + case Int: + return IntWritable.class; + case Long: + return LongWritable.class; + case Null: + return NullWritable.class; + case Text: + return Text.class; + case NDArray: + return NDArrayWritable.class; + case Bytes: + return ByteWritable.class; + case Image: + case Arrow: + default: + return null; } + } } diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 03cf9e177..774c600f5 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -106,6 +106,17 @@ publishing { */ } } + + task printDeps { + doLast { + configurations.api.dependencies.each { dep -> + println "${dep.group} - ${dep.name} - ${dep.version}" + dep.artifacts.each { art -> + println " ${art.extension} - ${art.classifier}" + } + } + } + } } diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 36fa1f765..0439110ce 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -424,6 +424,7 @@ tasks.withType(Javadoc) { options.addStringOption('Xdoclint:none', '-quiet') } +/* jar { manifest { attributes 'Class-Path': configurations.runtimeClasspath.collect { it.getName() }.join(' '), @@ -436,7 +437,7 @@ jar { } //archiveClassifier = "${javacppPlatform}${javacppPlatformExtension}-${chip}" } - +*/ javadoc { dependsOn "javacppPomProperties" failOnError = false @@ -457,6 +458,13 @@ javadoc { enabled = true } +artifacts { + chipList.each { thisChip -> + implementation(tasks.getByName("${thisChip}SupportJar")) + } +} + +/* artifacts { archives jar chipList.each { thisChip -> @@ -464,6 +472,8 @@ artifacts { } } + */ + publishing { publications { mavenJava(MavenPublication) { @@ -499,6 +509,19 @@ if( osdetector.os.startsWith("windows")) { } } + +task printDeps { + doLast { + configurations.apiElements.dependencies.each { dep -> + println "${dep.group} - ${dep.name} - ${dep.version}" + dep.artifacts.each { art -> + println " ${art.extension} - ${art.classifier}" + } + } + } +} + + /* def pomClosure = { name = 'Brutex AI - Native Components' diff --git a/chooseBackend.gradle b/chooseBackend.gradle index 258650a05..9e34b7552 100644 --- a/chooseBackend.gradle +++ b/chooseBackend.gradle @@ -19,13 +19,22 @@ * */ ext { - chip = (properties.CAVIS_CHIP ?: "cuda,cpu").toLowerCase() + chip = (properties.CAVIS_CHIP ?: "cuda,cpu").toLowerCase() //the default is to build for CPU and CUDA + testChip = (properties.CAVIS_TEST_CHIP ?: " ").toLowerCase() //the default is without specific backend chipList = chip.split(",") + testChipList = testChip.split(",") + /* just for usability */ withCuda = { -> return chip.contains("cuda") } withCpu = { -> return chip.contains("cpu") } + withCudaTest = { -> + return testChip.contains("cuda") + } + withCpuTest = { -> + return testChip.contains("cpu") + } } diff --git a/createTestBackends.gradle b/createTestBackends.gradle index cbe536802..a0cef6c24 100644 --- a/createTestBackends.gradle +++ b/createTestBackends.gradle @@ -24,7 +24,7 @@ ext { buildTarget = rootProject.ext.buildTarget apply from: new File("${project.rootProject.projectDir}/chooseBackend.gradle") - chipList.each { thisChip -> + testChipList.each { thisChip -> configurations.register("${thisChip}TestImplementation") { it.extendsFrom 
configurations.testImplementation, configurations.implementation @@ -79,7 +79,7 @@ ext { dependencies { - if (withCuda()) { + if (withCudaTest()) { cudaTestRuntime platform(projects.cavisCommonPlatform) cudaTestRuntime projects.cavisNative.cavisNativeJcublas cudaTestRuntime group: "org.bytedeco", name: "openblas" @@ -89,12 +89,12 @@ ext { cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist" cudaTestRuntime(project(":cavis-native:cavis-native-lib")) { capabilities { - it.requireCapabilities "net.brutex.cavis-native:cavis-native-lib-cuda-support:1.0.0-SNAPSHOT" + it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cuda-support:1.0.0-SNAPSHOT" } } } - if (withCpu()) { + if (withCpuTest()) { cpuTestRuntime platform(projects.cavisCommonPlatform) cpuTestRuntime projects.cavisNative.cavisNativeCpu cpuTestRuntime group: "org.bytedeco", name: "openblas" @@ -103,7 +103,7 @@ ext { cpuTestRuntime group: "org.bytedeco", name: "opencv", classifier: buildTarget cpuTestRuntime(project(":cavis-native:cavis-native-lib")) { capabilities { - it.requireCapabilities "net.brutex.cavis-native:cavis-native-lib-cpu-support:1.0.0-SNAPSHOT" + it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cpu-support:1.0.0-SNAPSHOT" } } } From 2bde6f0975c0cbaa66d507bf9cfb352de74de8b5 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 26 Oct 2022 13:48:16 +0200 Subject: [PATCH 100/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 1 + 1 file changed, 1 insertion(+) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 774c600f5..8b4d96604 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -28,6 +28,7 @@ dependencies { api sproj } } + api project(path: ":cavis-native:cavis-native-lib", configuration: "runtimeElements") /* api(projects.cavisNative.cavisNativeLib) { capabilities { From b2acaf80620940cb64d6254002437e8f4b2c96ba Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 26 Oct 2022 19:22:58 +0200 Subject: [PATCH 101/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- build.gradle | 2 +- cavis-full/build.gradle | 29 +++++++--------------- cavis-native/cavis-native-lib/build.gradle | 9 ++++--- chooseBackend.gradle | 2 ++ 4 files changed, 17 insertions(+), 25 deletions(-) diff --git a/build.gradle b/build.gradle index fc9167f30..479821e9d 100644 --- a/build.gradle +++ b/build.gradle @@ -114,7 +114,7 @@ allprojects { Project proj -> /* Need to verify the property exists, as some modules may not declare it (i.e. 
the java-platform plugin) */ - if (components.hasProperty("java") && !proj.name.equals("cavis-native-lib")) { + if (components.hasProperty("java") ) { from components.java } } diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 8b4d96604..c124dbac8 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -21,14 +21,14 @@ dependencies { && !sproj.name.equals("Cavis") && !sproj.name.equals("cavis-datavec") && !sproj.name.equals("cavis-dnn") - && !sproj.name.equals("cavis-native") + && !sproj.name.equals("cavis-native") && !sproj.name.equals("cavis-native-lib") && !sproj.name.equals("cavis-nd4j") && !sproj.name.equals("cavis-ui") && !sproj.name.equals("cavis-zoo")) { api sproj } } - api project(path: ":cavis-native:cavis-native-lib", configuration: "runtimeElements") + api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements") /* api(projects.cavisNative.cavisNativeLib) { capabilities { @@ -76,17 +76,17 @@ tasks.getByName("jar") { } } -/* - -/* -artifacts { - archives customFatJar -} */ + +artifacts { + archives shadowJar +} + shadowJar { - enabled false; + enabled true; zip64 true //need this to support jars with more than 65535 entries + classifier null } publishing { @@ -107,17 +107,6 @@ publishing { */ } } - - task printDeps { - doLast { - configurations.api.dependencies.each { dep -> - println "${dep.group} - ${dep.name} - ${dep.version}" - dep.artifacts.each { art -> - println " ${art.extension} - ${art.classifier}" - } - } - } - } } diff --git a/cavis-native/cavis-native-lib/build.gradle b/cavis-native/cavis-native-lib/build.gradle index 0439110ce..8ecbeafc1 100644 --- a/cavis-native/cavis-native-lib/build.gradle +++ b/cavis-native/cavis-native-lib/build.gradle @@ -459,6 +459,7 @@ javadoc { } artifacts { + //implementation(jar) chipList.each { thisChip -> implementation(tasks.getByName("${thisChip}SupportJar")) } @@ -473,7 +474,7 @@ artifacts { } */ - +/* publishing { publications { mavenJava(MavenPublication) { @@ -484,8 +485,8 @@ publishing { } } } - - +*/ +/* if( osdetector.os.startsWith("windows")) { @@ -508,7 +509,7 @@ if( osdetector.os.startsWith("windows")) { } } } - +*/ task printDeps { doLast { diff --git a/chooseBackend.gradle b/chooseBackend.gradle index 9e34b7552..9a25c6caf 100644 --- a/chooseBackend.gradle +++ b/chooseBackend.gradle @@ -21,6 +21,8 @@ ext { chip = (properties.CAVIS_CHIP ?: "cuda,cpu").toLowerCase() //the default is to build for CPU and CUDA testChip = (properties.CAVIS_TEST_CHIP ?: " ").toLowerCase() //the default is without specific backend + logger.quiet("Building for chips ${chip} and running tests with backends for ${testChip}") + chipList = chip.split(",") testChipList = testChip.split(",") From badbc19eae69b62b7d346d08a6f2878daa9771e3 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 26 Oct 2022 19:25:23 +0200 Subject: [PATCH 102/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index c124dbac8..f77c1dbd5 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -28,7 +28,7 @@ dependencies { api sproj } } - api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements") + api project(path: ":cavis-native:cavis-native-lib", configuration: "runtimeElements") /* api(projects.cavisNative.cavisNativeLib) { capabilities { From 565bc17cfb62aaef307386034eac82763a2a8b13 Mon 
Sep 17 00:00:00 2001 From: brian Date: Wed, 26 Oct 2022 19:38:10 +0200 Subject: [PATCH 103/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index f77c1dbd5..1d956c91f 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -84,9 +84,8 @@ artifacts { } shadowJar { - enabled true; + enabled false; zip64 true //need this to support jars with more than 65535 entries - classifier null } publishing { From bc99c932cfe93be0daff8530210d276edad15f76 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 26 Oct 2022 19:43:30 +0200 Subject: [PATCH 104/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 1d956c91f..ce704c79e 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -84,7 +84,7 @@ artifacts { } shadowJar { - enabled false; + enabled true; zip64 true //need this to support jars with more than 65535 entries } From 49fdbff24c98547bf02c2f6a42eee8ba6d65f00c Mon Sep 17 00:00:00 2001 From: brian Date: Thu, 27 Oct 2022 14:32:48 +0200 Subject: [PATCH 105/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index ce704c79e..6fedf169f 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -28,7 +28,8 @@ dependencies { api sproj } } - api project(path: ":cavis-native:cavis-native-lib", configuration: "runtimeElements") + if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuRuntimeElements") + if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaRuntimeElements") /* api(projects.cavisNative.cavisNativeLib) { capabilities { From fa3c9a3a4fb46227360c1dd192dc4625bcd3ab62 Mon Sep 17 00:00:00 2001 From: brian Date: Thu, 27 Oct 2022 14:36:31 +0200 Subject: [PATCH 106/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 6fedf169f..292def927 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -28,8 +28,8 @@ dependencies { api sproj } } - if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuRuntimeElements") - if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaRuntimeElements") + if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportRuntimeElements") + if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements") /* api(projects.cavisNative.cavisNativeLib) { capabilities { From 80aad3087ab3db63ece530a67e5bc5af34e9b4c7 Mon Sep 17 00:00:00 2001 From: brian Date: Thu, 27 Oct 2022 18:34:57 +0200 Subject: [PATCH 107/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- build.gradle | 11 +++++++---- cavis-full/build.gradle | 22 ++++++++++++++++------ chooseBackend.gradle | 2 +- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/build.gradle b/build.gradle index 479821e9d..20e45b528 100644 
--- a/build.gradle +++ b/build.gradle @@ -108,14 +108,17 @@ allprojects { Project proj -> } plugins.withType(MavenPublishPlugin) { + publishing { publications { - mavenJava(MavenPublication) { - /* Need to verify the property exists, as some + if(! proj.name.contains("cavis-full")) { + mavenJava(MavenPublication) { + /* Need to verify the property exists, as some modules may not declare it (i.e. the java-platform plugin) */ - if (components.hasProperty("java") ) { - from components.java + if (components.hasProperty("java")) { + from components.java + } } } } diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 292def927..0d8bfbdc7 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -28,16 +28,22 @@ dependencies { api sproj } } - if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportRuntimeElements") - if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements") -/* + // if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements") + // if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportApiElements") + api(projects.cavisNative.cavisNativeLib) { capabilities { - if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) + //if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) } } -*/ + api(projects.cavisNative.cavisNativeLib) { + capabilities { + if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) + //if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) + } + } + //if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportImplementation") //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportImplementation") //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath") @@ -87,12 +93,13 @@ artifacts { shadowJar { enabled true; zip64 true //need this to support jars with more than 65535 entries + archiveClassifier.set('') } publishing { publications { mavenJava(MavenPublication) { - // artifact customFatJar + //artifact customFatJar // from components.java /* pom.withXml { def dependenciesNode = asNode().dependencies @@ -106,6 +113,9 @@ publishing { } */ } + shadow(MavenPublication) { publication -> + project.shadow.component(publication) + } } } diff --git a/chooseBackend.gradle b/chooseBackend.gradle index 9a25c6caf..7a3159f59 100644 --- a/chooseBackend.gradle +++ b/chooseBackend.gradle @@ -21,7 +21,7 @@ ext { chip = (properties.CAVIS_CHIP ?: "cuda,cpu").toLowerCase() //the default is to build for CPU and CUDA testChip = (properties.CAVIS_TEST_CHIP ?: " ").toLowerCase() //the default is without specific backend - logger.quiet("Building for chips ${chip} and running tests with backends for ${testChip}") + logger.debug("Building for chips ${chip} and running tests with backends for ${testChip}") chipList = chip.split(",") testChipList = testChip.split(",") From 
d91d9db28344e80a1e47080763ec5e3af9e1e5b3 Mon Sep 17 00:00:00 2001 From: brian Date: Sun, 30 Oct 2022 05:28:44 +0100 Subject: [PATCH 108/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 68 +++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 0d8bfbdc7..cef1c944d 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -7,48 +7,49 @@ plugins { apply from: rootProject.projectDir.path+"/chooseBackend.gradle" dependencies { - //Todo clean this - api platform(project(":cavis-common-platform")) - //api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise - //api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5" - //api 'org.slf4j:slf4j-simple:2.0.3' - //api 'org.slf4j:slf4j-api:2.0.3' - //TODO for the two below.. either platform specific uber jars or a single big one with all platforms - //api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" + afterEvaluate { + //Todo clean this + api platform(project(":cavis-common-platform")) + //api "org.bytedeco:javacpp:1.5.7" //for some reason we needed to apply version numbers here, they do not end up in POM otherwise + //api "com.fasterxml.jackson.datatype:jackson-datatype-joda:2.10.5" + //api 'org.slf4j:slf4j-simple:2.0.3' + //api 'org.slf4j:slf4j-api:2.0.3' + //TODO for the two below.. either platform specific uber jars or a single big one with all platforms + //api group: "org.bytedeco", name: "javacpp", version: "1.5.7", classifier: "linux-x86_64" - rootProject.getAllprojects().each { Project sproj -> - if (!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") - && !sproj.name.equals("Cavis") - && !sproj.name.equals("cavis-datavec") - && !sproj.name.equals("cavis-dnn") - && !sproj.name.equals("cavis-native") && !sproj.name.equals("cavis-native-lib") - && !sproj.name.equals("cavis-nd4j") - && !sproj.name.equals("cavis-ui") - && !sproj.name.equals("cavis-zoo")) { - api sproj + rootProject.getAllprojects().each { Project sproj -> + if (!sproj.name.equals(name) && !sproj.name.equals("cavis-common-platform") + && !sproj.name.equals("Cavis") + && !sproj.name.equals("cavis-datavec") + && !sproj.name.equals("cavis-dnn") + && !sproj.name.equals("cavis-native") && !sproj.name.equals("cavis-native-lib") + && !sproj.name.equals("cavis-nd4j") + && !sproj.name.equals("cavis-ui") + && !sproj.name.equals("cavis-zoo")) { + api sproj + } } - } - // if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements") - // if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportApiElements") + // if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements") + // if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportApiElements") - api(projects.cavisNative.cavisNativeLib) { + api(projects.cavisNative.cavisNativeLib) { capabilities { //if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) - if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) + if (withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: 
"cavis-native-lib-cpu-support", version: project.version) + } + } + api(projects.cavisNative.cavisNativeLib) { + capabilities { + if (withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) + //if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) } - } - api(projects.cavisNative.cavisNativeLib) { - capabilities { - if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) - //if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) } - } - //if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportImplementation") - //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportImplementation") - //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath") + //if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportImplementation") + //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportImplementation") + //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath") - /* + /* api (project(':cavis-native:cavis-native-lib')) { capabilities { if(withCpu()) requireCapability("net.brutex.cavis.cavis-native:cavis-native-lib-cpu-support") @@ -56,6 +57,7 @@ dependencies { } } */ + } } From f4bd8c7400384bc6e2ef0b93f32b65374e3368cd Mon Sep 17 00:00:00 2001 From: brian Date: Sun, 30 Oct 2022 05:39:15 +0100 Subject: [PATCH 109/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index cef1c944d..395654e1c 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -26,7 +26,7 @@ dependencies { && !sproj.name.equals("cavis-nd4j") && !sproj.name.equals("cavis-ui") && !sproj.name.equals("cavis-zoo")) { - api sproj + api project(path: sproj.path, configuration: 'apiElements') } } // if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements") @@ -100,7 +100,7 @@ shadowJar { publishing { publications { - mavenJava(MavenPublication) { + /*mavenJava(MavenPublication) { //artifact customFatJar // from components.java /* pom.withXml { @@ -113,8 +113,9 @@ publishing { //dependencyNode.appendNode('classifier', 'linux-x86_64-avx2-cpu') //dependencyNode.appendNode('scope', 'compile') } - */ + } + */ shadow(MavenPublication) { publication -> project.shadow.component(publication) } From d02c5d78624b93710aba2624b3f45838beb9d5cb Mon Sep 17 00:00:00 2001 From: brian Date: Sun, 30 Oct 2022 05:42:02 +0100 Subject: [PATCH 110/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 395654e1c..a2724ffd4 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -26,19 +26,19 @@ dependencies { && !sproj.name.equals("cavis-nd4j") && !sproj.name.equals("cavis-ui") && !sproj.name.equals("cavis-zoo")) { - api project(path: 
sproj.path, configuration: 'apiElements') + implementation project(path: sproj.path, configuration: 'apiElements') } } // if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements") // if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportApiElements") - api(projects.cavisNative.cavisNativeLib) { + implementation(projects.cavisNative.cavisNativeLib) { capabilities { //if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) if (withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) } } - api(projects.cavisNative.cavisNativeLib) { + implementation(projects.cavisNative.cavisNativeLib) { capabilities { if (withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) //if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) From 9bc4009d10e99947fd451ac5235f0f1e54dfaa5c Mon Sep 17 00:00:00 2001 From: brian Date: Sun, 30 Oct 2022 05:45:35 +0100 Subject: [PATCH 111/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index a2724ffd4..649ed8354 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -26,7 +26,7 @@ dependencies { && !sproj.name.equals("cavis-nd4j") && !sproj.name.equals("cavis-ui") && !sproj.name.equals("cavis-zoo")) { - implementation project(path: sproj.path, configuration: 'apiElements') + implementation project(path: sproj.path, configuration: 'runtimeElements') } } // if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements") From e61068da55832c50eba3ff31e5ca8cf9d268e243 Mon Sep 17 00:00:00 2001 From: brian Date: Sun, 30 Oct 2022 05:52:19 +0100 Subject: [PATCH 112/126] Add jenkinsfile for pipeline build and dockerfile for build Signed-off-by: brian --- cavis-full/build.gradle | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index 649ed8354..c18c258ad 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -26,19 +26,19 @@ dependencies { && !sproj.name.equals("cavis-nd4j") && !sproj.name.equals("cavis-ui") && !sproj.name.equals("cavis-zoo")) { - implementation project(path: sproj.path, configuration: 'runtimeElements') + api project(path: sproj.path, configuration: 'runtimeElements') } } // if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements") // if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportApiElements") - implementation(projects.cavisNative.cavisNativeLib) { + api(projects.cavisNative.cavisNativeLib) { capabilities { //if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) if (withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) } } - implementation(projects.cavisNative.cavisNativeLib) { + api(projects.cavisNative.cavisNativeLib) { capabilities { if (withCuda()) it.requireCapability(group: 
"net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) //if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) From bf10564be7785ad23bd593db4ecd5448a79ba656 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 10 Mar 2023 11:20:32 +0100 Subject: [PATCH 113/126] Adding cuDNN support Signed-off-by: brian --- brutex-extended-tests/build.gradle | 23 +- .../src/test/java/net/brutex/gan/App.java | 279 +++++++ .../src/test/java/net/brutex/gan/GAN.java | 411 ++++++++++ .../net/brutex/gan/GANVisualizationUtils.java | 73 ++ .../net/brutex/gan/MnistDCGANExample.java | 193 +++++ .../java/net/brutex/gan/MnistSimpleGAN.java | 146 ++++ .../brutex/spark/BaseSparkSessionTest.java | 107 ++- .../test/java/net/brutex/spark/BrianTest.java | 396 ++++----- cavis-common-platform/build.gradle | 4 +- ...FieldInterface.java => AbstractField.java} | 61 +- .../java/net/brutex/cavis/dvec/api/Field.java | 61 +- .../common/config/DL4JClassLoading.java | 6 +- cavis-dnn/cavis-dnn-cudnn/build.gradle | 23 + .../deeplearning4j/cuda/BaseCudnnHelper.java | 252 ++++++ .../convolution/CudnnConvolutionHelper.java | 758 ++++++++++++++++++ .../subsampling/CudnnSubsamplingHelper.java | 308 +++++++ .../cuda/dropout/CudnnDropoutHelper.java | 245 ++++++ .../CudnnBatchNormalizationHelper.java | 384 +++++++++ ...CudnnLocalResponseNormalizationHelper.java | 240 ++++++ .../cuda/recurrent/CudnnLSTMHelper.java | 659 +++++++++++++++ .../KerasFlattenRnnPreprocessor.java | 2 + cavis-full/build.gradle | 6 +- cavis-native/cavis-native-lib/CMakeLists.txt | 2 +- chooseBackend.gradle | 10 +- createTestBackends.gradle | 17 +- settings.gradle | 4 + 26 files changed, 4361 insertions(+), 309 deletions(-) create mode 100644 brutex-extended-tests/src/test/java/net/brutex/gan/App.java create mode 100644 brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java create mode 100644 brutex-extended-tests/src/test/java/net/brutex/gan/GANVisualizationUtils.java create mode 100644 brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java create mode 100644 brutex-extended-tests/src/test/java/net/brutex/gan/MnistSimpleGAN.java rename cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/{FieldInterface.java => AbstractField.java} (51%) create mode 100644 cavis-dnn/cavis-dnn-cudnn/build.gradle create mode 100644 cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/BaseCudnnHelper.java create mode 100644 cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/convolution/CudnnConvolutionHelper.java create mode 100644 cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/convolution/subsampling/CudnnSubsamplingHelper.java create mode 100644 cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/dropout/CudnnDropoutHelper.java create mode 100644 cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnBatchNormalizationHelper.java create mode 100644 cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnLocalResponseNormalizationHelper.java create mode 100644 cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/recurrent/CudnnLSTMHelper.java diff --git a/brutex-extended-tests/build.gradle b/brutex-extended-tests/build.gradle index bd53f61bd..c15f6d325 100644 --- a/brutex-extended-tests/build.gradle +++ b/brutex-extended-tests/build.gradle @@ -19,8 +19,12 @@ * */ -apply plugin: 'java' -apply plugin: 'maven-publish' 
+plugins { + id 'java-library' + id 'maven-publish' + id 'com.github.johnrengelman.shadow' version '7.1.2' +} + apply from: "${project.rootProject.projectDir}/createTestBackends.gradle" @@ -54,6 +58,7 @@ dependencies { implementation projects.cavisDnn.cavisDnnSpark.cavisDnnSparkParameterserver implementation projects.cavisDnn.cavisDnnNnParent.cavisDnnNnCore implementation projects.cavisDnn.cavisDnnNn + implementation projects.cavisUi.cavisUiCommon implementation projects.cavisUi.cavisUiVertx implementation projects.cavisUi.cavisUiModel @@ -66,11 +71,21 @@ dependencies { implementation projects.cavisDnn.cavisDnnParallelwrapper implementation projects.cavisZoo.cavisZooModels - testRuntimeOnly "net.brutex.ai:dl4j-test-resources:1.0.1-SNAPSHOT" } + test { - dependsOn jar + enabled true + dependsOn shadowJar } + +shadowJar { + enabled true; + zip64 true //need this to support jars with more than 65535 entries + archiveClassifier.set('all') + from sourceSets.test.output +} + + diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java new file mode 100644 index 000000000..f4feb6fdf --- /dev/null +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java @@ -0,0 +1,279 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.gan; + +import org.apache.commons.lang3.ArrayUtils; +import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; +import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.optimize.listeners.PerformanceListener; +import org.junit.jupiter.api.Test; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.activations.impl.ActivationLReLU; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +import javax.swing.*; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.io.File; +import java.util.Arrays; + +public class App { + private static final double LEARNING_RATE = 0.0002; + private static final double GRADIENT_THRESHOLD = 100.0; + private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build(); + + private static JFrame frame; + private static JPanel panel; + + private static Layer[] genLayers() { + return new Layer[] { + new DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(), + new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), + new DenseLayer.Builder().nIn(256).nOut(512).build(), + new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), + new DenseLayer.Builder().nIn(512).nOut(1024).build(), + new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), + new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build() + }; + } + + /** + * Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image. 
+ * + * @return config + */ + private static MultiLayerConfiguration generator() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .seed(42) + .updater(UPDATER) + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(GRADIENT_THRESHOLD) + .weightInit(WeightInit.XAVIER) + .activation(Activation.IDENTITY) + .list(genLayers()) + .build(); + + return conf; + } + + private static Layer[] disLayers() { + return new Layer[]{ + new DenseLayer.Builder().nIn(784).nOut(1024).build(), + new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), + new DropoutLayer.Builder(1 - 0.5).build(), + new DenseLayer.Builder().nIn(1024).nOut(512).build(), + new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), + new DropoutLayer.Builder(1 - 0.5).build(), + new DenseLayer.Builder().nIn(512).nOut(256).build(), + new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), + new DropoutLayer.Builder(1 - 0.5).build(), + new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build() + }; + } + + private static MultiLayerConfiguration discriminator() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .seed(42) + .updater(UPDATER) + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(GRADIENT_THRESHOLD) + .weightInit(WeightInit.XAVIER) + .activation(Activation.IDENTITY) + .list(disLayers()) + .build(); + + return conf; + } + + private static MultiLayerConfiguration gan() { + Layer[] genLayers = genLayers(); + Layer[] disLayers = Arrays.stream(disLayers()) + .map((layer) -> { + if (layer instanceof DenseLayer || layer instanceof OutputLayer) { + return new FrozenLayerWithBackprop(layer); + } else { + return layer; + } + }).toArray(Layer[]::new); + Layer[] layers = ArrayUtils.addAll(genLayers, disLayers); + + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .seed(42) + .updater(UPDATER) + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(GRADIENT_THRESHOLD) + .weightInit(WeightInit.XAVIER) + .activation(Activation.IDENTITY) + .list(layers) + .build(); + + return conf; + } + + + @Test + public void runTest() throws Exception { + main(); + } + + public static void main(String... 
args) throws Exception { + Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000); + + MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 42); + + MultiLayerNetwork gen = new MultiLayerNetwork(generator()); + MultiLayerNetwork dis = new MultiLayerNetwork(discriminator()); + MultiLayerNetwork gan = new MultiLayerNetwork(gan()); + gen.init(); + dis.init(); + gan.init(); + + copyParams(gen, dis, gan); + + gen.setListeners(new PerformanceListener(10, true)); + dis.setListeners(new PerformanceListener(10, true)); + gan.setListeners(new PerformanceListener(10, true)); + + trainData.reset(); + + int j = 0; + for (int i = 0; i < 10; i++) { + while (trainData.hasNext()) { + j++; + + // generate data + INDArray real = trainData.next().getFeatures().muli(2).subi(1); + int batchSize = (int) real.shape()[0]; + + INDArray fakeIn = Nd4j.rand(batchSize, 100); + INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn); + + DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1)); + DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1)); + + DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet)); + + dis.fit(data); + dis.fit(data); + + // Update the discriminator in the GAN network + updateGan(gen, dis, gan); + + gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1))); + + + if (j % 10 == 1) { + System.out.println("Iteration " + j + " Visualizing..."); + INDArray[] samples = new INDArray[9]; + DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1)); + + for (int k = 0; k < 9; k++) { + INDArray input = fakeSet2.get(k).getFeatures(); + //samples[k] = gen.output(input, false); + samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input); + + } + visualize(samples); + } + } + trainData.reset(); + } + + // Copy the GANs generator to gen. 
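The loop above interleaves the adversarial update with visualization code, so the core step is easy to lose. As a condensed sketch, one iteration boils down to the method below (assumed to live inside this App class so it can reuse its imports and the updateGan helper defined a few lines further down; the loop above additionally fits the discriminator twice per batch):

    // One adversarial step: real images are labelled 0, generated images 1.
    private static void adversarialStep(MultiLayerNetwork gen, MultiLayerNetwork dis,
                                        MultiLayerNetwork gan, INDArray real) {
        int batchSize = (int) real.shape()[0];
        // Let the generator half of the stacked GAN produce fakes from noise.
        INDArray noise = Nd4j.rand(batchSize, 100);
        INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, noise);
        // Train the discriminator on the combined real/fake batch.
        DataSet data = DataSet.merge(Arrays.asList(
                new DataSet(real, Nd4j.zeros(batchSize, 1)),
                new DataSet(fake, Nd4j.ones(batchSize, 1))));
        dis.fit(data);
        // Copy the freshly trained discriminator weights into the frozen GAN tail,
        // then fit the whole stack on noise labelled as real to push the generator.
        updateGan(gen, dis, gan);
        gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1)));
    }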
+ updateGen(gen, gan); + + gen.save(new File("mnist-mlp-generator.dlj")); + } + + private static void copyParams(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { + int genLayerCount = gen.getLayers().length; + for (int i = 0; i < gan.getLayers().length; i++) { + if (i < genLayerCount) { + gen.getLayer(i).setParams(gan.getLayer(i).params()); + } else { + dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params()); + } + } + } + + private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) { + for (int i = 0; i < gen.getLayers().length; i++) { + gen.getLayer(i).setParams(gan.getLayer(i).params()); + } + } + + private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { + int genLayerCount = gen.getLayers().length; + for (int i = genLayerCount; i < gan.getLayers().length; i++) { + gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).params()); + } + } + + private static void visualize(INDArray[] samples) { + if (frame == null) { + frame = new JFrame(); + frame.setTitle("Viz"); + frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); + frame.setLayout(new BorderLayout()); + + panel = new JPanel(); + + panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8)); + frame.add(panel, BorderLayout.CENTER); + frame.setVisible(true); + } + + panel.removeAll(); + + for (INDArray sample : samples) { + panel.add(getImage(sample)); + } + + frame.revalidate(); + frame.pack(); + } + + private static JLabel getImage(INDArray tensor) { + BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY); + for (int i = 0; i < 784; i++) { + int pixel = (int)(((tensor.getDouble(i) + 1) * 2) * 255); + bi.getRaster().setSample(i % 28, i / 28, 0, pixel); + } + ImageIcon orig = new ImageIcon(bi); + Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE); + + ImageIcon scaled = new ImageIcon(imageScaled); + + return new JLabel(scaled); + } +} \ No newline at end of file diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java b/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java new file mode 100644 index 000000000..25473fc9e --- /dev/null +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java @@ -0,0 +1,411 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
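One detail worth flagging in the getImage helper above (the same expression recurs in GANVisualizationUtils and MnistDCGANExample later in this patch): the generator ends in a TANH activation, so tensor values lie in [-1, 1], and (value + 1) * 2 * 255 ranges up to roughly 1020, outside the 0-255 range of the byte-gray raster. If a full-range 0-255 image is what is intended, the conventional mapping would be the two lines below (a suggested correction, not what the patch currently does):

    // Map a tanh output in [-1, 1] linearly onto an 8-bit gray value in [0, 255].
    int pixel = (int) (((tensor.getDouble(i) + 1) / 2) * 255);
    bi.getRaster().setSample(i % 28, i / 28, 0, pixel);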
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.gan; + +import org.apache.commons.lang3.ArrayUtils; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.*; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.optimize.api.BaseTrainingListener; +import org.nd4j.evaluation.classification.Evaluation; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.learning.config.Sgd; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + + +/** + * Implementation of vanilla Generative Adversarial Networks as introduced in https://arxiv.org/pdf/1406.2661.pdf. + *
+ * A DL4J GAN is initialized from two networks: a generator and a discriminator and will build a third network, + * the GAN network, from the first two. + * + * @author Max Pumperla + */ +public class GAN { + private static final IUpdater UPDATER_ZERO = Sgd.builder().learningRate(0.0).build(); + + public interface DiscriminatorProvider { + MultiLayerNetwork provide(IUpdater updater); + } + + protected Supplier generatorSupplier; + protected DiscriminatorProvider discriminatorSupplier; + + protected MultiLayerNetwork generator; + protected MultiLayerNetwork discriminator; + protected MultiLayerNetwork gan; + protected int latentDim; + + protected IUpdater updater; + protected IUpdater biasUpdater; + protected OptimizationAlgorithm optimizer; + protected GradientNormalization gradientNormalizer; + protected double gradientNormalizationThreshold; + protected WorkspaceMode trainingWorkSpaceMode; + protected WorkspaceMode inferenceWorkspaceMode; + protected CacheMode cacheMode; + protected long seed; + + private Double[] discriminatorLearningRates; + + + public GAN(Builder builder) { + this.generatorSupplier = builder.generator; + this.discriminatorSupplier = builder.discriminator; + this.latentDim = builder.latentDimension; + this.updater = builder.iUpdater; + this.biasUpdater = builder.biasUpdater; + this.optimizer = builder.optimizationAlgo; + this.gradientNormalizer = builder.gradientNormalization; + this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold; + this.trainingWorkSpaceMode = builder.trainingWorkspaceMode; + this.inferenceWorkspaceMode = builder.inferenceWorkspaceMode; + this.cacheMode = builder.cacheMode; + this.seed = builder.seed; + + defineGan(); + } + + public MultiLayerNetwork getGenerator() { + return generator; + } + + public MultiLayerNetwork getDiscriminator() { + return discriminator; + } + + public Evaluation evaluateGan(DataSetIterator data) { + return gan.evaluate(data); + } + + public Evaluation evaluateGan(DataSetIterator data, List labelsList) { + return gan.evaluate(data, labelsList); + } + + + public void setGeneratorListeners(BaseTrainingListener[] listeners) { + generator.setListeners(listeners); + } + + public void setDiscriminatorListeners(BaseTrainingListener[] listeners) { + discriminator.setListeners(listeners); + } + + public void setGanListeners(BaseTrainingListener[] listeners) { + gan.setListeners(listeners); + } + + public void fit(DataSetIterator realData, int numEpochs) { + for (int i = 0; i < numEpochs; i++) { + while (realData.hasNext()) { + // Get real images as features + DataSet next = realData.next(); + fit(next); + } + realData.reset(); + } + } + + public void fit(DataSet next) { + int batchSize; + INDArray realImages = next.getFeatures().muli(2).subi(1); + batchSize = (int) realImages.shape()[0]; + + // Sample from latent space and let the generate create fake images. + INDArray randomLatentData = Nd4j.rand(new int[]{batchSize, latentDim}); + INDArray fakeImages = generator.output(randomLatentData); + + // Real images are marked as "0", fake images at "1". + DataSet realSet = new DataSet(realImages, Nd4j.zeros(batchSize, 1)); + DataSet fakeSet = new DataSet(fakeImages, Nd4j.ones(batchSize, 1)); + + // Fit the discriminator on a combined batch of real and fake images. 
+ DataSet combined = DataSet.merge(Arrays.asList(realSet, fakeSet)); + + /*for (int i = 0; i < discriminator.getLayers().length; i++) { + if (discriminatorLearningRates[i] != null) { + discriminator.setLearningRate(i, discriminatorLearningRates[i]); + } + }*/ + + discriminator.fit(combined); + //discriminator.fit(combined); + + // Update the discriminator in the GAN network + updateGanWithDiscriminator(); + + // Generate a new set of adversarial examples and try to mislead the discriminator. + // by labeling the fake images as real images we reward the generator when it's output + // tricks the discriminator. + INDArray adversarialExamples = Nd4j.rand(new int[]{batchSize, latentDim}); + INDArray misleadingLabels = Nd4j.zeros(batchSize, 1); + DataSet adversarialSet = new DataSet(adversarialExamples, misleadingLabels); + + // Set learning rate of discriminator part of gan to zero. + /*for (int i = generator.getLayers().length; i < gan.getLayers().length; i++) { + gan.setLearningRate(i, 0.0); + }*/ + + // Fit the GAN on the adversarial set, trying to fool the discriminator by generating + // better fake images. + gan.fit(adversarialSet); + + // Copy the GANs generator part to "generator". + updateGeneratorFromGan(); + } + + private void defineGan() { + generator = generatorSupplier.get(); + generator.init(); + + Layer[] genLayers = generator.getLayers(); + int numGenLayers = genLayers.length; + + discriminator = discriminatorSupplier.provide(updater); + discriminator.init(); + + MultiLayerNetwork ganDiscriminator = discriminatorSupplier.provide(UPDATER_ZERO); + ganDiscriminator.init(); + + Layer[] disLayers = ganDiscriminator.getLayers(); + Layer[] layers = ArrayUtils.addAll(genLayers, disLayers); + MultiLayerConfiguration genConf = generator.getLayerWiseConfigurations(); + MultiLayerConfiguration disConf = ganDiscriminator.getLayerWiseConfigurations(); + org.deeplearning4j.nn.conf.layers.Layer[] confLayers = new org.deeplearning4j.nn.conf.layers.Layer[layers.length]; + + Map preProcessors = new HashMap<>(); + for (int i = 0; i < layers.length; i++) { + confLayers[i] = layers[i].conf().getLayer(); + if (i < numGenLayers) { + preProcessors.put(i, genConf.getInputPreProcess(i)); + } else { + preProcessors.put(i, disConf.getInputPreProcess(i - numGenLayers)); + } + } + + MultiLayerConfiguration ganConf = new NeuralNetConfiguration.Builder() + .seed(seed) + .updater(updater) + .biasUpdater(biasUpdater) + .optimizationAlgo(optimizer) + .gradientNormalization(gradientNormalizer) + .gradientNormalizationThreshold(gradientNormalizationThreshold) + .activation(Activation.IDENTITY) + .trainingWorkspaceMode(trainingWorkSpaceMode) + .inferenceWorkspaceMode(inferenceWorkspaceMode) + .cacheMode(cacheMode) + .list(confLayers) + .inputPreProcessors(preProcessors) + .build(); + gan = new MultiLayerNetwork(ganConf); + gan.init(); + + // we lose proper init here, need to copy weights after + copyParamsToGan(); + } + + private void copyParamsToGan() { + int genLayerCount = generator.getLayers().length; + for (int i = 0; i < gan.getLayers().length; i++) { + if (i < genLayerCount) { + generator.getLayer(i).setParams(gan.getLayer(i).params()); + } else { + discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params()); + } + } + } + + /** + * After the GAN has been trained on misleading images, we update the generator the + * new weights (we don't have to update the discriminator, as it is frozen in the GAN). 
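The freezing mentioned just above is achieved in two different ways in this patch, which is easy to miss: defineGan() requests a second discriminator from the provider with UPDATER_ZERO (an Sgd updater whose learning rate is 0.0), so gan.fit() leaves those weights effectively untouched, while App.gan() instead wraps the discriminator layers in FrozenLayerWithBackprop. A minimal illustration of both idioms, using only types that already appear in this patch:

    // GAN.defineGan(): the discriminator copy inside the stacked network gets a
    // zero-learning-rate updater, so gan.fit() only updates the generator layers.
    IUpdater updaterZero = Sgd.builder().learningRate(0.0).build();

    // App.gan(): the same effect via wrapping; gradients still flow backwards to
    // the generator, but the wrapped discriminator parameters are never changed.
    org.deeplearning4j.nn.conf.layers.Layer frozenDense =
            new FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(784).nOut(1024).build());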
+ */ + private void updateGeneratorFromGan() { + for (int i = 0; i < generator.getLayers().length; i++) { + generator.getLayer(i).setParams(gan.getLayer(i).params()); + } + } + + /** + * After the discriminator has been trained, we update the respective parts of the GAN network + * as well. + */ + private void updateGanWithDiscriminator() { + int genLayerCount = generator.getLayers().length; + for (int i = genLayerCount; i < gan.getLayers().length; i++) { + gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).params()); + } + } + + /** + * GAN builder, used as a starting point for creating a MultiLayerConfiguration or + * ComputationGraphConfiguration.
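As a usage sketch for this builder: MnistSimpleGAN.main() further down in this patch wires it up roughly as follows (constants and method references are the ones defined there; the values are illustrative, not prescriptive):

    GAN gan = new GAN.Builder()
            .generator(MnistSimpleGAN::getGenerator)          // Supplier<MultiLayerNetwork>
            .discriminator(MnistSimpleGAN::getDiscriminator)  // GAN.DiscriminatorProvider
            .latentDimension(100)
            .seed(42)
            .updater(Adam.builder().learningRate(0.0002).beta1(0.5).build())
            .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
            .gradientNormalizationThreshold(100)
            .build();

    MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 42);
    gan.fit(trainData, 1);   // one epoch: alternate discriminator and generator updates per batch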
+ */ + public static class Builder implements Cloneable { + protected Supplier generator; + protected DiscriminatorProvider discriminator; + protected int latentDimension; + + protected IUpdater iUpdater = new Sgd(); + protected IUpdater biasUpdater = null; + protected long seed = System.currentTimeMillis(); + protected OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT; + protected GradientNormalization gradientNormalization = GradientNormalization.None; + protected double gradientNormalizationThreshold = 1.0; + + protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; + protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; + protected CacheMode cacheMode = CacheMode.NONE; + + + public Builder() { + } + + + /** + * Set the (fake) image generator of the GAN. + * + * @param generator MultilayerNetwork + * @return Builder + */ + public GAN.Builder generator(Supplier generator) { + this.generator = generator; + return this; + } + + /** + * Set the image discriminator of the GAN. + * + * @param discriminator MultilayerNetwork + * @return Builder + */ + public GAN.Builder discriminator(DiscriminatorProvider discriminator) { + this.discriminator = discriminator; + return this; + } + + /** + * Set the latent dimension, i.e. the input vector space dimension of the generator. + * + * @param latentDimension latent space input dimension. + * @return Builder + */ + public GAN.Builder latentDimension(int latentDimension) { + this.latentDimension = latentDimension; + return this; + } + + + /** + * Random number generator seed. Used for reproducibility between runs + */ + public GAN.Builder seed(long seed) { + this.seed = seed; + Nd4j.getRandom().setSeed(seed); + return this; + } + + /** + * Optimization algorithm to use. Most common: OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT + * + * @param optimizationAlgo Optimization algorithm to use when training + */ + public GAN.Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) { + this.optimizationAlgo = optimizationAlgo; + return this; + } + + + /** + * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} + * or {@link org.nd4j.linalg.learning.config.Nesterovs}
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param updater Updater to use + */ + public GAN.Builder updater(IUpdater updater) { + this.iUpdater = updater; + return this; + } + + /** + * Gradient updater configuration, for the biases only. If not set, biases will use the updater as + * set by {@link #updater(IUpdater)}
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param updater Updater to use for bias parameters + */ + public GAN.Builder biasUpdater(IUpdater updater) { + this.biasUpdater = updater; + return this; + } + + /** + * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc. + * See {@link GradientNormalization} for details
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param gradientNormalization Type of normalization to use. Defaults to None. + * @see GradientNormalization + */ + public GAN.Builder gradientNormalization(GradientNormalization gradientNormalization) { + this.gradientNormalization = gradientNormalization; + return this; + } + + /** + * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer, + * GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue
+ * Not used otherwise.
+ * L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + */ + public GAN.Builder gradientNormalizationThreshold(double threshold) { + this.gradientNormalizationThreshold = threshold; + return this; + } + + public GAN build() { + return new GAN(this); + } + + } + +} \ No newline at end of file diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/GANVisualizationUtils.java b/brutex-extended-tests/src/test/java/net/brutex/gan/GANVisualizationUtils.java new file mode 100644 index 000000000..b88a6ae8f --- /dev/null +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/GANVisualizationUtils.java @@ -0,0 +1,73 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.gan; + +import org.nd4j.linalg.api.ndarray.INDArray; + +import javax.swing.*; +import java.awt.*; +import java.awt.image.BufferedImage; + +public class GANVisualizationUtils { + + public static JFrame initFrame() { + JFrame frame = new JFrame(); + frame.setTitle("Viz"); + frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); + frame.setLayout(new BorderLayout()); + return frame; + } + + public static JPanel initPanel(JFrame frame, int numSamples) { + JPanel panel = new JPanel(); + + panel.setLayout(new GridLayout(numSamples / 3, 1, 8, 8)); + frame.add(panel, BorderLayout.CENTER); + frame.setVisible(true); + return panel; + } + + public static void visualize(INDArray[] samples, JFrame frame, JPanel panel) { + panel.removeAll(); + + for (int i = 0; i < samples.length; i++) { + panel.add(getImage(samples[i])); + } + + frame.revalidate(); + frame.pack(); + } + + private static JLabel getImage(INDArray tensor) { + BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY); + for (int i = 0; i < 784; i++) { + int pixel = (int) (((tensor.getDouble(i) + 1) * 2) * 255); + bi.getRaster().setSample(i % 28, i / 28, 0, pixel); + } + ImageIcon orig = new ImageIcon(bi); + Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE); + + ImageIcon scaled = new ImageIcon(imageScaled); + + return new JLabel(scaled); + } +} \ No newline at end of file diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java new file mode 100644 index 000000000..d0e5bb73d --- /dev/null +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java @@ -0,0 +1,193 @@ +/* + * + * 
****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.gan; + +import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; +import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.optimize.listeners.PerformanceListener; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.RmsProp; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +import javax.swing.*; +import java.awt.*; +import java.awt.image.BufferedImage; +import java.util.function.Supplier; + + +/** + * Training and visualizing a deep convolutional generative adversarial network (DCGAN) on handwritten digits. 
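Since this example relies on shape bookkeeping between the dense, deconvolution and preprocessor layers, the intended tensor flow of the generator defined below is easier to follow with the numbers written out (derived from latentDim = 100, height = width = 28, channels = 1; the stride-2 doubling assumes the usual behaviour of a same-mode Deconvolution2D):

    int latentDim = 100, height = 28, width = 28, channels = 1;
    // Dense output, reshaped by FeedForwardToCnnPreProcessor to 14 x 14 x 128:
    int denseOut = (width / 2) * (height / 2) * 128;   // 14 * 14 * 128 = 25088
    // Deconvolution2D with stride 2 in ConvolutionMode.Same doubles the spatial size:
    int upHeight = (height / 2) * 2;                   // 28
    int upWidth  = (width / 2) * 2;                    // 28
    // The final 7x7 convolution reduces the 128 feature maps to `channels` map(s), and
    // CnnToFeedForwardPreProcessor flattens 28 x 28 x 1 back to 784 values.
    int flattened = upHeight * upWidth * channels;     // 784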
+ * + * @author Max Pumperla, wmeddie + */ +public class MnistDCGANExample { + + private static JFrame frame; + private static JPanel panel; + + private static final int latentDim = 100; + private static final int height = 28; + private static final int width = 28; + private static final int channels = 1; + + + private static void visualize(INDArray[] samples) { + if (frame == null) { + frame = new JFrame(); + frame.setTitle("Viz"); + frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); + frame.setLayout(new BorderLayout()); + + panel = new JPanel(); + + panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8)); + frame.add(panel, BorderLayout.CENTER); + frame.setVisible(true); + } + + panel.removeAll(); + + for (int i = 0; i < samples.length; i++) { + panel.add(getImage(samples[i])); + } + + frame.revalidate(); + frame.pack(); + } + + private static JLabel getImage(INDArray tensor) { + BufferedImage bi = new BufferedImage(28, 28, BufferedImage.TYPE_BYTE_GRAY); + for (int i = 0; i < 784; i++) { + int pixel = (int) (((tensor.getDouble(i) + 1) * 2) * 255); + bi.getRaster().setSample(i % 28, i / 28, 0, pixel); + } + ImageIcon orig = new ImageIcon(bi); + Image imageScaled = orig.getImage().getScaledInstance((8 * 28), (8 * 28), Image.SCALE_REPLICATE); + + ImageIcon scaled = new ImageIcon(imageScaled); + + return new JLabel(scaled); + } + + public static void main(String[] args) throws Exception { + Supplier genSupplier = () -> { + return new MultiLayerNetwork(new NeuralNetConfiguration.Builder().list() + .layer(0, new DenseLayer.Builder().nIn(latentDim).nOut(width / 2 * height / 2 * 128) + .activation(Activation.LEAKYRELU).weightInit(WeightInit.NORMAL).build()) + .layer(1, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5) + .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build()) + // Up-sampling to 28x28x256 + .layer(2, new Deconvolution2D.Builder().nIn(128).nOut(128).stride(2, 2) + .kernelSize(5, 5).convolutionMode(ConvolutionMode.Same) + .activation(Activation.LEAKYRELU).build()) + .layer(3, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5) + .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build()) + .layer(4, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5) + .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build()) + .layer(5, new Convolution2D.Builder().nIn(128).nOut(channels).kernelSize(7, 7) + .convolutionMode(ConvolutionMode.Same).activation(Activation.LEAKYRELU).build()) + .layer(6, new ActivationLayer.Builder().activation(Activation.TANH).build()) + .inputPreProcessor(1, + new FeedForwardToCnnPreProcessor(height / 2, width / 2, 128)) + .inputPreProcessor(6, new CnnToFeedForwardPreProcessor(height, width, channels)) + .setInputType(InputType.feedForward(latentDim)) + .build()); + }; + + GAN.DiscriminatorProvider discriminatorProvider = (updater) -> { + return new MultiLayerNetwork(new NeuralNetConfiguration.Builder() + .updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build()) + //.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) + //.gradientNormalizationThreshold(100.0) + .list() + .layer(0, new Convolution2D.Builder().nIn(channels).nOut(64).kernelSize(3, 3) + .activation(Activation.LEAKYRELU).build()) + .layer(1, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2) + .activation(Activation.LEAKYRELU).build()) + .layer(2, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 
2) + .activation(Activation.LEAKYRELU).build()) + .layer(3, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2) + .activation(Activation.LEAKYRELU).build()) + .layer(4, new DropoutLayer.Builder().dropOut(0.5).build()) + .layer(5, new DenseLayer.Builder().nIn(64 * 2 * 2).nOut(1).activation(Activation.SIGMOID).build()) + .layer(6, new LossLayer.Builder().lossFunction(LossFunctions.LossFunction.XENT).build()) + .inputPreProcessor(0, new FeedForwardToCnnPreProcessor(height, width, channels)) + .inputPreProcessor(4, new CnnToFeedForwardPreProcessor(2, 2, 64)) + .setInputType(InputType.convolutionalFlat(height, width, channels)) + .build()); + }; + + GAN gan = new GAN.Builder() + .generator(genSupplier) + .discriminator(discriminatorProvider) + .latentDimension(latentDim) + //.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) + //.gradientNormalizationThreshold(1.0) + .updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build()) + .build(); + + gan.getGenerator().setListeners(new PerformanceListener(1, true)); + gan.getDiscriminator().setListeners(new PerformanceListener(1, true)); + + Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000); + + int batchSize = 64; + MnistDataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 42); + + for (int i = 0; i < 10; i++) { + //gan.fit(trainData, 1); + + System.out.println("Starting epoch: " + (i + 1)); + + trainData.reset(); + int j = 0; + while (trainData.hasNext()) { + DataSet next = trainData.next(); + gan.fit(next); + + if (j % 1 == 0) { + System.out.println("Epoch " + (i + 1) + " iteration " + j + " Visualizing..."); + INDArray fakeIn = Nd4j.rand(new int[]{batchSize, latentDim}); + + INDArray[] samples = new INDArray[9]; + for (int k = 0; k < 9; k++) { + samples[k] = gan.getGenerator().output(fakeIn.getRow(k), false); + } + visualize(samples); + } + j++; + } + + System.out.println("Finished epoch: " + (i + 1)); + } + } +} diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/MnistSimpleGAN.java b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistSimpleGAN.java new file mode 100644 index 000000000..037a0be9d --- /dev/null +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistSimpleGAN.java @@ -0,0 +1,146 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.gan; + +import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; +import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.ActivationLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.DropoutLayer; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.activations.impl.ActivationLReLU; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.learning.config.Sgd; +import org.nd4j.linalg.lossfunctions.LossFunctions; + +import javax.swing.*; + + +/** + * Relatively small GAN example using only Dense layers with dropout to generate handwritten + * digits from MNIST data. + */ +public class MnistSimpleGAN { + + private static final int LATENT_DIM = 100; + + private static final double LEARNING_RATE = 0.0002; + private static final IUpdater UPDATER_ZERO = Sgd.builder().learningRate(0.0).build(); + private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build(); + + + public static MultiLayerNetwork getGenerator() { + MultiLayerConfiguration genConf = new NeuralNetConfiguration.Builder() + .weightInit(WeightInit.XAVIER) + .activation(Activation.IDENTITY) + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(100) + .list() + .layer(new DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build()) + .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) + .layer(new DenseLayer.Builder().nIn(256).nOut(512).build()) + .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) + .layer(new DenseLayer.Builder().nIn(512).nOut(1024).build()) + .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) + .layer(new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build()) + .build(); + return new MultiLayerNetwork(genConf); + } + + + public static MultiLayerNetwork getDiscriminator(IUpdater updater) { + MultiLayerConfiguration discConf = new NeuralNetConfiguration.Builder() + .seed(42) + .updater(updater) + .weightInit(WeightInit.XAVIER) + .activation(Activation.IDENTITY) + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(100) + .list() + .layer(new DenseLayer.Builder().nIn(784).nOut(1024).updater(updater).build()) + .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) + .layer(new DropoutLayer.Builder(1 - 0.5).build()) + .layer(new DenseLayer.Builder().nIn(1024).nOut(512).updater(updater).build()) + .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) + .layer(new DropoutLayer.Builder(1 - 0.5).build()) + .layer(new DenseLayer.Builder().nIn(512).nOut(256).updater(updater).build()) + .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) + .layer(new DropoutLayer.Builder(1 - 0.5).build()) + .layer(new 
OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1) + .activation(Activation.SIGMOID).updater(updater).build()) + .build(); + + return new MultiLayerNetwork(discConf); + } + + public static void main(String[] args) throws Exception { + GAN gan = new GAN.Builder() + .generator(MnistSimpleGAN::getGenerator) + .discriminator(MnistSimpleGAN::getDiscriminator) + .latentDimension(LATENT_DIM) + .seed(42) + .updater(UPDATER) + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(100) + .build(); + + Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000); + + int batchSize = 128; + MnistDataSetIterator trainData = new MnistDataSetIterator(batchSize, true, 42); + + + // Sample from latent space once to visualize progress on image generation. + int numSamples = 9; + JFrame frame = GANVisualizationUtils.initFrame(); + JPanel panel = GANVisualizationUtils.initPanel(frame, numSamples); + + for (int i = 0; i < 100; i++) { + trainData.reset(); + int j = 0; + while (trainData.hasNext()) { + gan.fit(trainData.next()); + //gan.fit(trainData, 1); + + if (j % 10 == 0) { + INDArray fakeIn = Nd4j.rand(new int[]{batchSize, LATENT_DIM}); + System.out.println("Epoch " + (i + 1) + " Iteration " + j + " Visualizing..."); + INDArray[] samples = new INDArray[numSamples]; + for (int k = 0; k < numSamples; k++) { + INDArray input = fakeIn.getRow(k); + samples[k] = gan.getGenerator().output(input, false); + } + GANVisualizationUtils.visualize(samples, frame, panel); + } + j++; + } + } + } +} \ No newline at end of file diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BaseSparkSessionTest.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BaseSparkSessionTest.java index 5f81489e0..3b1b36c72 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/BaseSparkSessionTest.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BaseSparkSessionTest.java @@ -20,48 +20,103 @@ package net.brutex.spark; +import java.io.File; +import java.io.IOException; +import java.net.URI; +import java.util.EnumSet; import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.FileUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.CreateFlag; +import org.apache.hadoop.fs.FSDataOutputStream; +import org.apache.hadoop.fs.FileContext; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.FileUtil; +import org.apache.hadoop.fs.Options.CreateOpts; import org.apache.spark.SparkConf; +import org.apache.spark.SparkContext; import org.apache.spark.sql.SparkSession; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import java.io.Serializable; +import org.junit.jupiter.api.Test; @Slf4j public abstract class BaseSparkSessionTest implements Serializable { - private static SparkSession spark; - public static SparkSession getSession() { - SparkConf sparkConf = new SparkConf() - .setMaster("spark://10.5.5.200:7077") - .setAppName(BaseSparkSessionTest.class.getSimpleName()) - .set("spark.driver.bindAddress", "10.5.5.145") - .set("spark.network.timeout", "240000") - .set("spark.driver.host", "10.5.5.145") - .set("spark.deploy.mode", "client") - .set("spark.executor.memory", "4g") - .set("spark.cores.max", "4") - .set("spark.worker.cleanup.enabled", "true") - .set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml") - .set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml") - .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000"); + 
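+    // Shared SparkSession handle: getSession() populates it via getOrCreate() and afterAll() closes it.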
private static SparkSession spark; - spark = SparkSession.builder() - .config(sparkConf) - .getOrCreate(); + public static SparkSession getSession() { + final String jarPath = uploadToHdfs("./build/libs/brutex-extended-tests-1.0.0-SNAPSHOT-all.jar"); - return spark; + SparkConf sparkConf = new SparkConf() + .setMaster("spark://10.5.5.200:7077") + .setAppName(BaseSparkSessionTest.class.getSimpleName()) + .set("spark.driver.bindAddress", "10.5.5.145") + .set("spark.blockManager.port", "65001") + //.set("spark.driver.bindAddress", "0.0.0.0") + .set("spark.network.timeout", "240000") + .set("spark.driver.host", "10.5.5.145") + .set("spark.deploy.mode", "cluster") + .set("spark.executor.memory", "4g") + .set("spark.cores.max", "4") + .set("spark.worker.cleanup.enabled", "true") + .set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml") + .set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml") + .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000") + //.set("spark.jars", jarPath) + ; + spark = SparkSession.builder() + .config(sparkConf) + .getOrCreate(); + + spark.sparkContext().addJar(jarPath); + return spark; + } + public static String uploadToHdfs(String jarFile) { + File f = new File(jarFile); + if(!f.exists() && !f.isFile()) throw new RuntimeException("File to upload does not exist."); + final String base = "hdfs://10.5.5.200:9000/"; + String targetPath = "/user/brian/" + f.getName(); + try { + Configuration conf = new Configuration(); + + //FileContext hdfs = FileContext.getFileContext(URI.create(base), conf); + org.apache.hadoop.fs.FileSystem hdfs = FileSystem.get(URI.create(base), conf); + //String file = SparkFiles.get("phpMawTba"); + + org.apache.hadoop.fs.Path target = new org.apache.hadoop.fs.Path(targetPath); + + try { + hdfs.delete(target, false); + } catch (Exception e) {}; + + FileUtil.copy(f, hdfs, target, false, conf); + //Apache Commons + //FileUtils.copyFile(f, fTarget); + } catch(IOException ioe) { + ioe.printStackTrace(); + } + return base + targetPath; } - @BeforeAll - public static void beforeAll() { - } + @BeforeAll + public static void beforeAll() { - @AfterAll - public static synchronized void afterAll() { - getSession().close(); + } - } + @AfterAll + public static synchronized void afterAll() { + getSession().close(); + + } + + @Test + public void testSessionCreation() { + SparkSession session = getSession(); + log.info("Spark {} session id: {}", session.version(), session.sessionUUID()); + + } } diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java index cc88a0914..efb54aa29 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java @@ -20,22 +20,34 @@ */ package net.brutex.spark; -import com.fasterxml.jackson.core.Version; +import java.io.IOException; import lombok.extern.slf4j.Slf4j; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.ForeachFunction; import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import 
org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StringType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.filter.FilterInvalidValues; import org.datavec.api.transform.schema.Schema; import org.datavec.api.Writable; +import org.datavec.spark.transform.Normalization; import org.datavec.spark.transform.SparkTransformExecutor; import org.datavec.spark.transform.misc.StringToWritablesFunction; +import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator; +import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator.Set; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -47,7 +59,6 @@ import org.deeplearning4j.spark.api.TrainingMaster; import org.deeplearning4j.spark.datavec.DataVecDataSetFunction; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; -import org.deeplearning4j.ui.api.UIServer; import org.junit.jupiter.api.*; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; @@ -56,7 +67,6 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.nio.file.Paths; import java.util.Arrays; import java.util.Iterator; import java.util.List; @@ -70,23 +80,76 @@ import java.util.Random; @Slf4j @TestInstance(TestInstance.Lifecycle.PER_CLASS) @Tag("integration") -public class BrianTest /*extends BaseDL4JTest*/ { - static { - String OS = System.getProperty("os.name").toLowerCase(); +public class BrianTest extends BaseSparkSessionTest { +/* + static { + String OS = System.getProperty("os.name").toLowerCase(); - if (OS.contains("win")) { - System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString()); - } else { - System.setProperty("hadoop.home.dir", "/"); - } + if (OS.contains("win")) { + System.setProperty("hadoop.home.dir", + Paths.get("c:\\java\\winutils").toAbsolutePath().toString()); + } else { + System.setProperty("hadoop.home.dir", "/"); } + } +*/ + private JavaSparkContext sc; + private JavaRDD rdd; - public long getTimeoutMilliseconds() { - return 400000L; - } - private JavaSparkContext sc; - private JavaRDD rdd; + @Test + public void wrapEmnitDataset() throws IOException, InterruptedException { + SparkSession sc = getSession(); + EmnistDataSetIterator dataset = new EmnistDataSetIterator(Set.BALANCED, 128, true); + DataSet ds = dataset.next(); + System.out.println( "Number of features " + ds.numInputs()); + System.out.println( "Number of samples " + ds.numExamples()); + System.out.println( "Outcomes " + ds.numOutcomes()); + final String oppsFile = uploadToHdfs("c:/temp/opps.csv"); + + //System.out.println( "Reading file from " + oppsFile); + + JavaRDD rdd = sc.sparkContext().textFile(oppsFile, 1) + .toJavaRDD(); + System.out.println("Count " + rdd.count()); + //while(true) Thread.sleep(1000); + + //rdd.foreach( s -> { + // System.out.println("* "+s); + // }); + + + //JavaRDD rdd2 = rdd.flatMap( s -> Arrays.asList( s.split(";")).iterator() ); + //rdd2.collect().forEach( a -> System.out.print("# " + a + " ") ); + + 
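+    // Explicit all-string schema (stage, period, portfolio, country, lfr, saas) for the semicolon-separated rows parsed below.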
StructType struct = new StructType(Arrays.asList( + StructField.apply("stage", DataTypes.StringType, false, Metadata.empty()), + StructField.apply("period", DataTypes.StringType, false, Metadata.empty()), + StructField.apply("portfolio", DataTypes.StringType, false, Metadata.empty()), + StructField.apply("country", DataTypes.StringType, false, Metadata.empty()), + StructField.apply("lfr", DataTypes.StringType, false, Metadata.empty()), + StructField.apply("saas", DataTypes.StringType, false, Metadata.empty()) + ).toArray(new StructField[]{}) + ); + JavaRDD rdd3 = rdd.map( attributes -> RowFactory.create(attributes.split(";"))); + + Dataset frame = sc.createDataFrame(rdd3, struct); + Dataset frame2 = frame.select(frame.col("lfr").cast(DataTypes.FloatType)); + frame.show(200); + + // frame.collect().map(row -> System.out.println(row.fieldIndex("stage") + row.fieldIndex("country"))); + + + + //frame.agg( frame.col("stage"), frame.col("lfr")); + frame.foreach((ForeachFunction) s -> System.out.println(s)); + + //sc.read().csv(rdd2); + //Normalization normalization = Normalization.zeromeanUnitVariance() + //sc. + + } + /* @BeforeAll @@ -109,120 +172,53 @@ public class BrianTest /*extends BaseDL4JTest*/ { } */ - @BeforeAll - public void setUp() throws Exception { - log.info("Running @BeforeEach scope"); - System.setProperty("hadoop.home.dir", Paths.get("c:\\java\\winutils").toAbsolutePath().toString()); - Version version = com.fasterxml.jackson.databind.cfg.PackageVersion.VERSION; - System.out.println("Jackson version found: " + version); - SparkConf sparkConf = new SparkConf() - .setMaster("spark://10.5.5.200:7077") - .setAppName("Brian3") - .set("spark.driver.bindAddress", "10.5.5.145") - .set("spark.network.timeout", "240000") - .set("spark.driver.host", "10.5.5.145") - .set("spark.driver.bindAddress", "10.5.5.145") - .set("spark.deploy.mode", "cluster") - .set("spark.executor.memory", "2g") - .set("spark.executor.cores", "2") - .set("spark.cores.max", "4") - .set("spark.worker.cleanup.enabled", "false") - .set("spark.driver.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml") - .set("spark.executor.extraJavaOptions", "-Dlog4j.configurationFile=log4j2.xml") - .set("spark.driver.extraClassPath", "brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar;brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar") - .set("spark.executor.extraClassPath", "brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar;brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar") - .set("spark.hadoop.fs.defaultFS", "hdfs://10.5.5.200:9000"); - //.set("spark.driver.cores", "2") - //.set("spark.driver.memory", "8g") - //.set("spark.driver.host", "10.5.5.145") - //.setExecutorEnv("spark.executor.cores", "2") - //.setExecutorEnv("spark.executor.memory", "2g") - //.set("spark.submit.deployMode", "client") -/* - SparkSession spark = SparkSession - .builder() - .master("spark://10.5.5.200:7077") - .config("spark.driver.bindAddress", "10.5.5.145") - .config("spark.driver.host", "10.5.5.145") - //.config("spark.driver.memory", "5g") - .appName("BrianTest2") - .getOrCreate(); -*/ - sc = new JavaSparkContext(sparkConf); + @Test + ////@Ignore("AB 2019/05/21 - Failing - Issue #7657") + public void testStringsTokenization1() throws Exception { - // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\deeplearning4j\\deeplearning4j-scaleout\\spark\\dl4j-spark-nlp-java8\\target\\dl4j-spark-nlp-java8_2.12-1.0.0-SNAPSHOT-tests.jar"); - // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\datavec\\datavec-api\\target\\datavec-api-1.0.0-SNAPSHOT.jar"); - 
// sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\nd4j\\nd4j-uberjar\\target\\nd4j-uberjar-1.0.0-SNAPSHOT.jar"); - // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\nd4j\\nd4j-common\\target\\nd4j-common-1.0.0-SNAPSHOT.jar"); - // sc.addJar("C:\\Users\\brian\\_projects\\deeplearning4j\\datavec\\datavec-spark\\target\\datavec-spark_2.12-1.0.0-SNAPSHOT.jar"); - sc.addJar("C:\\Users\\brian\\_projects\\Brian-Spark-DL4J-Tests\\target\\brian-spark-dl4j-tests-1.0.1-SNAPSHOT-bin.jar"); - sc.addJar("C:\\Users\\brian\\_projects\\Brian-Spark-DL4J-Tests\\target\\brian-spark-dl4j-tests-1.0.1-SNAPSHOT-tests.jar"); + //shrink for Test + //List list = Arrays.asList(new String[]{"asdsad", "asdasdasd", "asdasdasd", "3easdasd"}); + //JavaRDD rdd = sc.parallelize(list); + // rdd = rdd.sample(true, 1.0, 1); + log.info("Datenmenge: " + rdd.count()); + log.info("Sample: " + rdd.top(3)); - rdd = sc.textFile("hdfs://10.5.5.200:9000/user/zeppelin/cities_full.csv.gz"); + Assertions.assertEquals(146889, rdd.count()); + } + @Test + public void testSchemaCreation() throws Exception { + rdd.cache(); + JavaRDD cities = rdd.map((Function) line -> { + return line.split(",")[1]; + }).cache(); - } + JavaRDD stateCodeList = rdd.map((Function) line -> { + return line.split(",")[2]; + }).cache(); - @AfterAll - public void tearDown() throws Exception { - sc.close(); - sc.stop(); - UIServer.stopInstance(); + JavaRDD countryCodeList = rdd.map((Function) line -> { + return line.split(",")[3]; + }).cache(); - } + CSVRecordReader recordReader = new CSVRecordReader(0, ','); + JavaRDD> convertedRDD = rdd.map((Function>) s -> { + return new StringToWritablesFunction(recordReader).call(s); + }); - @Test - ////@Ignore("AB 2019/05/21 - Failing - Issue #7657") - public void testStringsTokenization1() throws Exception { + //Source Schema + Schema inputSchema = new Schema.Builder() + .addColumnLong("city_id") + .addColumnsString("city_name", "state_code", "country_code") + .addColumnsString("country_full") + .addColumnsDouble("lat", "lon") + .build(); - //shrink for Test - //List list = Arrays.asList(new String[]{"asdsad", "asdasdasd", "asdasdasd", "3easdasd"}); - //JavaRDD rdd = sc.parallelize(list); - - // rdd = rdd.sample(true, 1.0, 1); - log.info("Datenmenge: " + rdd.count()); - log.info("Sample: " + rdd.top(3)); - - Assertions.assertEquals(146889, rdd.count()); - } - - @Test - public void testSchemaCreation() throws Exception { - - - rdd.cache(); - - JavaRDD cities = rdd.map( (Function) line -> { - return line.split(",")[1]; - }).cache(); - - JavaRDD stateCodeList = rdd.map( (Function) line -> { - return line.split(",")[2]; - }).cache(); - - JavaRDD countryCodeList = rdd.map( (Function) line -> { - return line.split(",")[3]; - }).cache(); - - - CSVRecordReader recordReader = new CSVRecordReader(0, ','); - JavaRDD> convertedRDD = rdd.map((Function>) s -> { - return new StringToWritablesFunction( recordReader).call(s); - }); - - //Source Schema - Schema inputSchema = new Schema.Builder() - .addColumnLong("city_id") - .addColumnsString("city_name", "state_code", "country_code") - .addColumnsString("country_full") - .addColumnsDouble("lat", "lon") - .build(); - - //Running Transformation + //Running Transformation /* TransformProcess tp = new TransformProcess.Builder(inputSchema) .removeColumns("country_full", "lat", "lon") @@ -236,38 +232,40 @@ public class BrianTest /*extends BaseDL4JTest*/ { .categoricalToOneHot("country_code") .build(); */ - TransformProcess tp = new TransformProcess.Builder(inputSchema) - 
.removeAllColumnsExceptFor("country_code", "lat", "lon") - .stringToCategorical("country_code", Arrays.asList("GR", "FR", "DE", "CH")) - .filter(new FilterInvalidValues()) - .categoricalToOneHot("country_code") - .build(); + TransformProcess tp = new TransformProcess.Builder(inputSchema) + .removeAllColumnsExceptFor("country_code", "lat", "lon") + .stringToCategorical("country_code", Arrays.asList("GR", "FR", "DE", "CH")) + .filter(new FilterInvalidValues()) + .categoricalToOneHot("country_code") + .build(); - //log.info("Final Schema: " +tp.getFinalSchema().toString()); - //Execute Transformation Process - convertedRDD.repartition(8); - convertedRDD.cache(); - JavaRDD> processedData = SparkTransformExecutor.execute(convertedRDD, tp); - processedData.repartition(8); - processedData.cache(); - //log.info("Datenmenge nach processing: " + processedData.count()); + //log.info("Final Schema: " +tp.getFinalSchema().toString()); + //Execute Transformation Process + convertedRDD.repartition(8); + convertedRDD.cache(); + JavaRDD> processedData = SparkTransformExecutor.execute(convertedRDD, tp); + processedData.repartition(8); + processedData.cache(); + //log.info("Datenmenge nach processing: " + processedData.count()); + //Vectorisieren + int labelIndex = 0; //in welcher Spalte ist das Label + int numLabels = 4; //Anzahl der Klassen 0-236 = 237 Werte - //Vectorisieren - int labelIndex = 0; //in welcher Spalte ist das Label - int numLabels = 4; //Anzahl der Klassen 0-236 = 237 Werte + DataVecDataSetFunction datavecFunction = new DataVecDataSetFunction(labelIndex, numLabels, + false); + JavaRDD rddDataSet = processedData.map(datavecFunction); + log.info("rddDataset: " + rddDataSet.toDebugString()); + Random rand = new Random(); + rddDataSet.sortBy((Function) s -> { + return rand.nextDouble(); + }, true, 8); - DataVecDataSetFunction datavecFunction = new DataVecDataSetFunction(labelIndex, numLabels, false); - JavaRDD rddDataSet = processedData.map(datavecFunction); - log.info("rddDataset: " + rddDataSet.toDebugString()); - Random rand = new Random(); - rddDataSet.sortBy( (Function) s -> {return rand.nextDouble(); }, true, 8); + //og.info("Sample: " + rddDataSet.sample(false, 0.005, 0).collect()); - //og.info("Sample: " + rddDataSet.sample(false, 0.005, 0).collect()); - - /* Skip, this will save each record one by one to hdfs - */ - //Now save this hard work + /* Skip, this will save each record one by one to hdfs + */ + //Now save this hard work /* int miniBatchSize = 1; //Minibatch size of the saved DataSet objects final String exportPath = "hdfs://10.5.5.200:9000/user/brian/data"; @@ -278,63 +276,67 @@ public class BrianTest /*extends BaseDL4JTest*/ { paths.collect(); */ - //Create Trainingmaster + //Create Trainingmaster - TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(4) - .rddTrainingApproach(RDDTrainingApproach.Direct) //when "export", tries to save everything first - .batchSizePerWorker(1000) - .collectTrainingStats(true) - .build(); + TrainingMaster trainingMaster = new ParameterAveragingTrainingMaster.Builder(4) + .rddTrainingApproach( + RDDTrainingApproach.Direct) //when "export", tries to save everything first + .batchSizePerWorker(1000) + .collectTrainingStats(true) + .build(); - //Define Network + //Define Network - MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder() - .seed(123) - .updater(new Nesterovs(0.1, 0.9)) - .list() - .layer(0, new 
DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).l2(0.001).build()) - .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) - //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4).weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build()) - .build(); + MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder() + .seed(123) + .updater(new Nesterovs(0.1, 0.9)) + .list() + .layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER) + .activation(Activation.RELU).l2(0.001).build()) + .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER) + .activation(Activation.RELU).build()) + //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) + .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4) + .weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build()) + .build(); - //Define SparkNet - SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, multiLayerConfiguration, trainingMaster); + //Define SparkNet + SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, multiLayerConfiguration, + trainingMaster); + JavaRDD[] split = rddDataSet.randomSplit(new double[]{0.9, 0.1}, 123); + //JavaRDD trainingData = split[0]; + JavaRDD trainingData = rddDataSet; + JavaRDD testData = split[1]; - JavaRDD[] split = rddDataSet.randomSplit(new double[] {0.9, 0.1}, 123); - //JavaRDD trainingData = split[0]; - JavaRDD trainingData = rddDataSet; - JavaRDD testData = split[1]; - - //Run Training on subset - for(int i =0; i<20; i++) { - sparkNet.fit(trainingData); - } - - //Evaluieren - MultiLayerNetwork finalNet = sparkNet.getNetwork(); - - //Speichern - Configuration conf = sc.hadoopConfiguration(); - conf.set("hadoop.tmp.dir", "/user/brian/tmp"); - FileSystem fs = FileSystem.get(conf); - Path p = new Path("hdfs://10.5.5.200:9000/user/brian/model"); - //fs.mkdirs(p); - //ModelSerializer.writeModel(finalNet, fs.create(p), true ); - - Evaluation eval = new Evaluation(4); // outputNum = 10: number of output classes - Iterator iter = testData.toLocalIterator(); - log.info("testData has " + testData.count() + " DataSets"); - while(iter.hasNext()){ - DataSet next = iter.next(); - //log.info("getFeatures " + next.getFeatures() ); - INDArray output = finalNet.output(next.getFeatures()); //get the networks prediction - //log.info("output "+ output.toStringFull()); - eval.eval(next.getLabels(), output); //check the prediction against the true class - //log.info("Predict " + finalNet.predict(next)); - } - log.info("Evaluation stats: " + eval.stats()); + //Run Training on subset + for (int i = 0; i < 20; i++) { + sparkNet.fit(trainingData); } + //Evaluieren + MultiLayerNetwork finalNet = sparkNet.getNetwork(); + + //Speichern + Configuration conf = sc.hadoopConfiguration(); + conf.set("hadoop.tmp.dir", "/user/brian/tmp"); + FileSystem fs = FileSystem.get(conf); + Path p = new Path("hdfs://10.5.5.200:9000/user/brian/model"); + //fs.mkdirs(p); + //ModelSerializer.writeModel(finalNet, fs.create(p), true ); + + Evaluation eval = new Evaluation(4); // outputNum = 10: number of output classes + Iterator iter = testData.toLocalIterator(); + log.info("testData has " + testData.count() + " 
DataSets"); + while (iter.hasNext()) { + DataSet next = iter.next(); + //log.info("getFeatures " + next.getFeatures() ); + INDArray output = finalNet.output(next.getFeatures()); //get the networks prediction + //log.info("output "+ output.toStringFull()); + eval.eval(next.getLabels(), output); //check the prediction against the true class + //log.info("Predict " + finalNet.predict(next)); + } + log.info("Evaluation stats: " + eval.stats()); + } + } diff --git a/cavis-common-platform/build.gradle b/cavis-common-platform/build.gradle index aaf070d84..a1a728508 100644 --- a/cavis-common-platform/build.gradle +++ b/cavis-common-platform/build.gradle @@ -25,8 +25,8 @@ ext { def flatbuffers = [version: "1.10.0"] - def spark = [version: "3.1.2"] - def scala = [version:"2.12.10"] //[version:"2.13.5"] + def spark = [version: "3.2.2"] + def scala = [version:"2.12.15"] //[version:"2.13.5"] def netty = [version: "4.1.68.Final"] diff --git a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldInterface.java b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/AbstractField.java similarity index 51% rename from cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldInterface.java rename to cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/AbstractField.java index 92705bea8..87aa389ce 100644 --- a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/FieldInterface.java +++ b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/AbstractField.java @@ -21,59 +21,44 @@ package net.brutex.cavis.dvec.api; -import java.io.Serializable; import java.nio.Buffer; -import java.nio.LongBuffer; -import java.util.List; +import java.nio.ByteBuffer; import net.brutex.cavis.dvec.api.exceptions.DVecException; /** - * A Field can be considered a "column" in a {@code Record}, as such a Field may refer to multiple - * entries of that "column". Fields are typed as Buffers. Some of them defined in the dvec core api, - * other (i.e. Image or Arrow) require dvec extensions accordingly. + * Abtract implementation of the Field interface {@see FieldInterface}, that handles all data storage + * in memory and adds basic error handling. * * @author Brian Rosenberger * @since 1.0 */ -public interface FieldInterface extends Serializable { +public abstract class AbstractField implements Field { /** - * Get a reference to the metadata for this Field. - * - * @return the {@link FieldMetadata} - */ - FieldMetadata getFieldMetadata(); - - /** - * Get the 1st field as Buffer. This deserializes the data from the underlying storage. - * - * @return T underlying Buffer - */ - default T read() throws DVecException { - return read(0, 1); - } - - /** - * Get a range of fields as a {@code Buffer} + * {@inheritDoc} * * @param start Index of starting position, zero based * @param length how many fields to read - * @return the buffers + * @return the list of Buffer */ - T read(long start, long length) throws DVecException; - - /** - * Write the data into the underlying storage. 
- */ - default void write(T buffer) { - write(0, buffer); + @Override + public T read(long start, long length) throws DVecException { + if (start<0 || start>internalStorage.capacity()-1 ) { + throw new DVecException("Read on Field start position is out of bounds."); + } + if (start+length> internalStorage.capacity()) { + throw new DVecException("Read on Field exceeds field length"); + } + return null; } - /** - * Write the data into the underyling storage starting at a position - * - * @param pos the position to start - */ - void write(long pos, T buffer); + @Override + public void write(long pos, T buffer) { + + } + + private ByteBuffer internalStorage = null; + + } diff --git a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java index a3be6313f..ace9be2a1 100644 --- a/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java +++ b/cavis-datavec/dvec-api/src/main/java/net/brutex/cavis/dvec/api/Field.java @@ -21,46 +21,57 @@ package net.brutex.cavis.dvec.api; +import java.io.Serializable; import java.nio.Buffer; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; import net.brutex.cavis.dvec.api.exceptions.DVecException; /** - * Abtract implementation of the Field interface {@see FieldInterface}, that handles all data storage - * in memory and adds basic error handling. + * A Field can be considered a "column" in a {@code Record}, as such a Field may refer to multiple + * entries of that "column". Fields are typed as Buffers. Some of them defined in the dvec core api, + * other (i.e. Image or Arrow) require dvec extensions accordingly. * * @author Brian Rosenberger * @since 1.0 */ -public abstract class Field implements FieldInterface { +public interface Field extends Serializable { /** - * {@inheritDoc} + * Get a reference to the metadata for this Field. + * + * @return the {@link FieldMetadata} + */ + FieldMetadata getFieldMetadata(); + + /** + * Get the 1st field as Buffer. This deserializes the data from the underlying storage. + * + * @return T underlying Buffer + */ + default T read() throws DVecException { + return read(0, 1); + } + + /** + * Get a range of fields as a {@code Buffer} * * @param start Index of starting position, zero based * @param length how many fields to read - * @return the list of Buffer + * @return the buffers */ - @Override - public T read(long start, long length) throws DVecException { - if (start<0 || start>internalStorage.capacity()-1 ) { - throw new DVecException("Read on Field start position is out of bounds."); - } - if (start+length> internalStorage.capacity()) { - throw new DVecException("Read on Field exceeds field length"); - } - return null; + T read(long start, long length) throws DVecException; + + /** + * Write the data into the underlying storage. 
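+   * The default implementation simply delegates to {@code write(0, buffer)}, i.e. it writes starting at position zero.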
+ */ + default void write(T buffer) { + write(0, buffer); } - @Override - public void write(long pos, T buffer) { - - } - - private ByteBuffer internalStorage = null; - - + /** + * Write the data into the underlying storage starting at a position + * + * @param pos the position to start + */ + void write(long pos, T buffer); } diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java index 83cba9988..16f6f134a 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java @@ -99,9 +99,13 @@ public class DL4JClassLoading { .asSubclass(superclass) .getDeclaredConstructor(parameterTypes) .newInstance(args); - } catch (InstantiationException | IllegalAccessException | InvocationTargetException + } catch (InstantiationException | IllegalAccessException | NoSuchMethodException instantiationException) { log.error(String.format("Cannot create instance of class '%s'.", className), instantiationException); + + throw new RuntimeException(instantiationException); + } catch (InvocationTargetException instantiationException) { + log.error(String.format("InvocationTargetException was '%s'.", instantiationException.getTargetException().getMessage()), instantiationException); throw new RuntimeException(instantiationException); } } diff --git a/cavis-dnn/cavis-dnn-cudnn/build.gradle b/cavis-dnn/cavis-dnn-cudnn/build.gradle new file mode 100644 index 000000000..725ca1f85 --- /dev/null +++ b/cavis-dnn/cavis-dnn-cudnn/build.gradle @@ -0,0 +1,23 @@ + +apply from: "${project.rootProject.projectDir}/createTestBackends.gradle" + +ext { + buildTarget = rootProject.ext.buildTarget +} + +dependencies { + implementation platform(projects.cavisCommonPlatform) + implementation projects.cavisNative.cavisNativeJcublas + implementation projects.cavisDnn.cavisDnnApi + implementation projects.cavisDnn.cavisDnnNn + + implementation group: "org.bytedeco", name: "cuda" + implementation group: "org.bytedeco", name: "cuda", classifier: buildTarget + implementation group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist" + + implementation group: "org.bytedeco", name: "javacpp" + implementation group: "org.bytedeco", name: "javacpp", classifier: buildTarget + + implementation 'com.jakewharton.byteunits:byteunits:0.9.1' + +} \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/BaseCudnnHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/BaseCudnnHelper.java new file mode 100644 index 000000000..5465f6224 --- /dev/null +++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/BaseCudnnHelper.java @@ -0,0 +1,252 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. 
+ * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.cuda; + +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; +import org.bytedeco.javacpp.*; +import org.nd4j.jita.allocator.impl.AtomicAllocator; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.factory.Nd4j; + +import org.bytedeco.cuda.cudnn.*; +import static org.bytedeco.cuda.global.cudart.*; +import static org.bytedeco.cuda.global.cudnn.*; + +/** + * Functionality shared by all cuDNN-based helpers. + * + * @author saudet + */ +@Slf4j +public abstract class BaseCudnnHelper { + + /* public BaseCudnnHelper() { + + } + */ + + protected static void checkCuda(int error) { + if (error != cudaSuccess) { + throw new RuntimeException("CUDA error = " + error + ": " + cudaGetErrorString(error).getString()); + } + } + + protected static void checkCudnn(int status) { + if (status != CUDNN_STATUS_SUCCESS) { + throw new RuntimeException("cuDNN status = " + status + ": " + cudnnGetErrorString(status).getString()); + } + } + + protected static class CudnnContext extends cudnnContext { + + protected static class Deallocator extends CudnnContext implements Pointer.Deallocator { + Deallocator(CudnnContext c) { + super(c); + } + + @Override + public void deallocate() { + destroyHandles(); + } + } + + public CudnnContext() { + // insure that cuDNN initializes on the same device as ND4J for this thread + Nd4j.create(1); + AtomicAllocator.getInstance(); + // This needs to be called in subclasses: + // createHandles(); + // deallocator(new Deallocator(this)); + } + + public CudnnContext(CudnnContext c) { + super(c); + } + + protected void createHandles() { + checkCudnn(cudnnCreate(this)); + } + + protected void destroyHandles() { + checkCudnn(cudnnDestroy(this)); + } + } + + protected static class DataCache extends Pointer { + + static class Deallocator extends DataCache implements Pointer.Deallocator { + Deallocator(DataCache c) { + super(c); + } + + @Override + public void deallocate() { + checkCuda(cudaFree(this)); + setNull(); + } + } + + static class HostDeallocator extends DataCache implements Pointer.Deallocator { + HostDeallocator(DataCache c) { + super(c); + } + + @Override + public void deallocate() { + checkCuda(cudaFreeHost(this)); + setNull(); + } + } + + public DataCache() {} + + public DataCache(long size) { + position = 0; + limit = capacity = size; + int error = cudaMalloc(this, size); + if (error != cudaSuccess) { + log.warn("Cannot allocate " + size + " bytes of device memory (CUDA error = " + error + + "), proceeding with host memory"); + checkCuda(cudaMallocHost(this, size)); + deallocator(new HostDeallocator(this)); + } else { + deallocator(new Deallocator(this)); + } + } + + public DataCache(DataCache c) { + super(c); + } + } + + protected static class TensorArray extends PointerPointer { + + static class Deallocator extends TensorArray implements Pointer.Deallocator { + Pointer owner; + + Deallocator(TensorArray a, Pointer owner) { + this.address = a.address; + this.capacity = a.capacity; + this.owner = owner; + } + + @Override + public void deallocate() { + for (int i = 
0; !isNull() && i < capacity; i++) { + cudnnTensorStruct t = this.get(cudnnTensorStruct.class, i); + checkCudnn(cudnnDestroyTensorDescriptor(t)); + } + if (owner != null) { + owner.deallocate(); + owner = null; + } + setNull(); + } + } + + public TensorArray() {} + + public TensorArray(long size) { + PointerPointer p = new PointerPointer(size); + p.deallocate(false); + this.address = p.address(); + this.limit = p.limit(); + this.capacity = p.capacity(); + + cudnnTensorStruct t = new cudnnTensorStruct(); + for (int i = 0; i < capacity; i++) { + checkCudnn(cudnnCreateTensorDescriptor(t)); + this.put(i, t); + } + deallocator(new Deallocator(this, p)); + } + + public TensorArray(TensorArray a) { + super(a); + } + } + + protected final DataType nd4jDataType; + protected final int dataType; + protected final int dataTypeSize; + // both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta + protected final Pointer alpha; + protected final Pointer beta; + protected SizeTPointer sizeInBytes = new SizeTPointer(1); + + public BaseCudnnHelper(@NonNull DataType dataType){ + this.nd4jDataType = dataType; + this.dataType = dataType == DataType.DOUBLE ? CUDNN_DATA_DOUBLE + : dataType == DataType.FLOAT ? CUDNN_DATA_FLOAT : CUDNN_DATA_HALF; + this.dataTypeSize = dataType == DataType.DOUBLE ? 8 : dataType == DataType.FLOAT ? 4 : 2; + // both CUDNN_DATA_HALF and CUDNN_DATA_FLOAT need a float value for alpha and beta + this.alpha = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(1.0) : new FloatPointer(1.0f); + this.beta = this.dataType == CUDNN_DATA_DOUBLE ? new DoublePointer(0.0) : new FloatPointer(0.0f); + } + + public static int toCudnnDataType(DataType type){ + switch (type){ + case DOUBLE: + return CUDNN_DATA_DOUBLE; + case FLOAT: + return CUDNN_DATA_FLOAT; + case INT: + return CUDNN_DATA_INT32; + case HALF: + return CUDNN_DATA_HALF; + default: + throw new RuntimeException("Cannot convert type: " + type); + } + } + + public boolean checkSupported() { + // add general checks here, if any + return true; + } + + + /** + * From CuDNN documentation - + * "Tensors are restricted to having at least 4 dimensions... When working with lower dimensional data, it is + * recommended that the user create a 4Dtensor, and set the size along unused dimensions to 1." + * + * This method implements that - basically appends 1s to the end (shape or stride) to make it length 4, + * or leaves it unmodified if the length is already 4 or more. + * This method can be used for both shape and strides + * + * @param shapeOrStrides + * @return + */ + protected static int[] adaptForTensorDescr(int[] shapeOrStrides){ + if(shapeOrStrides.length >= 4) + return shapeOrStrides; + int[] out = new int[4]; + int i=0; + for(; i backpropGradient(INDArray input, INDArray weights, INDArray bias, INDArray delta, int[] kernel, + int[] strides, int[] pad, INDArray biasGradView, INDArray weightGradView, IActivation afn, + AlgoMode mode, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo, + ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + + //AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working + // correctly on NHWC data, even after updating all descriptors, tensor format, etc. + //Therefore: all computation here is done in NCHW format only + //As of a future (next?) 
release we'll likely switch to C++ for cuDNN support + boolean origNHWC = false; + if(format == CNN2DFormat.NHWC){ + input = input.permute(0,3,1,2); //NHWC to NCHW + delta = delta.permute(0,3,1,2); + origNHWC = true; + } + + int TENSOR_FORMAT = CUDNN_TENSOR_NCHW; + + int code; + + val miniBatch = input.size(0); + val outDepth = weights.size(0); + val inDepth = weights.size(1); + val kH = weights.size(2); + val kW = weights.size(3); + + CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above + input = args.getInput(); + val inH = input.size(2); + val inW = input.size(3); + val srcStride = input.stride(); + val outSize = args.getOutSize(); + val outH = outSize[0]; + val outW = outSize[1]; + + if (!Shape.strideDescendingCAscendingF(delta)) { + // apparently not supported by cuDNN + delta = delta.dup(); + } + + val deltaStride = delta.stride(); + int[] algo1 = new int[1]; + int[] algo2 = new int[1]; + + + if (Nd4j.getExecutioner() instanceof GridExecutioner) + ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + + code = cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth,(int) inH, (int) inW, + (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]); + checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + code = cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) outDepth, (int) outH, (int) outW, + (int) deltaStride[0], (int) deltaStride[1], (int) deltaStride[2], (int) deltaStride[3]); + checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0], + dilation[1], CUDNN_CROSS_CORRELATION, dataType); + checkCudnn(false, "cudnnSetConvolution2dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW); + checkCudnn(false, "cudnnSetFilter4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + + if (mode == AlgoMode.USER_SPECIFIED && bwdFilterAlgo != null && bwdDataAlgo != null) { + switch (bwdFilterAlgo) { + case ALGO_0: + algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0; + break; + case ALGO_1: + algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + break; + case FFT: + algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT; + break; + case ALGO_3: + algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3; + break; + case WINOGRAD: + algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD; + break; + case WINOGRAD_NONFUSED: + algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED; + break; + case FFT_TILING: + algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING; + break; + case COUNT: + algo1[0] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT; + break; + default: + throw new IllegalArgumentException("Unknown BwdFilterAlgo: " + bwdFilterAlgo); + } + + switch (bwdDataAlgo) { + case ALGO_0: + algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0; + break; + case 
ALGO_1: + algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + break; + case FFT: + algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT; + break; + case FFT_TILING: + algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING; + break; + case WINOGRAD: + algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD; + break; + case WINOGRAD_NONFUSED: + algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED; + break; + case COUNT: + algo2[0] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; + break; + default: + throw new IllegalArgumentException("Unknown BwdDataAlgo: " + bwdDataAlgo); + } + } else { + /* + code = cudnnGetConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc, + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, + mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE + : CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST, + 0, algo1); + */ + val fa = new cudnnConvolutionBwdFilterAlgoPerf_t(); + val counts = new int[1]; + code = cudnnFindConvolutionBackwardFilterAlgorithm(cudnnContext, cudnnContext.srcTensorDesc, + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, 1, counts, fa); + algo1[0] = fa.algo(); + + checkCudnn(false, "cudnnGetConvolutionBackwardFilterAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + + /* + code = cudnnGetConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc, + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, + mode == AlgoMode.NO_WORKSPACE ? CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE + : CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST, + 0, algo2); + */ + + val da = new cudnnConvolutionBwdDataAlgoPerf_t(); + code = cudnnFindConvolutionBackwardDataAlgorithm(cudnnContext, cudnnContext.filterDesc, + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.srcTensorDesc, 1, counts, da); + + algo2[0] = da.algo(); + checkCudnn(false, "cudnnGetConvolutionBackwardDataAlgorithm", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + } + + if(log.isTraceEnabled()){ + BwdFilterAlgo fa = BwdFilterAlgo.values()[algo1[0]]; + BwdDataAlgo da = BwdDataAlgo.values()[algo2[0]]; + log.trace("CudnnConvolutionHelper backward algorithm selection: mode {}, filter algorithm {}, data algorithm {}", mode, fa, da); + } + + INDArray epsNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, weights.dataType(), new long[] {(int) miniBatch,(int) inDepth, (int) inH, (int) inW}, 'c'); + + val dstStride = epsNext.stride(); + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, weights, weightGradView, + biasGradView, delta, epsNext); + Pointer srcData = allocator.getPointer(input, context); + Pointer filterData = allocator.getPointer(weights, context); + Pointer filterGradData = allocator.getPointer(weightGradView, context); + Pointer biasGradData = allocator.getPointer(biasGradView, context); + Pointer deltaData = allocator.getPointer(delta, context); + Pointer dstData = allocator.getPointer(epsNext, context); + + code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())); + checkCudnn(false, "cudnnSetStream", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + + code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, 
dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW, + (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]); + checkCudnn(false, "cudnnSetTensor4dDescriptorEx", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + + code = cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc, + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.filterDesc, algo1[0], + sizeInBytes); + checkCudnn(false, "cudnnGetConvolutionBackwardFilterWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + + long sizeInBytes1 = sizeInBytes.get(0); + code = cudnnGetConvolutionBackwardDataWorkspaceSize(cudnnContext, cudnnContext.filterDesc, + cudnnContext.deltaTensorDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo2[0], + sizeInBytes); + checkCudnn(false, "cudnnGetConvolutionBackwardDataWorkspaceSize", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + + DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY); + long sizeInBytes2 = sizeInBytes.get(0); + if (workSpace == null || sizeInBytes1 > workSpace.capacity() || sizeInBytes2 > workSpace.capacity()) { + long newSize = Math.max(sizeInBytes1, sizeInBytes2); + if(log.isTraceEnabled()){ + if(workSpace == null){ + log.trace("CudnnConvolutionHelper backpropGradient: Allocating initial workspace of size {} ({})", newSize, + BinaryByteUnit.format(newSize, "#.00")); + } else { + log.trace("CudnnConvolutionHelper backpropGradient: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})", + workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"), + newSize, BinaryByteUnit.format(newSize, "#.00")); + } + } + if(workSpace != null) + workSpace.deallocate(); + workSpace = new DataCache(newSize); + workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace); + } + + code = cudnnSetTensor4dDescriptor(cudnnContext.biasTensorDesc, TENSOR_FORMAT, dataType, 1, (int) outDepth, 1, 1); + checkCudnn(false, "cudnnSetTensor4dDescriptor", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + + code = cudnnConvolutionBackwardBias(cudnnContext, alpha, cudnnContext.deltaTensorDesc, deltaData, beta, + cudnnContext.biasTensorDesc, biasGradData); + checkCudnn(false, "cudnnConvolutionBackwardBias", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + + code = cudnnConvolutionBackwardFilter(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData, + cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo1[0], workSpace, + workSpace.capacity(), beta, cudnnContext.filterDesc, filterGradData); + checkCudnn(false, "cudnnConvolutionBackwardFilter", code, input, weights, null, delta, kernel, strides, pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + + code = cudnnConvolutionBackwardData(cudnnContext, alpha, cudnnContext.filterDesc, filterData, + cudnnContext.deltaTensorDesc, deltaData, cudnnContext.convDesc, algo2[0], workSpace, + workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); + checkCudnn(false, "cudnnConvolutionBackwardData", code, input, weights, null, delta, kernel, strides, 
pad, mode, null, bwdFilterAlgo, bwdDataAlgo, convolutionMode, dilation); + + allocator.getFlowController().registerActionAllWrite(context, input, weights, weightGradView, biasGradView, + delta, epsNext); + + Gradient retGradient = new DefaultGradient(); + retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView); + retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, weightGradView, 'c'); + + if (CudaEnvironment.getInstance().getConfiguration().isDebug()) + context.syncOldStream(); + + //Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon + // we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input. + if(args.isManualPadBottom() || args.isManualPadRight()) { + epsNext = epsNext.get(all(), all(), + interval(0, epsNext.size(2) - (args.isManualPadBottom() ? 1 : 0)), + interval(0, epsNext.size(3) - (args.isManualPadRight() ? 1 : 0))); + } + + if(origNHWC){ + epsNext = epsNext.permute(0,2,3,1); //NCHW to NHWC + } + + return new Pair<>(retGradient, epsNext); + } + + @Override + public INDArray preOutput(INDArray input, INDArray weights, INDArray bias, int[] kernel, int[] strides, int[] pad, + AlgoMode mode, FwdAlgo fwdAlgo, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, + LayerWorkspaceMgr workspaceMgr) { + + //AB 2020/04/21 - cuDNN does have NHWC support (with limitations) however I have been unable to get it working + // correctly on NHWC data, even after updating all descriptors, tensor format, etc. + //Therefore: all computation here is done in NCHW format only + //As of a future (next?) release we'll likely switch to C++ for cuDNN support + boolean origNHWC = false; + if(format == CNN2DFormat.NHWC){ + input = input.permute(0,3,1,2); //NHWC to NCHW + origNHWC = true; + } + + int TENSOR_FORMAT = CUDNN_TENSOR_NCHW; + + int code; + + val miniBatch = input.size(0); + val outDepth = weights.size(0); + val inDepth = weights.size(1); + val kH = weights.size(2); + val kW = weights.size(3); + + CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, null, CNN2DFormat.NCHW); //Note hardcoded NCHW due to above + input = args.getInput(); + val inH = input.size(2); + val inW = input.size(3); + val srcStride = input.stride(); + val outSize = args.getOutSize(); + + if (Nd4j.getExecutioner() instanceof GridExecutioner) + ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + + INDArray z = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, weights.dataType(), new long[] {(int) miniBatch, (int) outDepth, outSize[0], outSize[1]}); + + code = cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW, + (int) srcStride[0], (int) srcStride[1], (int) srcStride[2], (int) srcStride[3]); + checkCudnn(true, "cudnnSetTensor4dDescriptorEx", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); + + code = cudnnSetFilter4dDescriptor(cudnnContext.filterDesc, dataType, TENSOR_FORMAT, (int) outDepth, (int) inDepth, (int) kH, (int) kW); + checkCudnn(true, "cudnnSetFilter4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); + + code = cudnnSetConvolution2dDescriptor(cudnnContext.convDesc, pad[0], pad[1], strides[0], strides[1], dilation[0], + dilation[1], CUDNN_CROSS_CORRELATION, dataType); + checkCudnn(true, 
"cudnnSetConvolution2dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); + + + // find dimension of convolution output + // checkCudnn(cudnnGetConvolution2dForwardOutputDim(cudnnContext.convDesc, cudnnContext.srcTensorDesc, cudnnContext.filterDesc, n, c, h, w)); + // INDArray z = Nd4j.createUninitialized(new int[]{n[0],c[0],h[0],w[0]},'c'); + + + int[] algo = new int[1]; + val dstStride = z.stride(); + code = cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) outDepth, (int) outSize[0], + (int) outSize[1], (int) dstStride[0], (int) dstStride[1], (int) dstStride[2], (int) dstStride[3]); + checkCudnn(true, "cudnnSetTensor4dDescriptorEx", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); + + if (mode == AlgoMode.USER_SPECIFIED && fwdAlgo != null) { + switch (fwdAlgo) { + case IMPLICIT_GEMM: + algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; + break; + case IMPLICIT_PRECOMP_GEMM: + algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + break; + case GEMM: + algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_GEMM; + break; + case DIRECT: + algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_DIRECT; + break; + case FFT: + algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_FFT; + break; + case FFT_TILING: + algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING; + break; + case WINOGRAD: + algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD; + break; + case WINOGRAD_NONFUSED: + algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED; + break; + case COUNT: + algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + break; + default: + throw new IllegalArgumentException("Unknown FwdAlgo: " + fwdAlgo); + } + } else { + /* + code = cudnnGetConvolutionForwardAlgorithm_v7(cudnnContext, cudnnContext.srcTensorDesc, + cudnnContext.filterDesc, cudnnContext.convDesc, + cudnnContext.dstTensorDesc, mode == AlgoMode.NO_WORKSPACE + ? CUDNN_CONVOLUTION_FWD_ : CUDNN_CONVOLUTION_FWD_PREFER_FASTEST, + 0, algo); + */ + + val cdf = new cudnnConvolutionFwdAlgoPerf_t(); + val count = new int[1]; + code = cudnnFindConvolutionForwardAlgorithm(cudnnContext, cudnnContext.srcTensorDesc, cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, 1, count, cdf); + + if(code != CUDNN_STATUS_SUCCESS){ + //If CuDNN can't infer algorithm - try IMPLICIT_GEMM + //Why this specifically? 
According to the docs, it seems to have the least number of restrictions + // to things like dilation + + OneTimeLogger.warn(log, "Error getting CuDNN forward algorithm - falling back on IMPLICIT_GEMM"); + mode = AlgoMode.USER_SPECIFIED; + fwdAlgo = FwdAlgo.IMPLICIT_GEMM; + algo[0] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; + } + + algo[0] = cdf.algo(); + } + + if(log.isTraceEnabled()){ + FwdAlgo a = FwdAlgo.values()[algo[0]]; + log.trace("CudnnConvolutionHelper forward algorithm selection: mode {}, algorithm {}", mode, a); + } + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = allocator.getFlowController().prepareAction(z, input, weights, bias); + Pointer srcData = allocator.getPointer(input, context); + Pointer filterData = allocator.getPointer(weights, context); + Pointer biasData = allocator.getPointer(bias, context); + Pointer dstData = allocator.getPointer(z, context); + + code = cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream())); + checkCudnn(true, "cudnnSetStream", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); + + code = cudnnGetConvolutionForwardWorkspaceSize(cudnnContext, cudnnContext.srcTensorDesc, + cudnnContext.filterDesc, cudnnContext.convDesc, cudnnContext.dstTensorDesc, algo[0], + sizeInBytes); + checkCudnn(true, "cudnnGetConvolutionForwardWorkspaceSize", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); + + DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY); + if (workSpace == null || sizeInBytes.get(0) > workSpace.capacity()) { + if(log.isTraceEnabled()){ + if(workSpace == null){ + log.trace("CudnnConvolutionHelper preOutput: allocating initial workspace of size {} ({})", + sizeInBytes.get(), BinaryByteUnit.format(sizeInBytes.get(), "#.00")); + } else { + log.trace("CudnnConvolutionHelper preOutput: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})", + workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"), + sizeInBytes.get(), BinaryByteUnit.format(sizeInBytes.get(), "#.00")); + } + } + if(workSpace != null) + workSpace.deallocate(); + workSpace = new DataCache(sizeInBytes.get(0)); + workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace); + } + code = cudnnConvolutionForward(cudnnContext, alpha, cudnnContext.srcTensorDesc, srcData, + cudnnContext.filterDesc, filterData, cudnnContext.convDesc, algo[0], workSpace, + workSpace.capacity(), beta, cudnnContext.dstTensorDesc, dstData); + checkCudnn(true, "cudnnConvolutionForward", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); + + + code = cudnnSetTensor4dDescriptor(cudnnContext.biasTensorDesc, TENSOR_FORMAT, dataType, 1, (int) outDepth, 1, 1); + checkCudnn(true, "cudnnSetTensor4dDescriptor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); + + code = cudnnAddTensor(cudnnContext, alpha, cudnnContext.biasTensorDesc, biasData, alpha, + cudnnContext.dstTensorDesc, dstData); + checkCudnn(true, "cudnnAddTensor", code, input, weights, bias, null, kernel, strides, pad, mode, fwdAlgo, null, null, convolutionMode, dilation); + + allocator.registerAction(context, z, input, weights, bias); + + if (CudaEnvironment.getInstance().getConfiguration().isDebug()) + context.syncOldStream(); + + if(origNHWC){ + z = 
z.permute(0,2,3,1); //NCHW to NHWC + } + + return z; + } + + private void checkCudnn(boolean forward, String step, int code, INDArray input, INDArray weights, INDArray bias, INDArray delta, + int[] kernel, int[] strides, int[] pad, + AlgoMode mode, FwdAlgo fwdAlgo, BwdFilterAlgo bwdFilterAlgo, BwdDataAlgo bwdDataAlgo, ConvolutionMode convolutionMode, int[] dilation) { + + if (code != CUDNN_STATUS_SUCCESS) { + StringBuilder sb = new StringBuilder(); + sb.append("CuDNN error = ").append(code).append(": ").append(cudnnGetErrorString(code).getString()) + .append(" during ") + .append(forward ? "forward pass" : "backward pass") + .append(" - step ").append(step) + .append(": inputShape=").append(Arrays.toString(input.shape())) + .append(", weightsShape=").append(Arrays.toString(weights.shape())) + .append(", biasShape=").append(bias == null ? null : Arrays.toString(bias.shape())); + if (!forward) { + sb.append(", gradientShape=").append(Arrays.toString(delta.shape())); + } + sb.append(", kernel=").append(Arrays.toString(kernel)) + .append(", stride=").append(Arrays.toString(strides)) + .append(", padding=").append(Arrays.toString(pad)) + .append(", dilation=").append(Arrays.toString(dilation)) + .append(", AlgoMode=").append(mode); + if (forward) { + sb.append(", fwdAlgo=").append(fwdAlgo); + } else { + sb.append(", bwdFilterAlgo=").append(bwdFilterAlgo) + .append(", bwdDataAlgo=").append(bwdDataAlgo); + } + sb.append(", convolutionMode=").append(convolutionMode); + + throw new RuntimeException(sb.toString()); + } + } + + @Override + public INDArray activate(INDArray z, IActivation afn, boolean training) { + if (Nd4j.getExecutioner() instanceof GridExecutioner) + ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + + INDArray activation = z; + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = allocator.getFlowController().prepareAction(z); + Pointer dstData = allocator.getPointer(z, context); + + checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()))); + switch (afn.toString()) { + case "identity": + break; + case "sigmoid": + checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_SIGMOID, + CUDNN_PROPAGATE_NAN, 0)); + checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + break; + case "relu": + checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_RELU, + CUDNN_PROPAGATE_NAN, 0)); + checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + break; + case "tanh": + checkCudnn(cudnnSetActivationDescriptor(cudnnContext.activationDesc, CUDNN_ACTIVATION_TANH, + CUDNN_PROPAGATE_NAN, 0)); + checkCudnn(cudnnActivationForward(cudnnContext, cudnnContext.activationDesc, alpha, + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + break; + case "softmax": + checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_CHANNEL, alpha, + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + break; + case "logsoftmax": + checkCudnn(cudnnSoftmaxForward(cudnnContext, CUDNN_SOFTMAX_LOG, CUDNN_SOFTMAX_MODE_CHANNEL, alpha, + cudnnContext.dstTensorDesc, dstData, beta, cudnnContext.dstTensorDesc, dstData)); + break; + default: + activation = null; + } + + 
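+        // Note: activation functions not covered by the switch above leave 'activation' as null, and that
+        // null is returned as-is. A hedged reading (not spelled out in this class) is that the caller treats
+        // null as "not supported by this cuDNN helper" and falls back to the generic path, roughly:
+        //
+        //     INDArray out = helper.activate(z, afn, training);
+        //     if (out == null) {
+        //         out = afn.getActivation(z, training);   // built-in (non-cuDNN) activation - assumed fallback
+        //     }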
allocator.registerAction(context, activation); + + if (CudaEnvironment.getInstance().getConfiguration().isDebug()) + context.syncOldStream(); + + return activation; + } + + /** + * @param poolingType Used when preparing data for subsampling layers ONLY. Null for convolution layers + * @return + */ + public static CudnnForwardArgs getCudnnForwardArgs(INDArray input, int[] kernel, int[] strides, int[] padding, int[] dilation, + ConvolutionMode convolutionMode, PoolingType poolingType, CNN2DFormat format){ + INDArray origInput = input; + + //Check if we need to dup the input: views, non-contiguous, etc. CuDNN also seems to have has issues if strides + // are non-default for C order - even if they *should* be OK otherwise + if(input.isView() || !Shape.hasDefaultStridesForShape(input)){ + input = input.dup('c'); + } + + boolean nchw = format == CNN2DFormat.NCHW; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; + + val inH = input.size(hIdx); + val inW = input.size(wIdx); + + boolean manualPadBottom = false; + boolean manualPadRight = false; + + int[] outSize; + if (convolutionMode == ConvolutionMode.Same) { + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation + padding = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation); + int[] padBottomRight = ConvolutionUtils.getSameModeBottomRightPadding(outSize, new int[] {(int) inH, (int) inW}, kernel, strides, dilation); + if(!Arrays.equals(padding, padBottomRight)){ + /* + CuDNN - even as of 7.1 (CUDA 9.1) still doesn't have support for proper SAME mode padding (i.e., asymmetric + padding) - padding can *only* be specified as the same amount for both the top/bottom, and for left/right. + In SAME mode padding, sometimes these are the same - but often they are not. + Note that when they differ, the bottom or right padding will be exactly 1 more than the top or left padding. + As per TF, we'll manually pad here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/conv_ops.cc#L571-L607 + */ + manualPadBottom = (padding[0] != padBottomRight[0]); + manualPadRight = (padding[1] != padBottomRight[1]); + + //NCHW format + long[] newShape; + if(nchw){ + newShape = new long[]{input.size(0), input.size(1), + input.size(2) + (manualPadBottom ? 1 : 0), + input.size(3) + (manualPadRight ? 1 : 0)}; + } else { + newShape = new long[]{input.size(0), + input.size(1) + (manualPadBottom ? 1 : 0), + input.size(2) + (manualPadRight ? 1 : 0), + input.size(3)}; + } + INDArray newInput; + if(poolingType == null || poolingType != PoolingType.MAX){ + newInput = Nd4j.create(input.dataType(), newShape); + } else { + //For max pooling, we don't want to include the padding in the maximum values. But, CuDNN doesn't knowm + // that these values are padding and hence should be excluded. Instead: We'll use -infinity so that, + // if the 'real' (non-padding) values are all < 0, we take the real value, not the padding value + newInput = Nd4j.valueArrayOf(newShape, Double.NEGATIVE_INFINITY, input.dataType()); + } + + if(nchw){ + newInput.put(new INDArrayIndex[]{all(), all(), interval(0,input.size(2)), + interval(0, input.size(3))}, input); + } else { + newInput.put(new INDArrayIndex[]{all(), interval(0,input.size(1)), + interval(0, input.size(2)), all()}, input); + } + + input = newInput; + //Now: we've manually applied the "extra" bottom/right padding only - if required. 
Consequently, we + // now have the same amount of padding required for top/bottom, and left/right - which we'll let + // CuDNN handle + } + } else { + outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, padding, convolutionMode, dilation, format); //Also performs validation + } + + return new CudnnForwardArgs(manualPadBottom, manualPadRight, input, origInput, padding, outSize); + } + + + @AllArgsConstructor + @Data + public static class CudnnForwardArgs { + private boolean manualPadBottom; + private boolean manualPadRight; + private INDArray input; + private INDArray origInput; + private int[] padding; + private int[] outSize; + } + + @Override + public Map helperMemoryUse() { + //No memory use other than shared, and the structs (which are small) + return Collections.emptyMap(); + } + +} \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/convolution/subsampling/CudnnSubsamplingHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/convolution/subsampling/CudnnSubsamplingHelper.java new file mode 100644 index 000000000..b92810959 --- /dev/null +++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/convolution/subsampling/CudnnSubsamplingHelper.java @@ -0,0 +1,308 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.cuda.convolution.subsampling; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.bytedeco.javacpp.Pointer; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.layers.PoolingType; +import org.deeplearning4j.nn.gradient.DefaultGradient; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.cuda.BaseCudnnHelper; +import org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper; +import org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingHelper; +import org.nd4j.jita.allocator.Allocator; +import org.nd4j.jita.allocator.impl.AtomicAllocator; +import org.nd4j.jita.conf.CudaEnvironment; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.executioner.GridExecutioner; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.jcublas.context.CudaContext; +import org.nd4j.common.primitives.Pair; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.deeplearning4j.nn.workspace.ArrayType; + +import java.util.Collections; +import java.util.Map; + +import org.bytedeco.cuda.cudart.*; +import org.bytedeco.cuda.cudnn.*; + +import static org.bytedeco.cuda.global.cudnn.*; +import static org.deeplearning4j.cuda.convolution.CudnnConvolutionHelper.getCudnnForwardArgs; +import static org.nd4j.linalg.indexing.NDArrayIndex.all; +import static org.nd4j.linalg.indexing.NDArrayIndex.interval; + +/** + * cuDNN-based helper for the subsampling layer. + * + * @author saudet + */ +@Slf4j +public class CudnnSubsamplingHelper extends BaseCudnnHelper implements SubsamplingHelper { + + public CudnnSubsamplingHelper(DataType dataType) { + super(dataType); + } + + private static class CudnnSubsamplingContext extends CudnnContext { + + private static class Deallocator extends CudnnSubsamplingContext implements Pointer.Deallocator { + Deallocator(CudnnSubsamplingContext c) { + super(c); + } + + @Override + public void deallocate() { + destroyHandles(); + } + } + + private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(), + deltaTensorDesc = new cudnnTensorStruct(); + private cudnnPoolingStruct poolingDesc = new cudnnPoolingStruct(); + + public CudnnSubsamplingContext() { + createHandles(); + deallocator(new Deallocator(this)); + } + + public CudnnSubsamplingContext(CudnnSubsamplingContext c) { + super(c); + srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc); + dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc); + deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc); + poolingDesc = new cudnnPoolingStruct(c.poolingDesc); + } + + @Override + protected void createHandles() { + super.createHandles(); + checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc)); + checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc)); + checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc)); + checkCudnn(cudnnCreatePoolingDescriptor(poolingDesc)); + } + + @Override + protected void destroyHandles() { + checkCudnn(cudnnDestroyPoolingDescriptor(poolingDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc)); + super.destroyHandles(); + } + 
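+        // Lifecycle note: createHandles() runs in the constructors above, and the registered Deallocator
+        // calls destroyHandles() when JavaCPP deallocates this context, so the native tensor and pooling
+        // descriptors are released even if the helper is simply garbage-collected.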
} + + private CudnnSubsamplingContext cudnnContext = new CudnnSubsamplingContext(); + + @Override + public Pair backpropGradient(INDArray input, INDArray epsilon, int[] kernel, int[] strides, + int[] pad, PoolingType poolingType, ConvolutionMode convolutionMode, + int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + if(dilation[0] != 1 || dilation[1] != 1){ + //CuDNN doesn't support dilated subsampling + return null; + } + + boolean nchw = format == CNN2DFormat.NCHW; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; + + //We require the output as one of the arguments for backprop here + //TODO we could add cache mode support here somehow... + INDArray reduced = activate(input, true, kernel, strides, pad, poolingType, convolutionMode, dilation, format, workspaceMgr); + + val miniBatch = input.size(0); + val depth = input.size(chIdx); + + CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format); + input = args.getInput(); + val inH = input.size(hIdx); + val inW = input.size(wIdx); + val srcStride = input.stride(); + int[] outSize = args.getOutSize(); + int outH = outSize[0]; + int outW = outSize[1]; + + //subsampling doesn't have weights and thus gradients are not calculated for this layer + //only scale and reshape epsilon + Gradient retGradient = new DefaultGradient(); + + //Epsilons in shape: [miniBatch, channels, outH, outW] + //Epsilons out shape: [miniBatch, channels, inH, inW] + + int poolingMode; + switch (poolingType) { + case AVG: + poolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + break; + case MAX: + poolingMode = CUDNN_POOLING_MAX; + break; + default: + return null; + } + + if (!Shape.hasDefaultStridesForShape(epsilon) || epsilon.isView()) { + // apparently not supported by cuDNN + epsilon = epsilon.dup('c'); + } + + input = input.dup(); + + val deltaStride = epsilon.stride(); + + if (Nd4j.getExecutioner() instanceof GridExecutioner) + ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, + (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx])); + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) outH, (int) outW, + (int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx])); + checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0], + kernel[1], pad[0], pad[1], strides[0], strides[1])); + + long[] outEpsShape = nchw ? 
new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth}; + INDArray outEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), outEpsShape, 'c'); + + val dstStride = outEpsilon.stride(); + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, + (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx])); + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = allocator.getFlowController().prepareAction(input, epsilon, reduced, outEpsilon); + Pointer srcData = allocator.getPointer(input, context); + Pointer epsData = allocator.getPointer(epsilon, context); + Pointer zData = allocator.getPointer(reduced, context); + Pointer dstData = allocator.getPointer(outEpsilon, context); + + checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()))); + checkCudnn(cudnnPoolingBackward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.deltaTensorDesc, + zData, cudnnContext.deltaTensorDesc, epsData, cudnnContext.srcTensorDesc, srcData, beta, + cudnnContext.dstTensorDesc, dstData)); + + allocator.registerAction(context, outEpsilon, input, epsilon, reduced); + + if (CudaEnvironment.getInstance().getConfiguration().isDebug()) + context.syncOldStream(); + + //Note that: if we had to manually pad for SAME mode, we have to 'undo' this manual padding for the epsilon + // we return. The returned epsilon (i.e., dL/dIn array) has to be the same shape as the *original* input. + if(args.isManualPadBottom() || args.isManualPadRight()) { + if(nchw){ + outEpsilon = outEpsilon.get(all(), all(), + interval(0, outEpsilon.size(2) - (args.isManualPadBottom() ? 1 : 0)), + interval(0, outEpsilon.size(3) - (args.isManualPadRight() ? 1 : 0))); + } else { + outEpsilon = outEpsilon.get(all(), + interval(0, outEpsilon.size(1) - (args.isManualPadBottom() ? 1 : 0)), + interval(0, outEpsilon.size(2) - (args.isManualPadRight() ? 1 : 0)), + all()); + } + } + + return new Pair<>(retGradient, outEpsilon); + } + + + @Override + public INDArray activate(INDArray input, boolean training, int[] kernel, int[] strides, int[] pad, + PoolingType poolingType, ConvolutionMode convolutionMode, int[] dilation, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + if(dilation[0] != 1 || dilation[1] != 1){ + //CuDNN doesn't support dilated subsampling + return null; + } + + boolean nchw = format == CNN2DFormat.NCHW; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; + + val miniBatch = input.size(0); + val inDepth = input.size(nchw ? 1 : 3); + + CudnnConvolutionHelper.CudnnForwardArgs args = getCudnnForwardArgs(input, kernel, strides, pad, dilation, convolutionMode, poolingType, format); + input = args.getInput(); + val inH = input.size(nchw ? 2 : 1); + val inW = input.size(nchw ? 
3 : 2); + val srcStride = input.stride(); + val outSize = args.getOutSize(); + int outH = outSize[0]; + int outW = outSize[1]; + + + int poolingMode; + switch (poolingType) { + case AVG: + poolingMode = CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING; + break; + case MAX: + poolingMode = CUDNN_POOLING_MAX; + break; + default: + return null; + } + + if (Nd4j.getExecutioner() instanceof GridExecutioner) + ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + + checkCudnn(cudnnSetPooling2dDescriptor(cudnnContext.poolingDesc, poolingMode, CUDNN_PROPAGATE_NAN, kernel[0], + kernel[1], pad[0], pad[1], strides[0], strides[1])); + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) inH, (int) inW, + (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx])); + + long[] outShape = nchw ? new long[] {miniBatch, inDepth, outH, outW} : new long[] {miniBatch, outH, outW, inDepth}; + INDArray reduced = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c'); + + val dstStride = reduced.stride(); + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, (int) miniBatch, (int) inDepth, (int) outH, (int) outW, + (int) dstStride[0], (int) dstStride[chIdx], (int) dstStride[hIdx], (int) dstStride[wIdx])); + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = allocator.getFlowController().prepareAction(input, reduced); + Pointer srcData = allocator.getPointer(input, context); + Pointer dstData = allocator.getPointer(reduced, context); + + checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()))); + checkCudnn(cudnnPoolingForward(cudnnContext, cudnnContext.poolingDesc, alpha, cudnnContext.srcTensorDesc, + srcData, beta, cudnnContext.dstTensorDesc, dstData)); + + allocator.registerAction(context, reduced, input); + + if (CudaEnvironment.getInstance().getConfiguration().isDebug()) + context.syncOldStream(); + + return reduced; + } + + @Override + public Map helperMemoryUse() { + //No persistent memory use other than the structs (which are small) + return Collections.emptyMap(); + } + +} diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/dropout/CudnnDropoutHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/dropout/CudnnDropoutHelper.java new file mode 100644 index 000000000..83fd9c7f0 --- /dev/null +++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/dropout/CudnnDropoutHelper.java @@ -0,0 +1,245 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.cuda.dropout; + +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import com.jakewharton.byteunits.BinaryByteUnit; +import org.bytedeco.javacpp.*; +import org.deeplearning4j.nn.conf.dropout.DropoutHelper; +import org.deeplearning4j.cuda.BaseCudnnHelper; +import org.nd4j.jita.allocator.Allocator; +import org.nd4j.jita.allocator.impl.AtomicAllocator; +import org.nd4j.jita.conf.CudaEnvironment; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.jcublas.context.CudaContext; +import org.nd4j.common.util.ArrayUtil; + +import org.bytedeco.cuda.cudart.*; +import org.bytedeco.cuda.cudnn.*; + +import java.util.Collections; +import java.util.Map; + +import static org.bytedeco.cuda.global.cudnn.*; + +/** + * CuDNN dropout helper + * + * Note that for repeatability between calls (for example, for gradient checks), we need to do two things: + * (a) set the ND4J RNG seed + * (b) clear the rngStates field + * + * @author Alex Black + */ +@Data +@Slf4j +public class CudnnDropoutHelper extends BaseCudnnHelper implements DropoutHelper { + + private static class CudnnDropoutContext extends CudnnContext { + + private static class Deallocator extends CudnnDropoutContext implements Pointer.Deallocator { + Deallocator(CudnnDropoutContext c) { + super(c); + } + + @Override + public void deallocate() { + destroyHandles(); + } + } + + private cudnnTensorStruct xTensorDesc = new cudnnTensorStruct(); //Input + private cudnnTensorStruct dxTensorDesc = new cudnnTensorStruct(); //Grad at input + private cudnnTensorStruct yTensorDesc = new cudnnTensorStruct(); //Output + private cudnnTensorStruct dyTensorDesc = new cudnnTensorStruct(); //Grad at output + private cudnnDropoutStruct dropoutDesc = new cudnnDropoutStruct(); + + public CudnnDropoutContext() { + createHandles(); + deallocator(new Deallocator(this)); + } + + public CudnnDropoutContext(CudnnDropoutContext c) { + super(c); + xTensorDesc = new cudnnTensorStruct(c.xTensorDesc); + dxTensorDesc = new cudnnTensorStruct(c.dxTensorDesc); + yTensorDesc = new cudnnTensorStruct(c.yTensorDesc); + dyTensorDesc = new cudnnTensorStruct(c.dyTensorDesc); + dropoutDesc = new cudnnDropoutStruct(c.dropoutDesc); + } + + @Override + protected void createHandles() { + super.createHandles(); + checkCudnn(cudnnCreateTensorDescriptor(xTensorDesc)); + checkCudnn(cudnnCreateTensorDescriptor(dxTensorDesc)); + checkCudnn(cudnnCreateTensorDescriptor(yTensorDesc)); + checkCudnn(cudnnCreateTensorDescriptor(dyTensorDesc)); + checkCudnn(cudnnCreateDropoutDescriptor(dropoutDesc)); + } + + @Override + protected void destroyHandles() { + checkCudnn(cudnnDestroyTensorDescriptor(xTensorDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(dxTensorDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(yTensorDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(dyTensorDesc)); + checkCudnn(cudnnDestroyDropoutDescriptor(dropoutDesc)); + super.destroyHandles(); + } + } + + private CudnnDropoutContext cudnnContext = new CudnnDropoutContext(); + private boolean initializedDescriptor = false; + private DataCache rngStates; //"Pointer to user-allocated GPU memory that will hold random number generator states." 
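+    // The rngStates buffer is sized via cudnnDropoutGetStatesSize() and handed to cudnnSetDropoutDescriptor()
+    // below; keeping it (and, per the class javadoc, clearing it together with re-seeding the ND4J RNG) is what
+    // makes repeated dropout passes reproducible. Sketch of the probability convention used in applyDropout():
+    //     float p = (float) (1.0 - dropoutInputRetainProb);   // cuDNN wants the probability of zeroing a unit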
+ private DataCache mask; //Mask: persistence between forward and backward + private SizeTPointer stateSizeBytesPtr; + private SizeTPointer reserveSizeBytesPtr; + private float lastInitializedP; + + public CudnnDropoutHelper(DataType dataType){ + super(dataType); + } + + //@Override + public Map helperMemoryUse() { + return Collections.emptyMap(); + } + + @Override + public boolean checkSupported() { + return true; + } + + @Override + public void applyDropout(INDArray input, INDArray resultArray, double dropoutInputRetainProb) { + float p = (float)(1.0 - dropoutInputRetainProb); //CuDNN uses p = probability of setting to 0. We use p = probability of retaining + + //TODO int cast + int[] inShape = adaptForTensorDescr(ArrayUtil.toInts(input.shape())); + int[] inStride = adaptForTensorDescr(ArrayUtil.toInts(input.stride())); + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.xTensorDesc, dataType, inShape.length, inShape, inStride)); + + int[] outShape = adaptForTensorDescr(ArrayUtil.toInts(resultArray.shape())); + int[] outStride = adaptForTensorDescr(ArrayUtil.toInts(resultArray.stride())); + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.yTensorDesc, dataType, outShape.length, outShape, outStride)); + + + if(stateSizeBytesPtr == null){ + stateSizeBytesPtr = new SizeTPointer(1); + reserveSizeBytesPtr = new SizeTPointer(1); + } + checkCudnn(cudnnDropoutGetStatesSize(cudnnContext, stateSizeBytesPtr)); + long rngStateSizeBytes = stateSizeBytesPtr.get(); + checkCudnn(cudnnDropoutGetReserveSpaceSize(cudnnContext.xTensorDesc, reserveSizeBytesPtr)); + long maskReserveSizeBytes = reserveSizeBytesPtr.get(); + + if(rngStates == null || rngStates.capacity() < rngStateSizeBytes){ + if(log.isTraceEnabled()){ + if(rngStates == null){ + log.trace("CudnnDropoutHelper: Allocating intial RNG states workspace of size {} ({})", rngStateSizeBytes, + BinaryByteUnit.format(rngStateSizeBytes, "#.00")); + } else { + log.trace("CudnnDropoutHelper: Deallocating RNG states of size {} ({}), allocating new workspace of size {} ({})", + rngStates.capacity(), BinaryByteUnit.format(rngStates.capacity(), "#.00"), + rngStateSizeBytes, BinaryByteUnit.format(rngStateSizeBytes, "#.00")); + } + } + + if(rngStates != null) + rngStates.deallocate(); + //states = "Pointer to user-allocated GPU memory that will hold random number generator states." + rngStates = new DataCache(rngStateSizeBytes); + initializedDescriptor = false; + } + if(mask == null || mask.capacity() < maskReserveSizeBytes){ + if(log.isTraceEnabled()){ + if(mask == null){ + log.trace("CudnnDropoutHelper: Allocating intial mask array of size {} ({})", maskReserveSizeBytes, + BinaryByteUnit.format(maskReserveSizeBytes, "#.00")); + } else { + log.trace("CudnnDropoutHelper: Deallocating mask array of size {} ({}), allocating new mask array of size {} ({})", + mask.capacity(), BinaryByteUnit.format(mask.capacity(), "#.00"), + maskReserveSizeBytes, BinaryByteUnit.format(maskReserveSizeBytes, "#.00")); + } + } + + if(mask != null) + mask.deallocate(); + //mask = "Pointer to user-allocated GPU memory used by this function. It is expected + //that contents of reserveSpace doe not change between cudnnDropoutForward and + //cudnnDropoutBackward calls." 
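+        // i.e. the reserve space must be kept and passed unchanged to cudnnDropoutBackward(), which is why it
+        // is cached in the 'mask' field above rather than being allocated per call.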
+ mask = new DataCache(maskReserveSizeBytes); + } + + //Dropout descriptor: (re)initialize if required + if(!initializedDescriptor || p != lastInitializedP) { + if(log.isTraceEnabled()){ + log.trace("CudnnDropoutHelper: (re)initializing dropout descriptor"); + } + //NOTE: cudnnSetDropoutDescriptor has some internal computation/initialization, and hence is expensive to + // call - so we want to call this as infrequently as possible, and cache the result + long seed = Nd4j.getRandom().nextLong(); + lastInitializedP = p; + checkCudnn(cudnnSetDropoutDescriptor(cudnnContext.dropoutDesc, cudnnContext, p, rngStates, rngStates.capacity(), seed)); + initializedDescriptor = true; + } + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = allocator.getFlowController().prepareAction(input, resultArray); + Pointer xPtr = allocator.getPointer(input, context); + Pointer yPtr = allocator.getPointer(resultArray, context); + + checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()))); + checkCudnn(cudnnDropoutForward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.xTensorDesc, xPtr, + cudnnContext.yTensorDesc, yPtr, mask, mask.capacity())); + + allocator.registerAction(context, input, resultArray); + if (CudaEnvironment.getInstance().getConfiguration().isDebug()) + context.syncOldStream(); + } + + @Override + public void backprop(INDArray gradAtOutput, INDArray gradAtInput) { + int[] gradAtOutShape = adaptForTensorDescr(ArrayUtil.toInts(gradAtOutput.shape())); + int[] gradAtOutStride = adaptForTensorDescr(ArrayUtil.toInts(gradAtOutput.stride())); + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dyTensorDesc, dataType, gradAtOutShape.length, gradAtOutShape, gradAtOutStride)); + + int[] gradAtInShape = adaptForTensorDescr(ArrayUtil.toInts(gradAtInput.shape())); + int[] gradAtInStride = adaptForTensorDescr(ArrayUtil.toInts(gradAtInput.stride())); + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dxTensorDesc, dataType, gradAtInShape.length, gradAtInShape, gradAtInStride)); + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = allocator.getFlowController().prepareAction(gradAtOutput, gradAtInput); + Pointer dyPtr = allocator.getPointer(gradAtOutput, context); + Pointer dxPtr = allocator.getPointer(gradAtInput, context); + + checkCudnn(cudnnDropoutBackward(cudnnContext, cudnnContext.dropoutDesc, cudnnContext.dyTensorDesc, dyPtr, + cudnnContext.dxTensorDesc, dxPtr, mask, mask.capacity())); + + allocator.registerAction(context, gradAtOutput, gradAtInput); + if (CudaEnvironment.getInstance().getConfiguration().isDebug()) + context.syncOldStream(); + } +} diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnBatchNormalizationHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnBatchNormalizationHelper.java new file mode 100644 index 000000000..fea813aa0 --- /dev/null +++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnBatchNormalizationHelper.java @@ -0,0 +1,384 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. 
+ * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.cuda.normalization; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.bytedeco.javacpp.Pointer; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.gradient.DefaultGradient; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.cuda.BaseCudnnHelper; +import org.deeplearning4j.nn.layers.normalization.BatchNormalizationHelper; +import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.nd4j.jita.allocator.Allocator; +import org.nd4j.jita.allocator.impl.AtomicAllocator; +import org.nd4j.jita.conf.CudaEnvironment; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.executioner.GridExecutioner; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.jcublas.context.CudaContext; +import org.nd4j.common.primitives.Pair; +import org.nd4j.common.util.ArrayUtil; + +import java.util.HashMap; +import java.util.Map; + +import org.bytedeco.cuda.cudart.*; +import org.bytedeco.cuda.cudnn.*; + +import static org.bytedeco.cuda.global.cudnn.*; + +/** + * cuDNN-based helper for the batch normalization layer. 
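+ * Per-batch statistics produced by the forward training pass are cached in meanCache/varCache and reused by
+ * backpropGradient(), which therefore expects a matching forward (training) pass to have run first.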
+ * + * @author saudet + */ +@Slf4j +public class CudnnBatchNormalizationHelper extends BaseCudnnHelper implements BatchNormalizationHelper { + + public CudnnBatchNormalizationHelper(DataType dataType) { + super(dataType); + } + + private static class CudnnBatchNormalizationContext extends CudnnContext { + + private static class Deallocator extends CudnnBatchNormalizationContext implements Pointer.Deallocator { + Deallocator(CudnnBatchNormalizationContext c) { + super(c); + } + + @Override + public void deallocate() { + destroyHandles(); + } + } + + private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(), + deltaTensorDesc = new cudnnTensorStruct(), gammaBetaTensorDesc = new cudnnTensorStruct(); + + public CudnnBatchNormalizationContext() { + createHandles(); + deallocator(new Deallocator(this)); + } + + public CudnnBatchNormalizationContext(CudnnBatchNormalizationContext c) { + super(c); + srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc); + dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc); + deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc); + gammaBetaTensorDesc = new cudnnTensorStruct(c.gammaBetaTensorDesc); + } + + @Override + protected void createHandles() { + super.createHandles(); + checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc)); + checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc)); + checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc)); + checkCudnn(cudnnCreateTensorDescriptor(gammaBetaTensorDesc)); + } + + @Override + protected void destroyHandles() { + checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(gammaBetaTensorDesc)); + super.destroyHandles(); + } + } + + protected final int batchNormMode = CUDNN_BATCHNORM_SPATIAL; // would need to increase rank of gamma and beta for CUDNN_BATCHNORM_PER_ACTIVATION + + private CudnnBatchNormalizationContext cudnnContext = new CudnnBatchNormalizationContext(); + private INDArray meanCache; + private INDArray varCache; + private double eps; + + public boolean checkSupported(double eps, boolean isFixedGammaBeta) { + boolean supported = checkSupported(); + if (eps < CUDNN_BN_MIN_EPSILON) { + supported = false; + log.warn("Not supported: eps < CUDNN_BN_MIN_EPSILON (" + eps + " < " + CUDNN_BN_MIN_EPSILON + ")"); + } + return supported; + } + + @Override + public Pair backpropGradient(INDArray input, INDArray epsilon, long[] shape, INDArray gamma, INDArray beta, + INDArray dGammaView, INDArray dBetaView, double eps, CNN2DFormat format, LayerWorkspaceMgr layerWorkspaceMgr) { + + boolean nchw = format == CNN2DFormat.NCHW; + + this.eps = eps; + + int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 
3 : 2; + + val miniBatch = (int) input.size(0); + val depth = (int) input.size(chIdx); + val inH = (int) input.size(hIdx); + val inW = (int) input.size(wIdx); + + final boolean isHalf = (input.dataType() == DataType.HALF); + INDArray gammaOrig = null; + INDArray dGammaViewOrig = null; + INDArray dBetaViewOrig = null; + if(isHalf) { //Convert FP16 to FP32 if required (CuDNN BN doesn't support FP16 for these params, only for input/output) + gammaOrig = gamma; + dGammaViewOrig = dGammaView; + dBetaViewOrig = dBetaView; + /* + From CuDNN docs: bnScale, resultBnScaleDiff, resultBnBiasDiff, savedMean, savedInvVariance + "Note: The data type of this tensor descriptor must be 'float' for FP16 and FP32 input tensors, and 'double' + for FP64 input tensors." + >> Last 2 are the meanCache and varCache; first 3 are below + */ + gamma = gamma.castTo(DataType.FLOAT); + dGammaView = dGammaView.castTo(DataType.FLOAT); + dBetaView = dBetaView.castTo(DataType.FLOAT); + } + + Gradient retGradient = new DefaultGradient(); + + if (!Shape.hasDefaultStridesForShape(epsilon)) { + // apparently not supported by cuDNN + epsilon = epsilon.dup('c'); + } + + val srcStride = ArrayUtil.toInts(input.stride()); + val deltaStride = ArrayUtil.toInts(epsilon.stride()); + + if (Nd4j.getExecutioner() instanceof GridExecutioner) + ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, + (int) srcStride[0], (int) srcStride[chIdx], (int) srcStride[hIdx], (int) srcStride[wIdx])); + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, (int) miniBatch, (int) depth, (int) inH, (int) inW, + (int) deltaStride[0], (int) deltaStride[chIdx], (int) deltaStride[hIdx], (int) deltaStride[wIdx])); + + long[] nextEpsShape = nchw ? new long[] {miniBatch, depth, inH, inW} : new long[] {miniBatch, inH, inW, depth}; + INDArray nextEpsilon = layerWorkspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), nextEpsShape, 'c'); + val dstStride = ArrayUtil.toInts(nextEpsilon.stride()); + + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW, + dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx])); + checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(gamma.data().dataType()), (int)shape[0], + (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? 
(int)shape[3] : 1)); + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, epsilon, nextEpsilon, gamma, + dGammaView, dBetaView); + Pointer srcData = allocator.getPointer(input, context); + Pointer epsData = allocator.getPointer(epsilon, context); + Pointer dstData = allocator.getPointer(nextEpsilon, context); + Pointer gammaData = allocator.getPointer(gamma, context); + Pointer dGammaData = allocator.getPointer(dGammaView, context); + Pointer dBetaData = allocator.getPointer(dBetaView, context); + Pointer meanCacheData = allocator.getPointer(meanCache, context); + Pointer varCacheData = allocator.getPointer(varCache, context); + + checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()))); + checkCudnn(cudnnBatchNormalizationBackward(cudnnContext, batchNormMode, alpha, this.beta, alpha, alpha, + cudnnContext.srcTensorDesc, srcData, cudnnContext.deltaTensorDesc, epsData, + cudnnContext.dstTensorDesc, dstData, cudnnContext.gammaBetaTensorDesc, gammaData, dGammaData, + dBetaData, eps, meanCacheData, varCacheData)); + + allocator.getFlowController().registerActionAllWrite(context, input, epsilon, nextEpsilon, gamma, dGammaView, + dBetaView); + + retGradient.setGradientFor(BatchNormalizationParamInitializer.GAMMA, dGammaView); + retGradient.setGradientFor(BatchNormalizationParamInitializer.BETA, dBetaView); + + context.syncOldStream(); + + //Convert back and assign, if required: + if(isHalf){ + gammaOrig.assign(gamma.castTo(DataType.HALF)); + dGammaViewOrig.assign(dGammaView.castTo(DataType.HALF)); + dBetaViewOrig.assign(dBetaView.castTo(DataType.HALF)); + } + + return new Pair<>(retGradient, nextEpsilon); + } + + + @Override + public INDArray preOutput(INDArray x, boolean training, long[] shape, INDArray gamma, INDArray beta, INDArray mean, + INDArray var, double decay, double eps, CNN2DFormat format, LayerWorkspaceMgr workspaceMgr) { + boolean nchw = format == CNN2DFormat.NCHW; + int cudnnTensorFormat = nchw ? CUDNN_TENSOR_NCHW : CUDNN_TENSOR_NHWC; + int chIdx = nchw ? 1 : 3; + int hIdx = nchw ? 2 : 1; + int wIdx = nchw ? 3 : 2; + + this.eps = eps; + final boolean isHalf = (x.dataType() == DataType.FLOAT16); + INDArray origGamma = gamma; + INDArray origBeta = beta; + INDArray origMean = mean; + INDArray origVar = var; + if(isHalf) { + gamma = gamma.castTo(DataType.FLOAT); + beta = beta.castTo(DataType.FLOAT); + mean = mean.castTo(DataType.FLOAT); + var = var.castTo(DataType.FLOAT); + } + + //Notation difference between CuDNN and our implementation: + //Us: runningMean = (1-decay) * batchMean + decay * runningMean + //CuDNN: runningMean = decay * batchMean + (1-decay) * runningMean + //i.e., "decay" has a different meaning... + //Disable in-place updating of running mean/variance, so that all parameter changes are done via the update/gradient + // vector. This is necessary for BatchNormalization to be safe to use in distributed gradient sharing settings + decay = 0.0; //From cudnn docs: runningMean = newMean*factor + runningMean*(1-factor). 
-> 0 = "in-place modification of running mean disabled" + + val miniBatch = (int) x.size(0); + val inDepth = (int) x.size(chIdx); + val inH = (int) x.size(hIdx); + val inW = (int) x.size(wIdx); + + val srcStride = ArrayUtil.toInts(x.stride()); + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW, + srcStride[0], srcStride[chIdx], srcStride[hIdx], srcStride[wIdx])); + + long[] actShape = nchw ? new long[] {miniBatch, inDepth, inH, inW} : new long[] {miniBatch, inH, inW, inDepth}; + INDArray activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, x.dataType(), actShape, 'c'); + + val dstStride = ArrayUtil.toInts(activations.stride()); + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW, + dstStride[0], dstStride[chIdx], dstStride[hIdx], dstStride[wIdx])); + + checkCudnn(cudnnSetTensor4dDescriptor(cudnnContext.gammaBetaTensorDesc, cudnnTensorFormat, toCudnnDataType(mean.data().dataType()), (int)shape[0], + (int)shape[1], shape.length > 2 ? (int)shape[2] : 1, shape.length > 3 ? (int)shape[3] : 1)); + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = + allocator.getFlowController().prepareActionAllWrite(x, activations, gamma, beta, mean, var); + Pointer srcData = allocator.getPointer(x, context); + Pointer dstData = allocator.getPointer(activations, context); + Pointer gammaData = allocator.getPointer(gamma, context); + Pointer betaData = allocator.getPointer(beta, context); + Pointer meanData = allocator.getPointer(mean, context); + Pointer varData = allocator.getPointer(var, context); + + if (Nd4j.getExecutioner() instanceof GridExecutioner) + ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + + checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()))); + if (training) { + if(meanCache == null || meanCache.length() < mean.length()){ + try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + meanCache = Nd4j.createUninitialized(x.dataType(), mean.length()); + } + if(x.dataType() == DataType.HALF){ + try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + meanCache = meanCache.castTo(DataType.FLOAT); + } + } + } + if(varCache == null || varCache.length() < mean.length()){ + try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + varCache = Nd4j.createUninitialized(x.dataType(), mean.length()); + } + if(nd4jDataType == DataType.HALF){ + try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + varCache = varCache.castTo(DataType.FLOAT); + } + } + } + Pointer meanCacheData = allocator.getPointer(meanCache, context); + Pointer varCacheData = allocator.getPointer(varCache, context); + + checkCudnn(cudnnBatchNormalizationForwardTraining(cudnnContext, batchNormMode, this.alpha, this.beta, + cudnnContext.srcTensorDesc, srcData, cudnnContext.dstTensorDesc, dstData, + cudnnContext.gammaBetaTensorDesc, gammaData, betaData, decay, meanData, varData, eps, + meanCacheData, varCacheData)); + } else { + checkCudnn(cudnnBatchNormalizationForwardInference(cudnnContext, batchNormMode, this.alpha, this.beta, + cudnnContext.srcTensorDesc, srcData, cudnnContext.dstTensorDesc, dstData, + cudnnContext.gammaBetaTensorDesc, gammaData, betaData, meanData, varData, eps)); + } + + allocator.getFlowController().registerActionAllWrite(context, x, activations, gamma, beta, mean, var); + + if (CudaEnvironment.getInstance().getConfiguration().isDebug()) + 
context.syncOldStream(); + + context.syncOldStream(); + if(training) { + AtomicAllocator.getInstance().getAllocationPoint(meanCache).tickDeviceWrite(); + AtomicAllocator.getInstance().getAllocationPoint(varCache).tickDeviceWrite(); + } + + if(training && isHalf){ + //Update the running mean and variance arrays; also gamma/beta + origMean.assign(mean.castTo(DataType.HALF)); + origVar.assign(var.castTo(DataType.HALF)); + origGamma.assign(gamma.castTo(DataType.HALF)); + origBeta.assign(beta.castTo(DataType.HALF)); + } + + return activations; + } + + @Override + public INDArray getMeanCache(DataType dataType) { + if(dataType == DataType.HALF){ + //Buffer is FP32 + return meanCache.castTo(DataType.HALF); + } + return meanCache; + } + + @Override + public INDArray getVarCache(DataType dataType) { + INDArray ret; + if(dataType == DataType.HALF){ + INDArray vc = varCache.castTo(DataType.HALF); + ret = vc.mul(vc).rdivi(1.0).subi(eps); + } else { + ret = varCache.mul(varCache).rdivi(1.0).subi(eps); + } + if(dataType == DataType.HALF){ + //Buffer is FP32 + return ret.castTo(DataType.HALF); + } + return ret; + } + + + @Override + public Map helperMemoryUse() { + Map memUse = new HashMap<>(); + memUse.put("meanCache", meanCache == null ? 0 : meanCache.length() * meanCache.data().getElementSize()); + memUse.put("varCache", varCache == null ? 0 : varCache.length() * varCache.data().getElementSize()); + return memUse; + } +} diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnLocalResponseNormalizationHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnLocalResponseNormalizationHelper.java new file mode 100644 index 000000000..e0257a3ec --- /dev/null +++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/normalization/CudnnLocalResponseNormalizationHelper.java @@ -0,0 +1,240 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.cuda.normalization; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.bytedeco.javacpp.Pointer; +import org.deeplearning4j.nn.gradient.DefaultGradient; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.cuda.BaseCudnnHelper; +import org.deeplearning4j.nn.layers.normalization.LocalResponseNormalizationHelper; +import org.nd4j.jita.allocator.Allocator; +import org.nd4j.jita.allocator.impl.AtomicAllocator; +import org.nd4j.jita.conf.CudaEnvironment; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.executioner.GridExecutioner; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.jcublas.context.CudaContext; +import org.nd4j.common.primitives.Pair; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.nd4j.common.util.ArrayUtil; + +import java.util.Collections; +import java.util.Map; + +import org.bytedeco.cuda.cudart.*; +import org.bytedeco.cuda.cudnn.*; + +import static org.bytedeco.cuda.global.cudnn.*; + +/** + * cuDNN-based helper for the local response normalization layer. + * + * @author saudet + */ +@Slf4j +public class CudnnLocalResponseNormalizationHelper extends BaseCudnnHelper implements LocalResponseNormalizationHelper { + + public CudnnLocalResponseNormalizationHelper(DataType dataType) { + super(dataType); + } + + private static class CudnnLocalResponseNormalizationContext extends CudnnContext { + + private static class Deallocator extends CudnnLocalResponseNormalizationContext implements Pointer.Deallocator { + Deallocator(CudnnLocalResponseNormalizationContext c) { + super(c); + } + + @Override + public void deallocate() { + destroyHandles(); + } + } + + private cudnnTensorStruct srcTensorDesc = new cudnnTensorStruct(), dstTensorDesc = new cudnnTensorStruct(), + deltaTensorDesc = new cudnnTensorStruct(); + private cudnnLRNStruct lrnDesc = new cudnnLRNStruct(); + + public CudnnLocalResponseNormalizationContext() { + createHandles(); + deallocator(new Deallocator(this)); + } + + public CudnnLocalResponseNormalizationContext(CudnnLocalResponseNormalizationContext c) { + super(c); + srcTensorDesc = new cudnnTensorStruct(c.srcTensorDesc); + dstTensorDesc = new cudnnTensorStruct(c.dstTensorDesc); + deltaTensorDesc = new cudnnTensorStruct(c.deltaTensorDesc); + lrnDesc = new cudnnLRNStruct(c.lrnDesc); + } + + @Override + protected void createHandles() { + super.createHandles(); + checkCudnn(cudnnCreateTensorDescriptor(srcTensorDesc)); + checkCudnn(cudnnCreateTensorDescriptor(dstTensorDesc)); + checkCudnn(cudnnCreateTensorDescriptor(deltaTensorDesc)); + checkCudnn(cudnnCreateLRNDescriptor(lrnDesc)); + } + + @Override + protected void destroyHandles() { + checkCudnn(cudnnDestroyLRNDescriptor(lrnDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(srcTensorDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(dstTensorDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(deltaTensorDesc)); + super.destroyHandles(); + } + } + + private CudnnLocalResponseNormalizationContext cudnnContext = new CudnnLocalResponseNormalizationContext(); + private INDArray activations = null; + + public boolean checkSupported(double k, double n, double alpha, double beta) { + boolean supported = checkSupported(); + if (n < 
CUDNN_LRN_MIN_N) { + supported = false; + log.warn("Not supported: n < CUDNN_LRN_MIN_N (" + n + " < " + CUDNN_LRN_MIN_N + ")"); + } + if (n > CUDNN_LRN_MAX_N) { + supported = false; + log.warn("Not supported: n > CUDNN_LRN_MAX_N (" + n + " > " + CUDNN_LRN_MAX_N + ")"); + } + if (k < CUDNN_LRN_MIN_K) { + supported = false; + log.warn("Not supported: k < CUDNN_LRN_MIN_K (" + k + " < " + CUDNN_LRN_MIN_K + ")"); + } + if (beta < CUDNN_LRN_MIN_BETA) { + supported = false; + log.warn("Not supported: beta < CUDNN_LRN_MIN_BETA (" + beta + " < " + CUDNN_LRN_MIN_BETA + ")"); + } + return supported; + } + + @Override + public Pair backpropGradient(INDArray input, INDArray epsilon, double k, double n, double alpha, + double beta, LayerWorkspaceMgr workspaceMgr) { + val miniBatch = (int) input.size(0); + val depth = (int) input.size(1); + val inH = (int) input.size(2); + val inW = (int) input.size(3); + + Gradient retGradient = new DefaultGradient(); + + if (!Shape.hasDefaultStridesForShape(epsilon)) { + // apparently not supported by cuDNN + epsilon = epsilon.dup('c'); + } + + val srcStride = ArrayUtil.toInts(input.stride()); + val deltaStride = ArrayUtil.toInts(epsilon.stride()); + + if (Nd4j.getExecutioner() instanceof GridExecutioner) + ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, depth, inH, inW, + srcStride[0], srcStride[1], srcStride[2], srcStride[3])); + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.deltaTensorDesc, dataType, miniBatch, depth, inH, inW, + deltaStride[0], deltaStride[1], deltaStride[2], deltaStride[3])); + checkCudnn(cudnnSetLRNDescriptor(cudnnContext.lrnDesc, (int) n, alpha, beta, k)); + + INDArray nextEpsilon = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[] {miniBatch, depth, inH, inW}, 'c'); + + val dstStride = ArrayUtil.toInts(nextEpsilon.stride()); + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, depth, inH, inW, + dstStride[0], dstStride[1], dstStride[2], dstStride[3])); + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = + allocator.getFlowController().prepareActionAllWrite(input, epsilon, activations, nextEpsilon); + Pointer srcData = allocator.getPointer(input, context); + Pointer epsData = allocator.getPointer(epsilon, context); + Pointer zData = allocator.getPointer(activations, context); + Pointer dstData = allocator.getPointer(nextEpsilon, context); + + checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()))); + checkCudnn(cudnnLRNCrossChannelBackward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1, + this.alpha, cudnnContext.deltaTensorDesc, zData, cudnnContext.deltaTensorDesc, epsData, + cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc, dstData)); + + allocator.getFlowController().registerActionAllWrite(context, input, epsilon, activations, nextEpsilon); + + if (CudaEnvironment.getInstance().getConfiguration().isDebug()) + context.syncOldStream(); + + return new Pair<>(retGradient, nextEpsilon); + } + + + @Override + public INDArray activate(INDArray input, boolean training, double k, double n, double alpha, double beta, LayerWorkspaceMgr workspaceMgr) { + val miniBatch = (int) input.size(0); + val inDepth = (int) input.size(1); + val inH = (int) input.size(2); + val inW = (int) input.size(3); + + if(!Shape.hasDefaultStridesForShape(input)){ + input = input.dup('c'); + } + + 
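// Descriptive note (added): the descriptors below are built from the array's actual NCHW strides, which is why a view with
+        // non-default strides is duplicated to 'c' order first; the LRN cross-channel forward kernel then runs on ND4J's CUDA stream.
+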
val srcStride = ArrayUtil.toInts(input.stride()); + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.srcTensorDesc, dataType, miniBatch, inDepth, inH, inW, + srcStride[0], srcStride[1], srcStride[2], srcStride[3])); + + activations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, input.dataType(), new long[] {miniBatch, inDepth, inH, inW}, 'c'); + + val dstStride = ArrayUtil.toInts(activations.stride()); + checkCudnn(cudnnSetTensor4dDescriptorEx(cudnnContext.dstTensorDesc, dataType, miniBatch, inDepth, inH, inW, + dstStride[0], dstStride[1], dstStride[2], dstStride[3])); + checkCudnn(cudnnSetLRNDescriptor(cudnnContext.lrnDesc, (int) n, alpha, beta, k)); + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = allocator.getFlowController().prepareActionAllWrite(input, activations); + Pointer srcData = allocator.getPointer(input, context); + Pointer dstData = allocator.getPointer(activations, context); + + if (Nd4j.getExecutioner() instanceof GridExecutioner) + ((GridExecutioner) Nd4j.getExecutioner()).flushQueue(); + + checkCudnn(cudnnSetStream(cudnnContext, new CUstream_st(context.getCublasStream()))); + checkCudnn(cudnnLRNCrossChannelForward(cudnnContext, cudnnContext.lrnDesc, CUDNN_LRN_CROSS_CHANNEL_DIM1, + this.alpha, cudnnContext.srcTensorDesc, srcData, this.beta, cudnnContext.dstTensorDesc, + dstData)); + + allocator.getFlowController().registerActionAllWrite(context, input, activations); + + if (CudaEnvironment.getInstance().getConfiguration().isDebug()) + context.syncOldStream(); + + return activations; + } + + @Override + public Map helperMemoryUse() { + //No persistent memory use other than the structs (which are small) + return Collections.emptyMap(); + } +} diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/recurrent/CudnnLSTMHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/recurrent/CudnnLSTMHelper.java new file mode 100644 index 000000000..120078d07 --- /dev/null +++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/recurrent/CudnnLSTMHelper.java @@ -0,0 +1,659 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.cuda.recurrent; + +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import com.jakewharton.byteunits.BinaryByteUnit; +import org.bytedeco.javacpp.Pointer; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.gradient.DefaultGradient; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.cuda.BaseCudnnHelper; +import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn; +import org.deeplearning4j.nn.layers.recurrent.LSTMHelper; +import org.nd4j.jita.allocator.Allocator; +import org.nd4j.jita.allocator.impl.AtomicAllocator; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.ActivationSigmoid; +import org.nd4j.linalg.activations.impl.ActivationTanH; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.jcublas.context.CudaContext; +import org.nd4j.common.primitives.Pair; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.deeplearning4j.nn.workspace.ArrayType; + +import java.util.HashMap; +import java.util.Map; + +import org.bytedeco.cuda.cudart.*; +import org.bytedeco.cuda.cudnn.*; +import static org.bytedeco.cuda.global.cudart.*; +import static org.bytedeco.cuda.global.cudnn.*; + +/** + * cuDNN-based helper for the recurrent LSTM layer (no peephole connections). + * + * @author saudet + */ +@Slf4j +public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper { + + public CudnnLSTMHelper(DataType dataType) { + super(dataType); + } + + private static class CudnnLSTMContext extends CudnnContext { + + private static class Deallocator extends CudnnLSTMContext implements Pointer.Deallocator { + Deallocator(CudnnLSTMContext c) { + super(c); + } + + @Override + public void deallocate() { + destroyHandles(); + } + } + + private cudnnTensorStruct hxDesc = new cudnnTensorStruct(), cxDesc = new cudnnTensorStruct(); + private cudnnTensorStruct hyDesc = new cudnnTensorStruct(), cyDesc = new cudnnTensorStruct(); + private cudnnTensorStruct dhxDesc = new cudnnTensorStruct(), dcxDesc = new cudnnTensorStruct(); + private cudnnTensorStruct dhyDesc = new cudnnTensorStruct(), dcyDesc = new cudnnTensorStruct(); + + private cudnnFilterStruct wDesc = new cudnnFilterStruct(), dwDesc = new cudnnFilterStruct(); + private cudnnFilterStruct linLayerMatDesc = new cudnnFilterStruct(), linLayerBiasDesc = new cudnnFilterStruct(); + + private cudnnRNNStruct rnnDesc = new cudnnRNNStruct(); + private cudnnDropoutStruct dropoutDesc = new cudnnDropoutStruct(); + private cudnnActivationStruct activationDesc = new cudnnActivationStruct(); + + public CudnnLSTMContext() { + createHandles(); + deallocator(new Deallocator(this)); + } + + public CudnnLSTMContext(CudnnLSTMContext c) { + super(c); + hxDesc = new cudnnTensorStruct(c.hxDesc); + cxDesc = new cudnnTensorStruct(c.cxDesc); + hyDesc = new cudnnTensorStruct(c.hyDesc); + cyDesc = new cudnnTensorStruct(c.cyDesc); + dhxDesc = new cudnnTensorStruct(c.dhxDesc); + dcxDesc = new cudnnTensorStruct(c.dcxDesc); + dhyDesc = new cudnnTensorStruct(c.dhyDesc); + dcyDesc = new cudnnTensorStruct(c.dcyDesc); + + wDesc = new cudnnFilterStruct(c.wDesc); + dwDesc = new cudnnFilterStruct(c.dwDesc); + linLayerMatDesc = new 
cudnnFilterStruct(c.linLayerMatDesc); + linLayerBiasDesc = new cudnnFilterStruct(c.linLayerBiasDesc); + + rnnDesc = new cudnnRNNStruct(c.rnnDesc); + dropoutDesc = new cudnnDropoutStruct(c.dropoutDesc); + activationDesc = new cudnnActivationStruct(c.activationDesc); + } + + @Override + protected void createHandles() { + super.createHandles(); + + checkCudnn(cudnnCreateTensorDescriptor(hxDesc)); + checkCudnn(cudnnCreateTensorDescriptor(cxDesc)); + checkCudnn(cudnnCreateTensorDescriptor(hyDesc)); + checkCudnn(cudnnCreateTensorDescriptor(cyDesc)); + checkCudnn(cudnnCreateTensorDescriptor(dhxDesc)); + checkCudnn(cudnnCreateTensorDescriptor(dcxDesc)); + checkCudnn(cudnnCreateTensorDescriptor(dhyDesc)); + checkCudnn(cudnnCreateTensorDescriptor(dcyDesc)); + + checkCudnn(cudnnCreateFilterDescriptor(wDesc)); + checkCudnn(cudnnCreateFilterDescriptor(dwDesc)); + checkCudnn(cudnnCreateFilterDescriptor(linLayerMatDesc)); + checkCudnn(cudnnCreateFilterDescriptor(linLayerBiasDesc)); + + checkCudnn(cudnnCreateRNNDescriptor(rnnDesc)); + checkCudnn(cudnnCreateDropoutDescriptor(dropoutDesc)); + checkCudnn(cudnnCreateActivationDescriptor(activationDesc)); + } + + @Override + protected void destroyHandles() { + checkCudnn(cudnnDestroyActivationDescriptor(activationDesc)); + checkCudnn(cudnnDestroyDropoutDescriptor(dropoutDesc)); + checkCudnn(cudnnDestroyRNNDescriptor(rnnDesc)); + + checkCudnn(cudnnDestroyFilterDescriptor(wDesc)); + checkCudnn(cudnnDestroyFilterDescriptor(dwDesc)); + checkCudnn(cudnnDestroyFilterDescriptor(linLayerMatDesc)); + checkCudnn(cudnnDestroyFilterDescriptor(linLayerBiasDesc)); + + checkCudnn(cudnnDestroyTensorDescriptor(hxDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(cxDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(hyDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(cyDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(dhxDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(dcxDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(dhyDesc)); + checkCudnn(cudnnDestroyTensorDescriptor(dcyDesc)); + + super.destroyHandles(); + } + } + + // These constants might eventually become variable parameters... 
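+    // A single unidirectional LSTM layer without dropout is assumed throughout this helper; cuDNN's
+    // CUDNN_LSTM mode exposes 8 linear layers per RNN layer (4 input-weight and 4 recurrent-weight matrices).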
+ protected static final int NUM_LAYERS = 1; + protected static final float DROPOUT = 0; + protected static final boolean BIDIRECTIONAL = false; + protected static final int RNN_MODE = CUDNN_LSTM; + protected static final int NUM_LINEAR_LAYERS = 8; // CUDNN_LSTM + + private CudnnLSTMContext cudnnContext = new CudnnLSTMContext(); + private TensorArray xDesc = new TensorArray(); + private TensorArray yDesc = new TensorArray(); + private TensorArray dxDesc = new TensorArray(); + private TensorArray dyDesc = new TensorArray(); + private DataCache stateSpace = new DataCache(); + private DataCache reserveSpace = new DataCache(); + private DataCache weightsSpace = new DataCache(); + + private boolean initializedDropoutDescriptor = false; + + private static INDArray toCOrder(INDArray arr) { + if (arr.isView() || arr.ordering() != 'c' || !Shape.strideDescendingCAscendingF(arr)) { + arr = arr.dup('c'); + } + return arr; + } + + @Override + public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, + boolean hasPeepholeConnections) { + boolean supported = checkSupported(); + if (!(gateActivationFn instanceof ActivationSigmoid)) { + supported = false; + log.warn("Not supported: Gate activation functions != ActivationSigmoid"); + } + if (!(activationFn instanceof ActivationTanH)) { + supported = false; + log.warn("Not supported: Layer activation functions != ActivationTanH"); + } + if (hasPeepholeConnections) { + supported = false; + log.warn("Not supported: LSTM layers with peephole connections"); + } + return supported; + } + + @Override + public Pair backpropGradient(final NeuralNetConfiguration conf, + final IActivation gateActivationFn, final INDArray input, final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG] + final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg] + final INDArray epsilon, final boolean truncatedBPTT, final int tbpttBackwardLength, + final FwdPassReturn fwdPass, final boolean forwards, final String inputWeightKey, + final String recurrentWeightKey, final String biasWeightKey, + final Map gradientViews, INDArray maskArray, //Input mask: should only be used with bidirectional RNNs + variable length + final boolean hasPeepholeConnections, //True for GravesLSTM, false for LSTM + final LayerWorkspaceMgr workspaceMgr) { + + //Expect errors to have shape: [miniBatchSize,n^(L+1),timeSeriesLength] + val hiddenLayerSize = recurrentWeights.size(0); //i.e., n^L + val prevLayerSize = inputWeights.size(0); //n^(L-1) + val inputLayerSize = input.size(1); + val miniBatchSize = epsilon.size(0); + boolean is2dInput = epsilon.rank() < 3; //Edge case: T=1 may have shape [miniBatchSize,n^(L+1)], equiv. to [miniBatchSize,n^(L+1),1] + long timeSeriesLength = (is2dInput ? 
1 : epsilon.size(2)); + + INDArray x = toCOrder(input.permute(2, 0, 1)); + INDArray dy = toCOrder(epsilon.permute(2, 0, 1)); + INDArray dx = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, inputWeights.dataType(), new long[] {timeSeriesLength, miniBatchSize, prevLayerSize}, 'c'); + + INDArray iwGradientsOut = gradientViews.get(inputWeightKey); + INDArray rwGradientsOut = gradientViews.get(recurrentWeightKey); //Order: {I,F,O,G} + INDArray bGradientsOut = gradientViews.get(biasWeightKey); + + INDArray outputActivations = toCOrder(fwdPass.fwdPassOutput.permute(2, 0, 1)); + INDArray prevStepMemCellState = toCOrder(fwdPass.prevMemCell); + INDArray prevStepActivations = toCOrder(fwdPass.prevAct); + + Nd4j.getExecutioner().commit(); + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = allocator.getFlowController().prepareActionAllWrite(x, dy, dx, outputActivations, + prevStepMemCellState, prevStepActivations, iwGradientsOut, rwGradientsOut, bGradientsOut); + Pointer xData = allocator.getPointer(x, context); + Pointer dyData = allocator.getPointer(dy, context); + Pointer dxData = allocator.getPointer(dx, context); + Pointer outputActivationsData = allocator.getPointer(outputActivations, context); + Pointer prevMemCellStateData = allocator.getPointer(prevStepMemCellState, context); + Pointer prevStepActivationsData = allocator.getPointer(prevStepActivations, context); + Pointer iwGradientsOutData = allocator.getPointer(iwGradientsOut, context); + Pointer rwGradientsOutData = allocator.getPointer(rwGradientsOut, context); + Pointer bGradientsOutData = allocator.getPointer(bGradientsOut, context); + + CUstream_st stream = new CUstream_st(context.getCublasStream()); + checkCudnn(cudnnSetStream(cudnnContext, stream)); + + if (truncatedBPTT) { + val endIdx = Math.max(0, timeSeriesLength - tbpttBackwardLength) * miniBatchSize * hiddenLayerSize; + xData.position(endIdx * dataTypeSize); + dyData.position(endIdx * (BIDIRECTIONAL ? 2 : 1) * dataTypeSize); + outputActivationsData.position(endIdx * (BIDIRECTIONAL ? 2 : 1) * dataTypeSize); + timeSeriesLength = (int) Math.min(timeSeriesLength, tbpttBackwardLength); + } + + cudnnTensorStruct xDesc0 = xDesc.get(cudnnTensorStruct.class, 0); + + DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY); + checkCudnn(cudnnRNNBackwardData(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, yDesc, + outputActivationsData, dyDesc, dyData, cudnnContext.dhyDesc, null, cudnnContext.dcyDesc, null, + cudnnContext.wDesc, weightsSpace, cudnnContext.hxDesc, prevStepActivationsData, //hx: initial hidden state of RNN + cudnnContext.cxDesc, prevMemCellStateData, //cx: initial cell state of RNN + dxDesc, dxData, //dx: gradient at input of each time step + cudnnContext.dhxDesc, null, //dhx: gradient at initial hidden state of RNN + cudnnContext.dcxDesc, null, //dcx: Gradient at initial cell state + workSpace, workSpace.limit(), reserveSpace, reserveSpace.limit())); + + // cudnnRNNBackwardWeights adds to the data in dW. 
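+        // Because those gradients are accumulated rather than overwritten, the (reused) weights buffer is
+        // zeroed first so the weight gradients computed below start from a clean slate.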
+ checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream)); + + checkCudnn(cudnnRNNBackwardWeights(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData, //Input data + cudnnContext.hxDesc, prevStepActivationsData, //Initial hidden state + yDesc, outputActivationsData, //Output data + workSpace, workSpace.limit(), cudnnContext.dwDesc, weightsSpace, reserveSpace, + reserveSpace.limit())); + + int[] dataType = new int[1]; + int[] format = new int[1]; + int[] nbDims = new int[1]; + int[] filterDimA = new int[3]; + Pointer linLayerMat = new Pointer(); + Pointer linLayerBias = new Pointer(); + + for (int layer = 0; layer < NUM_LAYERS * (BIDIRECTIONAL ? 2 : 1); layer++) { + for (int linLayerID = 0; linLayerID < NUM_LINEAR_LAYERS; linLayerID++) { + checkCudnn(cudnnGetRNNLinLayerMatrixParams(cudnnContext, cudnnContext.rnnDesc, layer, xDesc0, + cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerMatDesc, + linLayerMat)); + + checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerMatDesc, 3, dataType, format, nbDims, + filterDimA)); + + checkCudnn(cudnnGetRNNLinLayerBiasParams(cudnnContext, cudnnContext.rnnDesc, layer, xDesc0, + cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerBiasDesc, + linLayerBias)); + + checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerBiasDesc, 3, dataType, format, nbDims, + filterDimA)); + + // our data is in "new, forget, output, and input gates" order (aka IFOG), each kind of weight packed together + int position = 0; + long size = 0; + Pointer data = null; + switch (linLayerID) { + case 0: + data = iwGradientsOutData; + position = 3; + size = inputLayerSize; + break; // input gate + case 1: + data = iwGradientsOutData; + position = 1; + size = inputLayerSize; + break; // forget gate + case 2: + data = iwGradientsOutData; + position = 0; + size = inputLayerSize; + break; // new gate (input modulation gate) + case 3: + data = iwGradientsOutData; + position = 2; + size = inputLayerSize; + break; // output gate + case 4: + data = rwGradientsOutData; + position = 3; + size = hiddenLayerSize; + break; // input gate + case 5: + data = rwGradientsOutData; + position = 1; + size = hiddenLayerSize; + break; // forget gate + case 6: + data = rwGradientsOutData; + position = 0; + size = hiddenLayerSize; + break; // new gate (input modulation gate) + case 7: + data = rwGradientsOutData; + position = 2; + size = hiddenLayerSize; + break; // output gate + default: + throw new RuntimeException(); + } + checkCuda(cudaMemcpyAsync(data.position(position * size * hiddenLayerSize * dataTypeSize), linLayerMat, + size * hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream)); + if (linLayerID < 4) { + checkCuda(cudaMemcpyAsync(bGradientsOutData.position(position * hiddenLayerSize * dataTypeSize), + linLayerBias, hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream)); + } + } + } + + allocator.getFlowController().registerActionAllWrite(context, x, dy, dx, outputActivations, + prevStepMemCellState, prevStepActivations, iwGradientsOut, rwGradientsOut, bGradientsOut); + + Gradient retGradient = new DefaultGradient(); + retGradient.gradientForVariable().put(inputWeightKey, iwGradientsOut); + retGradient.gradientForVariable().put(recurrentWeightKey, rwGradientsOut); + retGradient.gradientForVariable().put(biasWeightKey, bGradientsOut); + + INDArray epsilonNext = dx.permute(1, 2, 0); //i.e., what would be W^L*(delta^L)^T. 
Shape: [m,n^(L-1),T] + + return new Pair<>(retGradient, epsilonNext); + } + + @Override + public FwdPassReturn activate(final Layer layer, final NeuralNetConfiguration conf, + final IActivation gateActivationFn, //Activation function for the gates - sigmoid or hard sigmoid (must be found in range 0 to 1) + INDArray input, final INDArray recurrentWeights, //Shape: [hiddenLayerSize,4*hiddenLayerSize+3]; order: [wI,wF,wO,wG,wFF,wOO,wGG] + final INDArray inputWeights, //Shape: [n^(L-1),4*hiddenLayerSize]; order: [wi,wf,wo,wg] + final INDArray biases, //Shape: [4,hiddenLayerSize]; order: [bi,bf,bo,bg]^T + final boolean training, final INDArray prevOutputActivations, final INDArray prevMemCellState, + boolean forBackprop, boolean forwards, final String inputWeightKey, INDArray maskArray, //Input mask: should only be used with bidirectional RNNs + variable length + final boolean hasPeepholeConnections, //True for GravesLSTM, false for LSTM + final LayerWorkspaceMgr workspaceMgr) { + + boolean is2dInput = input.rank() < 3; //Edge case of T=1, may have shape [m,nIn], equiv. to [m,nIn,1] + val timeSeriesLength = (is2dInput ? 1 : input.size(2)); + val hiddenLayerSize = recurrentWeights.size(0); + val miniBatchSize = input.size(0); + val inputLayerSize = input.size(1); + + INDArray x = toCOrder(input.permute(2, 0, 1)); + INDArray linInputWeights = inputWeights; + INDArray linRecurrentWeights = recurrentWeights; + INDArray linBiases = biases; + + INDArray prevAct = toCOrder(prevOutputActivations); + INDArray prevMemCell = toCOrder(prevMemCellState); + + INDArray outputActivations = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, + inputWeights.dataType(), new long[] {timeSeriesLength, miniBatchSize, hiddenLayerSize * (BIDIRECTIONAL ? 2 : 1)}, 'c'); + INDArray finalMemCellState = Nd4j.createUninitialized( inputWeights.dataType(), + new long[] {/*numLayers * (bidirectional ? 2 : 1),*/ miniBatchSize, hiddenLayerSize}, 'c'); + INDArray finalStepActivations = Nd4j.createUninitialized( inputWeights.dataType(), + new long[] {/*numLayers * (bidirectional ? 2 : 1),*/ miniBatchSize, hiddenLayerSize}, 'c'); + + FwdPassReturn toReturn = new FwdPassReturn(); + toReturn.prevAct = prevAct; + toReturn.prevMemCell = prevMemCell; + + Nd4j.getExecutioner().commit(); + + + + if (timeSeriesLength > xDesc.capacity()) { + xDesc.deallocate(); + xDesc = new TensorArray(timeSeriesLength); + } + if (timeSeriesLength > yDesc.capacity()) { + yDesc.deallocate(); + yDesc = new TensorArray(timeSeriesLength); + } + if (timeSeriesLength > dxDesc.capacity()) { + dxDesc.deallocate(); + dxDesc = new TensorArray(timeSeriesLength); + } + if (timeSeriesLength > dyDesc.capacity()) { + dyDesc.deallocate(); + dyDesc = new TensorArray(timeSeriesLength); + } + + for (int i = 0; i < timeSeriesLength; i++) { + int[] dimA = {(int) miniBatchSize, (int) inputLayerSize, 1}; + int[] strideA = {(int) dimA[2] * dimA[1], dimA[2], 1}; + + checkCudnn(cudnnSetTensorNdDescriptor(xDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimA, strideA)); + checkCudnn(cudnnSetTensorNdDescriptor(dxDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimA, strideA)); + + int[] dimB = {(int) miniBatchSize, (int) hiddenLayerSize * (BIDIRECTIONAL ? 
2 : 1), 1}; + int[] strideB = {dimB[2] * dimB[1], dimB[2], 1}; + + checkCudnn(cudnnSetTensorNdDescriptor(yDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimB, strideB)); + checkCudnn(cudnnSetTensorNdDescriptor(dyDesc.get(cudnnTensorStruct.class, i), dataType, 3, dimB, strideB)); + } + + int[] dimC = {NUM_LAYERS * (BIDIRECTIONAL ? 2 : 1), (int) miniBatchSize, (int) hiddenLayerSize}; + int[] strideC = {dimC[2] * dimC[1], dimC[2], 1}; + + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.hxDesc, dataType, 3, dimC, strideC)); + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.cxDesc, dataType, 3, dimC, strideC)); + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.hyDesc, dataType, 3, dimC, strideC)); + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.cyDesc, dataType, 3, dimC, strideC)); + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dhxDesc, dataType, 3, dimC, strideC)); + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dcxDesc, dataType, 3, dimC, strideC)); + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dhyDesc, dataType, 3, dimC, strideC)); + checkCudnn(cudnnSetTensorNdDescriptor(cudnnContext.dcyDesc, dataType, 3, dimC, strideC)); + + checkCudnn(cudnnDropoutGetStatesSize(cudnnContext, sizeInBytes)); + long stateSize = sizeInBytes.get(0); + if (stateSize > stateSpace.capacity()) { + stateSpace.deallocate(); + stateSpace = new DataCache(stateSize); + } + stateSpace.limit(stateSize); + + if(!initializedDropoutDescriptor) { + checkCudnn(cudnnSetDropoutDescriptor(cudnnContext.dropoutDesc, cudnnContext, DROPOUT, stateSpace, stateSize, + Nd4j.getRandom().getSeed())); + } + + checkCudnn(cudnnSetRNNDescriptor_v6(cudnnContext, cudnnContext.rnnDesc, (int) hiddenLayerSize, NUM_LAYERS, cudnnContext.dropoutDesc, + CUDNN_LINEAR_INPUT, BIDIRECTIONAL ? 
CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL, RNN_MODE, + CUDNN_RNN_ALGO_STANDARD, dataType)); + + cudnnTensorStruct xDesc0 = xDesc.get(cudnnTensorStruct.class, 0); + checkCudnn(cudnnGetRNNParamsSize(cudnnContext, cudnnContext.rnnDesc, xDesc0, sizeInBytes, dataType)); + long weightsSize = sizeInBytes.get(0); + if (weightsSize > weightsSpace.capacity()) { + weightsSpace.deallocate(); + weightsSpace = new DataCache(weightsSize); + } + weightsSpace.limit(weightsSize); + + int[] dimW = {(int) weightsSize / dataTypeSize, 1, 1}; + + checkCudnn(cudnnSetFilterNdDescriptor(cudnnContext.wDesc, dataType, CUDNN_TENSOR_NCHW, 3, dimW)); + checkCudnn(cudnnSetFilterNdDescriptor(cudnnContext.dwDesc, dataType, CUDNN_TENSOR_NCHW, 3, dimW)); + + checkCudnn(cudnnGetRNNWorkspaceSize(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, sizeInBytes)); + long workSize = sizeInBytes.get(0); + DataCache workSpace = workspaceMgr.getHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY); + if (workSpace == null || workSize > workSpace.capacity()) { + if(log.isTraceEnabled()){ + if(workSpace == null){ + log.trace("CudnnLSTMHelper activate: Allocating initial workspace of size {} ({})", workSize, + BinaryByteUnit.format(workSize, "#.00")); + } else { + log.trace("CudnnLSTMHelper activate: Deallocating workspace of size {} ({}), allocating new workspace of size {} ({})", + workSpace.capacity(), BinaryByteUnit.format(workSpace.capacity(), "#.00"), + workSize, BinaryByteUnit.format(workSize, "#.00")); + } + } + if(workSpace != null) + workSpace.deallocate(); + workSpace = new DataCache(workSize); + workspaceMgr.setHelperWorkspace(LayerWorkspaceMgr.CUDNN_WORKSPACE_KEY, workSpace); + } + workSpace.limit(workSize); + + checkCudnn(cudnnGetRNNTrainingReserveSize(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, + sizeInBytes)); + long reserveSize = sizeInBytes.get(0); + if (reserveSize > reserveSpace.capacity()) { + reserveSpace.deallocate(); + reserveSpace = new DataCache(reserveSize); + } + reserveSpace.limit(reserveSize); + + Allocator allocator = AtomicAllocator.getInstance(); + CudaContext context = allocator.getFlowController().prepareActionAllWrite(x, linInputWeights, + linRecurrentWeights, linBiases, prevAct, prevMemCell, outputActivations, finalMemCellState, + finalStepActivations); + Pointer xData = allocator.getPointer(x, context); + Pointer linInputWeightsData = allocator.getPointer(linInputWeights, context); + Pointer linRecurrentWeightsData = allocator.getPointer(linRecurrentWeights, context); + Pointer linBiasesData = allocator.getPointer(linBiases, context); + Pointer prevActData = allocator.getPointer(prevAct, context); + Pointer prevMemCellData = allocator.getPointer(prevMemCell, context); + Pointer outputActivationsData = allocator.getPointer(outputActivations, context); + Pointer finalMemCellStateData = allocator.getPointer(finalMemCellState, context); + Pointer finalTimeStepActivationsData = allocator.getPointer(finalStepActivations, context); + + CUstream_st stream = new CUstream_st(context.getCublasStream()); + checkCudnn(cudnnSetStream(cudnnContext, stream)); + + checkCuda(cudaMemsetAsync(weightsSpace, 0, weightsSpace.limit(), stream)); + + int[] dataType = new int[1]; + int[] format = new int[1]; + int[] nbDims = new int[1]; + int[] filterDimA = new int[3]; + Pointer linLayerMat = new Pointer(); + Pointer linLayerBias = new Pointer(); + + for (int layerNum = 0; layerNum < NUM_LAYERS * (BIDIRECTIONAL ? 
2 : 1); layerNum++) { + for (int linLayerID = 0; linLayerID < NUM_LINEAR_LAYERS; linLayerID++) { + checkCudnn(cudnnGetRNNLinLayerMatrixParams(cudnnContext, cudnnContext.rnnDesc, layerNum, xDesc0, + cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerMatDesc, + linLayerMat)); + + checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerMatDesc, 3, dataType, format, nbDims, + filterDimA)); + + checkCudnn(cudnnGetRNNLinLayerBiasParams(cudnnContext, cudnnContext.rnnDesc, layerNum, xDesc0, + cudnnContext.wDesc, weightsSpace, linLayerID, cudnnContext.linLayerBiasDesc, + linLayerBias)); + + checkCudnn(cudnnGetFilterNdDescriptor(cudnnContext.linLayerBiasDesc, 3, dataType, format, nbDims, + filterDimA)); + + // our data is in "new, forget, output, and input gates" order (aka IFOG), each kind of weight packed together + int position = 0; + long size = 0; + Pointer data = null; + switch (linLayerID) { + case 0: + data = linInputWeightsData; + position = 3; + size = inputLayerSize; + break; // input gate + case 1: + data = linInputWeightsData; + position = 1; + size = inputLayerSize; + break; // forget gate + case 2: + data = linInputWeightsData; + position = 0; + size = inputLayerSize; + break; // new gate + case 3: + data = linInputWeightsData; + position = 2; + size = inputLayerSize; + break; // output gate + case 4: + data = linRecurrentWeightsData; + position = 3; + size = hiddenLayerSize; + break; // input gate + case 5: + data = linRecurrentWeightsData; + position = 1; + size = hiddenLayerSize; + break; // forget gate + case 6: + data = linRecurrentWeightsData; + position = 0; + size = hiddenLayerSize; + break; // new gate + case 7: + data = linRecurrentWeightsData; + position = 2; + size = hiddenLayerSize; + break; // output gate + default: + throw new RuntimeException(); + } + checkCuda(cudaMemcpyAsync(linLayerMat, data.position(position * size * hiddenLayerSize * dataTypeSize), + size * hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream)); + if (linLayerID < 4) { + checkCuda(cudaMemcpyAsync(linLayerBias, + linBiasesData.position(position * hiddenLayerSize * dataTypeSize), + hiddenLayerSize * dataTypeSize, cudaMemcpyDeviceToDevice, stream)); + } + } + } + + if (training) { + checkCudnn(cudnnRNNForwardTraining(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData, + cudnnContext.hxDesc, prevActData, cudnnContext.cxDesc, prevMemCellData, cudnnContext.wDesc, + weightsSpace, yDesc, outputActivationsData, cudnnContext.hyDesc, + finalTimeStepActivationsData, cudnnContext.cyDesc, finalMemCellStateData, workSpace, + workSpace.limit(), reserveSpace, reserveSpace.limit())); + } else { + checkCudnn(cudnnRNNForwardInference(cudnnContext, cudnnContext.rnnDesc, (int) timeSeriesLength, xDesc, xData, + cudnnContext.hxDesc, prevActData, cudnnContext.cxDesc, prevMemCellData, cudnnContext.wDesc, + weightsSpace, yDesc, outputActivationsData, cudnnContext.hyDesc, + finalTimeStepActivationsData, cudnnContext.cyDesc, finalMemCellStateData, workSpace, + workSpace.limit())); + } + + allocator.getFlowController().registerActionAllWrite(context, x, linInputWeights, linRecurrentWeights, + linBiases, prevAct, prevMemCell, outputActivations, finalMemCellState, finalStepActivations); + + toReturn.fwdPassOutput = outputActivations.permute(1, 2, 0); + toReturn.lastAct = finalStepActivations; + toReturn.lastMemCell = finalMemCellState; + toReturn.prevAct = prevAct; + toReturn.prevMemCell = prevMemCell; + + return toReturn; + } + + @Override + public Map helperMemoryUse() { + 
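// Sizes (in bytes) of the persistent cuDNN scratch buffers held by this helper: the dropout state
+        // space, the training reserve space and the packed weights space.
+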
Map<String, Long> memUse = new HashMap<>(); + memUse.put("stateSpace", stateSpace.capacity()); + memUse.put("reserveSpace", reserveSpace.capacity()); + memUse.put("weightsSpace", weightsSpace.capacity()); + return memUse; + } +}
diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java index 37df1f31b..92ad3c679 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/KerasFlattenRnnPreprocessor.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.modelimport.keras.preprocessors; import lombok.Data; +import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; @@ -32,6 +33,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; @Slf4j @Data +@EqualsAndHashCode(callSuper=false) public class KerasFlattenRnnPreprocessor extends BaseInputPreProcessor { private long tsLength;
diff --git a/cavis-full/build.gradle b/cavis-full/build.gradle index c18c258ad..558d36365 100644 --- a/cavis-full/build.gradle +++ b/cavis-full/build.gradle @@ -22,7 +22,7 @@ dependencies { && !sproj.name.equals("Cavis") && !sproj.name.equals("cavis-datavec") && !sproj.name.equals("cavis-dnn") - && !sproj.name.equals("cavis-native") && !sproj.name.equals("cavis-native-lib") + && !sproj.name.equals("cavis-native") && !sproj.name.equals("cavis-nd4j") && !sproj.name.equals("cavis-ui") && !sproj.name.equals("cavis-zoo")) { @@ -31,7 +31,7 @@ } // if(withCpu) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportApiElements") // if(withCuda) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportApiElements") - +/* api(projects.cavisNative.cavisNativeLib) { capabilities { //if(withCuda()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cuda-support", version: project.version) @@ -44,7 +44,7 @@ dependencies { //if(withCpu()) it.requireCapability(group: "net.brutex.cavis.cavis-native", name: "cavis-native-lib-cpu-support", version: project.version) } } - +*/ //if(withCpu()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportImplementation") //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportImplementation") //if(withCuda()) api project(path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportCompileClasspath")
diff --git a/cavis-native/cavis-native-lib/CMakeLists.txt b/cavis-native/cavis-native-lib/CMakeLists.txt index 3795e7bd0..24360e856 100644 --- a/cavis-native/cavis-native-lib/CMakeLists.txt +++ b/cavis-native/cavis-native-lib/CMakeLists.txt @@ -121,7 +121,7 @@ endfunction() if (SD_CUDA) #enable_language(CUDA) - find_package(CUDAToolkit 11.2 REQUIRED) + find_package(CUDAToolkit 11.4 REQUIRED) message(STATUS "CUDAToolkit_VERSION: ${CUDAToolkit_VERSION}") message(STATUS "CUDAToolkit_VERSION_MAJOR: ${CUDAToolkit_VERSION_MAJOR}") message(STATUS "CUDAToolkit_VERSION_MINOR: ${CUDAToolkit_VERSION_MINOR}")
diff --git a/chooseBackend.gradle b/chooseBackend.gradle index 7a3159f59..d1a33af9e 100644 --- a/chooseBackend.gradle +++ 
b/chooseBackend.gradle @@ -20,11 +20,9 @@ */ ext { chip = (properties.CAVIS_CHIP ?: "cuda,cpu").toLowerCase() //the default is to build for CPU and CUDA - testChip = (properties.CAVIS_TEST_CHIP ?: " ").toLowerCase() //the default is without specific backend - logger.debug("Building for chips ${chip} and running tests with backends for ${testChip}") + logger.debug("Building for chips ${chip} and running tests with backends for ${chip}") chipList = chip.split(",") - testChipList = testChip.split(",") /* just for usability */ withCuda = { -> @@ -33,10 +31,4 @@ ext { withCpu = { -> return chip.contains("cpu") } - withCudaTest = { -> - return testChip.contains("cuda") - } - withCpuTest = { -> - return testChip.contains("cpu") - } } diff --git a/createTestBackends.gradle b/createTestBackends.gradle index a0cef6c24..638e511e2 100644 --- a/createTestBackends.gradle +++ b/createTestBackends.gradle @@ -24,7 +24,7 @@ ext { buildTarget = rootProject.ext.buildTarget apply from: new File("${project.rootProject.projectDir}/chooseBackend.gradle") - testChipList.each { thisChip -> + chipList.each { thisChip -> configurations.register("${thisChip}TestImplementation") { it.extendsFrom configurations.testImplementation, configurations.implementation @@ -79,33 +79,44 @@ ext { dependencies { - if (withCudaTest()) { + if (withCuda()) { cudaTestRuntime platform(projects.cavisCommonPlatform) cudaTestRuntime projects.cavisNative.cavisNativeJcublas + cudaTestRuntime projects.cavisDnn.cavisDnnCudnn cudaTestRuntime group: "org.bytedeco", name: "openblas" cudaTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget cudaTestRuntime group: "org.bytedeco", name: "cuda" cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: buildTarget cudaTestRuntime group: "org.bytedeco", name: "cuda", classifier: "${buildTarget}-redist" + cudaTestRuntime (project( path: ":cavis-native:cavis-native-lib", configuration: "cudaSupportRuntimeElements")) + /* cudaTestRuntime(project(":cavis-native:cavis-native-lib")) { + capabilities { it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cuda-support:1.0.0-SNAPSHOT" } } + + */ } - if (withCpuTest()) { + if (withCpu()) { cpuTestRuntime platform(projects.cavisCommonPlatform) cpuTestRuntime projects.cavisNative.cavisNativeCpu cpuTestRuntime group: "org.bytedeco", name: "openblas" cpuTestRuntime group: "org.bytedeco", name: "openblas", classifier: buildTarget cpuTestRuntime group: "org.bytedeco", name: "opencv" cpuTestRuntime group: "org.bytedeco", name: "opencv", classifier: buildTarget + cpuTestRuntime project( path: ":cavis-native:cavis-native-lib", configuration: "cpuSupportRuntimeElements") + /* cpuTestRuntime(project(":cavis-native:cavis-native-lib")) { + capabilities { it.requireCapabilities "net.brutex.cavis.native:cavis-native-lib-cpu-support:1.0.0-SNAPSHOT" } } + + */ } } } \ No newline at end of file diff --git a/settings.gradle b/settings.gradle index 17d2ee1b9..efde26230 100644 --- a/settings.gradle +++ b/settings.gradle @@ -89,6 +89,7 @@ include ':cavis-native:cavis-native-lib' include ':cavis-native:cavis-native-common' include ':cavis-dnn' include ':cavis-dnn:cavis-dnn-api' +include ':cavis-dnn:cavis-dnn-cudnn' include ':cavis-dnn:cavis-dnn-common' include ':cavis-dnn:cavis-dnn-common-tests' include ':cavis-dnn:cavis-dnn-core' @@ -151,3 +152,6 @@ include ':cavis-zoo' include ':cavis-zoo:cavis-zoo-models' include ':brutex-extended-tests' include ':cavis-full' +include 'cavis-dnn:cavis-dnn-cudnn' 
+findProject(':cavis-dnn:cavis-dnn-cudnn')?.name = 'cavis-dnn-cudnn' + From 8d31caffbe30bb0b3c34b342e22fa7817c7980ad Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 10 Mar 2023 16:32:41 +0100 Subject: [PATCH 114/126] Adding cuDNN support Signed-off-by: brian --- brutex-extended-tests/src/test/java/net/brutex/gan/App.java | 2 +- .../src/main/java/org/nd4j/linalg/factory/Nd4j.java | 6 +++--- .../org/deeplearning4j/common/config/DL4JClassLoading.java | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java index f4feb6fdf..5d4704f2c 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java @@ -167,7 +167,7 @@ public class App { trainData.reset(); int j = 0; - for (int i = 0; i < 10; i++) { + for (int i = 0; i < 20; i++) { while (trainData.hasNext()) { j++; diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index f542e3cce..359f30b02 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -4877,7 +4877,7 @@ public class Nd4j { * Create an ndarray of zeros * * @param shape the shape of the array - * @return an ndarray with ones filled in + * @return an ndarray with zeros filled in */ public static INDArray zeros(int[] shape, char order) { checkShapeValues(shape); @@ -4896,7 +4896,7 @@ public class Nd4j { * Create an ndarray of zeros * * @param shape the shape of the array - * @return an ndarray with ones filled in + * @return an ndarray with zeros filled in */ public static INDArray zeros(@NonNull int... shape) { return Nd4j.create(shape); @@ -4907,7 +4907,7 @@ public class Nd4j { * Create an ndarray of zeros * * @param shape the shape of the array - * @return an ndarray with ones filled in + * @return an ndarray with zeros filled in */ public static INDArray zeros(@NonNull long... 
shape) { return Nd4j.create(shape); diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java index 16f6f134a..55481d875 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/deeplearning4j/common/config/DL4JClassLoading.java @@ -105,7 +105,8 @@ public class DL4JClassLoading { throw new RuntimeException(instantiationException); } catch (InvocationTargetException instantiationException) { - log.error(String.format("InvocationTargetException was '%s'.", instantiationException.getTargetException().getMessage()), instantiationException); + log.error(String.format("---------- ----------- ---------- \nInvocationTargetException was '%s'.", instantiationException.getTargetException().getMessage()), instantiationException); + log.error(String.format("java.library.path was '%s'\n---------- ---------- ----------", System.getProperty("java.library.path"))); throw new RuntimeException(instantiationException); } } From e59e14da351e00b8d98b5cdae21bfb8ac822b694 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 10 Mar 2023 16:40:19 +0100 Subject: [PATCH 115/126] Adding cuDNN support Signed-off-by: brian --- settings.gradle | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/settings.gradle b/settings.gradle index efde26230..86efa281d 100644 --- a/settings.gradle +++ b/settings.gradle @@ -89,7 +89,7 @@ include ':cavis-native:cavis-native-lib' include ':cavis-native:cavis-native-common' include ':cavis-dnn' include ':cavis-dnn:cavis-dnn-api' -include ':cavis-dnn:cavis-dnn-cudnn' +if(withCuda()) { include ':cavis-dnn:cavis-dnn-cudnn' } include ':cavis-dnn:cavis-dnn-common' include ':cavis-dnn:cavis-dnn-common-tests' include ':cavis-dnn:cavis-dnn-core' From 67e32adb5605894791f421530e8a2750c13a324e Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 10 Mar 2023 17:17:19 +0100 Subject: [PATCH 116/126] Adding cuDNN support Signed-off-by: brian --- cavis-dnn/cavis-dnn-cudnn/build.gradle | 2 +- cavis-native/cavis-native-jcublas/build.gradle | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cavis-dnn/cavis-dnn-cudnn/build.gradle b/cavis-dnn/cavis-dnn-cudnn/build.gradle index 725ca1f85..8f40cf470 100644 --- a/cavis-dnn/cavis-dnn-cudnn/build.gradle +++ b/cavis-dnn/cavis-dnn-cudnn/build.gradle @@ -7,7 +7,7 @@ ext { dependencies { implementation platform(projects.cavisCommonPlatform) - implementation projects.cavisNative.cavisNativeJcublas + implementation project(":cavis-native:cavis-native-jcublas") implementation projects.cavisDnn.cavisDnnApi implementation projects.cavisDnn.cavisDnnNn diff --git a/cavis-native/cavis-native-jcublas/build.gradle b/cavis-native/cavis-native-jcublas/build.gradle index 0e0a9dd22..cf3372c5c 100644 --- a/cavis-native/cavis-native-jcublas/build.gradle +++ b/cavis-native/cavis-native-jcublas/build.gradle @@ -11,7 +11,8 @@ ext { dependencies { implementation platform(projects.cavisCommonPlatform) - implementation project(":cavis-native:cavis-native-blas") + //implementation project(":cavis-native:cavis-native-blas") + implementation projects.cavisNative.cavisNativeBlas implementation group: "org.bytedeco", name: "cuda" implementation group: "org.bytedeco", name: "cuda", classifier: buildTarget From 98665032b1fe8e170df0ed8053db4189db6f7da3 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 10 Mar 2023 17:25:32 +0100 
Subject: [PATCH 117/126] Adding cuDNN support Signed-off-by: brian --- settings.gradle | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/settings.gradle b/settings.gradle index 86efa281d..0002d667d 100644 --- a/settings.gradle +++ b/settings.gradle @@ -152,6 +152,5 @@ include ':cavis-zoo' include ':cavis-zoo:cavis-zoo-models' include ':brutex-extended-tests' include ':cavis-full' -include 'cavis-dnn:cavis-dnn-cudnn' -findProject(':cavis-dnn:cavis-dnn-cudnn')?.name = 'cavis-dnn-cudnn' + From 4665c5a10a4c7eabd59a880c5f521d1eee9e0be3 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 22 Mar 2023 17:34:43 +0100 Subject: [PATCH 118/126] Playing with GAN Signed-off-by: brian --- brutex-extended-tests/build.gradle | 12 ++ .../src/test/java/net/brutex/gan/App.java | 200 ++++++++++++++---- .../test/resources/simplelogger.properties | 49 +++++ cavis-common-platform/build.gradle | 7 + .../transform/ColorConversionTransform.java | 16 +- .../image/transform/ShowImageTransform.java | 1 + .../java/org/nd4j/linalg/factory/Nd4j.java | 18 +- .../nd4j/common/config/ND4JClassLoading.java | 8 +- cavis-dnn/cavis-dnn-nn/build.gradle | 7 + .../deeplearning4j/nn/conf/layers/Layer.java | 2 +- .../layers/feedforward/dense/DenseLayer.java | 39 ++-- .../listeners/ScoreToChartListener.java | 62 ++++++ .../nativeblas/NativeOpsGPUInfoProvider.java | 41 ++-- .../jita/allocator/impl/AtomicAllocator.java | 3 +- .../jcublas/buffer/BaseCudaDataBuffer.java | 4 + .../buffer/BaseCudaDataBufferTest.java | 55 +++++ 16 files changed, 421 insertions(+), 103 deletions(-) create mode 100644 brutex-extended-tests/src/test/resources/simplelogger.properties create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreToChartListener.java create mode 100644 cavis-native/cavis-native-jcublas/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java diff --git a/brutex-extended-tests/build.gradle b/brutex-extended-tests/build.gradle index c15f6d325..8115c10bd 100644 --- a/brutex-extended-tests/build.gradle +++ b/brutex-extended-tests/build.gradle @@ -34,6 +34,8 @@ ext { } dependencies { + implementation platform(projects.cavisCommonPlatform) + implementation "com.fasterxml.jackson.core:jackson-databind" implementation "com.google.guava:guava" implementation projects.cavisDnn.cavisDnnCore @@ -52,6 +54,16 @@ dependencies { testImplementation "org.apache.spark:spark-sql_${scalaVersion}" testCompileOnly "org.scala-lang:scala-library" + //Rest Client + // define any required OkHttp artifacts without version + implementation("com.squareup.okhttp3:okhttp") + implementation("com.squareup.okhttp3:logging-interceptor") + + + implementation "org.bytedeco:javacv" + implementation "org.bytedeco:opencv" + implementation group: "org.bytedeco", name: "opencv", classifier: buildTarget + implementation "it.unimi.dsi:fastutil-core:8.5.8" implementation projects.cavisDnn.cavisDnnSpark.cavisDnnSparkCore diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java index 5d4704f2c..bf4783145 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java @@ -21,49 +21,90 @@ package net.brutex.gan; +import java.util.Random; +import javax.ws.rs.client.ClientBuilder; +import lombok.extern.slf4j.Slf4j; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; import org.apache.commons.lang3.ArrayUtils; +import 
org.datavec.api.Writable; +import org.datavec.api.records.reader.RecordReader; +import org.datavec.api.split.FileSplit; +import org.datavec.image.loader.NativeImageLoader; +import org.datavec.image.recordreader.ImageRecordReader; +import org.datavec.image.transform.ColorConversionTransform; +import org.datavec.image.transform.ImageTransform; +import org.datavec.image.transform.PipelineImageTransform; +import org.datavec.image.transform.ResizeImageTransform; +import org.datavec.image.transform.ScaleImageTransform; +import org.datavec.image.transform.ShowImageTransform; +import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.PerformanceListener; +import org.deeplearning4j.optimize.listeners.ScoreToChartListener; +import org.glassfish.jersey.client.JerseyClient; +import org.glassfish.jersey.client.JerseyClientBuilder; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; +import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; + import javax.swing.*; import java.awt.*; import java.awt.image.BufferedImage; import java.io.File; import java.util.Arrays; +import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; +@Slf4j public class App { - private static final double LEARNING_RATE = 0.0002; + private static final double LEARNING_RATE = 0.000002; private static final double GRADIENT_THRESHOLD = 100.0; + + private static final int X_DIM = 28; + private static final int Y_DIM = 28; + private static final int CHANNELS = 1; + private static final int batchSize = 9; + private static final int INPUT = 128; + + private static final int OUTPUT_PER_PANEL = 4; + + private static final int ARRAY_SIZE_PER_SAMPLE = X_DIM*Y_DIM*CHANNELS; private static final IUpdater UPDATER = Adam.builder().learningRate(LEARNING_RATE).beta1(0.5).build(); - private static JFrame frame; + private static JFrame frame; + private static JFrame frame2; private static JPanel panel; + private static JPanel panel2; private static Layer[] genLayers() { return new Layer[] { - new DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build(), + new DenseLayer.Builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(), new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), - new DenseLayer.Builder().nIn(256).nOut(512).build(), + new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), - new 
DenseLayer.Builder().nIn(512).nOut(1024).build(), + new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(), new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), - new DenseLayer.Builder().nIn(1024).nOut(784).activation(Activation.TANH).build() + new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH) + .build() }; } @@ -81,6 +122,7 @@ public class App { .weightInit(WeightInit.XAVIER) .activation(Activation.IDENTITY) .list(genLayers()) + .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) .build(); return conf; @@ -88,16 +130,19 @@ public class App { private static Layer[] disLayers() { return new Layer[]{ - new DenseLayer.Builder().nIn(784).nOut(1024).build(), + new DenseLayer.Builder().nOut(X_DIM*Y_DIM*CHANNELS*2).build(), //input is set by setInputType on the network new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), new DropoutLayer.Builder(1 - 0.5).build(), - new DenseLayer.Builder().nIn(1024).nOut(512).build(), + new DenseLayer.Builder().nIn(X_DIM * Y_DIM*CHANNELS*2).nOut(X_DIM*Y_DIM*CHANNELS*4).build(), //HxBxC new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), new DropoutLayer.Builder(1 - 0.5).build(), - new DenseLayer.Builder().nIn(512).nOut(256).build(), + new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS*4).nOut(X_DIM*Y_DIM*CHANNELS).build(), new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), new DropoutLayer.Builder(1 - 0.5).build(), - new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(256).nOut(1).activation(Activation.SIGMOID).build() + new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), + new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), + new DropoutLayer.Builder(1 - 0.5).build(), + new OutputLayer.Builder(LossFunction.XENT).nIn(X_DIM*Y_DIM).nOut(1).activation(Activation.SIGMOID).build() }; } @@ -110,6 +155,7 @@ public class App { .weightInit(WeightInit.XAVIER) .activation(Activation.IDENTITY) .list(disLayers()) + .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) .build(); return conf; @@ -135,6 +181,7 @@ public class App { .weightInit(WeightInit.XAVIER) .activation(Activation.IDENTITY) .list(layers) + .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) .build(); return conf; @@ -149,7 +196,25 @@ public class App { public static void main(String... 
args) throws Exception { Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000); - MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 42); +// MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45); + // FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/flowers"), NativeImageLoader.getALLOWED_FORMATS()); + FileSplit fileSplit = new FileSplit(new File("c:/users/brian/downloads/humans"), NativeImageLoader.getALLOWED_FORMATS()); + + + ImageTransform transform = new ColorConversionTransform(new Random(42), 7 ); + + ImageTransform transform2 = new ShowImageTransform("Tester", 30); + ImageTransform transform3 = new ResizeImageTransform(X_DIM, Y_DIM); + + ImageTransform tr = new PipelineImageTransform.Builder() + .addImageTransform(transform) //convert to GREY SCALE + .addImageTransform(transform3) + //.addImageTransform(transform2) + .build(); + + ImageRecordReader imageRecordReader = new ImageRecordReader(X_DIM, Y_DIM, CHANNELS); + imageRecordReader.initialize(fileSplit, tr); + DataSetIterator trainData = new RecordReaderDataSetIterator(imageRecordReader, batchSize ); MultiLayerNetwork gen = new MultiLayerNetwork(generator()); MultiLayerNetwork dis = new MultiLayerNetwork(discriminator()); @@ -160,27 +225,50 @@ public class App { copyParams(gen, dis, gan); - gen.setListeners(new PerformanceListener(10, true)); - dis.setListeners(new PerformanceListener(10, true)); - gan.setListeners(new PerformanceListener(10, true)); + //gen.setListeners(new PerformanceListener(10, true)); + //dis.setListeners(new PerformanceListener(10, true)); + //gan.setListeners(new PerformanceListener(10, true)); + gan.setListeners(new ScoreToChartListener("gan")); + //dis.setListeners(new ScoreToChartListener("dis")); - trainData.reset(); + gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)); + + //gan.fit(new DataSet(trainData.next().getFeatures(), Nd4j.zeros(batchSize, 1))); + //trainData.reset(); int j = 0; - for (int i = 0; i < 20; i++) { + for (int i = 0; i < 201; i++) { //epoch while (trainData.hasNext()) { j++; + DataSet next = trainData.next(); // generate data - INDArray real = trainData.next().getFeatures().muli(2).subi(1); - int batchSize = (int) real.shape()[0]; + INDArray real = next.getFeatures();//.div(255f); - INDArray fakeIn = Nd4j.rand(batchSize, 100); + //start next round if there are not enough images left to have a full batchsize dataset + if(real.length() < ARRAY_SIZE_PER_SAMPLE*batchSize) { + log.warn("Your total number of input images is not a multiple of {}, " + + "thus skipping {} images to make it fit", batchSize, real.length()/ARRAY_SIZE_PER_SAMPLE); + break; + } + + if(i%20 == 0) { + // frame2 = visualize(new INDArray[]{real}, batchSize, + // frame2 == null ? 
new JFrame() : frame2, true); //real has batchsize number of images + } + real.divi(255f); + +// int batchSize = (int) real.shape()[0]; + + INDArray fakeIn = Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM); INDArray fake = gan.activateSelectedLayers(0, gen.getLayers().length - 1, fakeIn); + fake = fake.reshape(batchSize, CHANNELS, X_DIM, Y_DIM); + //log.info("real has {} items.", real.length()); DataSet realSet = new DataSet(real, Nd4j.zeros(batchSize, 1)); DataSet fakeSet = new DataSet(fake, Nd4j.ones(batchSize, 1)); + DataSet data = DataSet.merge(Arrays.asList(realSet, fakeSet)); dis.fit(data); @@ -189,21 +277,29 @@ public class App { // Update the discriminator in the GAN network updateGan(gen, dis, gan); - gan.fit(new DataSet(Nd4j.rand(batchSize, 100), Nd4j.zeros(batchSize, 1))); + //gan.fit(new DataSet(Nd4j.rand(batchSize, INPUT), Nd4j.zeros(batchSize, 1))); + gan.fit(new DataSet(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1))); if (j % 10 == 1) { System.out.println("Iteration " + j + " Visualizing..."); - INDArray[] samples = new INDArray[9]; - DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1)); + INDArray[] samples = batchSize > OUTPUT_PER_PANEL ? new INDArray[OUTPUT_PER_PANEL] : new INDArray[batchSize]; - for (int k = 0; k < 9; k++) { + + for (int k = 0; k < samples.length; k++) { + //INDArray input = fakeSet2.get(k).getFeatures(); + DataSet fakeSet2 = new DataSet(fakeIn, Nd4j.ones(batchSize, 1)); INDArray input = fakeSet2.get(k).getFeatures(); + input = input.reshape(1,CHANNELS, X_DIM, Y_DIM); //batch size will be 1 here + //samples[k] = gen.output(input, false); samples[k] = gan.activateSelectedLayers(0, gen.getLayers().length - 1, input); + samples[k] = samples[k].reshape(1, CHANNELS, X_DIM, Y_DIM); + //samples[k] = + samples[k].addi(1f).divi(2f).muli(255f); } - visualize(samples); + frame = visualize(samples, 1, frame == null ? 
new JFrame() : frame, false); //each samples only has 1 image, thus batchElements=1 } } trainData.reset(); @@ -239,41 +335,57 @@ public class App { } } - private static void visualize(INDArray[] samples) { - if (frame == null) { - frame = new JFrame(); - frame.setTitle("Viz"); - frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); - frame.setLayout(new BorderLayout()); - - panel = new JPanel(); - - panel.setLayout(new GridLayout(samples.length / 3, 1, 8, 8)); - frame.add(panel, BorderLayout.CENTER); - frame.setVisible(true); + private static JFrame visualize(INDArray[] samples, int batchElements, JFrame frame, boolean isOrig) { + if (isOrig) { + frame.setTitle("Viz Original"); + } else { + frame.setTitle("Generated"); } - panel.removeAll(); + frame.setDefaultCloseOperation(WindowConstants.DISPOSE_ON_CLOSE); + frame.setLayout(new BorderLayout()); + JPanel panelx = new JPanel(); + + panelx.setLayout(new GridLayout(4, 4, 8, 8)); for (INDArray sample : samples) { - panel.add(getImage(sample)); + for(int i = 0; i distributionFactoryClazz = ND4JClassLoading.loadClassByName(clazzName); - memoryManager = memoryManagerClazz.newInstance(); - constantHandler = constantProviderClazz.newInstance(); - shapeInfoProvider = shapeInfoProviderClazz.newInstance(); - workspaceManager = workspaceManagerClazz.newInstance(); + memoryManager = memoryManagerClazz.getDeclaredConstructor().newInstance(); + constantHandler = constantProviderClazz.getDeclaredConstructor().newInstance(); + shapeInfoProvider = shapeInfoProviderClazz.getDeclaredConstructor().newInstance(); + workspaceManager = workspaceManagerClazz.getDeclaredConstructor().newInstance(); Class opExecutionerClazz = ND4JClassLoading .loadClassByName(pp.toString(OP_EXECUTIONER, DefaultOpExecutioner.class.getName())); - OP_EXECUTIONER_INSTANCE = opExecutionerClazz.newInstance(); + OP_EXECUTIONER_INSTANCE = opExecutionerClazz.getDeclaredConstructor().newInstance(); Constructor c2 = ndArrayFactoryClazz.getConstructor(DataType.class, char.class); INSTANCE = (NDArrayFactory) c2.newInstance(dtype, ORDER); - CONVOLUTION_INSTANCE = convolutionInstanceClazz.newInstance(); - BLAS_WRAPPER_INSTANCE = blasWrapperClazz.newInstance(); - DATA_BUFFER_FACTORY_INSTANCE = dataBufferFactoryClazz.newInstance(); + CONVOLUTION_INSTANCE = convolutionInstanceClazz.getDeclaredConstructor().newInstance(); + BLAS_WRAPPER_INSTANCE = blasWrapperClazz.getDeclaredConstructor().newInstance(); + DATA_BUFFER_FACTORY_INSTANCE = dataBufferFactoryClazz.getDeclaredConstructor().newInstance(); - DISTRIBUTION_FACTORY = distributionFactoryClazz.newInstance(); + DISTRIBUTION_FACTORY = distributionFactoryClazz.getDeclaredConstructor().newInstance(); if (isFallback()) { fallbackMode.set(true); diff --git a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java index a16c7bac4..1a520c7cd 100644 --- a/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java +++ b/cavis-dnn/cavis-dnn-common/src/main/java/org/nd4j/common/config/ND4JClassLoading.java @@ -58,11 +58,13 @@ public final class ND4JClassLoading { @SuppressWarnings("unchecked") public static Class loadClassByName(String className, boolean initialize, ClassLoader classLoader) { + try { - log.info(String.format("Trying to load class [%s]", className)); - return (Class) Class.forName(className, initialize, classLoader); + Class clazz = (Class) Class.forName(className, initialize, classLoader); 
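+            // Resolve the class first so that the following log line can report success or failure in one place.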
+ log.info(String.format("Trying to load class [%s] - Success", className)); + return clazz; } catch (ClassNotFoundException classNotFoundException) { - log.error(String.format("Cannot find class [%s] of provided class-loader.", className)); + log.error(String.format("Trying to load class [%s] - Failure: Cannot find class with provided class-loader.", className)); return null; } } diff --git a/cavis-dnn/cavis-dnn-nn/build.gradle b/cavis-dnn/cavis-dnn-nn/build.gradle index 3ffdbee6a..e38b43f1d 100644 --- a/cavis-dnn/cavis-dnn-nn/build.gradle +++ b/cavis-dnn/cavis-dnn-nn/build.gradle @@ -21,6 +21,8 @@ apply from: "${project.rootProject.projectDir}/createTestBackends.gradle" dependencies { + implementation platform(projects.cavisCommonPlatform) + implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators implementation 'org.lucee:oswego-concurrent:1.3.4' implementation projects.cavisDnn.cavisDnnCommon @@ -50,4 +52,9 @@ dependencies { implementation "com.fasterxml.jackson.core:jackson-databind" implementation "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml" implementation "com.jakewharton.byteunits:byteunits:0.9.1" + + //Rest Client + // define any required OkHttp artifacts without version + implementation "com.squareup.okhttp3:okhttp" + implementation "com.squareup.okhttp3:logging-interceptor" } \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java index cccc3cb1b..a96ec6db7 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java @@ -215,7 +215,7 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable { /** * Get the updater for the given parameter. 
Typically the same updater will be used for all - * updaters, but this is not necessarily the case + * parameters, but this is not necessarily the case * * @param paramName Parameter name * @return IUpdater for the parameter diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java index 77b030b4c..d2aa10406 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java @@ -30,27 +30,28 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; * @author Adam Gibson */ public class DenseLayer extends BaseLayer { - public DenseLayer(NeuralNetConfiguration conf, DataType dataType) { - super(conf, dataType); - } - @Override - public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) { - throw new UnsupportedOperationException("Not supported"); - } + public DenseLayer(NeuralNetConfiguration conf, DataType dataType) { + super(conf, dataType); + } - @Override - public boolean isPretrainLayer() { - return false; - } + @Override + public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) { + throw new UnsupportedOperationException("Not supported"); + } - @Override - public boolean hasBias(){ - return layerConf().hasBias(); - } + @Override + public boolean isPretrainLayer() { + return false; + } - @Override - public boolean hasLayerNorm(){ - return layerConf().hasLayerNorm(); - } + @Override + public boolean hasBias() { + return layerConf().hasBias(); + } + + @Override + public boolean hasLayerNorm() { + return layerConf().hasLayerNorm(); + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreToChartListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreToChartListener.java new file mode 100644 index 000000000..2fc2999d6 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreToChartListener.java @@ -0,0 +1,62 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package org.deeplearning4j.optimize.listeners; + +import java.io.IOException; +import lombok.extern.slf4j.Slf4j; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import org.deeplearning4j.nn.api.Model; +import org.deeplearning4j.optimize.api.BaseTrainingListener; + +@Slf4j +public class ScoreToChartListener extends BaseTrainingListener { + + final String url = "http://bru5:8080/cavis-rest-1.0-SNAPSHOT.war/hello/hello-world?"; + final String seriesName; + + public ScoreToChartListener(String seriesName) { + this.seriesName = seriesName; + } + + @Override + public void iterationDone(Model model, int iteration, int epoch) { + double score = model.score(); + String nurl = url+"s="+score+"&n="+seriesName; + OkHttpClient client = new OkHttpClient(); + + Request request = new Request.Builder() + .url(nurl) + .build(); + + try { + Response response = client.newCall(request).execute(); + log.debug(String.format("Did send score to chart at '%s'.", nurl)); + response.body().close(); + } catch (IOException e) { + log.warn(String.format("Could not send score to chart at '%s' because %s", nurl, e.getMessage())); + } + //response.body().string(); + } + +} diff --git a/cavis-native/cavis-native-blas/src/main/java/org/nd4j/nativeblas/NativeOpsGPUInfoProvider.java b/cavis-native/cavis-native-blas/src/main/java/org/nd4j/nativeblas/NativeOpsGPUInfoProvider.java index 1a8d3950b..654825a4c 100644 --- a/cavis-native/cavis-native-blas/src/main/java/org/nd4j/nativeblas/NativeOpsGPUInfoProvider.java +++ b/cavis-native/cavis-native-blas/src/main/java/org/nd4j/nativeblas/NativeOpsGPUInfoProvider.java @@ -31,31 +31,30 @@ import java.util.List; @Slf4j public class NativeOpsGPUInfoProvider implements GPUInfoProvider { - @Override - public List getGPUs() { - NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + @Override + public List getGPUs() { + NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); - List gpus = new ArrayList<>(); + List gpus = new ArrayList<>(); + int nDevices = nativeOps.getAvailableDevices(); + if (nDevices > 0) { + for (int i = 0; i < nDevices; i++) { + try { + String name = nativeOps.getDeviceName(i); + long total = nativeOps.getDeviceTotalMemory(i); + long free = nativeOps.getDeviceFreeMemory(i); + int major = nativeOps.getDeviceMajor(i); + int minor = nativeOps.getDeviceMinor(i); - int nDevices = nativeOps.getAvailableDevices(); - if (nDevices > 0) { - for (int i = 0; i < nDevices; i++) { - try { - String name = nativeOps.getDeviceName(i); - long total = nativeOps.getDeviceTotalMemory(i); - long free = nativeOps.getDeviceFreeMemory(i); - int major = nativeOps.getDeviceMajor(i); - int minor = nativeOps.getDeviceMinor(i); - - gpus.add(new GPUInfo(name, total, free, major, minor)); - } catch (Exception e) { - log.info("Can't add GPU", e); - } - } + gpus.add(new GPUInfo(name, total, free, major, minor)); + } catch (Exception e) { + log.info("Can't add GPU", e); } - - return gpus; + } } + return gpus; + } + } diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java index e0102fcac..337cbc23e 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java +++ 
b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java @@ -83,7 +83,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; */ @Slf4j public class AtomicAllocator implements Allocator { - private static final AtomicAllocator INSTANCE = new AtomicAllocator(); + private static AtomicAllocator INSTANCE = new AtomicAllocator(); private Configuration configuration; @@ -122,6 +122,7 @@ public class AtomicAllocator implements Allocator { private final AtomicLong useTracker = new AtomicLong(System.currentTimeMillis()); public static AtomicAllocator getInstance() { + if(INSTANCE == null) INSTANCE = new AtomicAllocator(); if (INSTANCE == null) throw new RuntimeException("AtomicAllocator is NULL"); return INSTANCE; diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 6b4793704..59d775ec6 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -402,6 +402,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda val ctx = AtomicAllocator.getInstance().getDeviceContext(); val devicePtr = allocationPoint.getDevicePointer(); NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(devicePtr, 0, length * elementSize, 0, ctx.getSpecialStream()); + int ec = NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorCode(); + if(ec != 0) { + throw new RuntimeException(NativeOpsHolder.getInstance().getDeviceNativeOps().lastErrorMessage()); + } ctx.getSpecialStream().synchronize(); } diff --git a/cavis-native/cavis-native-jcublas/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java b/cavis-native/cavis-native-jcublas/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java new file mode 100644 index 000000000..39843d68f --- /dev/null +++ b/cavis-native/cavis-native-jcublas/src/test/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBufferTest.java @@ -0,0 +1,55 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package org.nd4j.linalg.jcublas.buffer; + +import static org.junit.jupiter.api.Assertions.*; + +import lombok.extern.slf4j.Slf4j; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Pointer; +import org.junit.jupiter.api.Test; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.environment.Nd4jEnvironment; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.nativeblas.NativeOpsHolder; + +@Slf4j +class BaseCudaDataBufferTest { + + @Test + public void testMemoryAlloc() throws InterruptedException { + BaseCudaDataBuffer cuBuffer = new CudaLongDataBuffer(16l); + log.info( + "Allocation Status: " + cuBuffer.getAllocationPoint().getAllocationStatus().toString()); + Thread.sleep(3000); + cuBuffer.getAllocationPoint().tickDeviceWrite(); + DataBuffer buf = Nd4j.rand(8,1).shapeInfoDataBuffer(); + NativeOpsHolder.getInstance().getDeviceNativeOps().memcpySync(cuBuffer.pointer(), buf.pointer(), 8, 0, new Pointer() ); + + log.info( + "Allocation Status: " + cuBuffer.getAllocationPoint().getAllocationStatus().toString()); + + cuBuffer.release(); + } + + +} \ No newline at end of file From fec570ff98003d9d92b216321c088fc6146515db Mon Sep 17 00:00:00 2001 From: brian Date: Thu, 23 Mar 2023 17:39:00 +0100 Subject: [PATCH 119/126] Playing with some new code Signed-off-by: brian --- .../src/test/java/net/brutex/gan/App.java | 40 +- cavis-dnn/cavis-dnn-nn-api/build.gradle | 27 + .../java/net/brutex/ai/dnn/api/Layer.java | 40 + .../brutex/ai/dnn/api/LayerConfiguration.java | 42 + .../net/brutex/ai/dnn/api/NeuralNetwork.java | 69 + .../dnn/api/NeuralNetworkConfiguration.java | 43 + cavis-dnn/cavis-dnn-nn/build.gradle | 2 +- .../dnn/conf/NeuralNetworkConfiguration.java | 143 ++ .../layer/AbstractLayerConfiguration.java | 35 + .../net/brutex/ai/dnn/conf/layer/FFLayer.java | 52 + .../ai/dnn/conf/layer/LayerConfiguration.java | 28 + .../impl/network/AbstractNeuralNetwork.java | 72 + .../ai/dnn/impl/network/NeuralNetwork.java | 692 ++++++++ .../deeplearning4j/nn/api/NeuralNetwork.java | 121 +- .../nn/conf/MultiLayerConfiguration.java | 1388 +++++++++-------- .../layers/wrapper/BuildingBlockLayer.java | 97 ++ .../nn/multilayer/MultiLayerNetwork.java | 2 +- settings.gradle | 2 + 18 files changed, 2160 insertions(+), 735 deletions(-) create mode 100644 cavis-dnn/cavis-dnn-nn-api/build.gradle create mode 100644 cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/Layer.java create mode 100644 cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/LayerConfiguration.java create mode 100644 cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/NeuralNetwork.java create mode 100644 cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/NeuralNetworkConfiguration.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FFLayer.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/LayerConfiguration.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/AbstractNeuralNetwork.java create mode 100644 
cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/NeuralNetwork.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BuildingBlockLayer.java diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java index bf4783145..f5b47031b 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java @@ -21,15 +21,10 @@ package net.brutex.gan; +import java.util.List; import java.util.Random; -import javax.ws.rs.client.ClientBuilder; import lombok.extern.slf4j.Slf4j; -import okhttp3.OkHttpClient; -import okhttp3.Request; -import okhttp3.Response; import org.apache.commons.lang3.ArrayUtils; -import org.datavec.api.Writable; -import org.datavec.api.records.reader.RecordReader; import org.datavec.api.split.FileSplit; import org.datavec.image.loader.NativeImageLoader; import org.datavec.image.recordreader.ImageRecordReader; @@ -37,34 +32,29 @@ import org.datavec.image.transform.ColorConversionTransform; import org.datavec.image.transform.ImageTransform; import org.datavec.image.transform.PipelineImageTransform; import org.datavec.image.transform.ResizeImageTransform; -import org.datavec.image.transform.ScaleImageTransform; import org.datavec.image.transform.ShowImageTransform; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; +import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; +import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.optimize.listeners.PerformanceListener; import org.deeplearning4j.optimize.listeners.ScoreToChartListener; -import org.glassfish.jersey.client.JerseyClient; -import org.glassfish.jersey.client.JerseyClientBuilder; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; -import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.lossfunctions.LossFunctions; import javax.swing.*; @@ -106,6 +96,8 @@ public class App { new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH) .build() }; + + } /** @@ -114,7 +106,7 @@ public class App { * @return config */ private static MultiLayerConfiguration generator() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + MultiLayerConfiguration confxx = new NeuralNetConfiguration.Builder() .seed(42) .updater(UPDATER) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) 
@@ -123,9 +115,25 @@ public class App { .activation(Activation.IDENTITY) .list(genLayers()) .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) + // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS)) + .build(); + log.debug("Generator network: \n{}", confxx.toJson()); + + NeuralNetworkConfiguration conf2 = NeuralNetworkConfiguration.builder().build(); + + NeuralNetworkConfiguration confx = NeuralNetworkConfiguration.builder() + .cacheMode(CacheMode.HOST) + .layer( new DenseLayer.Builder().build()) + .layer( new DenseLayer.Builder().build()) + .layer( BuildingBlockLayer.builder().build()) + .layers( List.of(genLayers())) + .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) .build(); - return conf; + + + + return confx; } private static Layer[] disLayers() { diff --git a/cavis-dnn/cavis-dnn-nn-api/build.gradle b/cavis-dnn/cavis-dnn-nn-api/build.gradle new file mode 100644 index 000000000..e41b96b8d --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn-api/build.gradle @@ -0,0 +1,27 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ +//apply from: "${project.rootProject.projectDir}/createTestBackends.gradle" + +dependencies { + implementation platform(projects.cavisCommonPlatform) + implementation projects.cavisDnn.cavisDnnApi + implementation projects.cavisDnn.cavisDnnNn +} \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/Layer.java b/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/Layer.java new file mode 100644 index 000000000..c4c81f8ad --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/Layer.java @@ -0,0 +1,40 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +/** + * This is an "executable" Layer, that is based on a {@link LayerConfiguration} + */ +public interface Layer { + + /** + * Get the underlying configuration for this Layer + * @return configuration + */ + LayerConfiguration getLayerConfiguration(); + + /** + * Set the underlying layer configuration + * @param conf The new configuration + */ + void setLayerConfiguration(LayerConfiguration conf); +} diff --git a/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/LayerConfiguration.java b/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/LayerConfiguration.java new file mode 100644 index 000000000..0b274cb8c --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/LayerConfiguration.java @@ -0,0 +1,42 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +public interface LayerConfiguration { + + /** + * Create and return an instance of a LayerConfiguration. + * + * @param network the "holding" network for the instance + * @return the new layer instance + */ + Layer instantiate(NeuralNetwork network); + + + /** + * Defines the valid input type for this Layer + * + * @return InputType + */ + org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType(); + +} diff --git a/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/NeuralNetwork.java b/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/NeuralNetwork.java new file mode 100644 index 000000000..93ef1263d --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/NeuralNetwork.java @@ -0,0 +1,69 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.MultiDataSet; + +/** + * A Neural Network is an instance of a {@link NeuralNetworkConfiguration}, that can be trained, + * evaluated, saved, exported, etc. Its configuration state is defined with the + * {@link #setConfiguration(NeuralNetworkConfiguration)} and {@link #getConfiguration()} methods. + * + */ +public interface NeuralNetwork { + + /** + * The configuration that defines this Neural Network + * + * @param conf the configuration to use for this network + */ + void setConfiguration(NeuralNetworkConfiguration conf); + NeuralNetworkConfiguration getConfiguration(); + + /** + * This method fits model with a given DataSet + * + * @param dataSet the dataset to use for training + */ + void fit(DataSet dataSet); + + /** + * This method fits model with a given MultiDataSet + * + * @param dataSet the multi dataset to use for training + */ + void fit(MultiDataSet dataSet); + + /** + * The name of the Neural Network + * @return the name + */ + String getName(); + + /** + * Set the name for this Neural Network + * @param name the name + */ + void setName(String name); + +} diff --git a/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/NeuralNetworkConfiguration.java b/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/NeuralNetworkConfiguration.java new file mode 100644 index 000000000..f29dd4916 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/NeuralNetworkConfiguration.java @@ -0,0 +1,43 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +import java.util.List; + +public interface NeuralNetworkConfiguration { + + /** + * Provides a flat list of all embedded layer configurations, this + * can only be called after the layer is initialized or {@link #getLayerConfigurations()} is + * called. 
+ * + * @return unstacked layer configurations + */ + List getLayerConfigurations(); + + + /** + * This uncollables any stacked layer configurations within building blocks like + * @link BuildingBlockLayer} + */ + void calculateInnerLayerConfigurations(); +} diff --git a/cavis-dnn/cavis-dnn-nn/build.gradle b/cavis-dnn/cavis-dnn-nn/build.gradle index e38b43f1d..e0f85570d 100644 --- a/cavis-dnn/cavis-dnn-nn/build.gradle +++ b/cavis-dnn/cavis-dnn-nn/build.gradle @@ -22,7 +22,7 @@ apply from: "${project.rootProject.projectDir}/createTestBackends.gradle" dependencies { implementation platform(projects.cavisCommonPlatform) - + implementation projects.cavisDnn.cavisDnnNnApi implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators implementation 'org.lucee:oswego-concurrent:1.3.4' implementation projects.cavisDnn.cavisDnnCommon diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java new file mode 100644 index 000000000..e383ea9c7 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java @@ -0,0 +1,143 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.conf; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; +import lombok.Singular; +import lombok.extern.jackson.Jacksonized; +import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.LayerConfiguration; +import org.deeplearning4j.nn.conf.BackpropType; +import org.deeplearning4j.nn.conf.CacheMode; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.WorkspaceMode; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer; + +/** + * The NeuralNetworkConfiguration is a sequential container for the different layers in your + * network (or other NeuralNetworkConfigurations). That said, NeuralNetworkConfigurations can be + * stacked.

+ * It then “chains” outputs to inputs sequentially for each NeuralNetworkConfiguration, + * finally returning the output of the "top" configuration. Any settings made are inherited and can + * be overridden on a "deeper" level. For this use case, you need to wrap the NeuralNetworkConfiguration + * into a BuildingBlockLayer. + * + */ +@Jacksonized +@JsonIgnoreProperties(ignoreUnknown = true) +@lombok.Builder +@Slf4j +public class NeuralNetworkConfiguration implements net.brutex.ai.dnn.api.NeuralNetworkConfiguration, Serializable, Cloneable { + + /** + * The default {@link CacheMode} for this configuration. It will be set to "NONE" if not specified otherwise. + * Valid values are
+ * CacheMode.NONE,
+ * CacheMode.HOST or
+ * CacheMode.DEVICE
+ */ + @NonNull + @lombok.Builder.Default private CacheMode cacheMode = CacheMode.NONE; + + @Getter @Setter @NonNull + protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; + + @Getter @Setter @NonNull + protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; + + @Getter @Setter @NonNull + protected BackpropType backpropType = BackpropType.Standard; + + @Getter + protected Map inputPreProcessors = new HashMap<>(); + + + @Getter @Setter protected int tbpttFwdLength = 20; + @Getter @Setter protected int tbpttBackLength = 20; + + + /** + * The list of layer configurations in this configuration. They will be indexed automatically + * as the layers get added starting with index 0. + */ + @Singular @Getter + private List layerConfigurations; + + /** + * The name for this configuration. Defaults to "Anonymous NeuralNetworkConfiguration" if + * it is not specified. + */ + @lombok.Builder.Default @Getter + private String name = "Anonymous NeuralNetworkConfiguration"; + + + /** + * The {@link InputType} of the data for this network configuration + */ + private InputType inputType; + + /** + * hidden list of layers, that "flattens" all the layers of this network and applies + * inheritance. + */ + @lombok.Builder.ObtainVia(method = "calculateInnerLayers") + private final List innerLayerConfigurations; + + @Override + public void calculateInnerLayerConfigurations() { + List list = new ArrayList<>(); + for( LayerConfiguration layer : this.layerConfigurations) { + if(layer instanceof BuildingBlockLayer) { + BuildingBlockLayer blayer = (BuildingBlockLayer) layer; + blayer.getConf().calculateInnerLayerConfigurations(); + list.addAll(blayer.getConf().getLayerConfigurations()); + } else { + list.add(layer); + } + } + this.layerConfigurations = list; + } + + /** + * Creates and returns a copy of this object. + * + * @return a clone of this instance. + * @throws CloneNotSupportedException if the object's class does not support the {@code Cloneable} + * interface. Subclasses that override the {@code clone} method + * can also throw this exception to indicate that an instance + * cannot be cloned. + * @see Cloneable + */ + @Override + protected Object clone() throws CloneNotSupportedException { + return super.clone(); + } +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java new file mode 100644 index 000000000..951688e51 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java @@ -0,0 +1,35 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.conf.layer; + +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; +import net.brutex.ai.dnn.api.LayerConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; + +public abstract class AbstractLayerConfiguration implements LayerConfiguration { + + @Getter @Setter @NonNull + private InputType.Type inputType; + +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FFLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FFLayer.java new file mode 100644 index 000000000..d903e9002 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FFLayer.java @@ -0,0 +1,52 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.conf.layer; + +import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.Layer; +import net.brutex.ai.dnn.api.NeuralNetwork; +import net.brutex.ai.dnn.conf.layer.AbstractLayerConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.inputs.InputType.Type; + +@Slf4j +public class FFLayer extends AbstractLayerConfiguration { + + + /** + * Create and return an instance of a LayerConfiguration. + * + * @param network the "holding" network for the instance + * @return the new layer instance + */ + @Override + public Layer instantiate(NeuralNetwork network) { + //Let's do some verifications first + if(getInputType() != Type.FF) { + log.error("The {} layer configuration must use an InputType of {}, but found {}", + this.getClass().getSimpleName(), + Type.FF.name(), + getInputType().name()); + } + return null; + } +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/LayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/LayerConfiguration.java new file mode 100644 index 000000000..16c67b491 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/LayerConfiguration.java @@ -0,0 +1,28 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. 
+ * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.conf.layer; + +public abstract class LayerConfiguration { + + + +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/AbstractNeuralNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/AbstractNeuralNetwork.java new file mode 100644 index 000000000..a1c36e988 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/AbstractNeuralNetwork.java @@ -0,0 +1,72 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.impl.network; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; +import net.brutex.ai.dnn.api.Layer; +import net.brutex.ai.dnn.api.NeuralNetwork; +import net.brutex.ai.dnn.api.LayerConfiguration; +import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.dataset.api.MultiDataSet; + +public abstract class AbstractNeuralNetwork implements NeuralNetwork { + + @Getter @Setter @NonNull + private String name; + + @Getter @NonNull + private NeuralNetworkConfiguration configuration; + + @Getter + private final Collection trainingListeners = new HashSet<>(); + + /** + * The neural network holds an instantiation of its configured + * layers. + * @return the actual runtime layers + */ + @Getter + private final List runtimeLayers = new ArrayList<>(); + + /** + * Sets the configuration to be used. Each time a configuration is set, the runtime layers + * of this NeuralNetwork are updated from the configuration. 
+ * + * @param conf the configuration to use for this network + */ + public void setConfiguration(net.brutex.ai.dnn.api.NeuralNetworkConfiguration conf) { + List layers = conf.getLayerConfigurations(); + for(LayerConfiguration layer : layers) { + Layer initializedLayer = layer.instantiate(this); + this.getRuntimeLayers().add(initializedLayer); + } + this.configuration = configuration; + } + +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/NeuralNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/NeuralNetwork.java new file mode 100644 index 000000000..198007baf --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/NeuralNetwork.java @@ -0,0 +1,692 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.impl.network; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +import lombok.val; +import org.bytedeco.javacpp.Pointer; +import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; +import org.deeplearning4j.exception.DL4JInvalidInputException; +import org.deeplearning4j.nn.api.Classifier; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.api.Updater; +import org.deeplearning4j.nn.api.layers.IOutputLayer; +import org.deeplearning4j.nn.api.layers.RecurrentLayer; +import org.deeplearning4j.nn.conf.BackpropType; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; +import org.deeplearning4j.nn.conf.WorkspaceMode; +import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop; +import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.updater.UpdaterCreator; +import org.deeplearning4j.nn.workspace.ArrayType; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.deeplearning4j.optimize.Solver; +import org.deeplearning4j.optimize.api.ConvexOptimizer; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.deeplearning4j.util.CrashReportingUtil; +import org.deeplearning4j.util.ModelSerializer; +import org.nd4j.common.base.Preconditions; +import org.nd4j.common.primitives.Pair; +import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; +import 
org.nd4j.linalg.api.memory.enums.AllocationPolicy; +import org.nd4j.linalg.api.memory.enums.LearningPolicy; +import org.nd4j.linalg.api.memory.enums.ResetPolicy; +import org.nd4j.linalg.api.memory.enums.SpillPolicy; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.AsyncDataSetIterator; +import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.exception.ND4JArraySizeException; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.heartbeat.Heartbeat; +import org.nd4j.linalg.heartbeat.reports.Environment; +import org.nd4j.linalg.heartbeat.reports.Event; +import org.nd4j.linalg.heartbeat.reports.Task; +import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils; +import org.nd4j.linalg.heartbeat.utils.TaskUtils; +import org.nd4j.linalg.indexing.NDArrayIndex; + +@Slf4j +public class NeuralNetwork extends AbstractNeuralNetwork { + + + //the hidden neural network layers (including output layer) + protected Layer[] layers; + + protected transient ThreadLocal lastEtlTime = new ThreadLocal<>(); + + //Current training data: input features and labels + @Getter @Setter @NonNull + protected INDArray input; + @Getter @Setter + protected INDArray labels; + + //Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers + @Getter + protected transient Map helperWorkspaces = new HashMap<>(); + + /** + * Used to call optimizers during backprop + */ + @NonNull + protected transient Solver solver = new Solver.Builder().configure(getConfiguration()). + listeners(getTrainingListeners()).model(this).build(); + + + /** + * Create a new NeuralNetwork from the given configuration + * @param conf + */ + public NeuralNetwork(NeuralNetworkConfiguration conf) { + if(! validateConfiguration() ) { + log.error("Configuration '{}' has failed validation.", conf.getName()); + throw new RuntimeException(); + } + log.info("Configuration '{}' has been validated successfully.", conf.getName()); + this.conf = conf; + } + + private boolean validateConfiguration() { + + return true; + } + + private void logNotImplemented( ) { + // getStackTrace() method return + // current method name at 0th index + String method = new Throwable() + .getStackTrace()[1] + .getMethodName(); + log.trace("Method '{}}' is not implemented for {}", method, this.getClass().getSimpleName()); + } + + /** + * This method does initialization of model + *

+ * PLEASE NOTE: All implementations should track own state, to avoid double spending + */ + @Override + public void init() { + logNotImplemented(); + } + + /** + * This method returns model parameters as single INDArray + * + * @return + */ + @Override + public INDArray params() { + logNotImplemented(); + return null; + } + + /** + * This method returns updater state (if applicable), null otherwise + * + * @return + */ + @Override + public INDArray updaterState() { + return getUpdater(true) != null ? getUpdater(true).getStateViewArray() : null; + } + + /** + * This method returns Optimizer used for training + * + * @return the optimizer + */ + @Override + public ConvexOptimizer getOptimizer() { + return solver.getOptimizer(); + } + + + + /** Get the updater for this NeuralNetwork from the Solver + * @return Updater for NeuralNetwork + */ + private Updater getUpdater(boolean initializeIfReq) { + if (solver == null && initializeIfReq) { + synchronized(this){ + if(solver == null) { //May have been created while waiting for lock + solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this).build(); + solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this)); + } + } + } + if(solver != null) { + return solver.getOptimizer().getUpdater(initializeIfReq); + } + return null; + } + + /** + * Set the updater for the NeuralNetwork in the Solver + * */ + public void setUpdater(@NonNull Updater updater) { + solver.getOptimizer().setUpdater(updater); + } + + + @Override + public void fit(MultiDataSet dataSet) { + if (dataSet.getFeatures().length == 1 && dataSet.getLabels().length == 1) { + INDArray features = dataSet.getFeatures(0); + INDArray labels = dataSet.getLabels(0); + INDArray fMask = null; + INDArray lMask = null; + + if (dataSet.getFeaturesMaskArrays() != null) + fMask = dataSet.getFeaturesMaskArrays()[0]; + + if (dataSet.getFeaturesMaskArrays() != null) + lMask = dataSet.getLabelsMaskArrays()[0]; + + DataSet ds = new DataSet(features, labels, fMask, lMask); + fit(ds); + } else { + throw new DL4JInvalidInputException( + "MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array." + + "Please consider use of ComputationGraph"); + } + } + + /** + * Perform minibatch training on all minibatches in the MultiDataSetIterator, for the specified number of epochs. + * Equvalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a loop + * + * @param iterator Training data (DataSetIterator). Iterator must support resetting + * @param numEpochs Number of training epochs, >= 1 + */ + public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs){ + Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", numEpochs); + Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using" + + "iterator has does not support resetting (iterator.resetSupported() returned false)"); + + for(int i = 0; i < numEpochs; i++) { + fit(iterator); + } + } + + /** + * Perform minibatch training on all minibatches in the MultiDataSetIterator.
+ * Note: The MultiDataSets in the MultiDataSetIterator must have exactly 1 input and output array (as + * MultiLayerNetwork only supports 1 input and 1 output) + * + * @param iterator Training data (DataSetIterator). Iterator must support resetting + */ + @Override + public void fit(MultiDataSetIterator iterator) { + fit(new MultiDataSetWrapperIterator(iterator)); + } + + /** + * Perform minibatch training on all minibatches in the DataSetIterator for 1 epoch.
+ * Note that this method does not do layerwise pretraining.
+ * For pretraining, use {@link #pretrain(DataSetIterator)} instead.
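+ * A minimal usage sketch (illustrative only - {@code conf} and {@code trainIter} are placeholder names, not part of this API):
+ *
+ * MultiLayerNetwork net = new MultiLayerNetwork(conf);
+ * net.init();
+ * net.fit(trainIter); // one full pass (epoch) over the iterator
+ *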
+ * @param iterator Training data (DataSetIterator) + */ + @Override + public void fit(DataSetIterator iterator) { + try{ + fitHelper(iterator); + } catch (OutOfMemoryError e){ + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; + } + } + + private synchronized void fitHelper(DataSetIterator iterator){ + // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate + DataSetIterator iter; + boolean destructable = false; + if (iterator.asyncSupported()) { + iter = new AsyncDataSetIterator(iterator, Math.min( + Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true); + destructable = true; + } else { + iter = iterator; + } + + for (TrainingListener tl : trainingListeners) { + tl.onEpochStart(this); + } + + LayerWorkspaceMgr workspaceMgr; + if(conf.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ + workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); + } else { + workspaceMgr = LayerWorkspaceMgr.builder() + .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) + //Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM + // as these should be closed by the time updaters are executed + //Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this + .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .build(); + } + workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); + + update(TaskUtils.buildTask(iter)); + if (!iter.hasNext() && iter.resetSupported()) { + iter.reset(); + } + long time1 = System.currentTimeMillis(); + while (iter.hasNext()) { + + DataSet next = iter.next(); + long time2 = System.currentTimeMillis(); + + lastEtlTime.set((time2 - time1)); + + if (next.getFeatures() == null || next.getLabels() == null) + break; + + // TODO: basically we want to wrap internals of this loop into workspace + + + boolean hasMaskArrays = next.hasMaskArrays(); + + if (conf.getBackpropType() == BackpropType.TruncatedBPTT) { + doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(), + next.getLabelsMaskArray(), workspaceMgr); + } else { + if (hasMaskArrays) + setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray()); + + setInput(next.getFeatures()); + setLabels(next.getLabels()); + + if (solver == null) { + try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this) + .build(); + } + } + + //TODO CACHE + solver.optimize(workspaceMgr); + } + + if (hasMaskArrays) + clearLayerMaskArrays(); + + time1 = System.currentTimeMillis(); + synchronizeIterEpochCounts(); + } + + if (!trainingListeners.isEmpty()) { + for (TrainingListener tl : trainingListeners) { + tl.onEpochEnd(this); + } + } + + clearLayersStates(); + + if (destructable) + ((AsyncDataSetIterator) iter).shutdown(); + + incrementEpochCount(); + } + + + /** + * Workspace for working memory for a single layer: forward pass and backward pass + * Note that this is opened/closed once per op 
(activate/backpropGradient call) + */ + protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM"; + /** + * Workspace for storing all layers' activations - used only to store activations (layer inputs) as part of backprop + * Not used for inference + */ + protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT"; + /** + * Next 2 workspaces: used for: + * (a) Inference: holds activations for one layer only + * (b) Backprop: holds activation gradients for one layer only + * In both cases, they are opened and closed on every second layer + */ + protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1"; + protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2"; + + /** + * Workspace for output methods that use OutputAdapter + */ + protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM"; + + /** + * Workspace for working memory in RNNs - opened and closed once per RNN time step + */ + protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM"; + + + protected WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG; + + protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder() + .initialSize(0) + .overallocationLimit(0.05) + .policyLearning(LearningPolicy.FIRST_LOOP) + .policyReset(ResetPolicy.BLOCK_LEFT) + .policySpill(SpillPolicy.REALLOCATE) + .policyAllocation(AllocationPolicy.OVERALLOCATE) + .build(); + + protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG; + + protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder() + .initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT) + .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE) + .policyLearning(LearningPolicy.FIRST_LOOP).build(); + + + boolean initDone; + protected void update(Task task) { + if (!initDone) { + initDone = true; + Heartbeat heartbeat = Heartbeat.getInstance(); + task = ModelSerializer.taskByModel(this); + Environment env = EnvironmentUtils.buildEnvironment(); + heartbeat.reportEvent(Event.STANDALONE, env, task); + } + } + + protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray, + INDArray labelsMaskArray, LayerWorkspaceMgr workspaceMgr) { + if (input.rank() != 3 || labels.rank() != 3) { + log.warn("Cannot do truncated BPTT with non-3d inputs or labels. 
Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got " + + Arrays.toString(input.shape()) + "\tand labels with shape " + + Arrays.toString(labels.shape())); + return; + } + if (input.size(2) != labels.size(2)) { + log.warn("Input and label time series have different lengths: {} input length, {} label length", + input.size(2), labels.size(2)); + return; + } + + int fwdLen = conf.getTbpttFwdLength(); + update(TaskUtils.buildTask(input, labels)); + val timeSeriesLength = input.size(2); + long nSubsets = timeSeriesLength / fwdLen; + if (timeSeriesLength % fwdLen != 0) + nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20) + + rnnClearPreviousState(); + + for (int i = 0; i < nSubsets; i++) { + long startTimeIdx = (long) i * fwdLen; + long endTimeIdx = startTimeIdx + fwdLen; + if (endTimeIdx > timeSeriesLength) + endTimeIdx = timeSeriesLength; + + if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); + INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels, + featuresMaskArray, labelsMaskArray); + + setInput(subsets[0]); + setLabels(subsets[1]); + setLayerMaskArrays(subsets[2], subsets[3]); + + if (solver == null) { + try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this) + .build(); + } + } + solver.optimize(workspaceMgr); + + //Finally, update the state of the RNN layers: + updateRnnStateWithTBPTTState(); + } + + rnnClearPreviousState(); + clearLayerMaskArrays(); + } + + private INDArray[] getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray input, INDArray labels, + INDArray fMask, INDArray lMask ){ + INDArray[] out = new INDArray[4]; + out[0] = input.get(NDArrayIndex.all(), NDArrayIndex.all(), + NDArrayIndex.interval(startTimeIdx, endTimeIdx)); + out[1] = labels.get(NDArrayIndex.all(), NDArrayIndex.all(), + NDArrayIndex.interval(startTimeIdx, endTimeIdx)); + + if (fMask != null) { + out[2] = fMask.get(NDArrayIndex.all(), + NDArrayIndex.interval(startTimeIdx, endTimeIdx)); + } + if (lMask != null) { + out[3] = lMask.get(NDArrayIndex.all(), + NDArrayIndex.interval(startTimeIdx, endTimeIdx)); + } + + return out; + } + + /** + * Intended for internal/developer use + */ + public void updateRnnStateWithTBPTTState() { + Layer[] layers = conf.calculateInnerLayers().toArray(new Layer[]{}); + for (int i = 0; i < layers.length; i++) { + if (layers[i] instanceof RecurrentLayer) { + RecurrentLayer l = ((RecurrentLayer) layers[i]); + l.rnnSetPreviousState(l.rnnGetTBPTTState()); + } else if (layers[i] instanceof MultiLayerNetwork) { + ((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState(); + } + } + } + + /** Clear the previous state of the RNN layers (if any). 
+ */ + public void rnnClearPreviousState() { + Layer[] layers = conf.getLayers().toArray(new Layer[]{}); + if (layers == null) + return; + for (int i = 0; i < layers.length; i++) { + if (layers[i] instanceof RecurrentLayer) + ((RecurrentLayer) layers[i]).rnnClearPreviousState(); + else if (layers[i] instanceof MultiLayerNetwork) { + ((MultiLayerNetwork) layers[i]).rnnClearPreviousState(); + } else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){ + ((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying()).rnnClearPreviousState(); + } + } + } + + + + /** Remove the mask arrays from all layers.
+ * See {@link #setLayerMaskArrays(INDArray, INDArray)} for details on mask arrays. + */ + public void clearLayerMaskArrays() { + Layer[] layers = conf.getLayers().toArray(new Layer[]{}); + for (Layer layer : layers) { + layer.setMaskArray(null); + } + } + + /** + * Increment the epoch count (in the underlying {@link MultiLayerConfiguration} by 1). + * Note that this is done automatically when using iterator-based fitting methods, such as + * {@link #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, INDArray/INDArray etc), + * the network has no way to know when one epoch ends and another starts. In such situations, this method + * can be used to increment the epoch counter.
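+ * A sketch of that situation (illustrative only - {@code net} and {@code dataSet} are placeholder names):
+ *
+ * for (int epoch = 0; epoch < numEpochs; epoch++) {
+ * net.fit(dataSet); // DataSet-based fit does not advance the epoch counter
+ * net.incrementEpochCount(); // so increment it manually once per epoch
+ * }
+ *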
+ * Note that the epoch counter is used for situations such as some learning rate schedules, and the like. + * + * The current epoch count can be obtained using {@code MultiLayerConfiguration.getLayerwiseConfiguration().getEpochCount()} + */ + public void incrementEpochCount(){ + conf.setEpochCount(conf.getEpochCount() + 1); + synchronizeIterEpochCounts(); + } + + protected void synchronizeIterEpochCounts() { + //TODO: this is necessary for some schedules - but the redundant values are a little ugly... + int currIter = conf.getIterationCount(); + int currEpoch = conf.getEpochCount(); + log.error("Something went wrong here. Code incomplete"); + /*for(Layer l : conf.getLayers()) { + l.setIterationCount(currIter); + l.setEpochCount(currEpoch); + } + */ + } + + /** + * This method just makes sure there's no state preserved within layers + */ + public void clearLayersStates() { + for (Layer layer : layers) { + layer.clear(); + layer.clearNoiseWeightParams(); + } + } + + + /**Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many + * and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths + * within the same minibatch.
+ * For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape
+ * [miniBatchSize,nOut,timeSeriesLength], the features and labels mask arrays will both have shape [miniBatchSize,timeSeriesLength]
+ * and contain values 0 or 1 at each element (to specify whether a given input/example is present - or merely padding -
+ * at a given time step).
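+ * A concrete sketch (illustrative values only): for a minibatch of 2 sequences padded to length 3, where the
+ * second sequence has only 2 valid time steps, the features mask could be built as
+ *
+ * INDArray featuresMask = Nd4j.create(new double[][]{{1, 1, 1}, {1, 1, 0}});
+ *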
+ * NOTE: This method is not usually used directly. Instead, methods such as @link #feedForward(INDArray, INDArray, INDArray)} + * and @link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking internally. + * @param featuresMaskArray Mask array for features (input) + * @param labelsMaskArray Mask array for labels (output) + * @see #clearLayerMaskArrays() + */ + public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) { + if (featuresMaskArray != null) { + + if (featuresMaskArray.size(0) > Integer.MAX_VALUE) + throw new ND4JArraySizeException(); + //New approach: use feedForwardMaskArray method + feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0)); + + + /* + //feedforward layers below a RNN layer: need the input (features) mask array + //Reason: even if the time series input is zero padded, the output from the dense layers are + // non-zero (i.e., activationFunction(0*weights + bias) != 0 in general) + //This assumes that the time series input is masked - i.e., values are 0 at the padded time steps, + // so we don't need to do anything for the recurrent layer + + //Now, if mask array is 2d -> need to reshape to 1d (column vector) in the exact same order + // as is done for 3d -> 2d time series reshaping + INDArray reshapedFeaturesMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(featuresMaskArray); + + for( int i=0; i feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, + int minibatchSize) { + if (maskArray == null) { + for (int i = 0; i < layers.length; i++) { + layers[i].feedForwardMaskArray(null, null, minibatchSize); + } + } else { + //Do a forward pass through each preprocessor and layer + for (int i = 0; i < layers.length; i++) { + InputPreProcessor preProcessor = conf.getInputPreProcessors().get(i); + + if (preProcessor != null) { + Pair p = + preProcessor.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize); + if (p != null) { + maskArray = p.getFirst(); + currentMaskState = p.getSecond(); + } else { + maskArray = null; + currentMaskState = null; + } + } + + Pair p = + layers[i].feedForwardMaskArray(maskArray, currentMaskState, minibatchSize); + if (p != null) { + maskArray = p.getFirst(); + currentMaskState = p.getSecond(); + } else { + maskArray = null; + currentMaskState = null; + } + } + } + + return new Pair<>(maskArray, currentMaskState); + } + + +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/NeuralNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/NeuralNetwork.java index 30215e916..c9437b838 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/NeuralNetwork.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/NeuralNetwork.java @@ -33,72 +33,75 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; */ public interface NeuralNetwork { - /** - * This method does initialization of model - * - * PLEASE NOTE: All implementations should track own state, to avoid double spending - */ - void init(); + /** + * This method does initialization of model + *

+ * PLEASE NOTE: All implementations should track own state, to avoid double spending + */ + void init(); - /** - * This method returns model parameters as single INDArray - * - * @return - */ - INDArray params(); + /** + * This method returns model parameters as single INDArray + * + * @return + */ + INDArray params(); - /** - * This method returns updater state (if applicable), null otherwise - * @return - */ - INDArray updaterState(); + /** + * This method returns updater state (if applicable), null otherwise + * + * @return + */ + INDArray updaterState(); - /** - * This method returns Optimizer used for training - * - * @return - */ - ConvexOptimizer getOptimizer(); + /** + * This method returns Optimizer used for training + * + * @return + */ + ConvexOptimizer getOptimizer(); - /** - * This method fits model with a given DataSet - * - * @param dataSet - */ - void fit(DataSet dataSet); + /** + * This method fits model with a given DataSet + * + * @param dataSet + */ + void fit(DataSet dataSet); - /** - * This method fits model with a given MultiDataSet - * - * @param dataSet - */ - void fit(MultiDataSet dataSet); + /** + * This method fits model with a given MultiDataSet + * + * @param dataSet + */ + void fit(MultiDataSet dataSet); - /** - * This method fits model with a given DataSetIterator - * - * @param iterator - */ - void fit(DataSetIterator iterator); + /** + * This method fits model with a given DataSetIterator + * + * @param iterator + */ + void fit(DataSetIterator iterator); - /** - * This method fits model with a given MultiDataSetIterator - * - * @param iterator - */ - void fit(MultiDataSetIterator iterator); + /** + * This method fits model with a given MultiDataSetIterator + * + * @param iterator + */ + void fit(MultiDataSetIterator iterator); - /** - * This method executes evaluation of the model against given iterator and evaluation implementations - * - * @param iterator - */ - T[] doEvaluation(DataSetIterator iterator, T... evaluations); + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + */ + T[] doEvaluation(DataSetIterator iterator, T... evaluations); - /** - * This method executes evaluation of the model against given iterator and evaluation implementations - * - * @param iterator - */ - T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations); + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + */ + T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java index a48dc85ba..47baaebfd 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java @@ -52,718 +52,790 @@ import java.io.IOException; import java.io.Serializable; import java.util.*; +/** + * Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of + * multiple layers. Everything starts with a MultiLayerConfiguration, which organizes those layers + * and their hyperparameters. Hyperparameters are variables that determine how a neural network + * learns. 
They include how many times to update the weights of the model, how to initialize those + * weights, which activation function to attach to the nodes, which optimization algorithm to use, + * and how fast the model should learn. This is what one configuration would look like: + *

+ * + * MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+ * .weightInit(WeightInit.XAVIER)
+ * .activation(Activation.RELU)
+ * .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+ * .updater(new Sgd(0.05)) //... other hyperparameters
+ * .list()
+ * .backprop(true)
+ * .build();

+ * + * With Deeplearning4j, you add a layer + * by calling layer on the NeuralNetConfiguration.Builder(), specifying its place in the order of + * layers (the zero-indexed layer below is the input layer), the number of input and output nodes, + * nIn and nOut, as well as the type: DenseLayer.

+ * + * .layer(0, new DenseLayer.Builder().nIn(784).nOut(250)
+ * .build())
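+ *
+ * Further layers follow the same pattern (illustrative only - the layer type and sizes below are placeholders,
+ * not part of this configuration), for example an output layer at index 1:
+ *
+ * .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
+ * .activation(Activation.SOFTMAX).nIn(250).nOut(10).build())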

+ * + * Once you've configured your net, you train the + * model with model.fit. + */ @Data @AllArgsConstructor(access = AccessLevel.PRIVATE) @NoArgsConstructor @Slf4j public class MultiLayerConfiguration implements Serializable, Cloneable { - protected List confs; - protected Map inputPreProcessors = new HashMap<>(); - protected BackpropType backpropType = BackpropType.Standard; - protected int tbpttFwdLength = 20; - protected int tbpttBackLength = 20; - protected boolean validateOutputLayerConfig = true; //Default to legacy for pre 1.0.0-beta3 networks on deserialization + protected List confs; + protected Map inputPreProcessors = new HashMap<>(); + protected BackpropType backpropType = BackpropType.Standard; + protected int tbpttFwdLength = 20; + protected int tbpttBackLength = 20; + protected boolean validateOutputLayerConfig = true; //Default to legacy for pre 1.0.0-beta3 networks on deserialization - @Getter - @Setter - protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; + @Getter + @Setter + protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; - @Getter - @Setter - protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; + @Getter + @Setter + protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; - @Getter - @Setter - protected CacheMode cacheMode; + @Getter + @Setter + protected CacheMode cacheMode; - @Getter - @Setter - protected DataType dataType = DataType.FLOAT; //Default to float for deserialization of beta3 and earlier nets + @Getter + @Setter + protected DataType dataType = DataType.FLOAT; //Default to float for deserialization of beta3 and earlier nets - //Counter for the number of parameter updates so far - // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted - // for Spark and model serialization - protected int iterationCount = 0; + //Counter for the number of parameter updates so far + // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted + // for Spark and model serialization + protected int iterationCount = 0; - //Counter for the number of epochs completed so far. Used for per-epoch schedules - protected int epochCount = 0; + //Counter for the number of epochs completed so far. 
Used for per-epoch schedules + protected int epochCount = 0; - public int getEpochCount() { - return epochCount; + /** + * Create a neural net configuration from json + * + * @param json the neural net configuration from json + * @return {@link MultiLayerConfiguration} + */ + public static MultiLayerConfiguration fromYaml(String json) { + ObjectMapper mapper = NeuralNetConfiguration.mapperYaml(); + try { + return mapper.readValue(json, MultiLayerConfiguration.class); + } catch (IOException e) { + throw new RuntimeException(e); } + } - public void setEpochCount(int epochCount) { - this.epochCount = epochCount; - for (int i = 0; i < confs.size(); i++) { - getConf(i).setEpochCount(epochCount); - } - } - - /** - * @return JSON representation of NN configuration - */ - public String toYaml() { - ObjectMapper mapper = NeuralNetConfiguration.mapperYaml(); - synchronized (mapper) { - try { - return mapper.writeValueAsString(this); - } catch (com.fasterxml.jackson.core.JsonProcessingException e) { - throw new RuntimeException(e); - } - } - } - - /** - * Create a neural net configuration from json - * - * @param json the neural net configuration from json - * @return {@link MultiLayerConfiguration} - */ - public static MultiLayerConfiguration fromYaml(String json) { - ObjectMapper mapper = NeuralNetConfiguration.mapperYaml(); + /** + * Create a neural net configuration from json + * + * @param json the neural net configuration from json + * @return {@link MultiLayerConfiguration} + */ + public static MultiLayerConfiguration fromJson(String json) { + MultiLayerConfiguration conf; + ObjectMapper mapper = NeuralNetConfiguration.mapper(); + try { + conf = mapper.readValue(json, MultiLayerConfiguration.class); + } catch (InvalidTypeIdException e) { + if (e.getMessage().contains("@class")) { try { - return mapper.readValue(json, MultiLayerConfiguration.class); - } catch (IOException e) { - throw new RuntimeException(e); + //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format + return JsonMappers.getLegacyMapper().readValue(json, MultiLayerConfiguration.class); + } catch (InvalidTypeIdException e2) { + //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..." + //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work + String msg = e2.getMessage(); + if (msg != null && msg.contains("Could not resolve type id")) { + throw new RuntimeException( + "Error deserializing MultiLayerConfiguration - configuration may have a custom " + + "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" + + + " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", + e); + } + throw new RuntimeException(e2); + } catch (IOException e2) { + throw new RuntimeException(e2); } + } + throw new RuntimeException(e); + } catch (IOException e) { + //Check if this exception came from legacy deserializer... + String msg = e.getMessage(); + if (msg != null && msg.contains("legacy")) { + throw new RuntimeException( + "Error deserializing MultiLayerConfiguration - configuration may have a custom " + + "layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. 
These layers can be " + + + "deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)", + e); + } + throw new RuntimeException(e); } + //To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier) + // Previously: enumeration used for loss functions. Now: use classes + // IN the past, could have only been an OutputLayer or RnnOutputLayer using these enums + int layerCount = 0; + JsonNode confs = null; + for (NeuralNetConfiguration nnc : conf.getConfs()) { + Layer l = nnc.getLayer(); + if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) { + //lossFn field null -> may be an old config format, with lossFunction field being for the enum + //if so, try walking the JSON graph to extract out the appropriate enum value - /** - * @return JSON representation of NN configuration - */ - public String toJson() { - ObjectMapper mapper = NeuralNetConfiguration.mapper(); - synchronized (mapper) { - //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally - //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243 - try { - return mapper.writeValueAsString(this); - } catch (com.fasterxml.jackson.core.JsonProcessingException e) { - throw new RuntimeException(e); - } - } - } - - /** - * Create a neural net configuration from json - * - * @param json the neural net configuration from json - * @return {@link MultiLayerConfiguration} - */ - public static MultiLayerConfiguration fromJson(String json) { - MultiLayerConfiguration conf; - ObjectMapper mapper = NeuralNetConfiguration.mapper(); + BaseOutputLayer ol = (BaseOutputLayer) l; try { - conf = mapper.readValue(json, MultiLayerConfiguration.class); - } catch (InvalidTypeIdException e){ - if(e.getMessage().contains("@class")){ - try { - //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format - return JsonMappers.getLegacyMapper().readValue(json, MultiLayerConfiguration.class); - } catch (InvalidTypeIdException e2){ - //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..." - //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work - String msg = e2.getMessage(); - if(msg != null && msg.contains("Could not resolve type id")){ - throw new RuntimeException("Error deserializing MultiLayerConfiguration - configuration may have a custom " + - "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" + - " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", e); - } - throw new RuntimeException(e2); - } catch (IOException e2){ - throw new RuntimeException(e2); - } + JsonNode jsonNode = mapper.readTree(json); + if (confs == null) { + confs = jsonNode.get("confs"); + } + if (confs instanceof ArrayNode) { + ArrayNode layerConfs = (ArrayNode) confs; + JsonNode outputLayerNNCNode = layerConfs.get(layerCount); + if (outputLayerNNCNode == null) { + return conf; //Should never happen... } - throw new RuntimeException(e); - } catch (IOException e) { - //Check if this exception came from legacy deserializer... 
- String msg = e.getMessage(); - if (msg != null && msg.contains("legacy")) { - throw new RuntimeException("Error deserializing MultiLayerConfiguration - configuration may have a custom " + - "layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be " + - "deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)", e); + JsonNode outputLayerNode = outputLayerNNCNode.get("layer"); + + JsonNode lossFunctionNode = null; + if (outputLayerNode.has("output")) { + lossFunctionNode = outputLayerNode.get("output").get("lossFunction"); + } else if (outputLayerNode.has("rnnoutput")) { + lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction"); } - throw new RuntimeException(e); - } + if (lossFunctionNode != null) { + String lossFunctionEnumStr = lossFunctionNode.asText(); + LossFunctions.LossFunction lossFunction = null; + try { + lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr); + } catch (Exception e) { + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", + e); + } - //To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier) - // Previously: enumeration used for loss functions. Now: use classes - // IN the past, could have only been an OutputLayer or RnnOutputLayer using these enums - int layerCount = 0; - JsonNode confs = null; - for (NeuralNetConfiguration nnc : conf.getConfs()) { - Layer l = nnc.getLayer(); - if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) { - //lossFn field null -> may be an old config format, with lossFunction field being for the enum - //if so, try walking the JSON graph to extract out the appropriate enum value + if (lossFunction != null) { + switch (lossFunction) { + case MSE: + ol.setLossFn(new LossMSE()); + break; + case XENT: + ol.setLossFn(new LossBinaryXENT()); + break; + case NEGATIVELOGLIKELIHOOD: + ol.setLossFn(new LossNegativeLogLikelihood()); + break; + case MCXENT: + ol.setLossFn(new LossMCXENT()); + break; - BaseOutputLayer ol = (BaseOutputLayer) l; - try { - JsonNode jsonNode = mapper.readTree(json); - if (confs == null) { - confs = jsonNode.get("confs"); - } - if (confs instanceof ArrayNode) { - ArrayNode layerConfs = (ArrayNode) confs; - JsonNode outputLayerNNCNode = layerConfs.get(layerCount); - if (outputLayerNNCNode == null) - return conf; //Should never happen... 
- JsonNode outputLayerNode = outputLayerNNCNode.get("layer"); - - JsonNode lossFunctionNode = null; - if (outputLayerNode.has("output")) { - lossFunctionNode = outputLayerNode.get("output").get("lossFunction"); - } else if (outputLayerNode.has("rnnoutput")) { - lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction"); - } - - if (lossFunctionNode != null) { - String lossFunctionEnumStr = lossFunctionNode.asText(); - LossFunctions.LossFunction lossFunction = null; - try { - lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr); - } catch (Exception e) { - log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", - e); - } - - if (lossFunction != null) { - switch (lossFunction) { - case MSE: - ol.setLossFn(new LossMSE()); - break; - case XENT: - ol.setLossFn(new LossBinaryXENT()); - break; - case NEGATIVELOGLIKELIHOOD: - ol.setLossFn(new LossNegativeLogLikelihood()); - break; - case MCXENT: - ol.setLossFn(new LossMCXENT()); - break; - - //Remaining: TODO - case SQUARED_LOSS: - case RECONSTRUCTION_CROSSENTROPY: - default: - log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", - lossFunction); - break; - } - } - } - - } else { - log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", - (confs != null ? confs.getClass() : null)); - } - } catch (IOException e) { - log.warn("OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", - e); + //Remaining: TODO + case SQUARED_LOSS: + case RECONSTRUCTION_CROSSENTROPY: + default: + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", + lossFunction); break; } + } } - //Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn") - //Try to load the old format if necessary, and create the appropriate IActivation instance - if ((l instanceof BaseLayer) && ((BaseLayer) l).getActivationFn() == null) { - try { - JsonNode jsonNode = mapper.readTree(json); - if (confs == null) { - confs = jsonNode.get("confs"); - } - if (confs instanceof ArrayNode) { - ArrayNode layerConfs = (ArrayNode) confs; - JsonNode outputLayerNNCNode = layerConfs.get(layerCount); - if (outputLayerNNCNode == null) - return conf; //Should never happen... - JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); - - if (layerWrapperNode == null || layerWrapperNode.size() != 1) { - continue; - } - - JsonNode layerNode = layerWrapperNode.elements().next(); - JsonNode activationFunction = layerNode.get("activationFunction"); //Should only have 1 element: "dense", "output", etc - - if (activationFunction != null) { - IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction(); - ((BaseLayer) l).setActivationFn(ia); - } - } - - } catch (IOException e) { - log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", - e); - } - } - - if(!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) { - return conf; - } - - layerCount++; + } else { + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", + (confs != null ? 
confs.getClass() : null)); + } + } catch (IOException e) { + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", + e); + break; } - return conf; - } + } - /** - * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied from handling of {@link Activation} - * above. - * @return True if all is well and layer iteration shall continue. False else-wise. - */ - private static boolean handleLegacyWeightInitFromJson(String json, Layer l, ObjectMapper mapper, JsonNode confs, int layerCount) { - if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) { - try { - JsonNode jsonNode = mapper.readTree(json); - if (confs == null) { - confs = jsonNode.get("confs"); - } - if (confs instanceof ArrayNode) { - ArrayNode layerConfs = (ArrayNode) confs; - JsonNode outputLayerNNCNode = layerConfs.get(layerCount); - if (outputLayerNNCNode == null) - return false; //Should never happen... - JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); - - if (layerWrapperNode == null || layerWrapperNode.size() != 1) { - return true; - } - - JsonNode layerNode = layerWrapperNode.elements().next(); - JsonNode weightInit = layerNode.get("weightInit"); //Should only have 1 element: "dense", "output", etc - JsonNode distribution = layerNode.get("dist"); - - Distribution dist = null; - if(distribution != null) { - dist = mapper.treeToValue(distribution, Distribution.class); - } - - if (weightInit != null) { - final IWeightInit wi = WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist); - ((BaseLayer) l).setWeightInitFn(wi); - } - } - - } catch (IOException e) { - log.warn("Layer with null WeightInit detected: " + l.getLayerName() + ", could not parse JSON", - e); - } - } - return true; - - } - - @Override - public String toString() { - return toJson(); - } - - public NeuralNetConfiguration getConf(int i) { - return confs.get(i); - } - - @Override - public MultiLayerConfiguration clone() { + //Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn") + //Try to load the old format if necessary, and create the appropriate IActivation instance + if ((l instanceof BaseLayer) && ((BaseLayer) l).getActivationFn() == null) { try { - MultiLayerConfiguration clone = (MultiLayerConfiguration) super.clone(); + JsonNode jsonNode = mapper.readTree(json); + if (confs == null) { + confs = jsonNode.get("confs"); + } + if (confs instanceof ArrayNode) { + ArrayNode layerConfs = (ArrayNode) confs; + JsonNode outputLayerNNCNode = layerConfs.get(layerCount); + if (outputLayerNNCNode == null) { + return conf; //Should never happen... 
+ } + JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); - if (clone.confs != null) { - List list = new ArrayList<>(); - for (NeuralNetConfiguration conf : clone.confs) { - list.add(conf.clone()); - } - clone.confs = list; + if (layerWrapperNode == null || layerWrapperNode.size() != 1) { + continue; } - if (clone.inputPreProcessors != null) { - Map map = new HashMap<>(); - for (Map.Entry entry : clone.inputPreProcessors.entrySet()) { - map.put(entry.getKey(), entry.getValue().clone()); - } - clone.inputPreProcessors = map; + JsonNode layerNode = layerWrapperNode.elements().next(); + JsonNode activationFunction = layerNode.get( + "activationFunction"); //Should only have 1 element: "dense", "output", etc + + if (activationFunction != null) { + IActivation ia = Activation.fromString(activationFunction.asText()) + .getActivationFunction(); + ((BaseLayer) l).setActivationFn(ia); } + } - clone.inferenceWorkspaceMode = this.inferenceWorkspaceMode; - clone.trainingWorkspaceMode = this.trainingWorkspaceMode; - clone.cacheMode = this.cacheMode; - clone.validateOutputLayerConfig = this.validateOutputLayerConfig; - clone.dataType = this.dataType; - - return clone; - - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); + } catch (IOException e) { + log.warn( + "Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", + e); } + } + + if (!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) { + return conf; + } + + layerCount++; + } + return conf; + } + + /** + * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied + * from handling of {@link Activation} above. + * + * @return True if all is well and layer iteration shall continue. False else-wise. + */ + private static boolean handleLegacyWeightInitFromJson(String json, Layer l, ObjectMapper mapper, + JsonNode confs, int layerCount) { + if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) { + try { + JsonNode jsonNode = mapper.readTree(json); + if (confs == null) { + confs = jsonNode.get("confs"); + } + if (confs instanceof ArrayNode) { + ArrayNode layerConfs = (ArrayNode) confs; + JsonNode outputLayerNNCNode = layerConfs.get(layerCount); + if (outputLayerNNCNode == null) { + return false; //Should never happen... 
+ } + JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); + + if (layerWrapperNode == null || layerWrapperNode.size() != 1) { + return true; + } + + JsonNode layerNode = layerWrapperNode.elements().next(); + JsonNode weightInit = layerNode.get( + "weightInit"); //Should only have 1 element: "dense", "output", etc + JsonNode distribution = layerNode.get("dist"); + + Distribution dist = null; + if (distribution != null) { + dist = mapper.treeToValue(distribution, Distribution.class); + } + + if (weightInit != null) { + final IWeightInit wi = WeightInit.valueOf(weightInit.asText()) + .getWeightInitFunction(dist); + ((BaseLayer) l).setWeightInitFn(wi); + } + } + + } catch (IOException e) { + log.warn( + "Layer with null WeightInit detected: " + l.getLayerName() + ", could not parse JSON", + e); + } + } + return true; + + } + + public int getEpochCount() { + return epochCount; + } + + public void setEpochCount(int epochCount) { + this.epochCount = epochCount; + for (int i = 0; i < confs.size(); i++) { + getConf(i).setEpochCount(epochCount); + } + } + + /** + * @return JSON representation of NN configuration + */ + public String toYaml() { + ObjectMapper mapper = NeuralNetConfiguration.mapperYaml(); + synchronized (mapper) { + try { + return mapper.writeValueAsString(this); + } catch (com.fasterxml.jackson.core.JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + /** + * @return JSON representation of NN configuration + */ + public String toJson() { + ObjectMapper mapper = NeuralNetConfiguration.mapper(); + synchronized (mapper) { + //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally + //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243 + try { + return mapper.writeValueAsString(this); + } catch (com.fasterxml.jackson.core.JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + @Override + public String toString() { + return toJson(); + } + + public NeuralNetConfiguration getConf(int i) { + return confs.get(i); + } + + @Override + public MultiLayerConfiguration clone() { + try { + MultiLayerConfiguration clone = (MultiLayerConfiguration) super.clone(); + + if (clone.confs != null) { + List list = new ArrayList<>(); + for (NeuralNetConfiguration conf : clone.confs) { + list.add(conf.clone()); + } + clone.confs = list; + } + + if (clone.inputPreProcessors != null) { + Map map = new HashMap<>(); + for (Map.Entry entry : clone.inputPreProcessors.entrySet()) { + map.put(entry.getKey(), entry.getValue().clone()); + } + clone.inputPreProcessors = map; + } + + clone.inferenceWorkspaceMode = this.inferenceWorkspaceMode; + clone.trainingWorkspaceMode = this.trainingWorkspaceMode; + clone.cacheMode = this.cacheMode; + clone.validateOutputLayerConfig = this.validateOutputLayerConfig; + clone.dataType = this.dataType; + + return clone; + + } catch (CloneNotSupportedException e) { + throw new RuntimeException(e); + } + } + + public InputPreProcessor getInputPreProcess(int curr) { + return inputPreProcessors.get(curr); + } + + /** + * Get a {@link MemoryReport} for the given MultiLayerConfiguration. 
This is used to estimate the + * memory requirements for the given network configuration and input + * + * @param inputType Input types for the network + * @return Memory report for the network + */ + public NetworkMemoryReport getMemoryReport(InputType inputType) { + + Map memoryReportMap = new LinkedHashMap<>(); + int nLayers = confs.size(); + for (int i = 0; i < nLayers; i++) { + String layerName = confs.get(i).getLayer().getLayerName(); + if (layerName == null) { + layerName = String.valueOf(i); + } + + //Pass input type through preprocessor, if necessary + InputPreProcessor preproc = getInputPreProcess(i); + //TODO memory requirements for preprocessor + if (preproc != null) { + inputType = preproc.getOutputType(inputType); + } + + LayerMemoryReport report = confs.get(i).getLayer().getMemoryReport(inputType); + memoryReportMap.put(layerName, report); + + inputType = confs.get(i).getLayer().getOutputType(i, inputType); } - public InputPreProcessor getInputPreProcess(int curr) { - return inputPreProcessors.get(curr); + return new NetworkMemoryReport(memoryReportMap, MultiLayerConfiguration.class, + "MultiLayerNetwork", inputType); + } + + /** + * For the given input shape/type for the network, return a list of activation sizes for each + * layer in the network.
i.e., list.get(i) is the output activation sizes for layer i + * + * @param inputType Input type for the network + * @return A lits of activation types for the network, indexed by layer number + */ + public List getLayerActivationTypes(@NonNull InputType inputType) { + List out = new ArrayList<>(); + int nLayers = confs.size(); + for (int i = 0; i < nLayers; i++) { + InputPreProcessor preproc = getInputPreProcess(i); + if (preproc != null) { + inputType = preproc.getOutputType(inputType); + } + + inputType = confs.get(i).getLayer().getOutputType(i, inputType); + out.add(inputType); + } + return out; + } + + @Data + public static class Builder { + + private static final int DEFAULT_TBPTT_LENGTH = 20; + + protected List confs = new ArrayList<>(); + protected double dampingFactor = 100; + protected Map inputPreProcessors = new HashMap<>(); + protected BackpropType backpropType = BackpropType.Standard; + protected int tbpttFwdLength = DEFAULT_TBPTT_LENGTH; + protected int tbpttBackLength = DEFAULT_TBPTT_LENGTH; + protected InputType inputType; + + protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; + protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; + protected CacheMode cacheMode = CacheMode.NONE; + protected boolean validateOutputConfig = true; + protected boolean validateTbpttConfig = true; + protected DataType dataType; + protected boolean overrideNinUponBuild = true; + + + /** + * Whether to over ride the nIn configuration forcibly upon construction. Default value is true + * + * @param overrideNinUponBuild Whether to over ride the nIn configuration forcibly upon + * construction. + * @return builder pattern + */ + public Builder overrideNinUponBuild(boolean overrideNinUponBuild) { + this.overrideNinUponBuild = overrideNinUponBuild; + return this; } /** - * Get a {@link MemoryReport} for the given MultiLayerConfiguration. This is used to estimate the - * memory requirements for the given network configuration and input + * Specify the processors. These are used at each layer for doing things like normalization and + * shaping of input. * - * @param inputType Input types for the network - * @return Memory report for the network + * @param processor what to use to preProcess the data. 
+ * @return builder pattern */ - public NetworkMemoryReport getMemoryReport(InputType inputType) { + public Builder inputPreProcessor(Integer layer, InputPreProcessor processor) { + inputPreProcessors.put(layer, processor); + return this; + } - Map memoryReportMap = new LinkedHashMap<>(); - int nLayers = confs.size(); - for (int i = 0; i < nLayers; i++) { - String layerName = confs.get(i).getLayer().getLayerName(); - if (layerName == null) { - layerName = String.valueOf(i); - } - - //Pass input type through preprocessor, if necessary - InputPreProcessor preproc = getInputPreProcess(i); - //TODO memory requirements for preprocessor - if (preproc != null) { - inputType = preproc.getOutputType(inputType); - } - - LayerMemoryReport report = confs.get(i).getLayer().getMemoryReport(inputType); - memoryReportMap.put(layerName, report); - - inputType = confs.get(i).getLayer().getOutputType(i, inputType); + public Builder inputPreProcessor(String layer, InputPreProcessor processor) { + int i = 0; + for (NeuralNetConfiguration conf : this.confs) { + if (conf.getLayer().getLayerName().equals(layer)) { + inputPreProcessors.put(i, processor); + log.trace("Assigned preProcessor to layer with name {} at index {}", layer, i); + break; } + i++; + } + if (i >= this.confs.size()) { + log.warn("Could not assign preprocessor to layer with name {} as layer was not found.", + layer); + } + return this; + } - return new NetworkMemoryReport(memoryReportMap, MultiLayerConfiguration.class, "MultiLayerNetwork", inputType); + public Builder inputPreProcessors(Map processors) { + this.inputPreProcessors = processors; + return this; } /** - * For the given input shape/type for the network, return a list of activation sizes for each layer in the network.
- * i.e., list.get(i) is the output activation sizes for layer i - * - * @param inputType Input type for the network - * @return A lits of activation types for the network, indexed by layer number + * @deprecated Use {@link NeuralNetConfiguration.Builder#trainingWorkspaceMode(WorkspaceMode)} */ - public List getLayerActivationTypes(@NonNull InputType inputType) { - List out = new ArrayList<>(); - int nLayers = confs.size(); - for (int i = 0; i < nLayers; i++) { - InputPreProcessor preproc = getInputPreProcess(i); - if (preproc != null) { - inputType = preproc.getOutputType(inputType); - } - - inputType = confs.get(i).getLayer().getOutputType(i, inputType); - out.add(inputType); - } - return out; + @Deprecated + public Builder trainingWorkspaceMode(@NonNull WorkspaceMode workspaceMode) { + this.trainingWorkspaceMode = workspaceMode; + return this; } - @Data - public static class Builder { - - private static final int DEFAULT_TBPTT_LENGTH = 20; - - protected List confs = new ArrayList<>(); - protected double dampingFactor = 100; - protected Map inputPreProcessors = new HashMap<>(); - protected BackpropType backpropType = BackpropType.Standard; - protected int tbpttFwdLength = DEFAULT_TBPTT_LENGTH; - protected int tbpttBackLength = DEFAULT_TBPTT_LENGTH; - protected InputType inputType; - - protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; - protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; - protected CacheMode cacheMode = CacheMode.NONE; - protected boolean validateOutputConfig = true; - protected boolean validateTbpttConfig = true; - protected DataType dataType; - protected boolean overrideNinUponBuild = true; - - - /** - * Whether to over ride the nIn - * configuration forcibly upon construction. - * Default value is true - * @param overrideNinUponBuild Whether to over ride the nIn - * configuration forcibly upon construction. - * @return builder pattern - */ - public Builder overrideNinUponBuild(boolean overrideNinUponBuild) { - this.overrideNinUponBuild = overrideNinUponBuild; - return this; - } - - /** - * Specify the processors. - * These are used at each layer for doing things like normalization and - * shaping of input. - * - * @param processor what to use to preProcess the data. - * @return builder pattern - */ - public Builder inputPreProcessor(Integer layer, InputPreProcessor processor) { - inputPreProcessors.put(layer, processor); - return this; - } - - public Builder inputPreProcessors(Map processors) { - this.inputPreProcessors = processors; - return this; - } - - /** - * @deprecated Use {@link NeuralNetConfiguration.Builder#trainingWorkspaceMode(WorkspaceMode)} - */ - @Deprecated - public Builder trainingWorkspaceMode(@NonNull WorkspaceMode workspaceMode) { - this.trainingWorkspaceMode = workspaceMode; - return this; - } - - /** - * @deprecated Use {@link NeuralNetConfiguration.Builder#inferenceWorkspaceMode(WorkspaceMode)} - */ - @Deprecated - public Builder inferenceWorkspaceMode(@NonNull WorkspaceMode workspaceMode) { - this.inferenceWorkspaceMode = workspaceMode; - return this; - } - - /** - * This method defines how/if preOutput cache is handled: - * NONE: cache disabled (default value) - * HOST: Host memory will be used - * DEVICE: GPU memory will be used (on CPU backends effect will be the same as for HOST) - * - * @param cacheMode - * @return - */ - public Builder cacheMode(@NonNull CacheMode cacheMode) { - this.cacheMode = cacheMode; - return this; - } - - /** - * The type of backprop. 
Default setting is used for most networks (MLP, CNN etc), - * but optionally truncated BPTT can be used for training recurrent neural networks. - * If using TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength() - */ - public Builder backpropType(@NonNull BackpropType type) { - this.backpropType = type; - return this; - } - - /** - * When doing truncated BPTT: how many steps should we do?
- * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
- * See: http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf - * - * @param bpttLength length > 0 - */ - public Builder tBPTTLength(int bpttLength) { - tBPTTForwardLength(bpttLength); - return tBPTTBackwardLength(bpttLength); - } - - /** - * When doing truncated BPTT: how many steps of forward pass should we do - * before doing (truncated) backprop?
- * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
- * Typically tBPTTForwardLength parameter is same as the tBPTTBackwardLength parameter, - * but may be larger than it in some circumstances (but never smaller)
- * Ideally your training data time series length should be divisible by this - * This is the k1 parameter on pg23 of - * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf - * - * @param forwardLength Forward length > 0, >= backwardLength - */ - public Builder tBPTTForwardLength(int forwardLength) { - this.tbpttFwdLength = forwardLength; - return this; - } - - /** - * When doing truncated BPTT: how many steps of backward should we do?
- * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
- * This is the k2 parameter on pg23 of - * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf - * - * @param backwardLength <= forwardLength - */ - public Builder tBPTTBackwardLength(int backwardLength) { - this.tbpttBackLength = backwardLength; - return this; - } - - public Builder confs(List confs) { - this.confs = confs; - return this; - } - - public Builder setInputType(InputType inputType) { - this.inputType = inputType; - return this; - } - - /** - * Enabled by default. If enabled, the output layer configuration will be validated, to throw an exception on - * likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.
- * If disabled (false) no output layer validation will be performed.
- * Disabling this validation is not recommended, as the configurations that fail validation usually will - * not be able to learn correctly. However, the option to disable this validation is provided for advanced users - * when creating non-standard architectures. - * - * @param validate If true: validate output layer configuration. False: don't validate - */ - public Builder validateOutputLayerConfig(boolean validate) { - this.validateOutputConfig = validate; - return this; - } - - /** - * Enabled by default. If enabled, an exception will be throw when using the (invalid) combination of truncated - * backpropagation through time (TBPTT) with either a GlobalPoolingLayer or LastTimeStepLayer.
- * It is possible to disable this validation to allow what is almost certainly an invalid configuration to be used, - * however this is not recommended. - * - * @param validate Whether TBPTT validation should be performed - */ - public Builder validateTbpttConfig(boolean validate){ - this.validateTbpttConfig = validate; - return this; - } - - /** - * Set the DataType for the network parameters and activations for all layers in the network. Default: Float - * @param dataType Datatype to use for parameters and activations - */ - public Builder dataType(@NonNull DataType dataType){ - this.dataType = dataType; - return this; - } - - - public MultiLayerConfiguration build() { - //Validate BackpropType setting - if ((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH) && backpropType != BackpropType.TruncatedBPTT) { - log.warn("Truncated backpropagation through time lengths have been configured with values " + tbpttFwdLength - + " and " + tbpttBackLength + " but backprop type is set to " + backpropType + ". TBPTT configuration" + - " settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT"); - } - - if(backpropType == BackpropType.TruncatedBPTT && validateTbpttConfig) { - //Check for invalid combination - tbptt plus LastTimeStepLayer or - for( int i = 0; i < confs.size(); i++) { - Layer l = confs.get(i).getLayer(); - if(l instanceof LastTimeStep || l instanceof GlobalPoolingLayer){ - throw new IllegalStateException("Invalid network configuration detected: Truncated backpropagation through time (TBPTT)" + - " cannot be used with layer " + i + " of type " + l.getClass().getName() + ": TBPTT is incompatible with this layer type (which is designed " + - "to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n" + - "This check can be disabled using validateTbpttConfig(false) but this is not recommended."); - } - } - } - - - if (inputType == null && inputPreProcessors.get(0) == null) { - //User hasn't set the InputType. Sometimes we can infer it... - // For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to feed in - // standard feedforward or RNN data - //This isn't the most elegant implementation, but should avoid breaking backward compatibility here - //Can't infer InputType for CNN layers, however (don't know image dimensions/depth) - Layer firstLayer = confs.get(0).getLayer(); - if (firstLayer instanceof BaseRecurrentLayer) { - BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer; - val nIn = brl.getNIn(); - if (nIn > 0) { - inputType = InputType.recurrent(nIn, brl.getRnnDataFormat()); - } - } else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer - || firstLayer instanceof OutputLayer) { - //Can't just use "instanceof FeedForwardLayer" here. ConvolutionLayer is also a FeedForwardLayer - FeedForwardLayer ffl = (FeedForwardLayer) firstLayer; - val nIn = ffl.getNIn(); - if (nIn > 0) { - inputType = InputType.feedForward(nIn); - } - } - } - - - //Add preprocessors and set nIns, if InputType has been set - // Builder.inputType field can be set in 1 of 4 ways: - // 1. User calls setInputType directly - // 2. Via ConvolutionLayerSetup -> internally calls setInputType(InputType.convolutional(...)) - // 3. 
Via the above code: i.e., assume input is as expected by the RNN or dense layer -> sets the inputType field - if (inputType != null) { - InputType currentInputType = inputType; - for (int i = 0; i < confs.size(); i++) { - Layer l = confs.get(i).getLayer(); - if (inputPreProcessors.get(i) == null) { - //Don't override preprocessor setting, but set preprocessor if required... - InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType); - if (inputPreProcessor != null) { - inputPreProcessors.put(i, inputPreProcessor); - } - } - - InputPreProcessor inputPreProcessor = inputPreProcessors.get(i); - if (inputPreProcessor != null) { - currentInputType = inputPreProcessor.getOutputType(currentInputType); - } - if(i > 0) { - Layer layer = confs.get(i - 1).getLayer(); - //convolution 1d is an edge case where it has rnn input type but the filters - //should be the output - if(layer instanceof Convolution1DLayer) { - if(l instanceof DenseLayer && inputType instanceof InputType.InputTypeRecurrent) { - FeedForwardLayer feedForwardLayer = (FeedForwardLayer) l; - if(inputType instanceof InputType.InputTypeRecurrent) { - InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType; - feedForwardLayer.setNIn(recurrent.getTimeSeriesLength()); - } - } - else - l.setNIn(currentInputType, overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user - } - else - l.setNIn(currentInputType, overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user - - } - else - l.setNIn(currentInputType, overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user - - - currentInputType = l.getOutputType(i, currentInputType); - } - - } - - MultiLayerConfiguration conf = new MultiLayerConfiguration(); - conf.confs = this.confs; - conf.inputPreProcessors = inputPreProcessors; - conf.backpropType = backpropType; - conf.tbpttFwdLength = tbpttFwdLength; - conf.tbpttBackLength = tbpttBackLength; - conf.trainingWorkspaceMode = trainingWorkspaceMode; - conf.inferenceWorkspaceMode = inferenceWorkspaceMode; - conf.cacheMode = cacheMode; - conf.dataType = dataType; - - Nd4j.getRandom().setSeed(conf.getConf(0).getSeed()); - - //Validate output layer configuration - if (validateOutputConfig) { - //Validate output layer configurations... - for (NeuralNetConfiguration n : conf.getConfs()) { - Layer l = n.getLayer(); - OutputLayerUtil.validateOutputLayer(l.getLayerName(), l); //No-op for non output/loss layers - } - } - - return conf; - - } + /** + * @deprecated Use {@link NeuralNetConfiguration.Builder#inferenceWorkspaceMode(WorkspaceMode)} + */ + @Deprecated + public Builder inferenceWorkspaceMode(@NonNull WorkspaceMode workspaceMode) { + this.inferenceWorkspaceMode = workspaceMode; + return this; } + + /** + * This method defines how/if preOutput cache is handled: NONE: cache disabled (default value) + * HOST: Host memory will be used DEVICE: GPU memory will be used (on CPU backends effect will + * be the same as for HOST) + * + * @param cacheMode + * @return + */ + public Builder cacheMode(@NonNull CacheMode cacheMode) { + this.cacheMode = cacheMode; + return this; + } + + /** + * The type of backprop. Default setting is used for most networks (MLP, CNN etc), but + * optionally truncated BPTT can be used for training recurrent neural networks. 
If using + * TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength() + */ + public Builder backpropType(@NonNull BackpropType type) { + this.backpropType = type; + return this; + } + + /** + * When doing truncated BPTT: how many steps should we do?<br>
Only applicable when doing + * backpropType(BackpropType.TruncatedBPTT)<br>
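+     * <p>A minimal usage sketch (the LSTM/RnnOutputLayer choices and the sizes below are
+     * illustrative assumptions, not prescribed by this builder):
+     * <pre>{@code
+     * // Configure TBPTT with equal forward/backward truncation lengths of 20 steps
+     * MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+     *     .list()
+     *     .layer(new LSTM.Builder().nIn(10).nOut(20).activation(Activation.TANH).build())
+     *     .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
+     *         .activation(Activation.SOFTMAX).nIn(20).nOut(4).build())
+     *     .backpropType(BackpropType.TruncatedBPTT)
+     *     .tBPTTForwardLength(20)
+     *     .tBPTTBackwardLength(20)
+     *     .build();
+     * }</pre>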
See: http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf + * + * @param bpttLength length > 0 + */ + public Builder tBPTTLength(int bpttLength) { + tBPTTForwardLength(bpttLength); + return tBPTTBackwardLength(bpttLength); + } + + /** + * When doing truncated BPTT: how many steps of forward pass should we do before doing + * (truncated) backprop?<br>
Only applicable when doing + * backpropType(BackpropType.TruncatedBPTT)<br>
Typically the tBPTTForwardLength parameter is the same + * as the tBPTTBackwardLength parameter, but may be larger than it in some circumstances (but + * never smaller)<br>
Ideally your training data time series length should be divisible by this.<br> This is the k1 parameter on pg23 of + * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf + * + * @param forwardLength Forward length > 0, >= backwardLength + */ + public Builder tBPTTForwardLength(int forwardLength) { + this.tbpttFwdLength = forwardLength; + return this; + } + + /** + * When doing truncated BPTT: how many steps of backward should we do?<br>
Only applicable when + * doing backpropType(BackpropType.TruncatedBPTT)<br>
This is the k2 parameter on pg23 of + * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf + * + * @param backwardLength <= forwardLength + */ + public Builder tBPTTBackwardLength(int backwardLength) { + this.tbpttBackLength = backwardLength; + return this; + } + + public Builder confs(List<NeuralNetConfiguration> confs) { + this.confs = confs; + return this; + } + + public Builder setInputType(InputType inputType) { + this.inputType = inputType; + return this; + } + + /** + * Enabled by default. If enabled, the output layer configuration will be validated, to throw an + * exception on likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.<br>
If + * disabled (false) no output layer validation will be performed.<br>
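+     * <p>As a sketch only (hypothetical layer sizes), an advanced user who deliberately wants an
+     * otherwise-flagged combination such as MCXENT + TANH could opt out explicitly:
+     * <pre>{@code
+     * // Intentionally non-standard output; skip the output-layer sanity check
+     * MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+     *     .list()
+     *     .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
+     *         .activation(Activation.TANH).nIn(10).nOut(3).build())
+     *     .validateOutputLayerConfig(false)
+     *     .build();
+     * }</pre>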
Disabling this validation + * is not recommended, as the configurations that fail validation usually will not be able to + * learn correctly. However, the option to disable this validation is provided for advanced + * users when creating non-standard architectures. + * + * @param validate If true: validate output layer configuration. False: don't validate + */ + public Builder validateOutputLayerConfig(boolean validate) { + this.validateOutputConfig = validate; + return this; + } + + /** + * Enabled by default. If enabled, an exception will be thrown when using the (invalid) + * combination of truncated backpropagation through time (TBPTT) with either a + * GlobalPoolingLayer or LastTimeStepLayer.<br>
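+     * <p>For instance, a configuration along these lines (layer choices and sizes are
+     * illustrative assumptions, not part of this patch) would trigger the check:
+     * <pre>{@code
+     * new NeuralNetConfiguration.Builder()
+     *     .list()
+     *     .layer(new LSTM.Builder().nIn(10).nOut(20).build())
+     *     .layer(new GlobalPoolingLayer.Builder(PoolingType.AVG).build()) // pools over the full sequence
+     *     .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
+     *         .activation(Activation.SOFTMAX).nIn(20).nOut(3).build())
+     *     .backpropType(BackpropType.TruncatedBPTT)
+     *     .build(); // throws IllegalStateException while this validation is enabled
+     * }</pre>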
It is possible to disable this validation to + * allow what is almost certainly an invalid configuration to be used, however this is not + * recommended. + * + * @param validate Whether TBPTT validation should be performed + */ + public Builder validateTbpttConfig(boolean validate) { + this.validateTbpttConfig = validate; + return this; + } + + /** + * Set the DataType for the network parameters and activations for all layers in the network. + * Default: Float + * + * @param dataType Datatype to use for parameters and activations + */ + public Builder dataType(@NonNull DataType dataType) { + this.dataType = dataType; + return this; + } + + + public MultiLayerConfiguration build() { + //Validate BackpropType setting + if ((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH) + && backpropType != BackpropType.TruncatedBPTT) { + log.warn("Truncated backpropagation through time lengths have been configured with values " + + tbpttFwdLength + + " and " + tbpttBackLength + " but backprop type is set to " + backpropType + + ". TBPTT configuration" + + " settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT"); + } + + if (backpropType == BackpropType.TruncatedBPTT && validateTbpttConfig) { + //Check for invalid combination - tbptt plus LastTimeStepLayer or + for (int i = 0; i < confs.size(); i++) { + Layer l = confs.get(i).getLayer(); + if (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer) { + throw new IllegalStateException( + "Invalid network configuration detected: Truncated backpropagation through time (TBPTT)" + + + " cannot be used with layer " + i + " of type " + l.getClass().getName() + + ": TBPTT is incompatible with this layer type (which is designed " + + "to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n" + + + "This check can be disabled using validateTbpttConfig(false) but this is not recommended."); + } + } + } + + if (inputType == null && inputPreProcessors.get(0) == null) { + //User hasn't set the InputType. Sometimes we can infer it... + // For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to feed in + // standard feedforward or RNN data + //This isn't the most elegant implementation, but should avoid breaking backward compatibility here + //Can't infer InputType for CNN layers, however (don't know image dimensions/depth) + Layer firstLayer = confs.get(0).getLayer(); + if (firstLayer instanceof BaseRecurrentLayer) { + BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer; + val nIn = brl.getNIn(); + if (nIn > 0) { + inputType = InputType.recurrent(nIn, brl.getRnnDataFormat()); + } + } else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer + || firstLayer instanceof OutputLayer) { + //Can't just use "instanceof FeedForwardLayer" here. ConvolutionLayer is also a FeedForwardLayer + FeedForwardLayer ffl = (FeedForwardLayer) firstLayer; + val nIn = ffl.getNIn(); + if (nIn > 0) { + inputType = InputType.feedForward(nIn); + } + } + } + + //Add preprocessors and set nIns, if InputType has been set + // Builder.inputType field can be set in 1 of 4 ways: + // 1. User calls setInputType directly + // 2. Via ConvolutionLayerSetup -> internally calls setInputType(InputType.convolutional(...)) + // 3. 
Via the above code: i.e., assume input is as expected by the RNN or dense layer -> sets the inputType field + if (inputType != null) { + InputType currentInputType = inputType; + for (int i = 0; i < confs.size(); i++) { + Layer l = confs.get(i).getLayer(); + if (inputPreProcessors.get(i) == null) { + //Don't override preprocessor setting, but set preprocessor if required... + InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType); + if (inputPreProcessor != null) { + inputPreProcessors.put(i, inputPreProcessor); + } + } + + InputPreProcessor inputPreProcessor = inputPreProcessors.get(i); + if (inputPreProcessor != null) { + currentInputType = inputPreProcessor.getOutputType(currentInputType); + } + if (i > 0) { + Layer layer = confs.get(i - 1).getLayer(); + //convolution 1d is an edge case where it has rnn input type but the filters + //should be the output + if (layer instanceof Convolution1DLayer) { + if (l instanceof DenseLayer && inputType instanceof InputType.InputTypeRecurrent) { + FeedForwardLayer feedForwardLayer = (FeedForwardLayer) l; + if (inputType instanceof InputType.InputTypeRecurrent) { + InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType; + feedForwardLayer.setNIn(recurrent.getTimeSeriesLength()); + } + } else { + l.setNIn(currentInputType, + overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user + } + } else { + l.setNIn(currentInputType, + overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user + } + + } else { + l.setNIn(currentInputType, + overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user + } + + currentInputType = l.getOutputType(i, currentInputType); + } + + } + + MultiLayerConfiguration conf = new MultiLayerConfiguration(); + conf.confs = this.confs; + conf.inputPreProcessors = inputPreProcessors; + conf.backpropType = backpropType; + conf.tbpttFwdLength = tbpttFwdLength; + conf.tbpttBackLength = tbpttBackLength; + conf.trainingWorkspaceMode = trainingWorkspaceMode; + conf.inferenceWorkspaceMode = inferenceWorkspaceMode; + conf.cacheMode = cacheMode; + conf.dataType = dataType; + + Nd4j.getRandom().setSeed(conf.getConf(0).getSeed()); + + //Validate output layer configuration + if (validateOutputConfig) { + //Validate output layer configurations... + for (NeuralNetConfiguration n : conf.getConfs()) { + Layer l = n.getLayer(); + OutputLayerUtil.validateOutputLayer(l.getLayerName(), + l); //No-op for non output/loss layers + } + } + + return conf; + + } + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BuildingBlockLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BuildingBlockLayer.java new file mode 100644 index 000000000..e150b850f --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BuildingBlockLayer.java @@ -0,0 +1,97 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. 
+ * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package org.deeplearning4j.nn.conf.layers.wrapper; + +import java.util.Collection; +import lombok.AccessLevel; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import net.brutex.ai.dnn.api.LayerConfiguration; +import net.brutex.ai.dnn.api.NeuralNetwork; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; + +@Builder(builderClassName = "Builder", access = AccessLevel.PUBLIC) +public class BuildingBlockLayer extends BaseLayer implements LayerConfiguration { + + @NonNull + @Getter + private NeuralNetworkConfiguration conf; + + @Override + public Layer instantiate(NeuralNetConfiguration conf, + Collection trainingListeners, int layerIndex, INDArray layerParamsView, + boolean initializeParams, DataType networkDataType) { + return null; + } + + @Override + public ParamInitializer initializer() { + return null; + } + + @Override + public InputType getOutputType(int layerIndex, InputType inputType) { + return null; + } + + @Override + public void setNIn(InputType inputType, boolean override) { + + } + + @Override + public InputPreProcessor getPreProcessorForInputType(InputType inputType) { + return null; + } + + @Override + public boolean isPretrainParam(String paramName) { + return false; + } + + @Override + public LayerMemoryReport getMemoryReport(InputType inputType) { + return null; + } + + /** + * Create and return an instance of a LayerConfiguration. 
+ * + * @param network the "holding" network for the instance + * @return the new layer instance + */ + @Override + public net.brutex.ai.dnn.api.Layer instantiate(NeuralNetwork network) { + return null; + } +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index f590a1caa..18397bd4d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -101,7 +101,7 @@ import java.util.*; @Slf4j -public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork { +public class MultiLayerNetwork implements Serializable, Classifier, Layer, org.deeplearning4j.nn.api.NeuralNetwork { //the hidden neural network layers (including output layer) protected Layer[] layers; diff --git a/settings.gradle b/settings.gradle index 0002d667d..80b29bef8 100644 --- a/settings.gradle +++ b/settings.gradle @@ -100,6 +100,7 @@ include ':cavis-dnn:cavis-dnn-data:cavis-dnn-data-utility-iterators' include ':cavis-dnn:cavis-dnn-modelimport' include ':cavis-dnn:cavis-dnn-nlp' include ':cavis-dnn:cavis-dnn-nn' +include ':cavis-dnn:cavis-dnn-nn-api' include ':cavis-dnn:cavis-dnn-nn-parent' include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-server' include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-client' @@ -154,3 +155,4 @@ include ':brutex-extended-tests' include ':cavis-full' + From 9af4f9f23a6577b73dd31018a06ad80d237df607 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 24 Mar 2023 15:04:06 +0100 Subject: [PATCH 120/126] Playing with some new code Signed-off-by: brian --- .../src/test/java/net/brutex/gan/App.java | 65 +- .../test/java/net/brutex/spark/BrianTest.java | 2 +- .../java/net/brutex/spark/BrianTest2.java | 2 +- .../java/net/brutex/spark/TestServer.java | 2 +- .../java/net/brutex/spark/TestServer2.java | 2 +- .../IntegrationTestBaselineGenerator.java | 4 +- .../integration/IntegrationTestRunner.java | 28 +- .../deeplearning4j/integration/TestUtils.java | 4 +- .../java/org/deeplearning4j/TestUtils.java | 4 +- .../org/deeplearning4j/eval/EvalTest.java | 2 +- .../gradientcheck/BNGradientCheckTest.java | 14 +- .../gradientcheck/CNN1DGradientCheckTest.java | 8 +- .../gradientcheck/CNN3DGradientCheckTest.java | 8 +- .../gradientcheck/CNNGradientCheckTest.java | 14 +- .../GlobalPoolingGradientCheckTests.java | 8 +- .../gradientcheck/GradientCheckTests.java | 16 +- .../GradientCheckTestsComputationGraph.java | 50 +- .../gradientcheck/LRNGradientCheckTests.java | 2 +- .../gradientcheck/LSTMGradientCheckTests.java | 6 +- .../NoBiasGradientCheckTests.java | 2 +- .../OutputLayerGradientChecks.java | 6 +- .../gradientcheck/VaeGradientCheckTests.java | 8 +- .../nn/conf/layers/LayerConfigTest.java | 36 +- .../deeplearning4j/nn/dtypes/DTypeTests.java | 4 +- .../nn/graph/ComputationGraphTestRNN.java | 10 +- .../nn/graph/TestCompGraphUnsupervised.java | 5 +- .../nn/graph/TestComputationGraphNetwork.java | 28 +- .../nn/layers/FrozenLayerTest.java | 2 +- .../deeplearning4j/nn/layers/TestDropout.java | 2 +- .../embedding/EmbeddingLayerTest.java | 4 +- .../nn/layers/ocnn/OCNNOutputLayerTest.java | 2 +- .../samediff/testlayers/SameDiffDense.java | 2 +- .../testlayers/SameDiffDenseVertex.java | 4 +- .../nn/misc/WorkspaceTests.java | 8 +- .../nn/multilayer/MultiLayerTest.java | 2 +- 
.../nn/multilayer/MultiLayerTestRNN.java | 2 +- .../rl/TestMultiModelGradientApplication.java | 4 +- .../TestTransferLearningModelSerializer.java | 2 +- .../TransferLearningCompGraphTest.java | 6 +- .../TransferLearningHelperTest.java | 2 +- .../optimize/solver/TestOptimizers.java | 4 +- .../regressiontest/RegressionTest060.java | 2 +- .../regressiontest/RegressionTest071.java | 2 +- .../regressiontest/RegressionTest080.java | 2 +- .../regressiontest/RegressionTest100a.java | 2 +- .../regressiontest/RegressionTest100b3.java | 2 +- .../regressiontest/RegressionTest100b4.java | 2 +- .../regressiontest/RegressionTest100b6.java | 2 +- .../customlayer100a/CustomLayer.java | 6 +- .../util/CrashReportingUtilTest.java | 12 +- .../util/ModelSerializerTest.java | 4 +- .../cuda/recurrent/CudnnLSTMHelper.java | 2 +- .../nn/modelimport/keras/KerasLayer.java | 8 +- .../nn/modelimport/keras/KerasModel.java | 4 +- .../keras/config/KerasLayerConfiguration.java | 2 +- .../keras/layers/core/KerasDense.java | 4 +- .../keras/layers/recurrent/KerasLSTM.java | 4 +- .../layers/recurrent/KerasSimpleRnn.java | 2 +- .../layers/wrappers/KerasBidirectional.java | 4 +- .../configurations/FullModelComparisons.java | 4 +- .../brutex/ai/dnn/api/LayerConfiguration.java | 9 + cavis-dnn/cavis-dnn-nn/build.gradle | 6 +- .../ILayer.java} | 22 +- .../ILayerConfiguration.java} | 57 +- .../java/net/brutex/ai/dnn/api/IModel.java | 86 + .../brutex/ai/dnn/api/INeuralNetwork.java} | 53 +- .../dnn/api/INeuralNetworkConfiguration.java | 52 + .../dnn/conf/NeuralNetworkConfiguration.java | 708 +- .../layer/AbstractLayerConfiguration.java | 10 +- .../conf/layer/DenseLayerConfiguration.java | 62 + .../layer/FeedForwardLayerConfiguration.java | 99 + .../impl/network/AbstractNeuralNetwork.java | 72 - .../ai/dnn/impl/network/NeuralNetwork.java | 692 -- .../dnn/networks/ArtificialNeuralNetwork.java | 53 + .../trainer/BaseEarlyStoppingTrainer.java | 2 +- .../gradientcheck/GradientCheckUtil.java | 6 +- .../java/org/deeplearning4j/nn/api/Layer.java | 377 +- .../deeplearning4j/nn/api/ModelAdapter.java | 2 +- .../nn/api/ParamInitializer.java | 10 +- .../deeplearning4j/nn/api/TrainingConfig.java | 2 +- .../org/deeplearning4j/nn/api/Updater.java | 2 +- .../nn/api/layers/LayerConstraint.java | 2 +- .../nn/api/layers/RecurrentLayer.java | 6 +- .../nn/conf/NeuralNetConfiguration.java | 5 +- .../nn/conf/constraint/MaxNormConstraint.java | 4 +- .../conf/constraint/MinMaxNormConstraint.java | 6 +- .../conf/constraint/UnitNormConstraint.java | 4 +- .../nn/conf/graph/LayerVertex.java | 7 +- .../nn/conf/layers/ActivationLayer.java | 2 +- .../nn/conf/layers/BaseLayer.java | 4 +- .../nn/conf/layers/CapsuleLayer.java | 4 +- .../nn/conf/layers/DenseLayer.java | 4 +- .../deeplearning4j/nn/conf/layers/Layer.java | 4 +- .../nn/conf/layers/LayerValidation.java | 4 +- .../layers/LocalResponseNormalization.java | 2 +- .../nn/conf/layers/PrimaryCapsules.java | 2 +- .../misc/ElementWiseMultiplicationLayer.java | 2 +- .../layers/recurrent/TimeDistributed.java | 2 +- .../layers/samediff/SameDiffLambdaLayer.java | 2 +- .../layers/samediff/SameDiffLambdaVertex.java | 2 +- .../layers/wrapper/BuildingBlockLayer.java | 97 - .../nn/conf/memory/NetworkMemoryReport.java | 2 +- .../nn/conf/weightnoise/IWeightNoise.java | 2 +- .../nn/graph/ComputationGraph.java | 230 +- .../nn/graph/vertex/BaseGraphVertex.java | 4 +- .../nn/graph/vertex/GraphVertex.java | 4 +- .../nn/graph/vertex/impl/LayerVertex.java | 6 +- .../impl/rnn/DuplicateToTimeSeriesVertex.java | 4 +- 
.../vertex/impl/rnn/LastTimeStepVertex.java | 4 +- .../impl/rnn/ReverseTimeSeriesVertex.java | 4 +- .../nn/layers/recurrent/LSTMHelpers.java | 2 +- .../nn/multilayer/MultiLayerNetwork.java | 8061 +++++++++-------- .../nn/transferlearning/TransferLearning.java | 2 +- .../TransferLearningHelper.java | 4 +- .../nn/updater/BaseMultiLayerUpdater.java | 4 +- .../optimize/api/TrainingListener.java | 4 +- .../listeners/CheckpointListener.java | 4 +- .../optimize/solvers/BaseOptimizer.java | 6 +- .../util/Convolution1DUtils.java | 2 +- .../util/CrashReportingUtil.java | 30 +- .../deeplearning4j/util/ModelSerializer.java | 2 +- .../org/deeplearning4j/util/NetworkUtils.java | 8 +- .../deeplearning4j/util/OutputLayerUtil.java | 2 +- .../deeplearning4j/util/TimeSeriesUtils.java | 2 +- .../java/net/brutex/ai/dnn/api/dnnTest.java | 127 + .../brutex/ai/dnn/conf/layer/FFLayerTest.java | 47 + .../nn/layers/HelperUtilsTest.java | 2 +- .../parallelism/InplaceParallelInference.java | 3 +- .../parallelism/ParallelInference.java | 2 +- .../parallelism/trainer/DefaultTrainer.java | 4 +- .../impl/graph/SparkComputationGraph.java | 2 +- ...VaeReconstructionErrorWithKeyFunction.java | 2 +- ...GVaeReconstructionProbWithKeyFunction.java | 2 +- ...VaeReconstructionErrorWithKeyFunction.java | 2 +- .../VaeReconstructionProbWithKeyFunction.java | 2 +- .../ParameterAveragingTrainingMaster.java | 4 +- .../spark/impl/misc/TestFrozenLayers.java | 4 +- ...TestSparkMultiLayerParameterAveraging.java | 10 +- .../pw/SharedTrainingWrapper.java | 6 +- .../training/SharedTrainingMaster.java | 2 +- .../ui/model/stats/BaseStatsListener.java | 5 +- .../ui/model/stats/impl/SbeStatsReport.java | 4 +- .../ui/module/train/TrainModuleUtils.java | 8 +- .../templates/TrainingModel.html.ftl | 6 +- .../org/deeplearning4j/zoo/TestUtils.java | 2 +- settings.gradle | 2 +- 146 files changed, 6151 insertions(+), 5493 deletions(-) rename cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/{conf/layer/LayerConfiguration.java => api/ILayer.java} (60%) rename cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/{conf/layer/FFLayer.java => api/ILayerConfiguration.java} (56%) create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java rename cavis-dnn/cavis-dnn-nn/src/main/java/{org/deeplearning4j/nn/api/NeuralNetwork.java => net/brutex/ai/dnn/api/INeuralNetwork.java} (58%) create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/DenseLayerConfiguration.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FeedForwardLayerConfiguration.java delete mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/AbstractNeuralNetwork.java delete mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/NeuralNetwork.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java delete mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BuildingBlockLayer.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/conf/layer/FFLayerTest.java diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java index f5b47031b..fca68610a 100644 --- 
a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java @@ -21,8 +21,19 @@ package net.brutex.gan; -import java.util.List; +import java.awt.BorderLayout; +import java.awt.Dimension; +import java.awt.GridLayout; +import java.awt.Image; +import java.awt.image.BufferedImage; +import java.io.File; +import java.util.Arrays; import java.util.Random; +import javax.swing.ImageIcon; +import javax.swing.JFrame; +import javax.swing.JLabel; +import javax.swing.JPanel; +import javax.swing.WindowConstants; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.ArrayUtils; import org.datavec.api.split.FileSplit; @@ -34,20 +45,23 @@ import org.datavec.image.transform.PipelineImageTransform; import org.datavec.image.transform.ResizeImageTransform; import org.datavec.image.transform.ShowImageTransform; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; -import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.ActivationLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.DropoutLayer; +import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; -import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.optimize.listeners.ScoreToChartListener; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -55,13 +69,6 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.IUpdater; - - -import javax.swing.*; -import java.awt.*; -import java.awt.image.BufferedImage; -import java.io.File; -import java.util.Arrays; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; @Slf4j @@ -106,7 +113,7 @@ public class App { * @return config */ private static MultiLayerConfiguration generator() { - MultiLayerConfiguration confxx = new NeuralNetConfiguration.Builder() + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(42) .updater(UPDATER) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) @@ -117,23 +124,8 @@ public class App { .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS)) .build(); - log.debug("Generator network: \n{}", confxx.toJson()); - NeuralNetworkConfiguration conf2 = NeuralNetworkConfiguration.builder().build(); - - NeuralNetworkConfiguration confx = NeuralNetworkConfiguration.builder() - .cacheMode(CacheMode.HOST) - .layer( new DenseLayer.Builder().build()) - .layer( new 
DenseLayer.Builder().build()) - .layer( BuildingBlockLayer.builder().build()) - .layers( List.of(genLayers())) - .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) - .build(); - - - - - return confx; + return conf; } private static Layer[] disLayers() { @@ -155,6 +147,7 @@ public class App { } private static MultiLayerConfiguration discriminator() { + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(42) .updater(UPDATER) @@ -183,13 +176,13 @@ public class App { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .seed(42) - .updater(UPDATER) - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) - .gradientNormalizationThreshold(GRADIENT_THRESHOLD) - .weightInit(WeightInit.XAVIER) - .activation(Activation.IDENTITY) - .list(layers) - .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) + .updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() ) + .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold( 100 ) + .weightInit( new WeightInitXavier() ) + .activation( new ActivationIdentity()) + .list( layers ) + .setInputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) .build(); return conf; diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java index efb54aa29..bc0aafa13 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java @@ -295,7 +295,7 @@ public class BrianTest extends BaseSparkSessionTest { .activation(Activation.RELU).l2(0.001).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER) .activation(Activation.RELU).build()) - //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) + //.layer(2, new DenseLayerConfiguration.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4) .weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build()) .build(); diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java index 4e340c69a..f32c3c4de 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java @@ -301,7 +301,7 @@ public class BrianTest2 /*extends BaseDL4JTest*/ { .list() .layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).l2(0.001).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) - //.layer(2, new DenseLayer.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) + //.layer(2, new DenseLayerConfiguration.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(20).nOut(4).weightInit(WeightInit.XAVIER).activation(Activation.SIGMOID).build()) .build(); diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer.java b/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer.java index 353195da4..b81f70fc8 100644 --- 
a/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer.java @@ -95,7 +95,7 @@ public class TestServer { .list() //.layer(0, new ConvolutionLayer.Builder().nIn(1).kernelSize(1, 5).stride(1,1).padding(0,2).nOut(1).name("1st Filter").updater(new Adam.Builder().learningRate(0.2).build()).build()) //.layer(1, new ConvolutionLayer.Builder().nIn(1).kernelSize(1, 2).stride(1,2).padding(0,0).nOut(1).name("2nd Filter").updater(new Adam.Builder().learningRate(0.1).build()).build()) - // .layer(1, new DenseLayer.Builder().nIn(10).nOut(64).activation(Activation.RELU).build()) + // .layer(1, new DenseLayerConfiguration.Builder().nIn(10).nOut(64).activation(Activation.RELU).build()) .layer(0, new DenseLayer.Builder().nIn(10).nOut(100).activation(Activation.RELU).l2(0.003).build()) .layer(1, new LSTM.Builder().nIn(100).nOut(100).activation(Activation.TANH).build()) .layer(2, new LSTM.Builder().nIn(100).nOut(100).activation(Activation.TANH).build()) diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java b/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java index d6ac22e11..ac625f2b6 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java @@ -131,7 +131,7 @@ public class TestServer2 { .list() //.layer(0, new ConvolutionLayer.Builder().nIn(1).kernelSize(1, 5).stride(1,1).padding(0,2).nOut(1).name("1st Filter").updater(new Adam.Builder().learningRate(0.2).build()).build()) //.layer(1, new ConvolutionLayer.Builder().nIn(1).kernelSize(1, 2).stride(1,2).padding(0,0).nOut(1).name("2nd Filter").updater(new Adam.Builder().learningRate(0.1).build()).build()) - // .layer(1, new DenseLayer.Builder().nIn(10).nOut(64).activation(Activation.RELU).build()) + // .layer(1, new DenseLayerConfiguration.Builder().nIn(10).nOut(64).activation(Activation.RELU).build()) .layer(0, new DenseLayer.Builder().nIn(10).nOut(100).activation(Activation.RELU).l2(0.003).build()) .layer(1, new LSTM.Builder().nIn(100).nOut(100).activation(Activation.TANH).build()) .layer(2, new LSTM.Builder().nIn(100).nOut(100).activation(Activation.TANH).build()) diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java index 7c4bcc9ac..8111d2b7d 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java @@ -284,7 +284,7 @@ public class IntegrationTestBaselineGenerator { INDArray paramsPostTraining; if (modelType == ModelType.MLN) { int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN(); - Preconditions.checkState(layersToTrain != null, "Layer indices must not be null"); + Preconditions.checkState(layersToTrain != null, "ILayer indices must not be null"); DataSetIterator dsi = new MultiDataSetWrapperIterator(iter); for (int i : layersToTrain) { @@ -293,7 +293,7 @@ public class IntegrationTestBaselineGenerator { paramsPostTraining = mln.params(); } else if (modelType == ModelType.CG) { String[] layersToTrain = tc.getUnsupervisedTrainLayersCG(); - Preconditions.checkState(layersToTrain != null, "Layer names must not be null"); + Preconditions.checkState(layersToTrain != null, "ILayer names must not 
be null"); for (String i : layersToTrain) { cg.pretrainLayer(i, iter); diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java index fbc0d60a3..489c8021d 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java @@ -200,7 +200,7 @@ public class IntegrationTestRunner { m = cg; ComputationGraph loaded = ComputationGraph.load(savedModel, true); - assertEquals(loaded.getConfiguration(), cg.getConfiguration(), "Configs not equal" ); + assertEquals(loaded.getComputationGraphConfiguration(), cg.getComputationGraphConfiguration(), "Configs not equal" ); assertEquals( loaded.params(), cg.params(), "Params not equal"); assertEquals(loaded.paramTable(), cg.paramTable(), "Param table not equal"); } else if(config instanceof SameDiff){ @@ -383,7 +383,7 @@ public class IntegrationTestRunner { org.deeplearning4j.nn.api.Layer[] layers; if(modelType == ModelType.MLN){ int[] layersToTrain = tc.getUnsupervisedTrainLayersMLN(); - Preconditions.checkState(layersToTrain != null, "Layer indices must not be null"); + Preconditions.checkState(layersToTrain != null, "ILayer indices must not be null"); DataSetIterator dsi = new MultiDataSetWrapperIterator(iter); for( int i : layersToTrain){ @@ -393,7 +393,7 @@ public class IntegrationTestRunner { layers = mln.getLayers(); } else if(modelType == ModelType.CG) { String[] layersToTrain = tc.getUnsupervisedTrainLayersCG(); - Preconditions.checkState(layersToTrain != null, "Layer names must not be null"); + Preconditions.checkState(layersToTrain != null, "ILayer names must not be null"); for( String i : layersToTrain){ cg.pretrainLayer(i, iter); @@ -429,8 +429,8 @@ public class IntegrationTestRunner { isTbptt = mln.getLayerWiseConfigurations().getBackpropType() == BackpropType.TruncatedBPTT; tbpttLength = mln.getLayerWiseConfigurations().getTbpttFwdLength(); } else if(modelType == ModelType.CG) { - isTbptt = cg.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT; - tbpttLength = cg.getConfiguration().getTbpttFwdLength(); + isTbptt = cg.getComputationGraphConfiguration().getBackpropType() == BackpropType.TruncatedBPTT; + tbpttLength = cg.getComputationGraphConfiguration().getTbpttFwdLength(); } else { isTbptt = false; tbpttLength = 0; @@ -458,11 +458,11 @@ public class IntegrationTestRunner { epochAfter = mln.getEpochCount(); layers = mln.getLayers(); } else if(modelType == ModelType.CG){ - iterBefore = cg.getConfiguration().getIterationCount(); - epochBefore = cg.getConfiguration().getEpochCount(); + iterBefore = cg.getComputationGraphConfiguration().getIterationCount(); + epochBefore = cg.getComputationGraphConfiguration().getEpochCount(); cg.fit(countingIter); - iterAfter = cg.getConfiguration().getIterationCount(); - epochAfter = cg.getConfiguration().getEpochCount(); + iterAfter = cg.getComputationGraphConfiguration().getIterationCount(); + epochAfter = cg.getComputationGraphConfiguration().getEpochCount(); layers = cg.getLayers(); } else { iterBefore = sd.getTrainingConfig().getIterationCount(); @@ -611,7 +611,7 @@ public class IntegrationTestRunner { } else if(modelType == ModelType.CG){ ModelSerializer.writeModel(m, f, true); ComputationGraph restored = ComputationGraph.load(f, true); - assertEquals(cg.getConfiguration(), restored.getConfiguration()); + 
assertEquals(cg.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration()); assertEquals(cg.params(), restored.params()); } else { sd.save(f, true); @@ -745,7 +745,7 @@ public class IntegrationTestRunner { preProcessors = mln.getLayerWiseConfigurations().getInputPreProcessors().values(); } else { preProcessors = new ArrayList<>(); - for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getConfiguration().getVertices().values()) { + for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getComputationGraphConfiguration().getVertices().values()) { if (gv instanceof LayerVertex) { InputPreProcessor pp = ((LayerVertex) gv).getPreProcessor(); if (pp != null) { @@ -760,7 +760,7 @@ public class IntegrationTestRunner { //Collect vertex coverage information if (!isMLN) { - for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getConfiguration().getVertices().values()) { + for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getComputationGraphConfiguration().getVertices().values()) { vertexConfClassesSeen.put(gv.getClass(), vertexConfClassesSeen.getOrDefault(gv.getClass(), 0) + 1); } } @@ -872,14 +872,14 @@ public class IntegrationTestRunner { log.info("||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"); - log.info("Layer coverage - classes seen:"); + log.info("ILayer coverage - classes seen:"); for (Class c : layerClasses) { if (layerConfClassesSeen.containsKey(c)) { log.info("Class seen {} times in tests: {}", layerConfClassesSeen.get(c), c.getName()); } } - log.info("Layer classes NOT seen in any tests:"); + log.info("ILayer classes NOT seen in any tests:"); for (Class c : layerClasses) { if (!layerConfClassesSeen.containsKey(c)) { log.info("Class NOT seen in any tests: {}", c.getName()); diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java index 5c16cc908..e03f2a523 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java @@ -73,7 +73,7 @@ public class TestUtils { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreComputationGraph(bais, true); - assertEquals(net.getConfiguration(), restored.getConfiguration()); + assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen @@ -81,7 +81,7 @@ public class TestUtils { } //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) - ComputationGraphConfiguration conf = net.getConfiguration(); + ComputationGraphConfiguration conf = net.getComputationGraphConfiguration(); serializeDeserializeJava(conf); return restored; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java index f1e12d123..cecc969ac 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -90,7 +90,7 @@ public class TestUtils { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreComputationGraph(bais, true); - assertEquals(net.getConfiguration(), restored.getConfiguration()); + 
assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen @@ -98,7 +98,7 @@ public class TestUtils { } //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) - ComputationGraphConfiguration conf = net.getConfiguration(); + ComputationGraphConfiguration conf = net.getComputationGraphConfiguration(); serializeDeserializeJava(conf); return restored; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java index 30cb1e5ca..7b44d26c9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java @@ -626,7 +626,7 @@ public class EvalTest extends BaseDL4JTest { net.evaluate(iter); net.evaluateROCMultiClass(iter, 0); - cg.getConfiguration().setValidateOutputLayerConfig(false); + cg.getComputationGraphConfiguration().setValidateOutputLayerConfig(false); cg.evaluate(iter); cg.evaluateROCMultiClass(iter, 0); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java index 65f8787d8..f45861f57 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java @@ -90,7 +90,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { mln.init(); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean @@ -135,7 +135,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { mln.init(); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean @@ -237,7 +237,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); // for (int k = 0; k < mln.getnLayers(); k++) -// System.out.println("Layer " + k + " # params: " + mln.getLayer(k).numParams()); +// System.out.println("ILayer " + k + " # params: " + mln.getLayer(k).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean @@ -341,7 +341,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); // for (int k = 0; k < mln.getnLayers(); k++) -// System.out.println("Layer " + k + " # params: " 
+ mln.getLayer(k).numParams()); +// System.out.println("ILayer " + k + " # params: " + mln.getLayer(k).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean @@ -385,7 +385,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { mln.init(); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean @@ -430,7 +430,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { mln.init(); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean @@ -572,7 +572,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l1=" + l1vals[j] + ", l2=" + l2vals[j]); // for (int k = 0; k < net.getNumLayers(); k++) -// System.out.println("Layer " + k + " # params: " + net.getLayer(k).numParams()); +// System.out.println("ILayer " + k + " # params: " + net.getLayer(k).numParams()); //Mean and variance vars are not gradient checkable; mean/variance "gradient" is used to implement running mean/variance calc //i.e., runningMean = decay * runningMean + (1-decay) * batchMean diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java index b61c1fe24..b9f461775 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java @@ -118,7 +118,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -198,7 +198,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -282,7 +282,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + 
net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -359,7 +359,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java index 4d3de0bfb..1f4a1ceec 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java @@ -149,7 +149,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { log.info(msg); // for (int j = 0; j < net.getnLayers(); j++) { -// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// log.info("ILayer " + j + " # params: " + net.getLayer(j).numParams()); // } } @@ -252,7 +252,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { log.info(msg); // for (int j = 0; j < net.getnLayers(); j++) { -// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// log.info("ILayer " + j + " # params: " + net.getLayer(j).numParams()); // } } @@ -431,7 +431,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { log.info(msg); // for (int j = 0; j < net.getnLayers(); j++) { -// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// log.info("ILayer " + j + " # params: " + net.getLayer(j).numParams()); // } } @@ -530,7 +530,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { log.info(msg); // for (int j = 0; j < net.getnLayers(); j++) { -// log.info("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// log.info("ILayer " + j + " # params: " + net.getLayer(j).numParams()); // } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index b9536ee41..b737fcf79 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -137,7 +137,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { System.out.println(name + " - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -231,7 +231,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + 
mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -293,7 +293,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); @@ -361,7 +361,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels); @@ -427,7 +427,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -500,7 +500,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -920,7 +920,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java index 7cb10f83b..36574096d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java @@ -95,7 +95,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { System.out.println("testLSTMGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -156,7 +156,7 @@ public class 
GlobalPoolingGradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize + " - " + (nchw ? "NCHW" : "NHWC")); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -216,7 +216,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) @@ -299,7 +299,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { System.out.println("testCnnGlobalPoolingBasicMultiLayer() - " + pt + ", minibatch = " + miniBatchSize); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java index cab80a69a..553477bd5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java @@ -123,7 +123,7 @@ public class GradientCheckTests extends BaseDL4JTest { + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -203,7 +203,7 @@ public class GradientCheckTests extends BaseDL4JTest { + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -297,7 +297,7 @@ public class GradientCheckTests extends BaseDL4JTest { + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", l2=" + l2 + ", l1=" + l1); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -342,7 +342,7 @@ public class GradientCheckTests extends 
BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testEmbeddingLayerSimple"); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -382,7 +382,7 @@ public class GradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testEmbeddingLayerSimple"); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -472,7 +472,7 @@ public class GradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -714,7 +714,7 @@ public class GradientCheckTests extends BaseDL4JTest { // (a) activation function // (b) Whether to test at random initialization, or after some learning (i.e., 'characteristic mode of operation') // (c) Loss function (with specified output activations) - // (d) Layer Normalization enabled / disabled + // (d) ILayer Normalization enabled / disabled Activation[] activFns = {Activation.SIGMOID, Activation.TANH}; boolean[] characteristic = {true, false}; //If true: run some backprop steps first @@ -776,7 +776,7 @@ public class GradientCheckTests extends BaseDL4JTest { + lf + ", outputActivation=" + outputActivation + ", doLearningFirst=" + doLearningFirst + ", layerNorm=" + layerNorm); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java index ec99f3852..7718078a6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java @@ -106,7 +106,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testBasicIris()"); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) @@ -157,7 +157,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testBasicIrisWithMerging()"); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + 
" # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) @@ -214,7 +214,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testBasicIrisWithElementWiseVertex(op=" + op + ")"); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) @@ -274,7 +274,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testBasicIrisWithElementWiseVertex(op=" + op + ")"); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) @@ -376,7 +376,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) @@ -439,7 +439,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) @@ -478,7 +478,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMWithSubset()"); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) @@ -515,7 +515,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMWithLastTimeStepVertex()"); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } //First: test with no input mask array @@ -579,7 +579,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMWithDuplicateToTimeSeries()"); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + 
graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input1, input2}) @@ -628,7 +628,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testLSTMWithReverseTimeSeriesVertex()"); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) @@ -683,7 +683,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(inputs) @@ -723,7 +723,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) @@ -769,7 +769,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(input) @@ -820,7 +820,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{input}) @@ -888,7 +888,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testBasicIrisTripletStackingL2Loss()"); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{pos, anc, neg}) @@ -949,7 +949,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # 
params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{example}) @@ -1014,7 +1014,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < net.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + net.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + net.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -1063,7 +1063,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) @@ -1121,7 +1121,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) @@ -1179,7 +1179,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) @@ -1242,7 +1242,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } graph.setLayerMaskArrays(new INDArray[] {inMask1, inMask2}, null); @@ -1301,7 +1301,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1, in2}) @@ -1347,7 +1347,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1}) @@ 
-1398,7 +1398,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < graph.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + graph.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + graph.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(graph).inputs(new INDArray[]{in1}) @@ -1436,7 +1436,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println("testGraphEmbeddingLayerSimple"); // for (int j = 0; j < cg.getNumLayers(); j++) -// System.out.println("Layer " + j + " # params: " + cg.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + cg.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(cg).inputs(new INDArray[]{input}) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java index 9d982818a..87ea20cf5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java @@ -84,7 +84,7 @@ public class LRNGradientCheckTests extends BaseDL4JTest { // if (PRINT_RESULTS) { // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); // } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java index c1e20d858..a2c7d7039 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java @@ -126,7 +126,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -215,7 +215,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.MLNConfig().net(mln).input(input) @@ -343,7 +343,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { + ", lossFn=" + lf + ", outputActivation=" + outputActivation + ", l2=" + l2 + ", l1=" + l1); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " 
+ j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java index 5cfec0631..477199be0 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java @@ -78,7 +78,7 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { .dist(new NormalDistribution(0, 1)) .activation(Activation.TANH) - .hasBias(true) //Layer 0: Always have a bias + .hasBias(true) //ILayer 0: Always have a bias .build()) .layer(1, new DenseLayer.Builder().nIn(layerSize).nOut(layerSize) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java index 1c1da4cee..0928b52de 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java @@ -137,7 +137,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } System.out.println("Starting test: " + testName); @@ -244,7 +244,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } System.out.println("Starting test: " + testName); @@ -393,7 +393,7 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(testName); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } System.out.println("Starting test: " + testName); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java index 92ddf8622..40041885e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java @@ -124,7 +124,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, @@ -195,7 +195,7 @@ public class VaeGradientCheckTests extends 
BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int l = 0; l < mln.getnLayers(); l++) -// System.out.println("Layer " + l + " # params: " + mln.getLayer(l).numParams()); +// System.out.println("ILayer " + l + " # params: " + mln.getLayer(l).numParams()); } boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS, @@ -283,7 +283,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS, @@ -325,7 +325,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest { if (PRINT_RESULTS) { System.out.println(msg); // for (int j = 0; j < mln.getnLayers(); j++) -// System.out.println("Layer " + j + " # params: " + mln.getLayer(j).numParams()); +// System.out.println("ILayer " + j + " # params: " + mln.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradientsPretrainLayer(layer, DEFAULT_EPS, diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java index 60b549714..be25a0ccd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java @@ -133,8 +133,8 @@ public class LayerConfigTest extends BaseDL4JTest { //Learning rate without layerwise override: MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(0.3).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + .layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -143,8 +143,8 @@ public class LayerConfigTest extends BaseDL4JTest { //With: conf = new NeuralNetConfiguration.Builder().learningRate(0.3).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).learningRate(0.2).build()).build(); + .layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).learningRate(0.2).build()).build(); net = new MultiLayerNetwork(conf); net.init(); @@ -154,8 +154,8 @@ public class LayerConfigTest extends BaseDL4JTest { //L1 and L2 without layerwise override: conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.2).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + .layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); net = new MultiLayerNetwork(conf); net.init(); @@ -166,8 +166,8 @@ public class LayerConfigTest extends BaseDL4JTest { //L1 and L2 with layerwise override: conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.2).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l1(0.9).build()) - .layer(1, new 
DenseLayer.Builder().nIn(2).nOut(2).l2(0.8).build()).build(); + .layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).l1(0.9).build()) + .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).l2(0.8).build()).build(); net = new MultiLayerNetwork(conf); net.init(); @@ -326,8 +326,8 @@ public class LayerConfigTest extends BaseDL4JTest { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(lr) .updater(Updater.SGD) .learningRateDecayPolicy(LearningRatePolicy.Exponential).lrPolicyDecayRate(lrDecayRate).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + .layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -345,8 +345,8 @@ public class LayerConfigTest extends BaseDL4JTest { int iterations = 1; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) .learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(lrDecayRate) - .lrPolicyPower(power).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + .lrPolicyPower(power).list().layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -367,8 +367,8 @@ public class LayerConfigTest extends BaseDL4JTest { int iterations = 1; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) .learningRateDecayPolicy(LearningRatePolicy.Step).lrPolicyDecayRate(lrDecayRate) - .lrPolicySteps(steps).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + .lrPolicySteps(steps).list().layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -388,8 +388,8 @@ public class LayerConfigTest extends BaseDL4JTest { int iterations = 1; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) .learningRateDecayPolicy(LearningRatePolicy.Poly).lrPolicyDecayRate(lrDecayRate) - .lrPolicyPower(power).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + .lrPolicyPower(power).list().layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) + .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -409,8 +409,8 @@ public class LayerConfigTest extends BaseDL4JTest { int iterations = 1; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) .learningRateDecayPolicy(LearningRatePolicy.Sigmoid).lrPolicyDecayRate(lrDecayRate) - .lrPolicySteps(steps).list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) - .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); + .lrPolicySteps(steps).list().layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) + .layer(1, new 
DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index b3e625849..edad9fb7d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -229,7 +229,7 @@ public class DTypeTests extends BaseDL4JTest { if (seenLayers.size() < layerClasses.size()) { for (Class c : layerClasses) { if (!seenLayers.contains(c) && !ignoreClasses.contains(c)) { - log.warn("Layer class not tested for global vs. network datatypes: {}", c); + log.warn("ILayer class not tested for global vs. network datatypes: {}", c); fail = true; } } @@ -279,7 +279,7 @@ public class DTypeTests extends BaseDL4JTest { } public static void logUsedClasses(ComputationGraph net) { - ComputationGraphConfiguration conf = net.getConfiguration(); + ComputationGraphConfiguration conf = net.getComputationGraphConfiguration(); for (GraphVertex gv : conf.getVertices().values()) { seenVertices.add(gv.getClass()); if (gv instanceof LayerVertex) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java index eb8c1cbcc..2d2379fdb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java @@ -65,7 +65,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); int timeSeriesLength = 12; - //4 layer network: 2 GravesLSTM + DenseLayer + RnnOutputLayer. Hence also tests preprocessors. + //4 layer network: 2 GravesLSTM + DenseLayerConfiguration + RnnOutputLayer. Hence also tests preprocessors. ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() .addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(5).nOut(7) @@ -208,7 +208,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); int timeSeriesLength = 12; - //4 layer network: 2 GravesLSTM + DenseLayer + RnnOutputLayer. Hence also tests preprocessors. + //4 layer network: 2 GravesLSTM + DenseLayerConfiguration + RnnOutputLayer. Hence also tests preprocessors. 
//Network architecture: lstm0 -> Dense -> RnnOutputLayer0 // and lstm1 -> Dense -> RnnOutputLayer1 ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() @@ -391,9 +391,9 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { graphTBPTT.init(); graphTBPTT.clearTbpttState = false; - assertEquals(BackpropType.TruncatedBPTT, graphTBPTT.getConfiguration().getBackpropType()); - assertEquals(timeSeriesLength, graphTBPTT.getConfiguration().getTbpttFwdLength()); - assertEquals(timeSeriesLength, graphTBPTT.getConfiguration().getTbpttBackLength()); + assertEquals(BackpropType.TruncatedBPTT, graphTBPTT.getComputationGraphConfiguration().getBackpropType()); + assertEquals(timeSeriesLength, graphTBPTT.getComputationGraphConfiguration().getTbpttFwdLength()); + assertEquals(timeSeriesLength, graphTBPTT.getComputationGraphConfiguration().getTbpttBackLength()); INDArray inputData = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); INDArray labels = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java index a17979bf2..794538c36 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java @@ -42,7 +42,6 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.learning.config.Adam; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -168,8 +167,8 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest { net.init(); ComputationGraph cg = net.toComputationGraph(); - cg.getConfiguration().setInferenceWorkspaceMode(wsm); - cg.getConfiguration().setTrainingWorkspaceMode(wsm); + cg.getComputationGraphConfiguration().setInferenceWorkspaceMode(wsm); + cg.getComputationGraphConfiguration().setTrainingWorkspaceMode(wsm); DataSetIterator ds = new EarlyTerminationDataSetIterator(new MnistDataSetIterator(1, true, 12345), 1); Nd4j.getRandom().setSeed(12345); net.pretrainLayer(0, ds); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index 7a918a674..a6373c6a9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -1033,15 +1033,15 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { DataSetIterator iter = new IrisDataSetIterator(50, 150); - assertEquals(0, network.getConfiguration().getIterationCount()); + assertEquals(0, network.getComputationGraphConfiguration().getIterationCount()); network.fit(iter); - assertEquals(3, network.getConfiguration().getIterationCount()); + assertEquals(3, network.getComputationGraphConfiguration().getIterationCount()); iter.reset(); network.fit(iter); - assertEquals(6, network.getConfiguration().getIterationCount()); + assertEquals(6, network.getComputationGraphConfiguration().getIterationCount()); iter.reset(); network.fit(iter.next()); - assertEquals(7, network.getConfiguration().getIterationCount()); + assertEquals(7, 
network.getComputationGraphConfiguration().getIterationCount()); ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(network, baos, true); @@ -1049,7 +1049,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ByteArrayInputStream bais = new ByteArrayInputStream(asBytes); ComputationGraph net = ModelSerializer.restoreComputationGraph(bais, true); - assertEquals(7, net.getConfiguration().getIterationCount()); + assertEquals(7, net.getComputationGraphConfiguration().getIterationCount()); } @Test @@ -1272,18 +1272,18 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph net = new ComputationGraph(conf); net.init(); - assertEquals(0, net.getConfiguration().getEpochCount()); + assertEquals(0, net.getComputationGraphConfiguration().getEpochCount()); DataSetIterator iter = new IrisDataSetIterator(150, 150); for( int i=0; i<4; i++ ){ - assertEquals(i, net.getConfiguration().getEpochCount()); + assertEquals(i, net.getComputationGraphConfiguration().getEpochCount()); net.fit(iter); - assertEquals(i+1, net.getConfiguration().getEpochCount()); + assertEquals(i+1, net.getComputationGraphConfiguration().getEpochCount()); } - assertEquals(4, net.getConfiguration().getEpochCount()); + assertEquals(4, net.getComputationGraphConfiguration().getEpochCount()); ByteArrayOutputStream baos = new ByteArrayOutputStream(); @@ -1293,7 +1293,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); ComputationGraph restored = ModelSerializer.restoreComputationGraph(bais, true); - assertEquals(4, restored.getConfiguration().getEpochCount()); + assertEquals(4, restored.getComputationGraphConfiguration().getEpochCount()); } @Test @@ -1619,13 +1619,13 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { GraphIndices indices = cg.calculateIndices(); int[] order = cg.topologicalSortOrder(); - List strOrder = cg.getConfiguration().getTopologicalOrderStr(); + List strOrder = cg.getComputationGraphConfiguration().getTopologicalOrderStr(); INDArray[] out1 = cg.output(in); //Check it's the same after loading: ComputationGraph cg2 = TestUtils.testModelSerialization(cg); int[] order2 = cg2.topologicalSortOrder(); - List strOrder2 = cg.getConfiguration().getTopologicalOrderStr(); + List strOrder2 = cg.getComputationGraphConfiguration().getTopologicalOrderStr(); assertArrayEquals(order, order2); assertEquals(strOrder, strOrder2); @@ -1633,7 +1633,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { assertArrayEquals(out1, out2); //Delete the topological order, ensure it gets recreated properly: - ComputationGraphConfiguration conf3 = cg2.getConfiguration().clone(); + ComputationGraphConfiguration conf3 = cg2.getComputationGraphConfiguration().clone(); conf3.setTopologicalOrder(null); conf3.setTopologicalOrderStr(null); ComputationGraph cg3 = new ComputationGraph(conf3); @@ -1641,7 +1641,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { cg3.setParams(cg2.params()); int[] order3 = cg3.topologicalSortOrder(); - List strOrder3 = cg.getConfiguration().getTopologicalOrderStr(); + List strOrder3 = cg.getComputationGraphConfiguration().getTopologicalOrderStr(); INDArray[] out3 = cg3.output(in); assertArrayEquals(order, order3); assertEquals(strOrder, strOrder3); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java index c3543e167..0f506dbfe 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java @@ -235,7 +235,7 @@ public class FrozenLayerTest extends BaseDL4JTest { ComputationGraph clonedModel = modelNow.clone(); //Check json - assertEquals(clonedModel.getConfiguration().toJson(), modelNow.getConfiguration().toJson()); + assertEquals(clonedModel.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson()); //Check params assertEquals(modelNow.params(), clonedModel.params()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java index 67f66fb21..868f34ba7 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java @@ -50,7 +50,7 @@ public class TestDropout extends BaseDL4JTest { @Test public void testDropoutSimple() throws Exception { //Testing dropout with a single layer - //Layer input: values should be set to either 0.0 or 2.0x original value + //ILayer input: values should be set to either 0.0 or 2.0x original value int nIn = 8; int nOut = 8; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java index 259a38382..55c26b12b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java @@ -200,7 +200,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { @Test public void testEmbeddingForwardPass() { //With the same parameters, embedding layer should have same activations as the equivalent one-hot representation - // input with a DenseLayer + // input with a DenseLayerConfiguration int nClassesIn = 10; @@ -243,7 +243,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { @Test public void testEmbeddingBackwardPass() { //With the same parameters, embedding layer should have same activations as the equivalent one-hot representation - // input with a DenseLayer + // input with a DenseLayerConfiguration int nClassesIn = 10; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java index e9f76dfc2..0eaa156f1 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java @@ -104,7 +104,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest { + "ocnn" + "sigmoid" + ", doLearningFirst=" + doLearningFirst); for (int j = 0; j < network.getnLayers(); j++) - System.out.println("Layer " + j + " # params: " + network.getLayer(j).numParams()); + System.out.println("ILayer " + j + " # params: " + network.getLayer(j).numParams()); } boolean gradOK = GradientCheckUtil.checkGradients(network, DEFAULT_EPS, 
DEFAULT_MAX_REL_ERROR, diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java index e84390916..3595282c0 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java @@ -98,7 +98,7 @@ public class SameDiffDense extends SameDiffLayer { if(DefaultParamInitializer.BIAS_KEY.equals(e.getKey())){ e.getValue().assign(0.0); } else { - //Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayer + //Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayerConfiguration WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, weightInit, null, 'f', e.getValue()); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java index da674ea7c..baa4cee7e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDenseVertex.java @@ -72,14 +72,14 @@ public class SameDiffDenseVertex extends SameDiffVertex { @Override public void initializeParameters(Map params) { - //Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayer + //Normally use 'c' order, but use 'f' for direct comparison to DL4J DenseLayerConfiguration WeightInitUtil.initWeights(nIn, nOut, new long[]{nIn, nOut}, weightInit, null, 'f', params.get("W")); params.get("b").assign(0.0); } @Override public char paramReshapeOrder(String paramName){ - return 'f'; //To match DL4J DenseLayer - for easy comparison + return 'f'; //To match DL4J DenseLayerConfiguration - for easy comparison } @Override diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java index cf7d31bd5..5b00685af 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java @@ -73,8 +73,8 @@ public class WorkspaceTests extends BaseDL4JTest { ComputationGraph c = createNet(); for (WorkspaceMode wm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) { log.info("Starting test: {}", wm); - c.getConfiguration().setTrainingWorkspaceMode(wm); - c.getConfiguration().setInferenceWorkspaceMode(wm); + c.getComputationGraphConfiguration().setTrainingWorkspaceMode(wm); + c.getComputationGraphConfiguration().setInferenceWorkspaceMode(wm); INDArray f = Nd4j.rand(8, 1, 28, 28); INDArray l = Nd4j.rand(8, 10); @@ -666,8 +666,8 @@ public class WorkspaceTests extends BaseDL4JTest { ComputationGraph c = createNet(); for (WorkspaceMode wm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) { log.info("Starting test: {}", wm); - c.getConfiguration().setTrainingWorkspaceMode(wm); - c.getConfiguration().setInferenceWorkspaceMode(wm); + c.getComputationGraphConfiguration().setTrainingWorkspaceMode(wm); + c.getComputationGraphConfiguration().setInferenceWorkspaceMode(wm); INDArray f = 
Nd4j.rand(8, 1, 28, 28); INDArray l = Nd4j.rand(8, 10); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index 056f4a43e..49d70647c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -995,7 +995,7 @@ public class MultiLayerTest extends BaseDL4JTest { @Test public void testCompareLayerMethods(){ - //Simple test: compare .layer(int, Layer) and .layer(Layer) are identical + //Simple test: compare .layer(int, ILayer) and .layer(ILayer) are identical MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(123).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java index 5064e44ab..a12bd88f9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java @@ -261,7 +261,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); int timeSeriesLength = 12; - //4 layer network: 2 GravesLSTM + DenseLayer + RnnOutputLayer. Hence also tests preprocessors. + //4 layer network: 2 GravesLSTM + DenseLayerConfiguration + RnnOutputLayer. Hence also tests preprocessors. MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() .layer(0, l0) .layer(1, l1) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java index 410abf970..92b8375dd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java @@ -216,8 +216,8 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest { net2GradUpd.getUpdater().getStateViewArray()); //Remove the next 2 lines: fails - as net 1 is 1 iteration ahead - net1GradCalc.getConfiguration().setIterationCount(0); - net2GradUpd.getConfiguration().setIterationCount(0); + net1GradCalc.getComputationGraphConfiguration().setIterationCount(0); + net2GradUpd.getComputationGraphConfiguration().setIterationCount(0); for (int i = 0; i < 100; i++) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java index ad92a7c47..44c3bcb07 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java @@ -120,7 +120,7 @@ public class TestTransferLearningModelSerializer extends BaseDL4JTest { assertTrue(withFrozen.getLayer(0) instanceof FrozenLayer); assertTrue(withFrozen.getLayer(1) instanceof FrozenLayer); - Map m = 
withFrozen.getConfiguration().getVertices(); + Map m = withFrozen.getComputationGraphConfiguration().getVertices(); Layer l0 = ((LayerVertex) m.get("0")).getLayerConf().getLayer(); Layer l1 = ((LayerVertex) m.get("1")).getLayerConf().getLayer(); assertTrue(l0 instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java index a81d96838..efc821b6e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java @@ -102,7 +102,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { .build(); //Check json - assertEquals(expectedConf.toJson(), modelNow.getConfiguration().toJson()); + assertEquals(expectedConf.toJson(), modelNow.getComputationGraphConfiguration().toJson()); //Check params after fit modelNow.fit(randomData); @@ -382,7 +382,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { modelExpectedArch.getVertex("layer0").setLayerAsFrozen(); modelExpectedArch.getVertex("layer1").setLayerAsFrozen(); - assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson()); + assertEquals(modelExpectedArch.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson()); modelNow.setParams(modelExpectedArch.params()); int i = 0; @@ -445,7 +445,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { // assertEquals(confExpected, graph.getConfiguration()); - assertEquals(confExpected.toJson(), graph.getConfiguration().toJson()); + assertEquals(confExpected.toJson(), graph.getComputationGraphConfiguration().toJson()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java index 0e78a3d6c..d7e58be43 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java @@ -126,7 +126,7 @@ public class TransferLearningHelperTest extends BaseDL4JTest { .setOutputs("outLeft", "outCentre", "outRight").build(); ComputationGraph expectedModel = new ComputationGraph(expectedConf); expectedModel.init(); - assertEquals(expectedConf.toJson(), modelSubset.getConfiguration().toJson()); + assertEquals(expectedConf.toJson(), modelSubset.getComputationGraphConfiguration().toJson()); } @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java index 5b7bec134..73e1a7a56 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java @@ -764,7 +764,7 @@ public class TestOptimizers extends BaseDL4JTest { } - /** Simple abstract class to deal with the fact that we don't care about the majority of the Model/Layer + /** Simple abstract class to deal with the fact 
that we don't care about the majority of the Model/ILayer * methods here. Classes extending this model for optimizer tests need only implement the score() and * gradient() methods. */ @@ -907,7 +907,7 @@ public class TestOptimizers extends BaseDL4JTest { @Override public INDArray input() { - //Work-around for BaseUpdater.postApply(): Uses Layer.input().size(0) + //Work-around for BaseUpdater.postApply(): Uses ILayer.input().size(0) //in order to get mini-batch size. i.e., divide by 1 here. return Nd4j.zeros(1); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java index 985f347d8..87a53e54a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java @@ -221,7 +221,7 @@ public class RegressionTest060 extends BaseDL4JTest { ComputationGraph net = ModelSerializer.restoreComputationGraph(f, true); - ComputationGraphConfiguration conf = net.getConfiguration(); + ComputationGraphConfiguration conf = net.getComputationGraphConfiguration(); assertEquals(3, conf.getVertices().size()); GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java index 2a75e7994..0dc3839bb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java @@ -221,7 +221,7 @@ public class RegressionTest071 extends BaseDL4JTest { ComputationGraph net = ModelSerializer.restoreComputationGraph(f, true); - ComputationGraphConfiguration conf = net.getConfiguration(); + ComputationGraphConfiguration conf = net.getComputationGraphConfiguration(); assertEquals(3, conf.getVertices().size()); GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java index 6566f03fe..6460582ba 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java @@ -237,7 +237,7 @@ public class RegressionTest080 extends BaseDL4JTest { ComputationGraph net = ModelSerializer.restoreComputationGraph(f, true); - ComputationGraphConfiguration conf = net.getConfiguration(); + ComputationGraphConfiguration conf = net.getComputationGraphConfiguration(); assertEquals(3, conf.getVertices().size()); GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index acee54871..f294e16a7 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -171,7 +171,7 @@ public class RegressionTest100a extends BaseDL4JTest { int nBoxes = 5; int nClasses = 10; - ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getConfiguration().getVertices().get("convolution2d_9")).getLayerConf().getLayer(); + ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getComputationGraphConfiguration().getVertices().get("convolution2d_9")).getLayerConf().getLayer(); assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java index 8df2f258b..35fb7391b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java @@ -206,7 +206,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { int nBoxes = 5; int nClasses = 10; - ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getConfiguration().getVertices().get("convolution2d_9")).getLayerConf().getLayer(); + ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getComputationGraphConfiguration().getVertices().get("convolution2d_9")).getLayerConf().getLayer(); assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java index 5b4270a4e..00e46bf0c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -224,7 +224,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { int nBoxes = 5; int nClasses = 10; - ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getConfiguration().getVertices() + ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getComputationGraphConfiguration().getVertices() .get("convolution2d_9")).getLayerConf().getLayer(); assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(new ActivationIdentity(), cl.getActivationFn()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java index 40df45924..15a9c2bc3 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java @@ -205,7 +205,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { int nBoxes = 5; int nClasses = 10; - ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getConfiguration().getVertices() + ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getComputationGraphConfiguration().getVertices() .get("convolution2d_9")).getLayerConf().getLayer(); assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); 
assertEquals(new ActivationIdentity(), cl.getActivationFn()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java index 00a2b6242..acb3963b1 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java @@ -94,7 +94,7 @@ public class CustomLayer extends FeedForwardLayer { @Override public ParamInitializer initializer() { //This method returns the parameter initializer for this type of layer - //In this case, we can use the DefaultParamInitializer, which is the same one used for DenseLayer + //In this case, we can use the DefaultParamInitializer, which is the same one used for DenseLayerConfiguration //For more complex layers, you may need to implement a custom parameter initializer //See the various parameter initializers here: //https://github.com/deeplearning4j/deeplearning4j/tree/master/deeplearning4j-core/src/main/java/org/deeplearning4j/nn/params @@ -108,7 +108,7 @@ public class CustomLayer extends FeedForwardLayer { //If you don't need this functionality for your custom layer, you can return a LayerMemoryReport // with all 0s, or - //This implementation: based on DenseLayer implementation + //This implementation: based on DenseLayerConfiguration implementation InputType outputType = getOutputType(-1, inputType); val numParams = initializer().numParams(this); @@ -131,7 +131,7 @@ public class CustomLayer extends FeedForwardLayer { .workingMemory(0, 0, trainSizeFixed, trainSizeVariable) //No additional memory (beyond activations) for inference .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, - MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayer + MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayerConfiguration .build(); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java index 4da9883b8..8bfaa9eb2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java @@ -117,7 +117,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest { String str = FileUtils.readFileToString(list[0]); // System.out.println(str); assertTrue(str.contains("Network Information")); - assertTrue(str.contains("Layer Helpers")); + assertTrue(str.contains("ILayer Helpers")); assertTrue(str.contains("JavaCPP")); assertTrue(str.contains("ScoreIterationListener")); @@ -134,7 +134,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest { assertEquals(1, list.length); str = FileUtils.readFileToString(list[0]); assertTrue(str.contains("Network Information")); - assertTrue(str.contains("Layer Helpers")); + assertTrue(str.contains("ILayer Helpers")); assertTrue(str.contains("JavaCPP")); assertTrue(str.contains("ScoreIterationListener(1)")); @@ -150,7 +150,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest { // System.out.println("///////////////////////////////////////////////////////////"); assertTrue(mlnMemoryInfo.contains("Network Information")); - assertTrue(mlnMemoryInfo.contains("Layer Helpers")); + assertTrue(mlnMemoryInfo.contains("ILayer Helpers")); 
assertTrue(mlnMemoryInfo.contains("JavaCPP")); assertTrue(mlnMemoryInfo.contains("ScoreIterationListener(1)")); @@ -172,7 +172,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest { assertEquals(1, list.length); str = FileUtils.readFileToString(list[0]); assertTrue(str.contains("Network Information")); - assertTrue(str.contains("Layer Helpers")); + assertTrue(str.contains("ILayer Helpers")); assertTrue(str.contains("JavaCPP")); assertTrue(str.contains("ScoreIterationListener(1)")); @@ -187,7 +187,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest { assertEquals(1, list.length); str = FileUtils.readFileToString(list[0]); assertTrue(str.contains("Network Information")); - assertTrue(str.contains("Layer Helpers")); + assertTrue(str.contains("ILayer Helpers")); assertTrue(str.contains("JavaCPP")); assertTrue(str.contains("ScoreIterationListener(1)")); @@ -203,7 +203,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest { // System.out.println("///////////////////////////////////////////////////////////"); assertTrue(cgMemoryInfo.contains("Network Information")); - assertTrue(cgMemoryInfo.contains("Layer Helpers")); + assertTrue(cgMemoryInfo.contains("ILayer Helpers")); assertTrue(cgMemoryInfo.contains("JavaCPP")); assertTrue(cgMemoryInfo.contains("ScoreIterationListener(1)")); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java index 610cb0961..e01d42f01 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java @@ -151,7 +151,7 @@ public class ModelSerializerTest extends BaseDL4JTest { ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile); - assertEquals(network.getConfiguration().toJson(), cg.getConfiguration().toJson()); + assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson()); assertEquals(cg.params(), network.params()); assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @@ -177,7 +177,7 @@ public class ModelSerializerTest extends BaseDL4JTest { ComputationGraph network = ModelSerializer.restoreComputationGraph(fis); - assertEquals(network.getConfiguration().toJson(), cg.getConfiguration().toJson()); + assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson()); assertEquals(cg.params(), network.params()); assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } diff --git a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/recurrent/CudnnLSTMHelper.java b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/recurrent/CudnnLSTMHelper.java index 120078d07..2b71d920a 100644 --- a/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/recurrent/CudnnLSTMHelper.java +++ b/cavis-dnn/cavis-dnn-cudnn/src/main/java/org/deeplearning4j/cuda/recurrent/CudnnLSTMHelper.java @@ -198,7 +198,7 @@ public class CudnnLSTMHelper extends BaseCudnnHelper implements LSTMHelper { } if (!(activationFn instanceof ActivationTanH)) { supported = false; - log.warn("Not supported: Layer activation functions != ActivationTanH"); + log.warn("Not supported: ILayer activation functions != ActivationTanH"); } if (hasPeepholeConnections) { supported = false; diff --git 
a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java index 5c8c829c4..601237b53 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java @@ -295,7 +295,7 @@ public class KerasLayer { } /** - * Copy Keras layer weights to DL4J Layer. + * Copy Keras layer weights to DL4J ILayer. * * @param layer DL4J layer * @throws InvalidKerasConfigurationException Invalid Keras configuration @@ -358,7 +358,7 @@ public class KerasLayer { } /** - * Whether this Keras layer maps to a DL4J Layer. + * Whether this Keras layer maps to a DL4J ILayer. * * @return true or false */ @@ -367,9 +367,9 @@ public class KerasLayer { } /** - * Gets corresponding DL4J Layer, if any. + * Gets corresponding DL4J ILayer, if any. * - * @return DL4J Layer + * @return DL4J ILayer * @see org.deeplearning4j.nn.api.Layer */ public Layer getLayer() { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java index d4bf6ba92..ea0b99f0c 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java @@ -583,8 +583,8 @@ public class KerasModel { graphBuilder.addVertex(layer.getLayerName(), layer.getVertex(), inboundLayerNamesArray); } else if (layer.isInputPreProcessor()) { if (preprocessor == null) - throw new UnsupportedKerasConfigurationException("Layer " + layer.getLayerName() - + " could not be mapped to Layer, Vertex, or InputPreProcessor"); + throw new UnsupportedKerasConfigurationException("ILayer " + layer.getLayerName() + + " could not be mapped to ILayer, Vertex, or InputPreProcessor"); graphBuilder.addVertex(layer.getLayerName(), new PreprocessorVertex(preprocessor), inboundLayerNamesArray); } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java index a0082f4f1..d454d1e97 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/config/KerasLayerConfiguration.java @@ -246,7 +246,7 @@ public class KerasLayerConfiguration { private final String LAYER_FIELD_RATE = "rate"; private final String LAYER_FIELD_GAUSSIAN_VARIANCE = ""; // 1: sigma, 2: stddev - /* Layer wrappers */ + /* ILayer wrappers */ // Missing: TimeDistributed diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java index 9eae1f08e..f49599ccf 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDense.java 
@@ -115,9 +115,9 @@ public class KerasDense extends KerasLayer { } /** - * Get DL4J DenseLayer. + * Get DL4J DenseLayerConfiguration. * - * @return DenseLayer + * @return DenseLayerConfiguration */ public DenseLayer getDenseLayer() { return (DenseLayer) this.layer; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java index 4e35a6867..e1c6be765 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java @@ -211,10 +211,10 @@ public class KerasLSTM extends KerasLayer { } /** - * Get DL4J Layer. If returnSequences is true, this can be casted to an "LSTM" layer, otherwise it can be casted + * Get DL4J ILayer. If returnSequences is true, this can be casted to an "LSTM" layer, otherwise it can be casted * to a "LastTimeStep" layer. * - * @return LSTM Layer + * @return LSTM ILayer */ public Layer getLSTMLayer() { return layer; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java index ac2d4c234..ea71fc8d7 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java @@ -184,7 +184,7 @@ public class KerasSimpleRnn extends KerasLayer { /** * Get DL4J SimpleRnn layer. * - * @return SimpleRnn Layer + * @return SimpleRnn ILayer */ public Layer getSimpleRnnLayer() { return this.layer; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java index fa5f5b508..ccbbbd9d6 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java @@ -160,7 +160,7 @@ public class KerasBidirectional extends KerasLayer { /** * Return the underlying recurrent layer of this bidirectional layer * - * @return Layer, recurrent layer + * @return ILayer, recurrent layer */ public Layer getUnderlyingRecurrentLayer() { return kerasRnnlayer.getLayer(); @@ -169,7 +169,7 @@ public class KerasBidirectional extends KerasLayer { /** * Get DL4J Bidirectional layer. 
* - * @return Bidirectional Layer + * @return Bidirectional ILayer */ public Bidirectional getBidirectionalLayer() { return (Bidirectional) this.layer; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java index 1120dfbb8..f50df5084 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java @@ -85,7 +85,7 @@ public class FullModelComparisons extends BaseDL4JTest { System.out.println(model.summary()); - // 1. Layer + // 1. ILayer LSTM firstLstm = (LSTM) model.getLayer(0); org.deeplearning4j.nn.conf.layers.LSTM firstConf = (org.deeplearning4j.nn.conf.layers.LSTM) firstLstm.conf().getLayer(); @@ -123,7 +123,7 @@ public class FullModelComparisons extends BaseDL4JTest { Assertions.assertEquals(b.getDouble(0, 192), -0.13569744, 1e-7); // Keras O Assertions.assertEquals(b.getDouble(0, 0), -0.2587392, 1e-7); // Keras C - // 2. Layer + // 2. ILayer LSTM secondLstm = (LSTM) ((LastTimeStepLayer) model.getLayer(1)).getUnderlying(); org.deeplearning4j.nn.conf.layers.LSTM secondConf = (org.deeplearning4j.nn.conf.layers.LSTM) secondLstm.conf().getLayer(); diff --git a/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/LayerConfiguration.java b/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/LayerConfiguration.java index 0b274cb8c..6b395a5b2 100644 --- a/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/LayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn-api/src/main/java/net/brutex/ai/dnn/api/LayerConfiguration.java @@ -39,4 +39,13 @@ public interface LayerConfiguration { */ org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType(); + + /** + * Defines the valid input type for this Layer + * + * @return InputType + */ + org.deeplearning4j.nn.conf.inputs.InputType.Type getOutputType(); + + } diff --git a/cavis-dnn/cavis-dnn-nn/build.gradle b/cavis-dnn/cavis-dnn-nn/build.gradle index e0f85570d..0e097093d 100644 --- a/cavis-dnn/cavis-dnn-nn/build.gradle +++ b/cavis-dnn/cavis-dnn-nn/build.gradle @@ -22,7 +22,7 @@ apply from: "${project.rootProject.projectDir}/createTestBackends.gradle" dependencies { implementation platform(projects.cavisCommonPlatform) - implementation projects.cavisDnn.cavisDnnNnApi +// implementation projects.cavisDnn.cavisDnnNnApi implementation projects.cavisDnn.cavisDnnData.cavisDnnDataUtilityIterators implementation 'org.lucee:oswego-concurrent:1.3.4' implementation projects.cavisDnn.cavisDnnCommon @@ -57,4 +57,6 @@ dependencies { // define any required OkHttp artifacts without version implementation "com.squareup.okhttp3:okhttp" implementation "com.squareup.okhttp3:logging-interceptor" -} \ No newline at end of file +} +sourceCompatibility = JavaVersion.VERSION_11 +targetCompatibility = JavaVersion.VERSION_11 diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/LayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/ILayer.java similarity index 60% rename from cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/LayerConfiguration.java rename to cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/ILayer.java index 16c67b491..a43b94265 100644 
--- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/LayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/ILayer.java @@ -19,10 +19,28 @@ * */ -package net.brutex.ai.dnn.conf.layer; +package net.brutex.ai.dnn.api; -public abstract class LayerConfiguration { +/** + * This is an "executable" ILayer, that is based on a {@link ILayerConfiguration} + */ +public interface ILayer { + /** + * Get the underlying configuration for this ILayer + * @return configuration + */ + ILayerConfiguration getLayerConfiguration(); + /** + * Set the underlying layer configuration + * @param conf The new configuration + */ + void setLayerConfiguration(ILayerConfiguration conf); + /** + * An implementation should provide a method to validate the network + * @return true if no errors found; false otherwise + */ + boolean isValid(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FFLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/ILayerConfiguration.java similarity index 56% rename from cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FFLayer.java rename to cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/ILayerConfiguration.java index d903e9002..e0f5d856b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FFLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/ILayerConfiguration.java @@ -19,34 +19,45 @@ * */ -package net.brutex.ai.dnn.conf.layer; - -import lombok.extern.slf4j.Slf4j; -import net.brutex.ai.dnn.api.Layer; -import net.brutex.ai.dnn.api.NeuralNetwork; -import net.brutex.ai.dnn.conf.layer.AbstractLayerConfiguration; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.inputs.InputType.Type; - -@Slf4j -public class FFLayer extends AbstractLayerConfiguration { +package net.brutex.ai.dnn.api; +public interface ILayerConfiguration { /** - * Create and return an instance of a LayerConfiguration. + * Create and return an instance of a ILayerConfiguration. 
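For orientation, the following is a minimal, hypothetical implementation of the ILayer contract introduced above; the class name and its single stored field are illustrative only and are not part of this patch:

    package net.brutex.ai.dnn.api;

    // Hypothetical ILayer implementation: it only stores and validates its configuration.
    public class SimpleLayer implements ILayer {

        private ILayerConfiguration configuration;

        public SimpleLayer(ILayerConfiguration configuration) {
            this.configuration = configuration;
        }

        @Override
        public ILayerConfiguration getLayerConfiguration() {
            return configuration;
        }

        @Override
        public void setLayerConfiguration(ILayerConfiguration conf) {
            this.configuration = conf;
        }

        @Override
        public boolean isValid() {
            // A layer without a configuration cannot be considered valid.
            return configuration != null;
        }
    }
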
* * @param network the "holding" network for the instance * @return the new layer instance */ - @Override - public Layer instantiate(NeuralNetwork network) { - //Let's do some verifications first - if(getInputType() != Type.FF) { - log.error("The {} layer configuration must use an InputType of {}, but found {}", - this.getClass().getSimpleName(), - Type.FF.name(), - getInputType().name()); - } - return null; - } + ILayer instantiate(IModel network); + + + /** + * Defines the valid input type for this ILayer + * + * @return InputType + */ + org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType(); + + + /** + * Defines the valid input type for this ILayer + * + * @return InputType + */ + org.deeplearning4j.nn.conf.inputs.InputType.Type getOutputType(); + + + /** + * Number of trainable parameter in this layer + * @return number of parameter + */ + long numParameters(); + + /** + * An implementation should provide a method to validate the network + * @return true if no errors found; false otherwise + */ + boolean isValid(); + } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java new file mode 100644 index 000000000..f0c6a722a --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java @@ -0,0 +1,86 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.MultiDataSet; + +/** + * A Neural Network is an instance of a {@link INeuralNetworkConfiguration}, that can be trained, + * evaluated, saved, exported, etc. Its configuration state is defined with the + * {@link #setConfiguration(INeuralNetworkConfiguration)} and {@link #getConfiguration()} methods. + * + */ +public interface IModel { + + /** + * The configuration that defines this Neural Network + * + * @param conf the configuration to use for this network + */ + void setConfiguration(INeuralNetworkConfiguration conf); + INeuralNetworkConfiguration getConfiguration(); + + /** + * Fit the model for one iteration on the provided data + * + * @param features the examples to classify (one example in each row) + * @param labels the example labels(a binary outcome matrix) + * @param featuresMask The mask array for the features (used for variable length time series, etc). May be null. + * @param labelsMask The mask array for the labels (used for variable length time series, etc). May be null. 
+ */ + void fit(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask); + + /** + * This method fits model with a given DataSet + * + * @param dataSet the dataset to use for training + */ + void fit(DataSet dataSet); + + /** + * This method fits model with a given MultiDataSet + * + * @param dataSet the multi dataset to use for training + */ + void fit(MultiDataSet dataSet); + + /** + * The name of the Neural Network + * @return the name + */ + String getName(); + + /** + * Set the name for this Neural Network + * @param name the name + */ + void setName(String name); + + /** + * An implementation should provide a method to validate the network + * @return true if no errors found; false otherwise + */ + boolean isValid(); + +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/NeuralNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetwork.java similarity index 58% rename from cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/NeuralNetwork.java rename to cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetwork.java index c9437b838..48d6c561b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/NeuralNetwork.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetwork.java @@ -1,25 +1,27 @@ /* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * */ -package org.deeplearning4j.nn.api; +package net.brutex.ai.dnn.api; +import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.nd4j.evaluation.IEvaluation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -31,7 +33,7 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; /** * @author raver119 */ -public interface NeuralNetwork { +public interface INeuralNetwork { /** * This method does initialization of model @@ -104,4 +106,17 @@ public interface NeuralNetwork { * @param iterator */ T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations); + + /** + * A neural network is created from a configuration. + * @param conf the configuration to create the network from + */ + void setConfiguration(NeuralNetworkConfiguration conf); + + /** + * Return the configuration for this configuration + * @return + */ + NeuralNetworkConfiguration getConfiguration(); + } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java new file mode 100644 index 000000000..81d447fa3 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java @@ -0,0 +1,52 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +import java.util.List; + +public interface INeuralNetworkConfiguration { + +} +/** + /** + * Provides a flat list of all embedded layer configurations, this + * can only be called after the layer is initialized or {@link #getLayerConfigurations()} is + * called. 
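As a usage sketch only, this is how the IModel interface added earlier in this patch might be driven; the wrapper class and method below are assumptions for illustration, not part of the change:

    import net.brutex.ai.dnn.api.IModel;
    import net.brutex.ai.dnn.api.INeuralNetworkConfiguration;
    import org.nd4j.linalg.dataset.api.DataSet;

    class IModelUsageSketch {
        static void train(IModel model, INeuralNetworkConfiguration conf, DataSet trainingData) {
            model.setName("example-model");   // any descriptive name
            model.setConfiguration(conf);     // attach the configuration that defines the network
            if (model.isValid()) {            // validate before training
                model.fit(trainingData);      // fit the model on the provided DataSet
            }
        }
    }
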
+ * + * @return unstacked layer configurations + + List getLayerConfigurations(); + + + /** + * This uncollables any stacked layer configurations within building blocks like + * @link BuildingBlockLayer} + + void calculateInnerLayerConfigurations(); + + /** + * An implementation should provide a method to validate the network + * @return true if no errors found; false otherwise + + boolean isValid(); +} +**/ \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java index e383ea9c7..51de9f873 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java @@ -22,32 +22,61 @@ package net.brutex.ai.dnn.conf; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; +import com.fasterxml.jackson.databind.node.ArrayNode; +import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Random; import lombok.Getter; import lombok.NonNull; import lombok.Setter; import lombok.Singular; import lombok.extern.jackson.Jacksonized; import lombok.extern.slf4j.Slf4j; -import net.brutex.ai.dnn.api.LayerConfiguration; +import net.brutex.ai.dnn.api.ILayerConfiguration; +import net.brutex.ai.dnn.api.INeuralNetworkConfiguration; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.CacheMode; +import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.WorkspaceMode; +import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; +import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; +import org.deeplearning4j.nn.conf.memory.MemoryReport; +import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport; +import org.deeplearning4j.nn.conf.serde.JsonMappers; +import org.deeplearning4j.nn.weights.IWeightInit; +import org.deeplearning4j.nn.weights.WeightInit; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT; +import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; +import org.nd4j.linalg.lossfunctions.impl.LossMSE; +import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; /** - * The NeuralNetworkConfiguration is a sequential container for the different layers in your + * The INeuralNetworkConfiguration is a sequential container for the different layers in your * network (or other NeuralNetworkConfigurations). 
That said, NeuralNetworkConfigurations can be * stacked.

- * It then “chains” outputs to inputs sequentially for each NeuralNetworkConfiguration, + * It then “chains” outputs to inputs sequentially for each INeuralNetworkConfiguration, * finally returning the output of the "top" configuration. Any settings made, are inherited and can - * be overridden on a "deeper" level. For this use case, you need to wrap the NeuralNetworkConfiguration + * be overridden on a "deeper" level. For this use case, you need to wrap the INeuralNetworkConfiguration * into a BuildingBlockLayer * */ @@ -55,77 +84,54 @@ import org.deeplearning4j.nn.conf.layers.wrapper.BuildingBlockLayer; @JsonIgnoreProperties(ignoreUnknown = true) @lombok.Builder @Slf4j -public class NeuralNetworkConfiguration implements net.brutex.ai.dnn.api.NeuralNetworkConfiguration, Serializable, Cloneable { - - /** - * The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified otherwise. - * Valid values are
- * CacheMode.NONE,
- * CacheMode.HOST or
- * CacheMode.DEVICE
- */ - @NonNull - @lombok.Builder.Default private CacheMode cacheMode = CacheMode.NONE; - - @Getter @Setter @NonNull - protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; - - @Getter @Setter @NonNull - protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; - - @Getter @Setter @NonNull - protected BackpropType backpropType = BackpropType.Standard; - - @Getter - protected Map inputPreProcessors = new HashMap<>(); - - - @Getter @Setter protected int tbpttFwdLength = 20; - @Getter @Setter protected int tbpttBackLength = 20; - - - /** - * The list of layer configurations in this configuration. They will be indexed automatically - * as the layers get added starting with index 0. - */ - @Singular @Getter - private List layerConfigurations; - - /** - * The name for this configuration. Defaults to "Anonymous NeuralNetworkConfiguration" if - * it is not specified. - */ - @lombok.Builder.Default @Getter - private String name = "Anonymous NeuralNetworkConfiguration"; - - - /** - * The {@link InputType} of the data for this network configuration - */ - private InputType inputType; +public class NeuralNetworkConfiguration extends NeuralNetConfiguration implements + INeuralNetworkConfiguration, Serializable, Cloneable { + private static final int DEFAULT_TBPTT_LENGTH = 20; + @Getter protected final List confs = new ArrayList<>(); /** * hidden list of layers, that "flattens" all the layers of this network and applies * inheritance. */ @lombok.Builder.ObtainVia(method = "calculateInnerLayers") - private final List innerLayerConfigurations; - - @Override - public void calculateInnerLayerConfigurations() { - List list = new ArrayList<>(); - for( LayerConfiguration layer : this.layerConfigurations) { - if(layer instanceof BuildingBlockLayer) { - BuildingBlockLayer blayer = (BuildingBlockLayer) layer; - blayer.getConf().calculateInnerLayerConfigurations(); - list.addAll(blayer.getConf().getLayerConfigurations()); - } else { - list.add(layer); - } - } - this.layerConfigurations = list; - } - + private final List innerLayerConfigurations; + @Getter @Setter @NonNull @Singular + protected List layers = new ArrayList<>(); + @Getter @Setter @NonNull @lombok.Builder.Default @Deprecated + protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; + @Getter @Setter @NonNull @lombok.Builder.Default @Deprecated + protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; + /** + * The type of backprop. Default setting is used for most networks (MLP, CNN etc), but + * optionally truncated BPTT can be used for training recurrent neural networks. If using + * TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength() + */ + @Getter @Setter @NonNull @lombok.Builder.Default + protected BackpropType backpropType = BackpropType.Standard; + @Getter + protected Map inputPreProcessors = new HashMap<>(); + /** + * When doing truncated BPTT: how many steps of forward pass should we do before doing + * (truncated) backprop?
Only applicable when doing + * backpropType(BackpropType.TruncatedBPTT)
Typically tBPTTForwardLength parameter is same + * as the tBPTTBackwardLength parameter, but may be larger than it in some circumstances (but + * never smaller)
Ideally your training data time series length should be divisible by this + * This is the k1 parameter on pg23 of + * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf + * + * @param forwardLength Forward length > 0, >= backwardLength + */ + @Getter @Setter protected int tbpttFwdLength = 20; + /** + * When doing truncated BPTT: how many steps of backward should we do?
Only applicable when + * doing backpropType(BackpropType.TruncatedBPTT)
This is the k2 parameter on pg23 of + * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf + * + * @param backwardLength <= forwardLength + */ + @Getter @Setter protected int tbpttBackLength = 20; /** * Creates and returns a copy of this object. * @@ -136,8 +142,564 @@ public class NeuralNetworkConfiguration implements net.brutex.ai.dnn.api.NeuralN * cannot be cloned. * @see Cloneable */ - @Override - protected Object clone() throws CloneNotSupportedException { - return super.clone(); + + //Nd4j.getRandom().setSeed(getConf(0).getSeed()); //TODO + //Counter for the number of parameter updates so far + // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted + // for Spark and model serialization + @Getter @Setter + protected int iterationCount = 0; + //Counter for the number of epochs completed so far. Used for per-epoch schedules + @Getter @Setter + protected int epochCount = 0; + protected double dampingFactor = 100; + @Getter @Setter //todo why? + private Layer layer; + /** + * A seed for this network, will be random if not specified. + */ + @Getter @Setter @NonNull @lombok.Builder.Default + private long seed = new Random().nextLong(); + /** + * The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified otherwise. + * This method defines how/if preOutput cache is handled: NONE: cache disabled (default value) + * HOST: Host memory will be used DEVICE: GPU memory will be used (on CPU backends effect will + * be the same as for HOST) + * + * Valid values are
+ * CacheMode.NONE,
+ * CacheMode.HOST or
+ * CacheMode.DEVICE
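To make these defaults concrete, here is a hedged sketch of building a configuration with an explicit cache mode and truncated-BPTT settings; it assumes the Lombok-generated builder methods mirror the field names in this class, and the concrete values are illustrative only:

    import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration;
    import org.deeplearning4j.nn.conf.BackpropType;
    import org.deeplearning4j.nn.conf.CacheMode;
    import org.nd4j.linalg.learning.config.Adam;

    class ConfigurationDefaultsSketch {
        static NeuralNetworkConfiguration recurrentExample() {
            return NeuralNetworkConfiguration.builder()
                    .name("tbptt-example")
                    .updater(new Adam(1e-3))                    // network-wide default updater
                    .cacheMode(CacheMode.HOST)                  // default would be CacheMode.NONE
                    .backpropType(BackpropType.TruncatedBPTT)   // enable truncated BPTT
                    .tbpttFwdLength(40)                         // k1: forward steps per update
                    .tbpttBackLength(40)                        // k2: backward steps, never larger than k1
                    .build();
        }
    }
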
+ * @param cacheMode + */ + @NonNull @Getter @Setter + @lombok.Builder.Default private CacheMode cacheMode = CacheMode.NONE; + /** + * The list of layer configurations in this configuration. They will be indexed automatically + * as the layers get added starting with index 0. + */ + @Singular @Getter + private List layerConfigurations; + /** + * The name for this configuration. Defaults to "Anonymous INeuralNetworkConfiguration" if + * it is not specified. + */ + @lombok.Builder.Default @Getter + private String name = "Anonymous INeuralNetworkConfiguration"; + /** + * The {@link InputType} of the data for this network configuration + */ + private InputType inputType; + /** + * Set the DataType for the network parameters and activations for all layers in the network. + * Default: Float + * + * @param dataType Datatype to use for parameters and activations + */ + @Getter @Setter @lombok.Builder.Default @NonNull + private DataType dataType = DataType.FLOAT; + /** + * Whether to override the nIn configuration forcibly upon construction. Default value is true. + * @return builder pattern + */ + @Getter @Setter + @lombok.Builder.Default + private boolean overrideNinUponBuild = true; + /** + * Enabled by default. If enabled, the output layer configuration will be validated, to throw an + * exception on likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.
If + * disabled (false) no output layer validation will be performed.
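For context, this is the kind of likely-invalid output layer that the validation described above is meant to catch; it uses the standard DL4J builder API and the sizes are illustrative:

    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    class OutputValidationSketch {
        // Softmax over a single output unit is almost certainly a configuration mistake:
        // a network built with this layer fails validation unless validateOutputLayerConfig
        // is set to false.
        static OutputLayer suspiciousLayer() {
            return new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX)
                    .nIn(10)
                    .nOut(1)
                    .build();
        }
    }
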
Disabling this validation + * is not recommended, as the configurations that fail validation usually will not be able to + * learn correctly. However, the option to disable this validation is provided for advanced + * users when creating non-standard architectures. + * + * @param validate If true: validate output layer configuration. False: don't validate + */ + @Getter @Setter @lombok.Builder.Default + private boolean validateOutputLayerConfig=true; + /** + * Enabled by default. If enabled, an exception will be thrown when using the (invalid) + * combination of truncated backpropagation through time (TBPTT) with either a + * GlobalPoolingLayer or LastTimeStepLayer.
It is possible to disable this validation to + * allow what is almost certainly an invalid configuration to be used, however this is not + * recommended. + * + * @param validate Whether TBPTT validation should be performed + */ + @Getter @Setter @lombok.Builder.Default + private boolean validateTbpttConfig=true; + + + + /** + * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} + * or {@link org.nd4j.linalg.learning.config.Nesterovs}
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param updater Updater to use + */ + @Getter @Setter @NonNull + private IUpdater updater; + + /** + * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc. + * See {@link GradientNormalization} for details
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param gradientNormalization Type of normalization to use. Defaults to None. + * @see GradientNormalization + */ + @Getter @Setter @NonNull @lombok.Builder.Default + private GradientNormalization gradientNormalization = GradientNormalization.None; + + /** + * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer, + * GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue
+ * Not used otherwise.
+ * L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping.
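A short sketch of the "network default, per-layer override" behaviour these notes describe, using the existing NeuralNetConfiguration.Builder API; the layer sizes and updater choices are illustrative only:

    import org.deeplearning4j.nn.conf.GradientNormalization;
    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.weights.WeightInit;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.learning.config.Adam;
    import org.nd4j.linalg.learning.config.Nesterovs;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    class DefaultsVsOverrideSketch {
        static MultiLayerConfiguration build() {
            return new NeuralNetConfiguration.Builder()
                    .updater(new Adam(1e-3))            // network-wide default updater
                    .weightInit(WeightInit.XAVIER)      // network-wide default weight init
                    .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
                    .gradientNormalizationThreshold(1.0)
                    .list()
                    .layer(new DenseLayer.Builder().nIn(10).nOut(20).build())   // inherits all defaults
                    .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX)
                            .updater(new Nesterovs(1e-2))                        // per-layer override
                            .nIn(20).nOut(3).build())
                    .build();
        }
    }
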
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + */ + @Getter @Setter + private double gradientNormalizationThreshold; + + + /** + * Weight initialization scheme to use, for initial weight values + * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + */ + @Getter @Setter + private IWeightInit weightInit; + + /** + * Activation function / neuron non-linearity
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + */ + @Getter @Setter + private IActivation activation; + + + + /** + * Create a neural net configuration from json + * + * @param json the neural net configuration from json + * @return {@link NeuralNetworkConfiguration} + */ + public static NeuralNetworkConfiguration fromJson(String json) { + NeuralNetworkConfiguration conf; + ObjectMapper mapper = NeuralNetworkConfiguration.mapper(); + try { + conf = mapper.readValue(json, NeuralNetworkConfiguration.class); + } catch (InvalidTypeIdException e) { + if (e.getMessage().contains("@class")) { + try { + //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format + return JsonMappers.getLegacyMapper().readValue(json, NeuralNetworkConfiguration.class); + } catch (InvalidTypeIdException e2) { + //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.ILayer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..." + //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work + String msg = e2.getMessage(); + if (msg != null && msg.contains("Could not resolve type id")) { + throw new RuntimeException( + "Error deserializing MultiLayerConfiguration - configuration may have a custom " + + "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" + + + " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", + e); + } + throw new RuntimeException(e2); + } catch (IOException e2) { + throw new RuntimeException(e2); + } + } + throw new RuntimeException(e); + } catch (IOException e) { + //Check if this exception came from legacy deserializer... + String msg = e.getMessage(); + if (msg != null && msg.contains("legacy")) { + throw new RuntimeException( + "Error deserializing MultiLayerConfiguration - configuration may have a custom " + + "layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be " + + + "deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)", + e); + } + throw new RuntimeException(e); + } + + //To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier) + // Previously: enumeration used for loss functions. 
Now: use classes + // IN the past, could have only been an OutputLayer or RnnOutputLayer using these enums + int layerCount = 0; + JsonNode confs = null; + for (NeuralNetworkConfiguration nnc : conf.getConfs()) { + Layer l = nnc.getLayer(); + if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) { + //lossFn field null -> may be an old config format, with lossFunction field being for the enum + //if so, try walking the JSON graph to extract out the appropriate enum value + + BaseOutputLayer ol = (BaseOutputLayer) l; + try { + JsonNode jsonNode = mapper.readTree(json); + if (confs == null) { + confs = jsonNode.get("confs"); + } + if (confs instanceof ArrayNode) { + ArrayNode layerConfs = (ArrayNode) confs; + JsonNode outputLayerNNCNode = layerConfs.get(layerCount); + if (outputLayerNNCNode == null) { + throw new RuntimeException("should never happen"); //return conf; //Should never happen... + } + JsonNode outputLayerNode = outputLayerNNCNode.get("layer"); + + JsonNode lossFunctionNode = null; + if (outputLayerNode.has("output")) { + lossFunctionNode = outputLayerNode.get("output").get("lossFunction"); + } else if (outputLayerNode.has("rnnoutput")) { + lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction"); + } + + if (lossFunctionNode != null) { + String lossFunctionEnumStr = lossFunctionNode.asText(); + LossFunctions.LossFunction lossFunction = null; + try { + lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr); + } catch (Exception e) { + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", + e); + } + + if (lossFunction != null) { + switch (lossFunction) { + case MSE: + ol.setLossFn(new LossMSE()); + break; + case XENT: + ol.setLossFn(new LossBinaryXENT()); + break; + case NEGATIVELOGLIKELIHOOD: + ol.setLossFn(new LossNegativeLogLikelihood()); + break; + case MCXENT: + ol.setLossFn(new LossMCXENT()); + break; + + //Remaining: TODO + case SQUARED_LOSS: + case RECONSTRUCTION_CROSSENTROPY: + default: + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", + lossFunction); + break; + } + } + } + + } else { + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", + (confs != null ? confs.getClass() : null)); + } + } catch (IOException e) { + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", + e); + break; + } + } + + //Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn") + //Try to load the old format if necessary, and create the appropriate IActivation instance + if ((l instanceof BaseLayer) && ((BaseLayer) l).getActivationFn() == null) { + try { + JsonNode jsonNode = mapper.readTree(json); + if (confs == null) { + confs = jsonNode.get("confs"); + } + if (confs instanceof ArrayNode) { + ArrayNode layerConfs = (ArrayNode) confs; + JsonNode outputLayerNNCNode = layerConfs.get(layerCount); + if (outputLayerNNCNode == null) { + throw new RuntimeException("Should never happen"); //return conf; //Should never happen... 
+ } + JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); + + if (layerWrapperNode == null || layerWrapperNode.size() != 1) { + continue; + } + + JsonNode layerNode = layerWrapperNode.elements().next(); + JsonNode activationFunction = layerNode.get( + "activationFunction"); //Should only have 1 element: "dense", "output", etc + + if (activationFunction != null) { + IActivation ia = Activation.fromString(activationFunction.asText()) + .getActivationFunction(); + ((BaseLayer) l).setActivationFn(ia); + } + } + + } catch (IOException e) { + log.warn( + "ILayer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", + e); + } + } + + if (!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) { + return conf; + } + + layerCount++; + } + return conf; } + + /** + * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied + * from handling of {@link Activation} above. + * + * @return True if all is well and layer iteration shall continue. False else-wise. + */ + private static boolean handleLegacyWeightInitFromJson(String json, Layer l, ObjectMapper mapper, + JsonNode confs, int layerCount) { + if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) { + try { + JsonNode jsonNode = mapper.readTree(json); + if (confs == null) { + confs = jsonNode.get("confs"); + } + if (confs instanceof ArrayNode) { + ArrayNode layerConfs = (ArrayNode) confs; + JsonNode outputLayerNNCNode = layerConfs.get(layerCount); + if (outputLayerNNCNode == null) { + return false; //Should never happen... + } + JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); + + if (layerWrapperNode == null || layerWrapperNode.size() != 1) { + return true; + } + + JsonNode layerNode = layerWrapperNode.elements().next(); + JsonNode weightInit = layerNode.get( + "weightInit"); //Should only have 1 element: "dense", "output", etc + JsonNode distribution = layerNode.get("dist"); + + Distribution dist = null; + if (distribution != null) { + dist = mapper.treeToValue(distribution, Distribution.class); + } + + if (weightInit != null) { + final IWeightInit wi = WeightInit.valueOf(weightInit.asText()) + .getWeightInitFunction(dist); + ((BaseLayer) l).setWeightInitFn(wi); + } + } + + } catch (IOException e) { + log.warn( + "ILayer with null WeightInit detected: " + l.getLayerName() + ", could not parse JSON", + e); + } + } + return true; + + } + + /** + * Object mapper for serialization of configurations + * + * @return + */ + public static ObjectMapper mapperYaml() { + return JsonMappers.getMapperYaml(); + } + + /** + * Object mapper for serialization of configurations + * + * @return + */ + public static ObjectMapper mapper() { + return JsonMappers.getMapper(); + } + + + + /** + * @return JSON representation of NN configuration + */ + public String toYaml() { + ObjectMapper mapper = NeuralNetConfiguration.mapperYaml(); + synchronized (mapper) { + try { + return mapper.writeValueAsString(this); + } catch (com.fasterxml.jackson.core.JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + /** + * @return JSON representation of NN configuration + */ + public String toJson() { + ObjectMapper mapper = NeuralNetConfiguration.mapper(); + synchronized (mapper) { + //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally + //when writeValueAsString is used by multiple threads. This results in invalid JSON. 
See issue #3243 + try { + return mapper.writeValueAsString(this); + } catch (com.fasterxml.jackson.core.JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + @Override + public String toString() { + return toJson(); + } + + public NeuralNetworkConfiguration getConf(int i) { + return confs.get(i); + } + + @Override + public NeuralNetworkConfiguration clone() { + + NeuralNetworkConfiguration clone = (NeuralNetworkConfiguration) super.clone(); + List confList = clone.getConfs(); + if (confList != null) { + List list = new ArrayList<>(); + for (NeuralNetworkConfiguration conf : confList) { + list.add(conf.clone()); + } + } + + if (clone.getInputPreProcessors() != null) { + Map map = new HashMap<>(); + for (Map.Entry entry : clone.getInputPreProcessors().entrySet()) { + map.put(entry.getKey(), entry.getValue().clone()); + } + clone.getInputPreProcessors().clear(); + clone.getInputPreProcessors().putAll(map); + } + + clone.setInferenceWorkspaceMode(this.inferenceWorkspaceMode); + clone.setTrainingWorkspaceMode(this.trainingWorkspaceMode); + clone.setCacheMode(this.cacheMode); + clone.setValidateOutputLayerConfig(this.validateOutputLayerConfig); + clone.setDataType(this.dataType); + + return clone; + + } + + public InputPreProcessor getInputPreProcess(int curr) { + return inputPreProcessors.get(curr); + } + + /** + * Get a {@link MemoryReport} for the given MultiLayerConfiguration. This is used to estimate the + * memory requirements for the given network configuration and input + * + * @param inputType Input types for the network + * @return Memory report for the network + */ + public NetworkMemoryReport getMemoryReport(InputType inputType) { + + Map memoryReportMap = new LinkedHashMap<>(); + int nLayers = confs.size(); + for (int i = 0; i < nLayers; i++) { + String layerName = confs.get(i).getLayer().getLayerName(); + if (layerName == null) { + layerName = String.valueOf(i); + } + + //Pass input type through preprocessor, if necessary + InputPreProcessor preproc = getInputPreProcess(i); + //TODO memory requirements for preprocessor + if (preproc != null) { + inputType = preproc.getOutputType(inputType); + } + + LayerMemoryReport report = confs.get(i).getLayer().getMemoryReport(inputType); + memoryReportMap.put(layerName, report); + + inputType = confs.get(i).getLayer().getOutputType(i, inputType); + } + + return new NetworkMemoryReport(memoryReportMap, MultiLayerConfiguration.class, + "MultiLayerNetwork", inputType); + } + + /** + * For the given input shape/type for the network, return a list of activation sizes for each + * layer in the network.
i.e., list.get(i) is the output activation sizes for layer i + * + * @param inputType Input type for the network + * @return A lits of activation types for the network, indexed by layer number + */ + public List getLayerActivationTypes(@NonNull InputType inputType) { + List out = new ArrayList<>(); + int nLayers = confs.size(); + for (int i = 0; i < nLayers; i++) { + InputPreProcessor preproc = getInputPreProcess(i); + if (preproc != null) { + inputType = preproc.getOutputType(inputType); + } + + inputType = confs.get(i).getLayer().getOutputType(i, inputType); + out.add(inputType); + } + return out; + } + + /** + * Defines some additional handy methods. Other than that, + * the builder is generated by lombok. + */ + public static class NeuralNetworkConfigurationBuilder { + + /** + * Specify the processors. These are used at each layer for doing things like normalization and + * shaping of input. + * + * @param processor what to use to preProcess the data. + * @return builder pattern + */ + public NeuralNetworkConfigurationBuilder inputPreProcessor(Integer layer, + InputPreProcessor processor) { + inputPreProcessors.put(layer, processor); + return this; + } + + /** + * Specify additional layer configurations + */ + @Deprecated + public NeuralNetworkConfigurationBuilder layersFromArray(Layer[] arrLayers) { + for(Layer l : arrLayers) { + layers.add( l ); + } + return this; + } + } + + } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java index 951688e51..1ed923bda 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java @@ -24,12 +24,12 @@ package net.brutex.ai.dnn.conf.layer; import lombok.Getter; import lombok.NonNull; import lombok.Setter; -import net.brutex.ai.dnn.api.LayerConfiguration; -import org.deeplearning4j.nn.conf.inputs.InputType; +import lombok.experimental.SuperBuilder; +import net.brutex.ai.dnn.api.ILayerConfiguration; -public abstract class AbstractLayerConfiguration implements LayerConfiguration { +@SuperBuilder +public abstract class AbstractLayerConfiguration implements ILayerConfiguration { @Getter @Setter @NonNull - private InputType.Type inputType; - + private String name; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/DenseLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/DenseLayerConfiguration.java new file mode 100644 index 000000000..d472d99b2 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/DenseLayerConfiguration.java @@ -0,0 +1,62 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.conf.layer; + +import lombok.Builder; +import lombok.experimental.SuperBuilder; +import org.deeplearning4j.nn.conf.layers.LayerValidation; + +/** + * The dense layer is a neural network layer that is connected deeply, which means each neuron in + * the dense layer receives input from all neurons of its previous layer. The dense layer is found + * to be the most commonly used layer in the models. + *
+ * In the background, the dense layer performs a matrix-vector multiplication. The values used in + * the matrix are actually parameters that can be trained and updated with the help of + * backpropagation. + *
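+ * <p>
+ * Illustrative sketch only (not part of this configuration class): the computation described
+ * above written with ND4J, plus the builder usage and the parameter count this configuration
+ * reports through {@code numParameters() == in * out + out}. The {@code builder()} is the one
+ * Lombok generates from {@code @SuperBuilder} over the fields {@code name}, {@code in},
+ * {@code out} and {@code bias}; the sizes are made-up example values.
+ * <pre>{@code
+ * DenseLayerConfiguration dense = DenseLayerConfiguration.builder()
+ *     .name("dense-1")
+ *     .in(784)        // inputs per example
+ *     .out(128)       // neurons in this layer
+ *     .bias(true)
+ *     .build();
+ * // dense.numParameters() == 784 * 128 + 128
+ *
+ * // the matrix-vector multiplication plus bias and non-linearity, per minibatch:
+ * INDArray x = Nd4j.rand(32, 784);                              // input   [batchSize, in]
+ * INDArray W = Nd4j.rand(784, 128);                             // weights [in, out]
+ * INDArray b = Nd4j.zeros(1, 128);                              // bias    [1, out]
+ * INDArray a = Transforms.sigmoid(x.mmul(W).addRowVector(b));   // output  [batchSize, out]
+ * }</pre>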
+ * The output generated by the dense layer is an ‘m’ dimensional vector. Thus, dense layer is + * basically used for changing the dimensions of the vector. Dense layers also applies operations + * like rotation, scaling, translation on the vector. + */ +@SuperBuilder +public class DenseLayerConfiguration extends FeedForwardLayerConfiguration { + + /** + * Decides whether we should include a bias vector for calculation purposes or not. + */ + @Builder.Default + boolean bias = true; + + + + /** + * An implementation to validate the network + * + * @return true if no errors found; false otherwise + */ + @Override + public boolean isValid() { + LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getName(), -99, getIn(), getOut()); + return super.isValid(); + } +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FeedForwardLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FeedForwardLayerConfiguration.java new file mode 100644 index 000000000..c86869d54 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FeedForwardLayerConfiguration.java @@ -0,0 +1,99 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.conf.layer; + +import lombok.Getter; +import lombok.experimental.SuperBuilder; +import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.ILayer; +import net.brutex.ai.dnn.api.ILayerConfiguration; +import net.brutex.ai.dnn.api.IModel; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.inputs.InputType.Type; + +/** + * A Feed Forward Layer Configuration + */ +@Slf4j +@SuperBuilder +public class FeedForwardLayerConfiguration extends AbstractLayerConfiguration implements ILayerConfiguration { + + @Getter private int in; + @Getter private int out; + + /** + * This Fast Forward ILayer will always output data as + * FF type. + * @return InputType for FF + **/ + @Getter + final InputType.Type outputType = InputType.Type.FF; + + @Getter + final InputType.Type inputType = InputType.Type.FF; + + /** + * Create and return an instance of a ILayerConfiguration. 
+ * + * @param network the "holding" network for the instance + * @return the new layer instance + */ + //@Override + public ILayer instantiate(IModel network) { + //Let's do some verifications first + if(getInputType() != Type.FF) { + log.error("The {} layer configuration must use an InputType of {}, but found {}", + this.getClass().getSimpleName(), + Type.FF.name(), + getInputType().name()); + } + return null; + } + + /** + * Number of trainable parameter in this layer + * + * @return number of parameter + */ + @Override + public long numParameters() { + return in * out + out; //add one extra out for the bias + } + + /** + * An implementation should provide a method to validate the network + * + * @return true if no errors found; false otherwise + */ + @Override + public boolean isValid() { + boolean result = true; + if(getInputType() != Type.FF) { + log.error("The {} layer configuration must use an InputType of {}, but found {}", + this.getClass().getSimpleName(), + Type.FF.name(), + getInputType().name()); + result = false; + } + return result; + } +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/AbstractNeuralNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/AbstractNeuralNetwork.java deleted file mode 100644 index a1c36e988..000000000 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/AbstractNeuralNetwork.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * - * ****************************************************************************** - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - * - */ - -package net.brutex.ai.dnn.impl.network; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import lombok.Getter; -import lombok.NonNull; -import lombok.Setter; -import net.brutex.ai.dnn.api.Layer; -import net.brutex.ai.dnn.api.NeuralNetwork; -import net.brutex.ai.dnn.api.LayerConfiguration; -import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; -import org.deeplearning4j.optimize.api.TrainingListener; -import org.nd4j.linalg.dataset.api.MultiDataSet; - -public abstract class AbstractNeuralNetwork implements NeuralNetwork { - - @Getter @Setter @NonNull - private String name; - - @Getter @NonNull - private NeuralNetworkConfiguration configuration; - - @Getter - private final Collection trainingListeners = new HashSet<>(); - - /** - * The neural network holds an instantiation of its configured - * layers. - * @return the actual runtime layers - */ - @Getter - private final List runtimeLayers = new ArrayList<>(); - - /** - * Sets the configuration to be used. Each time a configuration is set, the runtime layers - * of this NeuralNetwork are updated from the configuration. 
- * - * @param conf the configuration to use for this network - */ - public void setConfiguration(net.brutex.ai.dnn.api.NeuralNetworkConfiguration conf) { - List layers = conf.getLayerConfigurations(); - for(LayerConfiguration layer : layers) { - Layer initializedLayer = layer.instantiate(this); - this.getRuntimeLayers().add(initializedLayer); - } - this.configuration = configuration; - } - -} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/NeuralNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/NeuralNetwork.java deleted file mode 100644 index 198007baf..000000000 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/impl/network/NeuralNetwork.java +++ /dev/null @@ -1,692 +0,0 @@ -/* - * - * ****************************************************************************** - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - * - */ - -package net.brutex.ai.dnn.impl.network; - -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import lombok.Getter; -import lombok.NonNull; -import lombok.Setter; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.bytedeco.javacpp.Pointer; -import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; -import org.deeplearning4j.exception.DL4JInvalidInputException; -import org.deeplearning4j.nn.api.Classifier; -import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.MaskState; -import org.deeplearning4j.nn.api.Updater; -import org.deeplearning4j.nn.api.layers.IOutputLayer; -import org.deeplearning4j.nn.api.layers.RecurrentLayer; -import org.deeplearning4j.nn.conf.BackpropType; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; -import org.deeplearning4j.nn.conf.WorkspaceMode; -import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop; -import org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.nn.updater.UpdaterCreator; -import org.deeplearning4j.nn.workspace.ArrayType; -import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.deeplearning4j.optimize.Solver; -import org.deeplearning4j.optimize.api.ConvexOptimizer; -import org.deeplearning4j.optimize.api.TrainingListener; -import org.deeplearning4j.util.CrashReportingUtil; -import org.deeplearning4j.util.ModelSerializer; -import org.nd4j.common.base.Preconditions; -import org.nd4j.common.primitives.Pair; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; -import 
org.nd4j.linalg.api.memory.enums.AllocationPolicy; -import org.nd4j.linalg.api.memory.enums.LearningPolicy; -import org.nd4j.linalg.api.memory.enums.ResetPolicy; -import org.nd4j.linalg.api.memory.enums.SpillPolicy; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.AsyncDataSetIterator; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.dataset.api.MultiDataSet; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import org.nd4j.linalg.exception.ND4JArraySizeException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.heartbeat.Heartbeat; -import org.nd4j.linalg.heartbeat.reports.Environment; -import org.nd4j.linalg.heartbeat.reports.Event; -import org.nd4j.linalg.heartbeat.reports.Task; -import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils; -import org.nd4j.linalg.heartbeat.utils.TaskUtils; -import org.nd4j.linalg.indexing.NDArrayIndex; - -@Slf4j -public class NeuralNetwork extends AbstractNeuralNetwork { - - - //the hidden neural network layers (including output layer) - protected Layer[] layers; - - protected transient ThreadLocal lastEtlTime = new ThreadLocal<>(); - - //Current training data: input features and labels - @Getter @Setter @NonNull - protected INDArray input; - @Getter @Setter - protected INDArray labels; - - //Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers - @Getter - protected transient Map helperWorkspaces = new HashMap<>(); - - /** - * Used to call optimizers during backprop - */ - @NonNull - protected transient Solver solver = new Solver.Builder().configure(getConfiguration()). - listeners(getTrainingListeners()).model(this).build(); - - - /** - * Create a new NeuralNetwork from the given configuration - * @param conf - */ - public NeuralNetwork(NeuralNetworkConfiguration conf) { - if(! validateConfiguration() ) { - log.error("Configuration '{}' has failed validation.", conf.getName()); - throw new RuntimeException(); - } - log.info("Configuration '{}' has been validated successfully.", conf.getName()); - this.conf = conf; - } - - private boolean validateConfiguration() { - - return true; - } - - private void logNotImplemented( ) { - // getStackTrace() method return - // current method name at 0th index - String method = new Throwable() - .getStackTrace()[1] - .getMethodName(); - log.trace("Method '{}}' is not implemented for {}", method, this.getClass().getSimpleName()); - } - - /** - * This method does initialization of model - *

- * PLEASE NOTE: All implementations should track own state, to avoid double spending - */ - @Override - public void init() { - logNotImplemented(); - } - - /** - * This method returns model parameters as single INDArray - * - * @return - */ - @Override - public INDArray params() { - logNotImplemented(); - return null; - } - - /** - * This method returns updater state (if applicable), null otherwise - * - * @return - */ - @Override - public INDArray updaterState() { - return getUpdater(true) != null ? getUpdater(true).getStateViewArray() : null; - } - - /** - * This method returns Optimizer used for training - * - * @return the optimizer - */ - @Override - public ConvexOptimizer getOptimizer() { - return solver.getOptimizer(); - } - - - - /** Get the updater for this NeuralNetwork from the Solver - * @return Updater for NeuralNetwork - */ - private Updater getUpdater(boolean initializeIfReq) { - if (solver == null && initializeIfReq) { - synchronized(this){ - if(solver == null) { //May have been created while waiting for lock - solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this).build(); - solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this)); - } - } - } - if(solver != null) { - return solver.getOptimizer().getUpdater(initializeIfReq); - } - return null; - } - - /** - * Set the updater for the NeuralNetwork in the Solver - * */ - public void setUpdater(@NonNull Updater updater) { - solver.getOptimizer().setUpdater(updater); - } - - - @Override - public void fit(MultiDataSet dataSet) { - if (dataSet.getFeatures().length == 1 && dataSet.getLabels().length == 1) { - INDArray features = dataSet.getFeatures(0); - INDArray labels = dataSet.getLabels(0); - INDArray fMask = null; - INDArray lMask = null; - - if (dataSet.getFeaturesMaskArrays() != null) - fMask = dataSet.getFeaturesMaskArrays()[0]; - - if (dataSet.getFeaturesMaskArrays() != null) - lMask = dataSet.getLabelsMaskArrays()[0]; - - DataSet ds = new DataSet(features, labels, fMask, lMask); - fit(ds); - } else { - throw new DL4JInvalidInputException( - "MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array." + - "Please consider use of ComputationGraph"); - } - } - - /** - * Perform minibatch training on all minibatches in the MultiDataSetIterator, for the specified number of epochs. - * Equvalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a loop - * - * @param iterator Training data (DataSetIterator). Iterator must support resetting - * @param numEpochs Number of training epochs, >= 1 - */ - public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs){ - Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", numEpochs); - Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using" + - "iterator has does not support resetting (iterator.resetSupported() returned false)"); - - for(int i = 0; i < numEpochs; i++) { - fit(iterator); - } - } - - /** - * Perform minibatch training on all minibatches in the MultiDataSetIterator.
- * Note: The MultiDataSets in the MultiDataSetIterator must have exactly 1 input and output array (as - * MultiLayerNetwork only supports 1 input and 1 output) - * - * @param iterator Training data (DataSetIterator). Iterator must support resetting - */ - @Override - public void fit(MultiDataSetIterator iterator) { - fit(new MultiDataSetWrapperIterator(iterator)); - } - - /** - * Perform minibatch training on all minibatches in the DataSetIterator for 1 epoch.
- * Note that this method does not do layerwise pretraining.
- * For pretraining use method pretrain.. #pretrain(DataSetIterator)
- * @param iterator Training data (DataSetIterator) - */ - @Override - public void fit(DataSetIterator iterator) { - try{ - fitHelper(iterator); - } catch (OutOfMemoryError e){ - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - private synchronized void fitHelper(DataSetIterator iterator){ - // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate - DataSetIterator iter; - boolean destructable = false; - if (iterator.asyncSupported()) { - iter = new AsyncDataSetIterator(iterator, Math.min( - Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true); - destructable = true; - } else { - iter = iterator; - } - - for (TrainingListener tl : trainingListeners) { - tl.onEpochStart(this); - } - - LayerWorkspaceMgr workspaceMgr; - if(conf.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ - workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); - } else { - workspaceMgr = LayerWorkspaceMgr.builder() - .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - //Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM - // as these should be closed by the time updaters are executed - //Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this - .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .build(); - } - workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); - - update(TaskUtils.buildTask(iter)); - if (!iter.hasNext() && iter.resetSupported()) { - iter.reset(); - } - long time1 = System.currentTimeMillis(); - while (iter.hasNext()) { - - DataSet next = iter.next(); - long time2 = System.currentTimeMillis(); - - lastEtlTime.set((time2 - time1)); - - if (next.getFeatures() == null || next.getLabels() == null) - break; - - // TODO: basically we want to wrap internals of this loop into workspace - - - boolean hasMaskArrays = next.hasMaskArrays(); - - if (conf.getBackpropType() == BackpropType.TruncatedBPTT) { - doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(), - next.getLabelsMaskArray(), workspaceMgr); - } else { - if (hasMaskArrays) - setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray()); - - setInput(next.getFeatures()); - setLabels(next.getLabels()); - - if (solver == null) { - try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this) - .build(); - } - } - - //TODO CACHE - solver.optimize(workspaceMgr); - } - - if (hasMaskArrays) - clearLayerMaskArrays(); - - time1 = System.currentTimeMillis(); - synchronizeIterEpochCounts(); - } - - if (!trainingListeners.isEmpty()) { - for (TrainingListener tl : trainingListeners) { - tl.onEpochEnd(this); - } - } - - clearLayersStates(); - - if (destructable) - ((AsyncDataSetIterator) iter).shutdown(); - - incrementEpochCount(); - } - - - /** - * Workspace for working memory for a single layer: forward pass and backward pass - * Note that this is opened/closed once per op 
(activate/backpropGradient call) - */ - protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM"; - /** - * Workspace for storing all layers' activations - used only to store activations (layer inputs) as part of backprop - * Not used for inference - */ - protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT"; - /** - * Next 2 workspaces: used for: - * (a) Inference: holds activations for one layer only - * (b) Backprop: holds activation gradients for one layer only - * In both cases, they are opened and closed on every second layer - */ - protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1"; - protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2"; - - /** - * Workspace for output methods that use OutputAdapter - */ - protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM"; - - /** - * Workspace for working memory in RNNs - opened and closed once per RNN time step - */ - protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM"; - - - protected WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG; - - protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder() - .initialSize(0) - .overallocationLimit(0.05) - .policyLearning(LearningPolicy.FIRST_LOOP) - .policyReset(ResetPolicy.BLOCK_LEFT) - .policySpill(SpillPolicy.REALLOCATE) - .policyAllocation(AllocationPolicy.OVERALLOCATE) - .build(); - - protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG; - - protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder() - .initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT) - .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE) - .policyLearning(LearningPolicy.FIRST_LOOP).build(); - - - boolean initDone; - protected void update(Task task) { - if (!initDone) { - initDone = true; - Heartbeat heartbeat = Heartbeat.getInstance(); - task = ModelSerializer.taskByModel(this); - Environment env = EnvironmentUtils.buildEnvironment(); - heartbeat.reportEvent(Event.STANDALONE, env, task); - } - } - - protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray, - INDArray labelsMaskArray, LayerWorkspaceMgr workspaceMgr) { - if (input.rank() != 3 || labels.rank() != 3) { - log.warn("Cannot do truncated BPTT with non-3d inputs or labels. 
Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got " - + Arrays.toString(input.shape()) + "\tand labels with shape " - + Arrays.toString(labels.shape())); - return; - } - if (input.size(2) != labels.size(2)) { - log.warn("Input and label time series have different lengths: {} input length, {} label length", - input.size(2), labels.size(2)); - return; - } - - int fwdLen = conf.getTbpttFwdLength(); - update(TaskUtils.buildTask(input, labels)); - val timeSeriesLength = input.size(2); - long nSubsets = timeSeriesLength / fwdLen; - if (timeSeriesLength % fwdLen != 0) - nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20) - - rnnClearPreviousState(); - - for (int i = 0; i < nSubsets; i++) { - long startTimeIdx = (long) i * fwdLen; - long endTimeIdx = startTimeIdx + fwdLen; - if (endTimeIdx > timeSeriesLength) - endTimeIdx = timeSeriesLength; - - if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels, - featuresMaskArray, labelsMaskArray); - - setInput(subsets[0]); - setLabels(subsets[1]); - setLayerMaskArrays(subsets[2], subsets[3]); - - if (solver == null) { - try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getTrainingListeners()).model(this) - .build(); - } - } - solver.optimize(workspaceMgr); - - //Finally, update the state of the RNN layers: - updateRnnStateWithTBPTTState(); - } - - rnnClearPreviousState(); - clearLayerMaskArrays(); - } - - private INDArray[] getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray input, INDArray labels, - INDArray fMask, INDArray lMask ){ - INDArray[] out = new INDArray[4]; - out[0] = input.get(NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); - out[1] = labels.get(NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); - - if (fMask != null) { - out[2] = fMask.get(NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); - } - if (lMask != null) { - out[3] = lMask.get(NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); - } - - return out; - } - - /** - * Intended for internal/developer use - */ - public void updateRnnStateWithTBPTTState() { - Layer[] layers = conf.calculateInnerLayers().toArray(new Layer[]{}); - for (int i = 0; i < layers.length; i++) { - if (layers[i] instanceof RecurrentLayer) { - RecurrentLayer l = ((RecurrentLayer) layers[i]); - l.rnnSetPreviousState(l.rnnGetTBPTTState()); - } else if (layers[i] instanceof MultiLayerNetwork) { - ((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState(); - } - } - } - - /** Clear the previous state of the RNN layers (if any). 
- */ - public void rnnClearPreviousState() { - Layer[] layers = conf.getLayers().toArray(new Layer[]{}); - if (layers == null) - return; - for (int i = 0; i < layers.length; i++) { - if (layers[i] instanceof RecurrentLayer) - ((RecurrentLayer) layers[i]).rnnClearPreviousState(); - else if (layers[i] instanceof MultiLayerNetwork) { - ((MultiLayerNetwork) layers[i]).rnnClearPreviousState(); - } else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){ - ((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying()).rnnClearPreviousState(); - } - } - } - - - - /** Remove the mask arrays from all layers.
- * See {@link #setLayerMaskArrays(INDArray, INDArray)} for details on mask arrays. - */ - public void clearLayerMaskArrays() { - Layer[] layers = conf.getLayers().toArray(new Layer[]{}); - for (Layer layer : layers) { - layer.setMaskArray(null); - } - } - - /** - * Increment the epoch count (in the underlying {@link MultiLayerConfiguration} by 1). - * Note that this is done automatically when using iterator-based fitting methods, such as - * {@link #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, INDArray/INDArray etc), - * the network has no way to know when one epoch ends and another starts. In such situations, this method - * can be used to increment the epoch counter.
- * Note that the epoch counter is used for situations such as some learning rate schedules, and the like. - * - * The current epoch count can be obtained using {@code MultiLayerConfiguration.getLayerwiseConfiguration().getEpochCount()} - */ - public void incrementEpochCount(){ - conf.setEpochCount(conf.getEpochCount() + 1); - synchronizeIterEpochCounts(); - } - - protected void synchronizeIterEpochCounts() { - //TODO: this is necessary for some schedules - but the redundant values are a little ugly... - int currIter = conf.getIterationCount(); - int currEpoch = conf.getEpochCount(); - log.error("Something went wrong here. Code incomplete"); - /*for(Layer l : conf.getLayers()) { - l.setIterationCount(currIter); - l.setEpochCount(currEpoch); - } - */ - } - - /** - * This method just makes sure there's no state preserved within layers - */ - public void clearLayersStates() { - for (Layer layer : layers) { - layer.clear(); - layer.clearNoiseWeightParams(); - } - } - - - /**Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many - * and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths - * within the same minibatch.
- * For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape - * [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape [miniBatchSize,timeSeriesLength] - * and contain values 0 or 1 at each element (to specify whether a given input/example is present - or merely padding - - * at a given time step).
- * NOTE: This method is not usually used directly. Instead, methods such as @link #feedForward(INDArray, INDArray, INDArray)} - * and @link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking internally. - * @param featuresMaskArray Mask array for features (input) - * @param labelsMaskArray Mask array for labels (output) - * @see #clearLayerMaskArrays() - */ - public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) { - if (featuresMaskArray != null) { - - if (featuresMaskArray.size(0) > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - //New approach: use feedForwardMaskArray method - feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0)); - - - /* - //feedforward layers below a RNN layer: need the input (features) mask array - //Reason: even if the time series input is zero padded, the output from the dense layers are - // non-zero (i.e., activationFunction(0*weights + bias) != 0 in general) - //This assumes that the time series input is masked - i.e., values are 0 at the padded time steps, - // so we don't need to do anything for the recurrent layer - - //Now, if mask array is 2d -> need to reshape to 1d (column vector) in the exact same order - // as is done for 3d -> 2d time series reshaping - INDArray reshapedFeaturesMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(featuresMaskArray); - - for( int i=0; i feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, - int minibatchSize) { - if (maskArray == null) { - for (int i = 0; i < layers.length; i++) { - layers[i].feedForwardMaskArray(null, null, minibatchSize); - } - } else { - //Do a forward pass through each preprocessor and layer - for (int i = 0; i < layers.length; i++) { - InputPreProcessor preProcessor = conf.getInputPreProcessors().get(i); - - if (preProcessor != null) { - Pair p = - preProcessor.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize); - if (p != null) { - maskArray = p.getFirst(); - currentMaskState = p.getSecond(); - } else { - maskArray = null; - currentMaskState = null; - } - } - - Pair p = - layers[i].feedForwardMaskArray(maskArray, currentMaskState, minibatchSize); - if (p != null) { - maskArray = p.getFirst(); - currentMaskState = p.getSecond(); - } else { - maskArray = null; - currentMaskState = null; - } - } - } - - return new Pair<>(maskArray, currentMaskState); - } - - -} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java new file mode 100644 index 000000000..0a605b94f --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java @@ -0,0 +1,53 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.networks; + +import lombok.Getter; +import lombok.Setter; +import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; +import net.brutex.ai.dnn.api.INeuralNetwork; + +/** + * Artificial Neural Network An artificial neural network (1) takes some input data, and (2) + * transforms this input data by calculating a weighted sum over the inputs and (3) applies a + * non-linear function to this transformation to calculate an intermediate state. The three steps + * above constitute what is known as a layer, and the transformative function is often referred to + * as a unit. The intermediate states—often termed features—are used as the input into another + * layer. + *
+ * Through repetition of these steps, the artificial neural network learns multiple layers of + * non-linear features, which it then combines in a final layer to create a prediction. + *
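+ * <p>
+ * A compact, purely illustrative sketch of the three steps above for two stacked layers
+ * (pseudo-code, not an API of this class; {@code f} and {@code g} stand for the non-linear
+ * unit functions):
+ * <pre>{@code
+ * hidden     = f( W1 * input  + b1 );   // weighted sum + non-linearity = one layer
+ * prediction = g( W2 * hidden + b2 );   // the hidden features feed the next (final) layer
+ * }</pre>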
+ * The neural network learns by generating an error signal that measures the difference between the + * predictions of the network and the desired values and then using this error signal to change the + * weights (or parameters) so that predictions get more accurate. + */ +public abstract class ArtificialNeuralNetwork implements INeuralNetwork { + + /** + * A neural network is created from a configuration. + * @param conf The (new net.brutex.ai) configuration for the network + */ + @Getter + @Setter //TODO make this also final and @NonNull + private NeuralNetworkConfiguration configuration; +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java index a39a08d97..4d6ff7675 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java @@ -346,7 +346,7 @@ public abstract class BaseEarlyStoppingTrainer implements IEarl } else if(model instanceof ComputationGraph){ ComputationGraph cg = ((ComputationGraph) model); listeners = cg.getListeners(); - cg.getConfiguration().setEpochCount(epochNum); + cg.getComputationGraphConfiguration().setEpochCount(epochNum); } else { return; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java index 121102214..696e92bc2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java @@ -431,7 +431,7 @@ public class GradientCheckUtil { + "DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil"); } - DataType netDataType = c.net.getConfiguration().getDataType(); + DataType netDataType = c.net.getComputationGraphConfiguration().getDataType(); if (netDataType != DataType.DOUBLE) { throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision (" + "is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil"); @@ -444,8 +444,8 @@ public class GradientCheckUtil { //Check configuration int layerCount = 0; - for (String vertexName : c.net.getConfiguration().getVertices().keySet()) { - GraphVertex gv = c.net.getConfiguration().getVertices().get(vertexName); + for (String vertexName : c.net.getComputationGraphConfiguration().getVertices().keySet()) { + GraphVertex gv = c.net.getComputationGraphConfiguration().getVertices().get(vertexName); if (!(gv instanceof LayerVertex)) continue; LayerVertex lv = (LayerVertex) gv; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java index 60780ab99..e7500055f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java @@ -32,194 +32,209 @@ import org.nd4j.common.primitives.Pair; import java.io.Serializable; import java.util.Collection; +/** + * A layer is the highest-level building block in deep learning. 
A layer is a container that usually + * receives weighted input, transforms it with a set of mostly non-linear functions and then passes + * these values as output to the next layer. A layer is usually uniform, that is it only contains + * one type of activation function, pooling, convolution etc. so that it can be easily compared to + * other parts of the network. The first and last layers in a network are called input and output + * layers, respectively, and all layers in between are called hidden layers. + * + * @see NVIDIA Deep Learning In A Nutshell + */ public interface Layer extends Serializable, Cloneable, Model, Trainable { - enum Type { - FEED_FORWARD, RECURRENT, CONVOLUTIONAL, CONVOLUTIONAL3D, - SUBSAMPLING, UPSAMPLING, RECURSIVE, MULTILAYER, NORMALIZATION - } + /** + * This method sets given CacheMode for current layer + * + * @param mode + */ + void setCacheMode(CacheMode mode); - enum TrainingMode { - TRAIN, TEST - } + /** + * Calculate the regularization component of the score, for the parameters in this layer
For + * example, the L1, L2 and/or weight decay components of the loss function
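+   * <p>
+   * Illustrative sketch of how an implementation typically accumulates this value; the helper
+   * {@code regularizationsForParam} is an assumption for the example, not part of this
+   * interface:
+   * <pre>{@code
+   * double score = 0.0;
+   * for (Map.Entry<String, INDArray> p : paramTable().entrySet()) {
+   *     for (Regularization r : regularizationsForParam(p.getKey())) {
+   *         score += r.score(p.getValue(), getIterationCount(), getEpochCount());
+   *     }
+   * }
+   * return score;
+   * }</pre>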
+ * + * @param backpropOnlyParams If true: calculate regularization score based on backprop params + * only. If false: calculate based on all params (including pretrain + * params, if any) + * @return the regularization score of + */ + double calcRegularizationScore(boolean backpropOnlyParams); - /** - * This method sets given CacheMode for current layer - * - * @param mode - */ - void setCacheMode(CacheMode mode); + /** + * Returns the layer type + * + * @return + */ + Type type(); - /** - * Calculate the regularization component of the score, for the parameters in this layer
- * For example, the L1, L2 and/or weight decay components of the loss function
- * - * @param backpropOnlyParams If true: calculate regularization score based on backprop params only. If false: calculate - * based on all params (including pretrain params, if any) - * @return the regularization score of - */ - double calcRegularizationScore(boolean backpropOnlyParams); + /** + * Calculate the gradient relative to the error in the next layer + * + * @param epsilon w^(L+1)*delta^(L+1). Or, equiv: dC/da, i.e., (dC/dz)*(dz/da) = dC/da, where + * C is cost function a=sigma(z) is activation. + * @param workspaceMgr Workspace manager + * @return Pair where Gradient is gradient for this layer, INDArray is + * epsilon (activation gradient) needed by next layer, but before element-wise multiply by + * sigmaPrime(z). So for standard feed-forward layer, if this layer is L, then return.getSecond() + * == dL/dIn = (w^(L)*(delta^(L))^T)^T. Note that the returned array should be placed in the + * {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATION_GRAD} workspace via the workspace + * manager + */ + Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr); - /** - * Returns the layer type - * - * @return - */ - Type type(); + /** + * Perform forward pass and return the activations array with the last set input + * + * @param training training or test mode + * @param workspaceMgr Workspace manager + * @return the activation (layer output) of the last specified input. Note that the returned array + * should be placed in the {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATIONS} workspace + * via the workspace manager + */ + INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr); + + /** + * Perform forward pass and return the activations array with the specified input + * + * @param input the input to use + * @param training train or test mode + * @param mgr Workspace manager. + * @return Activations array. Note that the returned array should be placed in the + * {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATIONS} workspace via the workspace + * manager + */ + INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr mgr); + + /** + * Get the iteration listeners for this layer. + */ + Collection getListeners(); + + /** + * Set the {@link TrainingListener}s for this model. If any listeners have previously been set, + * they will be replaced by this method + */ + void setListeners(TrainingListener... listeners); + + /** + * Set the {@link TrainingListener}s for this model. If any listeners have previously been set, + * they will be replaced by this method + */ + void setListeners(Collection listeners); + + /** + * Get the layer index. + */ + int getIndex(); + + /** + * Set the layer index. + */ + void setIndex(int index); + + /** + * @return The current iteration count (number of parameter updates) for the layer/network + */ + int getIterationCount(); + + /** + * Set the current iteration count (number of parameter updates) for the layer/network + */ + void setIterationCount(int iterationCount); + + /** + * @return The current epoch count (number of training epochs passed) for the layer/network + */ + int getEpochCount(); + + /** + * Set the current epoch count (number of epochs passed ) for the layer/network + */ + void setEpochCount(int epochCount); + + /** + * Set the layer input. 
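+   * <p>
+   * A typical call sequence during a forward pass, using only methods declared on this
+   * interface (illustrative only; {@code features} and {@code workspaceMgr} are assumed to be
+   * supplied by the caller):
+   * <pre>{@code
+   * layer.setInput(features, workspaceMgr);
+   * INDArray activations = layer.activate(true, workspaceMgr);   // true = training mode
+   * }</pre>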
+ */ + void setInput(INDArray input, LayerWorkspaceMgr workspaceMgr); + + /** + * Get current/last input mini-batch size, as set by setInputMiniBatchSize(int) + * + * @see Layer#setInputMiniBatchSize(int) + */ + int getInputMiniBatchSize(); + + /** + * Set current/last input mini-batch size.
Used for score and gradient calculations. Mini + * batch size may be different from getInput().size(0) due to reshaping operations - for example, + * when using RNNs with DenseLayerConfiguration and OutputLayer. Called automatically during + * forward pass. + */ + void setInputMiniBatchSize(int size); + + INDArray getMaskArray(); + + /** + * Set the mask array. Note: In general, {@link #feedForwardMaskArray(INDArray, MaskState, int)} + * should be used in preference to this. + * + * @param maskArray Mask array to set + */ + void setMaskArray(INDArray maskArray); + + /** + * Returns true if the layer can be trained in an unsupervised/pretrain manner (AE, VAE, etc) + * + * @return true if the layer can be pretrained (using fit(INDArray), false otherwise + */ + boolean isPretrainLayer(); + + void clearNoiseWeightParams(); + + /** + * A performance optimization: mark whether the layer is allowed to modify its input array + * in-place. In many cases, this is totally safe - in others, the input array will be shared by + * multiple layers, and hence it's not safe to modify the input array. This is usually used by ops + * such as dropout. + * + * @param allow If true: the input array is safe to modify. If false: the input array should be + * copied before it is modified (i.e., in-place modifications are un-safe) + */ + void allowInputModification(boolean allow); + + /** + * Feed forward the input mask array, setting in the layer as appropriate. This allows different + * layers to handle masks differently - for example, bidirectional RNNs and normal RNNs operate + * differently with masks (the former sets activations to 0 outside of the data present region + * (and keeps the mask active for future layers like dense layers), whereas normal RNNs don't zero + * out the activations/errors )instead relying on backpropagated error arrays to handle the + * variable length case.
This is also used for example for networks that contain global + * pooling layers, arbitrary preprocessors, etc. + * + * @param maskArray Mask array to set + * @param currentMaskState Current state of the mask - see {@link MaskState} + * @param minibatchSize Current minibatch size. Needs to be known as it cannot always be + * inferred from the activations array due to reshaping (such as a + * DenseLayerConfiguration within a recurrent neural network) + * @return New mask array after this layer, along with the new mask state. + */ + Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, + int minibatchSize); + + /** + * @return Get the layer helper, if any + */ + LayerHelper getHelper(); - /** - * Calculate the gradient relative to the error in the next layer - * - * @param epsilon w^(L+1)*delta^(L+1). Or, equiv: dC/da, i.e., (dC/dz)*(dz/da) = dC/da, where C - * is cost function a=sigma(z) is activation. - * @param workspaceMgr Workspace manager - * @return Pair where Gradient is gradient for this layer, INDArray is epsilon (activation gradient) - * needed by next layer, but before element-wise multiply by sigmaPrime(z). So for standard feed-forward layer, if this layer is - * L, then return.getSecond() == dL/dIn = (w^(L)*(delta^(L))^T)^T. Note that the returned array should be placed in the - * {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATION_GRAD} workspace via the workspace manager - */ - Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr); + enum Type { + FEED_FORWARD, RECURRENT, CONVOLUTIONAL, CONVOLUTIONAL3D, + SUBSAMPLING, UPSAMPLING, RECURSIVE, MULTILAYER, NORMALIZATION + } - - /** - * Perform forward pass and return the activations array with the last set input - * - * @param training training or test mode - * @param workspaceMgr Workspace manager - * @return the activation (layer output) of the last specified input. Note that the returned array should be placed - * in the {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATIONS} workspace via the workspace manager - */ - INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr); - - /** - * Perform forward pass and return the activations array with the specified input - * - * @param input the input to use - * @param training train or test mode - * @param mgr Workspace manager. - * @return Activations array. Note that the returned array should be placed in the - * {@link org.deeplearning4j.nn.workspace.ArrayType#ACTIVATIONS} workspace via the workspace manager - */ - INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr mgr); - - /** - * Get the iteration listeners for this layer. - */ - Collection getListeners(); - - /** - * Set the {@link TrainingListener}s for this model. If any listeners have previously been set, they will be - * replaced by this method - */ - void setListeners(TrainingListener... listeners); - - /** - * Set the {@link TrainingListener}s for this model. If any listeners have previously been set, they will be - * replaced by this method - */ - void setListeners(Collection listeners); - - /** - * Set the layer index. - */ - void setIndex(int index); - - /** - * Get the layer index. 
- */ - int getIndex(); - - /** - * @return The current iteration count (number of parameter updates) for the layer/network - */ - int getIterationCount(); - - /** - * @return The current epoch count (number of training epochs passed) for the layer/network - */ - int getEpochCount(); - - /** - * Set the current iteration count (number of parameter updates) for the layer/network - */ - void setIterationCount(int iterationCount); - - /** - * Set the current epoch count (number of epochs passed ) for the layer/network - */ - void setEpochCount(int epochCount); - - /** - * Set the layer input. - */ - void setInput(INDArray input, LayerWorkspaceMgr workspaceMgr); - - /** - * Set current/last input mini-batch size.
- * Used for score and gradient calculations. Mini batch size may be different from - * getInput().size(0) due to reshaping operations - for example, when using RNNs with - * DenseLayer and OutputLayer. Called automatically during forward pass. - */ - void setInputMiniBatchSize(int size); - - /** - * Get current/last input mini-batch size, as set by setInputMiniBatchSize(int) - * - * @see Layer#setInputMiniBatchSize(int) - */ - int getInputMiniBatchSize(); - - /** - * Set the mask array. Note: In general, {@link #feedForwardMaskArray(INDArray, MaskState, int)} should be used in - * preference to this. - * - * @param maskArray Mask array to set - */ - void setMaskArray(INDArray maskArray); - - - INDArray getMaskArray(); - - /** - * Returns true if the layer can be trained in an unsupervised/pretrain manner (AE, VAE, etc) - * - * @return true if the layer can be pretrained (using fit(INDArray), false otherwise - */ - boolean isPretrainLayer(); - - - void clearNoiseWeightParams(); - - /** - * A performance optimization: mark whether the layer is allowed to modify its input array in-place. In many cases, - * this is totally safe - in others, the input array will be shared by multiple layers, and hence it's not safe to - * modify the input array. - * This is usually used by ops such as dropout. - * @param allow If true: the input array is safe to modify. If false: the input array should be copied before it - * is modified (i.e., in-place modifications are un-safe) - */ - void allowInputModification(boolean allow); - - - /** - * Feed forward the input mask array, setting in the layer as appropriate. This allows different layers to - * handle masks differently - for example, bidirectional RNNs and normal RNNs operate differently with masks (the - * former sets activations to 0 outside of the data present region (and keeps the mask active for future layers like - * dense layers), whereas normal RNNs don't zero out the activations/errors )instead relying on backpropagated error - * arrays to handle the variable length case.
- * This is also used for example for networks that contain global pooling layers, arbitrary preprocessors, etc. - * - * @param maskArray Mask array to set - * @param currentMaskState Current state of the mask - see {@link MaskState} - * @param minibatchSize Current minibatch size. Needs to be known as it cannot always be inferred from the activations - * array due to reshaping (such as a DenseLayer within a recurrent neural network) - * @return New mask array after this layer, along with the new mask state. - */ - Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize); - - /** - * @return Get the layer helper, if any - */ - LayerHelper getHelper(); + enum TrainingMode { + TRAIN, TEST + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java index 8b7d816d6..01a60b73e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java @@ -25,7 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; public interface ModelAdapter extends OutputAdapter { /** - * This method invokes model internally, and does convertion to T + * This method invokes model internally, and does conversion to T * @return */ T apply(Model model, INDArray[] inputs, INDArray[] inputMasks, INDArray[] labelsMasks); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ParamInitializer.java index 7170953e9..7b6483483 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ParamInitializer.java @@ -41,7 +41,7 @@ public interface ParamInitializer { /** * Get a list of all parameter keys given the layer configuration * - * @param layer Layer + * @param layer ILayer * @return All parameter keys */ List paramKeys(org.deeplearning4j.nn.conf.layers.Layer layer); @@ -49,7 +49,7 @@ public interface ParamInitializer { /** * Weight parameter keys given the layer configuration * - * @param layer Layer + * @param layer ILayer * @return Weight parameter keys */ List weightKeys(org.deeplearning4j.nn.conf.layers.Layer layer); @@ -57,7 +57,7 @@ public interface ParamInitializer { /** * Bias parameter keys given the layer configuration * - * @param layer Layer + * @param layer ILayer * @return Bias parameter keys */ List biasKeys(org.deeplearning4j.nn.conf.layers.Layer layer); @@ -65,7 +65,7 @@ public interface ParamInitializer { /** * Is the specified parameter a weight? * - * @param layer Layer + * @param layer ILayer * @param key Key to check * @return True if parameter is a weight */ @@ -74,7 +74,7 @@ public interface ParamInitializer { /** * Is the specified parameter a bias? 
* - * @param layer Layer + * @param layer ILayer * @param key Key to check * @return True if parameter is a bias */ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/TrainingConfig.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/TrainingConfig.java index ae7601a6f..58f101260 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/TrainingConfig.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/TrainingConfig.java @@ -47,7 +47,7 @@ public interface TrainingConfig { * Is the specified parameter a layerwise pretraining only parameter?
* For example, visible bias params in an autoencoder (or, decoder params in a variational autoencoder) aren't * used during supervised backprop.
- * Layers (like DenseLayer, etc) with no pretrainable parameters will return false for all (valid) inputs. + * Layers (like DenseLayerConfiguration, etc) with no pretrainable parameters will return false for all (valid) inputs. * * @param paramName Parameter name/key * @return True if the parameter is for layerwise pretraining only, false otherwise diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Updater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Updater.java index d63b57bb8..2c01298cb 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Updater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Updater.java @@ -36,7 +36,7 @@ public interface Updater extends Serializable { /** * Set the internal (historical) state view array for this updater * - * @param layer Layer that this updater belongs to + * @param layer ILayer that this updater belongs to * @param viewArray View array * @param initialize Whether to initialize the array or not */ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/LayerConstraint.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/LayerConstraint.java index fff8bd77d..cfa82b050 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/LayerConstraint.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/LayerConstraint.java @@ -33,7 +33,7 @@ public interface LayerConstraint extends Cloneable, Serializable { * Apply a given constraint to a layer at each iteration * in the provided epoch, after parameters have been updated. * - * @param layer org.deeplearning4j.nn.api.Layer + * @param layer org.deeplearning4j.nn.api.ILayer * @param iteration given iteration as integer * @param epoch current epoch as integer */ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java index 62050b88e..a4f73d3b0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java @@ -66,10 +66,10 @@ public interface RecurrentLayer extends Layer { * (a) result in the same output
* (b) leave the state maps (both stateMap and tBpttStateMap) in an identical state * - * @param input Layer input + * @param input ILayer input * @param training if true: training. Otherwise: test * @param storeLastForTBPTT If true: store the final state in tBpttStateMap for use in truncated BPTT training - * @return Layer activations + * @return ILayer activations */ INDArray rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT, LayerWorkspaceMgr workspaceMg); @@ -92,7 +92,7 @@ public interface RecurrentLayer extends Layer { void rnnSetTBPTTState(Map state); /** - * Truncated BPTT equivalent of Layer.backpropGradient(). + * Truncated BPTT equivalent of ILayer.backpropGradient(). * Primary difference here is that forward pass in the context of BPTT is that we do * forward pass using stored state for truncated BPTT vs. from zero initialization * for standard BPTT. diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java index 69ff898e2..f44a8f3ab 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java @@ -25,6 +25,7 @@ import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.INeuralNetworkConfiguration; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.distribution.Distribution; @@ -68,7 +69,9 @@ import java.util.*; @NoArgsConstructor @Slf4j @EqualsAndHashCode(exclude = {"iterationCount", "epochCount"}) -public class NeuralNetConfiguration implements Serializable, Cloneable { +public class NeuralNetConfiguration implements Serializable, Cloneable, + INeuralNetworkConfiguration { + protected Layer layer; //batch size: primarily used for conv nets. Will be reinforced if set. diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MaxNormConstraint.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MaxNormConstraint.java index 43fdc4254..a38e6dfcf 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MaxNormConstraint.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MaxNormConstraint.java @@ -43,7 +43,7 @@ public class MaxNormConstraint extends BaseConstraint { /** * @param maxNorm Maximum L2 value * @param paramNames Which parameter names to apply constraint to - * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should + * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should * be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of * parameters which have order [depthOut, depthIn, kH, kW] */ @@ -56,7 +56,7 @@ public class MaxNormConstraint extends BaseConstraint { * Apply to weights but not biases by default * * @param maxNorm Maximum L2 value - * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should + * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should * be dimension 1. 
For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of * parameters which have order [depthOut, depthIn, kH, kW] */ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MinMaxNormConstraint.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MinMaxNormConstraint.java index 6449a9abd..ca43d4ca0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MinMaxNormConstraint.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/MinMaxNormConstraint.java @@ -51,7 +51,7 @@ public class MinMaxNormConstraint extends BaseConstraint { * * @param max Maximum L2 value * @param min Minimum L2 value - * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should + * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should * be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of * parameters which have order [depthOut, depthIn, kH, kW] */ @@ -65,7 +65,7 @@ public class MinMaxNormConstraint extends BaseConstraint { * @param max Maximum L2 value * @param min Minimum L2 value * @param rate Constraint rate - * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should + * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should * be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of * parameters which have order [depthOut, depthIn, kH, kW] */ @@ -79,7 +79,7 @@ public class MinMaxNormConstraint extends BaseConstraint { * @param min Minimum L2 value * @param rate Constraint rate * @param paramNames Which parameter names to apply constraint to - * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should + * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should * be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of * parameters which have order [depthOut, depthIn, kH, kW] */ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/UnitNormConstraint.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/UnitNormConstraint.java index a082056a7..3e80f341b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/UnitNormConstraint.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/UnitNormConstraint.java @@ -39,7 +39,7 @@ public class UnitNormConstraint extends BaseConstraint { /** * Apply to weights but not biases by default * - * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should + * @param dimensions Dimensions to apply to. For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should * be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of * parameters which have order [depthOut, depthIn, kH, kW] */ @@ -49,7 +49,7 @@ public class UnitNormConstraint extends BaseConstraint { /** - * @param dimensions Dimensions to apply to. For DenseLayer, OutputLayer, RnnOutputLayer, LSTM, etc: this should + * @param dimensions Dimensions to apply to. 
For DenseLayerConfiguration, OutputLayer, RnnOutputLayer, LSTM, etc: this should * be dimension 1. For CNNs, this should be dimensions [1,2,3] corresponding to last 3 of * parameters which have order [depthOut, depthIn, kH, kW] */ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java index b1734682d..0c7565db1 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java @@ -21,7 +21,6 @@ package org.deeplearning4j.nn.conf.graph; import lombok.Data; -import lombok.EqualsAndHashCode; import lombok.NoArgsConstructor; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -40,8 +39,8 @@ public class LayerVertex extends GraphVertex { private NeuralNetConfiguration layerConf; private InputPreProcessor preProcessor; - //Set outputVertex to true when Layer is an OutputLayer, OR For use in specialized situations like reinforcement learning - // For RL situations, this Layer insn't an OutputLayer, but is the last layer in a graph, that gets its error/epsilon + //Set outputVertex to true when ILayer is an OutputLayer, OR For use in specialized situations like reinforcement learning + // For RL situations, this ILayer isn't an OutputLayer, but is the last layer in a graph, that gets its error/epsilon // passed in externally private boolean outputVertex; @@ -99,7 +98,7 @@ public class LayerVertex extends GraphVertex { public org.deeplearning4j.nn.graph.vertex.GraphVertex instantiate(ComputationGraph graph, String name, int idx, INDArray paramsView, boolean initializeParams, DataType networkDatatype) { //Now, we need to work out if this vertex is an output vertex or not... - boolean isOutput = graph.getConfiguration().getNetworkOutputs().contains(name); + boolean isOutput = graph.getComputationGraphConfiguration().getNetworkOutputs().contains(name); org.deeplearning4j.nn.api.Layer layer = layerConf.getLayer().instantiate(layerConf, null, idx, paramsView, initializeParams, networkDatatype); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java index 0fb559c74..0b10cedd4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java @@ -134,7 +134,7 @@ public class ActivationLayer extends NoParamLayer { private IActivation activationFn = null; /** - * Layer activation function. Typical values include:
"relu" (rectified linear), "tanh", "sigmoid", + * ILayer activation function. Typical values include:
"relu" (rectified linear), "tanh", "sigmoid", * "softmax", "hardtanh", "leakyrelu", "maxout", "softsign", "softplus" * * @deprecated Use {@link #activation(Activation)} or {@link @activation(IActivation)} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java index fc751e91b..6aad5b0ef 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java @@ -176,7 +176,7 @@ public abstract class BaseLayer extends Layer implements Serializable, Cloneable protected double biasInit = Double.NaN; /** - * Gain initialization value, for layers with Layer Normalization. Defaults to 1 + * Gain initialization value, for layers with ILayer Normalization. Defaults to 1 * */ protected double gainInit = Double.NaN; @@ -292,7 +292,7 @@ public abstract class BaseLayer extends Layer implements Serializable, Cloneable } /** - * Gain initialization value, for layers with Layer Normalization. Defaults to 1 + * Gain initialization value, for layers with ILayer Normalization. Defaults to 1 * * @param gainInit Value to use for initializing gain */ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java index c6f31faf3..4081930c9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java @@ -63,14 +63,14 @@ public class CapsuleLayer extends SameDiffLayer { this.routings = builder.routings; if(capsules <= 0 || capsuleDimensions <= 0 || routings <= 0){ - throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \"" + throw new IllegalArgumentException("Invalid configuration for Capsule ILayer (layer name = \"" + layerName + "\"):" + " capsules, capsuleDimensions, and routings must be > 0. Got: " + capsules + ", " + capsuleDimensions + ", " + routings); } if(inputCapsules < 0 || inputCapsuleDimensions < 0){ - throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \"" + throw new IllegalArgumentException("Invalid configuration for Capsule ILayer (layer name = \"" + layerName + "\"):" + " inputCapsules and inputCapsuleDimensions must be >= 0 if set. 
Got: " + inputCapsules + ", " + inputCapsuleDimensions); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java index d77f13e5c..1a6ce905c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java @@ -55,7 +55,7 @@ public class DenseLayer extends FeedForwardLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - LayerValidation.assertNInNOutSet("DenseLayer", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret = new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(conf, networkDataType); @@ -101,7 +101,7 @@ public class DenseLayer extends FeedForwardLayer { return new LayerMemoryReport.Builder(layerName, DenseLayer.class, inputType, outputType) .standardMemory(numParams, updaterStateSize) .workingMemory(0, 0, trainSizeFixed, trainSizeVariable) //No additional memory (beyond activations) for inference - .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayer + .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayerConfiguration .build(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java index a96ec6db7..66f48dd14 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java @@ -205,7 +205,7 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable { /** * Is the specified parameter a layerwise pretraining only parameter?
For example, visible * bias params in an autoencoder (or, decoder params in a variational autoencoder) aren't used - * during supervised backprop.
Layers (like DenseLayer, etc) with no pretrainable parameters + * during supervised backprop.
Layers (like DenseLayerConfiguration, etc) with no pretrainable parameters * will return false for all (valid) inputs. * * @param paramName Parameter name/key @@ -255,7 +255,7 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable { protected IDropout iDropout; /** - * Layer name assigns layer string name. Allows easier differentiation between layers. + * ILayer name assigns layer string name. Allows easier differentiation between layers. */ public T name(String layerName) { this.setLayerName(layerName); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java index 2a5f16be6..571f884e3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java @@ -42,7 +42,7 @@ public class LayerValidation { /** * Asserts that the layer nIn and nOut values are set for the layer * - * @param layerType Type of layer ("DenseLayer", etc) + * @param layerType Type of layer ("DenseLayerConfiguration", etc) * @param layerName Name of the layer (may be null if not set) * @param layerIndex Index of the layer * @param nIn nIn value @@ -60,7 +60,7 @@ public class LayerValidation { /** * Asserts that the layer nOut value is set for the layer * - * @param layerType Type of layer ("DenseLayer", etc) + * @param layerType Type of layer ("DenseLayerConfiguration", etc) * @param layerName Name of the layer (may be null if not set) * @param layerIndex Index of the layer * @param nOut nOut value diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java index 8648a2814..98d7fa093 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java @@ -147,7 +147,7 @@ public class LocalResponseNormalization extends Layer { return new LayerMemoryReport.Builder(layerName, DenseLayer.class, inputType, inputType).standardMemory(0, 0) .workingMemory(0, 2 * actElementsPerEx, 0, 3 * actElementsPerEx) - .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayer + .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayerConfiguration .build(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java index 2107bdede..4d3f56a84 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java @@ -87,7 +87,7 @@ public class PrimaryCapsules extends SameDiffLayer { } if(capsules < 0){ - throw new IllegalArgumentException("Invalid configuration for Capsule Layer (layer name = \"" + throw new IllegalArgumentException("Invalid configuration for Capsule ILayer (layer name = \"" + layerName + "\"):" + " capsules must be >= 0 if set. 
Got: " + capsules); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java index 79ab2ca54..9eea40cfc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java @@ -113,7 +113,7 @@ public class ElementWiseMultiplicationLayer extends org.deeplearning4j.nn.conf.l return new LayerMemoryReport.Builder(layerName, ElementWiseMultiplicationLayer.class, inputType, outputType) .standardMemory(numParams, updaterStateSize) .workingMemory(0, 0, trainSizeFixed, trainSizeVariable) //No additional memory (beyond activations) for inference - .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayer + .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS) //No caching in DenseLayerConfiguration .build(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java index d6004f6bb..54a93b904 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java @@ -44,7 +44,7 @@ public class TimeDistributed extends BaseWrapperLayer { private RNNFormat rnnDataFormat = RNNFormat.NCW; /** - * @param underlying Underlying (internal) layer - should be a feed forward type such as DenseLayer + * @param underlying Underlying (internal) layer - should be a feed forward type such as DenseLayerConfiguration */ public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) { super(underlying); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java index 51cdb3b6f..0b68bf649 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaLayer.java @@ -33,7 +33,7 @@ public abstract class SameDiffLambdaLayer extends SameDiffLayer { * The defineLayer method is used to define the forward pass for the layer * * @param sameDiff SameDiff instance to use to define the vertex - * @param layerInput Layer input variable + * @param layerInput ILayer input variable * @return The output variable (corresponding to the output activations for the layer) */ public abstract SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java index d3c10ec2f..7ec4fb2d5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java +++ 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLambdaVertex.java @@ -37,7 +37,7 @@ public abstract class SameDiffLambdaVertex extends SameDiffVertex { * The defineVertex method is used to define the foward pass for the vertex * * @param sameDiff SameDiff instance to use to define the vertex - * @param inputs Layer input variable + * @param inputs ILayer input variable * @return The output variable (orresponding to the output activations for the vertex) */ public abstract SDVariable defineVertex(SameDiff sameDiff, VertexInputs inputs); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BuildingBlockLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BuildingBlockLayer.java deleted file mode 100644 index e150b850f..000000000 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BuildingBlockLayer.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * - * ****************************************************************************** - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - * - */ - -package org.deeplearning4j.nn.conf.layers.wrapper; - -import java.util.Collection; -import lombok.AccessLevel; -import lombok.Builder; -import lombok.Getter; -import lombok.NonNull; -import net.brutex.ai.dnn.api.LayerConfiguration; -import net.brutex.ai.dnn.api.NeuralNetwork; -import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.BaseLayer; -import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; -import org.deeplearning4j.optimize.api.TrainingListener; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; - -@Builder(builderClassName = "Builder", access = AccessLevel.PUBLIC) -public class BuildingBlockLayer extends BaseLayer implements LayerConfiguration { - - @NonNull - @Getter - private NeuralNetworkConfiguration conf; - - @Override - public Layer instantiate(NeuralNetConfiguration conf, - Collection trainingListeners, int layerIndex, INDArray layerParamsView, - boolean initializeParams, DataType networkDataType) { - return null; - } - - @Override - public ParamInitializer initializer() { - return null; - } - - @Override - public InputType getOutputType(int layerIndex, InputType inputType) { - return null; - } - - @Override - public void setNIn(InputType inputType, boolean override) { - - } - - @Override - public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return 
null; - } - - @Override - public boolean isPretrainParam(String paramName) { - return false; - } - - @Override - public LayerMemoryReport getMemoryReport(InputType inputType) { - return null; - } - - /** - * Create and return an instance of a LayerConfiguration. - * - * @param network the "holding" network for the instance - * @return the new layer instance - */ - @Override - public net.brutex.ai.dnn.api.Layer instantiate(NeuralNetwork network) { - return null; - } -} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/NetworkMemoryReport.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/NetworkMemoryReport.java index 9182ccfb9..d3f7b1955 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/NetworkMemoryReport.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/memory/NetworkMemoryReport.java @@ -153,7 +153,7 @@ public class NetworkMemoryReport extends MemoryReport { .append(modelName).append("\n").append(" Network Input: ") .append(Arrays.toString(networkInputTypes)).append("\n") .append(" # Layers: ").append(layerAndVertexReports.size()) - .append("\n").append(" Layer Types: ").append(sbLayerCounts) + .append("\n").append(" ILayer Types: ").append(sbLayerCounts) .append("\n"); appendFixedPlusVariable(sb, " Inference Memory (FP32) ", fixedMemBytes, perEx); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/IWeightNoise.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/IWeightNoise.java index 4c45b762f..c6c77d3d2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/IWeightNoise.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/IWeightNoise.java @@ -33,7 +33,7 @@ public interface IWeightNoise extends Serializable, Cloneable{ /** * Get the parameter, after applying weight noise * - * @param layer Layer to get the parameter for + * @param layer ILayer to get the parameter for * @param paramKey Parameter key * @param iteration Iteration number * @param epoch Epoch number diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index ac8a05be4..4a080bb28 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -25,6 +25,8 @@ import lombok.NonNull; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import lombok.val; +import net.brutex.ai.dnn.api.INeuralNetwork; +import net.brutex.ai.dnn.networks.ArtificialNeuralNetwork; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.bytedeco.javacpp.Pointer; @@ -103,9 +105,16 @@ import java.util.*; import java.util.concurrent.atomic.AtomicLong; @Slf4j -public class ComputationGraph implements Serializable, Model, NeuralNetwork { +public class ComputationGraph extends ArtificialNeuralNetwork implements Serializable, Model, + INeuralNetwork { - protected ComputationGraphConfiguration configuration; + /** + * This method returns configuration of this ComputationGraph + * + * @return + */ + @Getter + protected ComputationGraphConfiguration computationGraphConfiguration; protected boolean initCalled = false; protected transient Solver solver; //Used to call optimizers during backprop 
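The hunk above renames ComputationGraph's "configuration" field to "computationGraphConfiguration" and exposes it through Lombok's @Getter, so the generated accessor getComputationGraphConfiguration() takes over from the removed hand-written getConfiguration(). A minimal, hedged sketch of how calling code follows the rename; it assumes the upstream graph-builder API is otherwise unchanged in this fork, and the vertex names and layer sizes are illustrative only:

    import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class ConfigurationAccessorSketch {
        public static void main(String[] args) {
            // Illustrative two-layer graph; names and sizes are placeholders.
            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(8).build(), "in")
                    .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .nIn(8).nOut(3).activation(Activation.SOFTMAX).build(), "dense")
                    .setOutputs("out")
                    .build();
            ComputationGraph net = new ComputationGraph(conf);
            net.init();
            // Old accessor (removed by this patch): net.getConfiguration()
            // New accessor (Lombok-generated from the @Getter field above):
            int nOut = net.getComputationGraphConfiguration().getNetworkOutputs().size();
            System.out.println("Network outputs: " + nOut);
        }
    }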
protected INDArray flattenedParams; //Params for all layers are a view/subset of this array @@ -210,17 +219,17 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { private Collection trainingListeners = new ArrayList<>(); - public ComputationGraph(ComputationGraphConfiguration configuration) { - this.configuration = configuration; - this.numInputArrays = configuration.getNetworkInputs().size(); - this.numOutputArrays = configuration.getNetworkOutputs().size(); + public ComputationGraph(ComputationGraphConfiguration computationGraphConfiguration) { + this.computationGraphConfiguration = computationGraphConfiguration; + this.numInputArrays = computationGraphConfiguration.getNetworkInputs().size(); + this.numOutputArrays = computationGraphConfiguration.getNetworkOutputs().size(); this.inputs = new INDArray[numInputArrays]; this.labels = new INDArray[numOutputArrays]; - this.defaultConfiguration = configuration.getDefaultConfiguration(); + this.defaultConfiguration = computationGraphConfiguration.getDefaultConfiguration(); //Working memory: should learn over course of: (a) full forward pass, and (b) full backward pass //Working memory should be opened once per vertex, for each of forward and backward passes - int numWorkingMem = 2 * configuration.getVertices().size(); + int numWorkingMem = 2 * computationGraphConfiguration.getVertices().size(); WS_LAYER_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder() .initialSize(0) .overallocationLimit(0.02) @@ -238,7 +247,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { .initialSize(0) .overallocationLimit(0.02) .policyLearning(LearningPolicy.OVER_TIME) - .cyclesBeforeInitialization(configuration.getVertices().size()) + .cyclesBeforeInitialization(computationGraphConfiguration.getVertices().size()) .policyReset(ResetPolicy.BLOCK_LEFT) .policySpill(SpillPolicy.REALLOCATE) .policyAllocation(AllocationPolicy.OVERALLOCATE) @@ -278,14 +287,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } } - /** - * This method returns configuration of this ComputationGraph - * - * @return - */ - public ComputationGraphConfiguration getConfiguration() { - return configuration; - } + /** * Returns the number of layers in the ComputationGraph @@ -313,7 +315,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { * Get a given layer by name. */ public Layer getLayer(String name) { - Preconditions.checkState(verticesMap.containsKey(name), "Layer with name %s does not exist in the network", name); + Preconditions.checkState(verticesMap.containsKey(name), "ILayer with name %s does not exist in the network", name); return verticesMap.get(name).getLayer(); } @@ -449,7 +451,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { if (initCalled) return; - DataType netDtype = getConfiguration().getDataType(); + DataType netDtype = this.getComputationGraphConfiguration().getDataType(); if(parameters != null && parameters.dataType() != netDtype){ Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. 
Got %ndShape", parameters); if(cloneParametersArray){ @@ -463,31 +465,31 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } } - if (configuration.getTrainingWorkspaceMode() == null) - configuration.setTrainingWorkspaceMode(WorkspaceMode.NONE); + if (computationGraphConfiguration.getTrainingWorkspaceMode() == null) + computationGraphConfiguration.setTrainingWorkspaceMode(WorkspaceMode.NONE); - if (configuration.getInferenceWorkspaceMode() == null) - configuration.setInferenceWorkspaceMode(WorkspaceMode.NONE); + if (computationGraphConfiguration.getInferenceWorkspaceMode() == null) + computationGraphConfiguration.setInferenceWorkspaceMode(WorkspaceMode.NONE); - if (configuration.getCacheMode() == null) - configuration.setCacheMode(CacheMode.NONE); + if (computationGraphConfiguration.getCacheMode() == null) + computationGraphConfiguration.setCacheMode(CacheMode.NONE); OneTimeLogger.info(log, "Starting ComputationGraph with WorkspaceModes set to [training: {}; inference: {}], cacheMode set to [{}]", - configuration.getTrainingWorkspaceMode(), configuration.getInferenceWorkspaceMode(), configuration.getCacheMode()); + computationGraphConfiguration.getTrainingWorkspaceMode(), computationGraphConfiguration.getInferenceWorkspaceMode(), computationGraphConfiguration.getCacheMode()); //First: build topological ordering, based on configuration. Used for forward pass, backprop and order of parameters/gradients GraphIndices indices = calculateIndices(); topologicalOrder = indices.getTopologicalSortOrder(); //Initialization: create the GraphVertex objects, based on configuration structure - Map configVertexMap = configuration.getVertices(); + Map configVertexMap = computationGraphConfiguration.getVertices(); //Names of all of the (data) inputs to the ComputationGraph - List networkInputNames = configuration.getNetworkInputs(); + List networkInputNames = computationGraphConfiguration.getNetworkInputs(); //Inputs for each layer and GraphNode: - Map> vertexInputs = configuration.getVertexInputs(); - this.vertices = new GraphVertex[networkInputNames.size() + configuration.getVertices().size()]; + Map> vertexInputs = computationGraphConfiguration.getVertexInputs(); + this.vertices = new GraphVertex[networkInputNames.size() + computationGraphConfiguration.getVertices().size()]; //All names: inputs, layers and graph nodes (index to name map) Map allNamesReverse = new HashMap<>(); @@ -504,7 +506,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { long numParams = 0; long[] numParamsForVertex = new long[topologicalOrder.length]; int i = 0; - for (; i < configuration.getNetworkInputs().size(); i++) { + for (; i < computationGraphConfiguration.getNetworkInputs().size(); i++) { numParamsForVertex[i] = 0; //No parameters for input vertices } for(; i < topologicalOrder.length; i++) { @@ -513,7 +515,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { n.setDataType(netDtype); numParamsForVertex[i] = n.numParams(true); if(numParamsForVertex[i] < 0) - throw new DL4JInvalidConfigException("Layer " + name + " had parameters < 0 " + numParamsForVertex[i]); + throw new DL4JInvalidConfigException("ILayer " + name + " had parameters < 0 " + numParamsForVertex[i]); numParams += numParamsForVertex[i]; } @@ -564,7 +566,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { List tempLayerList = new ArrayList<>(); defaultConfiguration.clearVariables(); List variables = defaultConfiguration.variables(false); - i 
= configuration.getNetworkInputs().size(); + i = computationGraphConfiguration.getNetworkInputs().size(); for(; i> seenAsInputTo = new HashMap<>(); - for(Map.Entry> entry : configuration.getVertexInputs().entrySet()){ + for(Map.Entry> entry : computationGraphConfiguration.getVertexInputs().entrySet()){ for(String s : entry.getValue() ){ if (!seenAsInputTo.containsKey(s)) { seenAsInputTo.put(s, new ArrayList()); @@ -709,10 +711,10 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { for(Layer l : layers){ String layerName = l.conf().getLayer().getLayerName(); - List inputs = configuration.getVertexInputs().get(layerName); + List inputs = computationGraphConfiguration.getVertexInputs().get(layerName); String in = inputs.get(0); //For now: layers should have exactly 1 input - if(configuration.getNetworkInputs().contains(in)){ + if(computationGraphConfiguration.getNetworkInputs().contains(in)){ //TODO When is it safe to NOT allow input modifucation? It's not always safe... // For example dropout + iterating over List that is used for multiple epochs... continue; @@ -761,10 +763,10 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { long numParams = 0; long[] numParamsForVertex = new long[topologicalOrder.length]; int i = 0; - for (; i < configuration.getNetworkInputs().size(); i++) { + for (; i < computationGraphConfiguration.getNetworkInputs().size(); i++) { numParamsForVertex[i] = 0; //No parameters for input vertices } - Map configVertexMap = configuration.getVertices(); + Map configVertexMap = computationGraphConfiguration.getVertices(); for (; i < topologicalOrder.length; i++) { String name = indices.getIdxToName().get(i); org.deeplearning4j.nn.conf.graph.GraphVertex n = configVertexMap.get(name); @@ -796,7 +798,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { if(outputLayerIdxs == null) { outputLayerIdxs = new int[numOutputArrays]; int i = 0; - for (String s : configuration.getNetworkOutputs()) { + for (String s : computationGraphConfiguration.getNetworkOutputs()) { outputLayerIdxs[i++] = verticesMap.get(s).getVertexIndex(); } } @@ -875,7 +877,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { /** * Pretrain a specified layer with the given DataSetIterator * - * @param layerName Layer name + * @param layerName ILayer name * @param dataSetIterator Data */ public void pretrainLayer(String layerName, DataSetIterator dataSetIterator) { @@ -890,7 +892,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { /** * Pretrain a specified layer with the given MultiDataSetIterator * - * @param layerName Layer name + * @param layerName ILayer name * @param iter Training data */ public void pretrainLayer(String layerName, MultiDataSetIterator iter) { @@ -920,7 +922,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { int idx = toTrain.getVertexIndex(); LayerWorkspaceMgr workspaceMgr; - if(configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ + if(computationGraphConfiguration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { workspaceMgr = LayerWorkspaceMgr.builder() @@ -1133,7 +1135,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { update(TaskUtils.buildTask(inputs, labels)); LayerWorkspaceMgr workspaceMgr; - if(configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ + 
if(computationGraphConfiguration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { workspaceMgr = LayerWorkspaceMgr.builder() @@ -1151,7 +1153,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); - if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) { + if (computationGraphConfiguration.getBackpropType() == BackpropType.TruncatedBPTT) { doTruncatedBPTT(inputs, labels, featureMaskArrays, labelMaskArrays, workspaceMgr); } else { if (solver == null) { @@ -1202,9 +1204,9 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { //Get cached topological sort order from config, if present - if(configuration.getTopologicalOrder() != null && configuration.getTopologicalOrderStr() != null){ - int[] t = configuration.getTopologicalOrder(); - List s = configuration.getTopologicalOrderStr(); + if(computationGraphConfiguration.getTopologicalOrder() != null && computationGraphConfiguration.getTopologicalOrderStr() != null){ + int[] t = computationGraphConfiguration.getTopologicalOrder(); + List s = computationGraphConfiguration.getTopologicalOrderStr(); Map m1 = new HashMap<>(); Map m2 = new HashMap<>(); for( int i=0; i nodeMap = configuration.getVertices(); - List networkInputNames = configuration.getNetworkInputs(); - int numVertices = networkInputNames.size() + configuration.getVertices().size(); + Map nodeMap = computationGraphConfiguration.getVertices(); + List networkInputNames = computationGraphConfiguration.getNetworkInputs(); + int numVertices = networkInputNames.size() + computationGraphConfiguration.getVertices().size(); int[] out = new int[numVertices]; int outCounter = 0; @@ -1233,7 +1235,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { Map vertexNamesMap = new HashMap<>(); Map vertexNamesMap2 = new HashMap<>(); int i = 0; - for (String inputName : configuration.getNetworkInputs()) { + for (String inputName : computationGraphConfiguration.getNetworkInputs()) { vertexNamesMap.put(i, inputName); vertexNamesMap2.put(inputName, i); i++; @@ -1248,7 +1250,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { Map> inputEdges = new HashMap<>(); //key: vertex. Values: vertices that the key vertex receives input from Map> outputEdges = new HashMap<>(); //key: vertex. 
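fit() above branches on getBackpropType() == BackpropType.TruncatedBPTT before calling doTruncatedBPTT(...), and the truncated-BPTT path later reads the segment length back via getTbpttFwdLength(). A hedged configuration sketch for that branch, assuming the upstream GraphBuilder methods backpropType(...), tBPTTForwardLength(...) and tBPTTBackwardLength(...) are still present in this fork; layer names, sizes and the segment length of 20 are placeholders:

    // Fragment; same imports as the sketch further up, plus org.deeplearning4j.nn.conf.BackpropType,
    // org.deeplearning4j.nn.conf.layers.LSTM and org.deeplearning4j.nn.conf.layers.RnnOutputLayer.
    ComputationGraphConfiguration tbpttConf = new NeuralNetConfiguration.Builder()
            .graphBuilder()
            .addInputs("seqIn")
            .addLayer("lstm", new LSTM.Builder().nIn(10).nOut(16).build(), "seqIn")
            .addLayer("rnnOut", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(16).nOut(5).activation(Activation.SOFTMAX).build(), "lstm")
            .setOutputs("rnnOut")
            .backpropType(BackpropType.TruncatedBPTT) // selects the doTruncatedBPTT(...) path in fit()
            .tBPTTForwardLength(20)                   // surfaced later via getTbpttFwdLength()
            .tBPTTBackwardLength(20)
            .build();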
Values: vertices that the key vertex outputs to - for (String s : configuration.getNetworkInputs()) { + for (String s : computationGraphConfiguration.getNetworkInputs()) { int idx = vertexNamesMap2.get(s); inputEdges.put(idx, null); } @@ -1256,7 +1258,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { for (Map.Entry entry : nodeMap.entrySet()) { String thisVertexName = entry.getKey(); int idx = vertexNamesMap2.get(thisVertexName); - List inputsToThisVertex = configuration.getVertexInputs().get(thisVertexName); + List inputsToThisVertex = computationGraphConfiguration.getVertexInputs().get(thisVertexName); if (inputsToThisVertex == null || inputsToThisVertex.isEmpty()) { inputEdges.put(idx, null); @@ -1324,8 +1326,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { for( int idx : out){ s.add(vertexNamesMap.get(idx)); } - configuration.setTopologicalOrder(out); - configuration.setTopologicalOrderStr(s); + computationGraphConfiguration.setTopologicalOrder(out); + computationGraphConfiguration.setTopologicalOrderStr(s); graphIndices = GraphIndices.builder() .topologicalSortOrder(out) @@ -1344,7 +1346,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { synchronizeIterEpochCounts(); LayerWorkspaceMgr workspaceMgr; - if(configuration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ + if(computationGraphConfiguration.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { workspaceMgr = LayerWorkspaceMgr.builder() @@ -1362,7 +1364,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); - boolean tbptt = configuration.getBackpropType() == BackpropType.TruncatedBPTT; + boolean tbptt = computationGraphConfiguration.getBackpropType() == BackpropType.TruncatedBPTT; FwdPassType fwdType = (tbptt ? FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE : FwdPassType.STANDARD); synchronizeIterEpochCounts(); @@ -1386,7 +1388,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { score = 0.0; int outNum = 0; - for (String s : configuration.getNetworkOutputs()) { + for (String s : computationGraphConfiguration.getNetworkOutputs()) { GraphVertex gv = verticesMap.get(s); if(gv instanceof LayerVertex) { //At this point: the input to the output layer might not be set on the layer itself - just the vertex @@ -1863,7 +1865,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { int[] layerNums = new int[layers.size()]; for( int i=0; i freeWorkspaceManagers = new ArrayList<>(); //Basically used as a stack Map openActivationsWorkspaces = new IdentityHashMap<>(); - WorkspaceMode wsm = (train ? configuration.getTrainingWorkspaceMode() : configuration.getInferenceWorkspaceMode()); + WorkspaceMode wsm = (train ? computationGraphConfiguration.getTrainingWorkspaceMode() : computationGraphConfiguration.getInferenceWorkspaceMode()); boolean noWS = wsm == WorkspaceMode.NONE; LayerWorkspaceMgr allNone = noWS ? 
LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null; List[] closeAtEndIteraton = (List[])new List[topologicalOrder.length]; @@ -2438,7 +2440,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) { if (current.hasLayer()) { - //Layer + //ILayer INDArray input = current.getInputs()[0]; Layer l = current.getLayer(); if (l instanceof RecurrentLayer) { @@ -2562,7 +2564,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { try { - calcBackpropGradients(true, configuration.getBackpropType() == BackpropType.TruncatedBPTT, epsilons); + calcBackpropGradients(true, computationGraphConfiguration.getBackpropType() == BackpropType.TruncatedBPTT, epsilons); return gradient; } catch (OutOfMemoryError e){ CrashReportingUtil.writeMemoryCrashDump(this, e); @@ -2595,19 +2597,19 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { consumed by all layers */ - if(externalEpsilons == null || externalEpsilons.length == 0 && configuration.getTrainingWorkspaceMode() != WorkspaceMode.NONE){ + if(externalEpsilons == null || externalEpsilons.length == 0 && computationGraphConfiguration.getTrainingWorkspaceMode() != WorkspaceMode.NONE){ WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, "Expected workspace WS_ALL_LAYERS_ACT to be active and open" + " in calcBackpropGradients when workspace mode is not set to NONE"); } //Validate the network configuration for external errors - no output layers if(externalEpsilons != null && externalEpsilons.length > 0){ - List outputLayers = configuration.getNetworkOutputs(); + List outputLayers = computationGraphConfiguration.getNetworkOutputs(); for(String s : outputLayers ){ GraphVertex gv = getVertex(s); if(gv instanceof LayerVertex && gv.getLayer() instanceof IOutputLayer){ throw new IllegalStateException("Cannot perform backprop with external errors in conjunction with an output layer:" + - " output layers cannot use external errors for backprop. Layer name: " + s); + " output layers cannot use external errors for backprop. ILayer name: " + s); } } @@ -2643,7 +2645,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } - boolean noWS = configuration.getInferenceWorkspaceMode() == WorkspaceMode.NONE; + boolean noWS = computationGraphConfiguration.getInferenceWorkspaceMode() == WorkspaceMode.NONE; LayerWorkspaceMgr allNone = noWS ? 
LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null; List allWorkspaceManagers = new ArrayList<>(); @@ -2722,7 +2724,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { //(a) it's an output layer (i.e., instanceof IOutputLayer), or //(b) it's a normal layer, but it has been marked as an output layer for use in external errors - for reinforcement learning, for example - int thisOutputNumber = configuration.getNetworkOutputs().indexOf(current.getVertexName()); + int thisOutputNumber = computationGraphConfiguration.getNetworkOutputs().indexOf(current.getVertexName()); Layer currentLayer = current.getLayer(); if (currentLayer instanceof FrozenLayerWithBackprop) { currentLayer = ((FrozenLayerWithBackprop) currentLayer).getInsideLayer(); @@ -2735,7 +2737,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } else { if ((externalEpsilons == null || externalEpsilons.length == 0) && labels[thisOutputNumber] != null) { - throw new DL4JException("Layer \"" + current.getVertexName() + "\" of type " + throw new DL4JException("ILayer \"" + current.getVertexName() + "\" of type " + current.getLayer().getClass().getSimpleName() + " is set as network output " + "(but isn't an IOutputLayer). Only IOutputLayer layers can be fit via backprop with" @@ -2882,7 +2884,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { @Override public ComputationGraph clone() { - ComputationGraph cg = new ComputationGraph(configuration.clone()); + ComputationGraph cg = new ComputationGraph(computationGraphConfiguration.clone()); cg.init(params().dup(), false); if (solver != null) { //If solver is null: updater hasn't been initialized -> getUpdater call will force initialization, however @@ -3019,7 +3021,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { if (outputLayerIdx >= numOutputArrays) throw new IllegalArgumentException("Invalid index: cannot get output layer " + outputLayerIdx + ", total number of network outputs = " + numOutputArrays); - return getLayer(configuration.getNetworkOutputs().get(outputLayerIdx)); + return getLayer(computationGraphConfiguration.getNetworkOutputs().get(outputLayerIdx)); } /** @@ -3086,7 +3088,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { private double scoreHelper(MultiDataSet dataSet, boolean training){ LayerWorkspaceMgr mgr; - WorkspaceMode wsm = (training ? configuration.getTrainingWorkspaceMode() : configuration.getInferenceWorkspaceMode()); + WorkspaceMode wsm = (training ? 
computationGraphConfiguration.getTrainingWorkspaceMode() : computationGraphConfiguration.getInferenceWorkspaceMode()); if(wsm == WorkspaceMode.NONE){ mgr = LayerWorkspaceMgr.noWorkspaces(); } else { @@ -3120,7 +3122,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { double r = calcRegularizationScore(true); int i = 0; - for (String s : configuration.getNetworkOutputs()) { + for (String s : computationGraphConfiguration.getNetworkOutputs()) { GraphVertex gv = verticesMap.get(s); Layer outLayer = gv.getLayer(); if (outLayer == null || !(outLayer instanceof IOutputLayer)) { @@ -3180,7 +3182,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { private INDArray scoreExamplesHelper(MultiDataSet dataSet, boolean addRegularizationTerms){ LayerWorkspaceMgr mgr; - if(configuration.getInferenceWorkspaceMode() == WorkspaceMode.NONE){ + if(computationGraphConfiguration.getInferenceWorkspaceMode() == WorkspaceMode.NONE){ mgr = LayerWorkspaceMgr.noWorkspaces(); } else { mgr = LayerWorkspaceMgr.builder() @@ -3212,7 +3214,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { double r = (addRegularizationTerms ? calcRegularizationScore(true) : 0.0); int i = 0; - for (String s : configuration.getNetworkOutputs()) { + for (String s : computationGraphConfiguration.getNetworkOutputs()) { GraphVertex gv = verticesMap.get(s); Layer outLayer = gv.getLayer(); if (outLayer == null || !(outLayer instanceof IOutputLayer)) { @@ -3640,7 +3642,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } if (l == null || !(l instanceof RecurrentLayer)) { throw new UnsupportedOperationException( - "Layer \"" + layerName + "\" is not a recurrent layer. Cannot set state"); + "ILayer \"" + layerName + "\" is not a recurrent layer. Cannot set state"); } ((RecurrentLayer) l).rnnSetPreviousState(state); } @@ -3704,7 +3706,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } } - long fwdLen = configuration.getTbpttFwdLength(); + long fwdLen = computationGraphConfiguration.getTbpttFwdLength(); long nSubsets = timeSeriesLength / fwdLen; if (timeSeriesLength % fwdLen != 0) nSubsets++; @@ -3882,7 +3884,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { // This output doesn't have a mask, we can skip it. 
continue; } - String outputName = configuration.getNetworkOutputs().get(i); + String outputName = computationGraphConfiguration.getNetworkOutputs().get(i); GraphVertex v = verticesMap.get(outputName); Layer ol = v.getLayer(); ol.setMaskArray(labelMaskArrays[i]); @@ -3972,7 +3974,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { labelsList = iterator.getLabels(); Layer outputLayer = getOutputLayer(0); - if(getConfiguration().isValidateOutputLayerConfig()){ + if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class); } @@ -3990,7 +3992,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { */ public T evaluate(MultiDataSetIterator iterator, List labelsList, int topN) { Layer outputLayer = getOutputLayer(0); - if(getConfiguration().isValidateOutputLayerConfig()){ + if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class); } return (T)doEvaluation(iterator, new org.deeplearning4j.eval.Evaluation(labelsList, topN))[0]; @@ -4055,7 +4057,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { */ public T evaluateROC(DataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(0); - if(getConfiguration().isValidateOutputLayerConfig()){ + if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class); } return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0]; @@ -4078,7 +4080,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { */ public T evaluateROC(MultiDataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(0); - if(getConfiguration().isValidateOutputLayerConfig()){ + if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class); } return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0]; @@ -4101,7 +4103,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { */ public T evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(0); - if(getConfiguration().isValidateOutputLayerConfig()){ + if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class); } return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0]; @@ -4116,7 +4118,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { */ public T evaluateROCMultiClass(MultiDataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(0); - if(getConfiguration().isValidateOutputLayerConfig()){ + if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class); } return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0]; @@ -4202,13 +4204,13 @@ public class ComputationGraph implements Serializable, 
Model, NeuralNetwork { MultiDataSetIterator iter = iterator.asyncSupported() ? new AsyncMultiDataSetIterator(iterator, 2, true) : iterator; - WorkspaceMode cMode = configuration.getTrainingWorkspaceMode(); - configuration.setTrainingWorkspaceMode(configuration.getInferenceWorkspaceMode()); + WorkspaceMode cMode = computationGraphConfiguration.getTrainingWorkspaceMode(); + computationGraphConfiguration.setTrainingWorkspaceMode(computationGraphConfiguration.getInferenceWorkspaceMode()); - boolean useRnnSegments = (configuration.getBackpropType() == BackpropType.TruncatedBPTT); + boolean useRnnSegments = (computationGraphConfiguration.getBackpropType() == BackpropType.TruncatedBPTT); MemoryWorkspace outputWs; - if(getConfiguration().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED){ + if(this.getComputationGraphConfiguration().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED){ outputWs = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM); } else { outputWs = new DummyWorkspace(); @@ -4256,7 +4258,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } else { rnnClearPreviousState(); - int fwdLen = configuration.getTbpttFwdLength(); + int fwdLen = computationGraphConfiguration.getTbpttFwdLength(); long tsLength = -1; long nF = next.getFeatures().length; for (int i = 0; i < nF; i++) { @@ -4309,7 +4311,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { if (iterator.asyncSupported()) ((AsyncMultiDataSetIterator) iter).shutdown(); - configuration.setTrainingWorkspaceMode(cMode); + computationGraphConfiguration.setTrainingWorkspaceMode(cMode); return evaluations; } @@ -4380,9 +4382,9 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { String out = "-"; String paramShape = "-"; if (currentVertex.isInputVertex()) { - if (inputTypes != null) vertexOutputs.put(currentVertexName, inputTypes[configuration.getNetworkInputs().indexOf(currentVertexName)]); //for input vertices the outputs are just the input types (only layer vertices have preprocessing?) + if (inputTypes != null) vertexOutputs.put(currentVertexName, inputTypes[computationGraphConfiguration.getNetworkInputs().indexOf(currentVertexName)]); //for input vertices the outputs are just the input types (only layer vertices have preprocessing?) 
} else { - connections = configuration.getVertexInputs().get(currentVertexName).toString(); + connections = computationGraphConfiguration.getVertexInputs().get(currentVertexName).toString(); List inputTypeList = new ArrayList<>(); if (currentVertex.hasLayer()) { Layer currentLayer = currentVertex.getLayer(); @@ -4425,7 +4427,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { inShape = currentInType.toString(); inputTypeList.add(currentInType); - InputPreProcessor layerVertexPreProcesor = ((org.deeplearning4j.nn.conf.graph.LayerVertex)configuration.getVertices().get(currentVertexName)).getPreProcessor(); + InputPreProcessor layerVertexPreProcesor = ((org.deeplearning4j.nn.conf.graph.LayerVertex) computationGraphConfiguration.getVertices().get(currentVertexName)).getPreProcessor(); if (layerVertexPreProcesor != null) { inShape += "-->" + layerVertexPreProcesor.getOutputType(currentInType); } @@ -4444,7 +4446,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { } } if (inputTypes != null) { - InputType currentVertexOutputType = configuration.getVertices().get(currentVertexName).getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()])); + InputType currentVertexOutputType = computationGraphConfiguration.getVertices().get(currentVertexName).getOutputType(currLayerIdx, inputTypeList.toArray(new InputType[inputTypeList.size()])); outShape = currentVertexOutputType.toString(); vertexOutputs.put(currentVertexName, currentVertexOutputType); } @@ -4546,14 +4548,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { * The current epoch count can be obtained using {@code ComputationGraph.getConfiguration().getEpochCount()} */ public void incrementEpochCount(){ - configuration.setEpochCount(configuration.getEpochCount() + 1); + computationGraphConfiguration.setEpochCount(computationGraphConfiguration.getEpochCount() + 1); synchronizeIterEpochCounts(); } protected void synchronizeIterEpochCounts(){ //TODO: this is necessrry for some schedules - but the redundant values are a little ugly... 
- int currIter = getConfiguration().getIterationCount(); - int currEpoch = getConfiguration().getEpochCount(); + int currIter = this.getComputationGraphConfiguration().getIterationCount(); + int currEpoch = this.getComputationGraphConfiguration().getEpochCount(); for(Layer l : layers){ l.setIterationCount(currIter); l.setEpochCount(currEpoch); @@ -4565,7 +4567,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { * @return Number of iterations */ public int getIterationCount(){ - return configuration.getIterationCount(); + return computationGraphConfiguration.getIterationCount(); } /** @@ -4576,7 +4578,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { * @return Number of epochs */ public int getEpochCount(){ - return configuration.getEpochCount(); + return computationGraphConfiguration.getEpochCount(); } /** @@ -4633,7 +4635,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { INDArray newParams = params().castTo(dataType); - String jsonConfig = getConfiguration().toJson(); + String jsonConfig = this.getComputationGraphConfiguration().toJson(); ComputationGraphConfiguration newConf = ComputationGraphConfiguration.fromJson(jsonConfig); newConf.setDataType(dataType); ComputationGraph newNet = new ComputationGraph(newConf); @@ -4714,7 +4716,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { /** * Get the current learning rate, for the specified layer, from the network. * Note: If the layer has no learning rate (no parameters, or an updater without a learning rate) then null is returned - * @param layerName Layer name + * @param layerName ILayer name * @return Learning rate for the specified layer, or null */ public Double getLearningRate(String layerName){ @@ -4724,7 +4726,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { /** * Return the layer size (number of units) for the specified layer. * Note that the meaning of the "layer size" can depend on the type of layer. For example:
- * - DenseLayer, OutputLayer, recurrent layers: number of units (nOut configuration option)
+ * - DenseLayerConfiguration, OutputLayer, recurrent layers: number of units (nOut configuration option)
* - ConvolutionLayer: the channels (number of channels)
* - Subsampling layers, global pooling layers, etc: size of 0 is always returned
* @@ -4733,7 +4735,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { */ public long layerSize(int layer) { if (layer < 0 || layer > layers.length) { - throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " + throw new IllegalArgumentException("Invalid layer index: " + layer + ". ILayer index must be between 0 and " + (layers.length - 1) + " inclusive"); } return layerSize(layers[layer].conf().getLayer().getLayerName()); @@ -4742,7 +4744,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { /** * Return the input size (number of inputs) for the specified layer.
* Note that the meaning of the "input size" can depend on the type of layer. For example:
- * - DenseLayer, OutputLayer, etc: the feature vector size (nIn configuration option)
+ * - DenseLayerConfiguration, OutputLayer, etc: the feature vector size (nIn configuration option)
* - Recurrent layers: the feature vector size per time step (nIn configuration option)
* - ConvolutionLayer: the channels (number of channels)
* - Subsampling layers, global pooling layers, etc: size of 0 is always returned
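The two hunks above touch the javadoc of ComputationGraph.layerSize(int) and ComputationGraph.layerInputSize(int). As a minimal usage sketch (not part of the patch; it assumes an already initialized ComputationGraph named "graph" whose layer 0 is a dense layer with nIn=784 and nOut=100, all illustrative values):

    // Number of units (nOut) of layer 0 - 100 in this assumed setup
    long nOut = graph.layerSize(0);
    // Number of inputs (nIn) of layer 0 - 784 in this assumed setup
    long nIn = graph.layerInputSize(0);
    // Subsampling and global pooling layers would return 0 from both calls,
    // as stated in the javadoc above.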
@@ -4752,7 +4754,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { */ public long layerInputSize(int layer) { if (layer < 0 || layer > layers.length) { - throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " + throw new IllegalArgumentException("Invalid layer index: " + layer + ". ILayer index must be between 0 and " + (layers.length - 1) + " inclusive"); } return layerInputSize(layers[layer].conf().getLayer().getLayerName()); @@ -4761,7 +4763,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { /** * Return the layer size (number of units) for the specified layer.
* Note that the meaning of the "layer size" can depend on the type of layer. For example:
- * - DenseLayer, OutputLayer, recurrent layers: number of units (nOut configuration option)
+ * - DenseLayerConfiguration, OutputLayer, recurrent layers: number of units (nOut configuration option)
* - ConvolutionLayer: the channels (number of channels)
* - Subsampling layers, global pooling layers, etc: size of 0 is always returned
* @@ -4785,7 +4787,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { /** * Return the input size (number of inputs) for the specified layer.
* Note that the meaning of the "input size" can depend on the type of layer. For example:
- * - DenseLayer, OutputLayer, etc: the feature vector size (nIn configuration option)
+ * - DenseLayerConfiguration, OutputLayer, etc: the feature vector size (nIn configuration option)
* - Recurrent layers: the feature vector size per time step (nIn configuration option)
* - ConvolutionLayer: the channels (number of channels)
* - Subsampling layers, global pooling layers, etc: size of 0 is always returned
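The index-based methods shown earlier delegate to name-keyed lookups (see the visible body "return layerSize(layers[layer].conf().getLayer().getLayerName())" in the hunk at -4733), so the same queries can be made by layer name. A short sketch under the same assumptions, with the layer name "dense0" purely hypothetical:

    long nOutByName = graph.layerSize("dense0");       // units (nOut) of the named layer
    long nInByName  = graph.layerInputSize("dense0");  // inputs (nIn) of the named layer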
@@ -4860,7 +4862,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { if (obj instanceof ComputationGraph) { ComputationGraph network = (ComputationGraph) obj; boolean paramsEquals = network.params().equals(params()); - boolean confEquals = getConfiguration().equals(network.getConfiguration()); + boolean confEquals = this.getComputationGraphConfiguration().equals(network.getComputationGraphConfiguration()); boolean updaterEquals = getUpdater().equals(network.getUpdater()); return paramsEquals && confEquals && updaterEquals; } @@ -4875,7 +4877,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork { val cg = ModelSerializer.restoreComputationGraph(ois, true); this.defaultConfiguration = cg.defaultConfiguration.clone(); - this.configuration = cg.configuration.clone(); + this.computationGraphConfiguration = cg.computationGraphConfiguration.clone(); this.init(); this.flattenedParams.assign(cg.flattenedParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java index afffe99d4..cdb124d75 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java @@ -58,8 +58,8 @@ public abstract class BaseGraphVertex implements GraphVertex { protected INDArray[] inputs; protected INDArray epsilon; - //Set outputVertex to true when Layer is an OutputLayer, OR For use in specialized situations like reinforcement learning - // For RL situations, this Layer insn't an OutputLayer, but is the last layer in a graph, that gets its error/epsilon + //Set outputVertex to true when ILayer is an OutputLayer, OR For use in specialized situations like reinforcement learning + // For RL situations, this ILayer insn't an OutputLayer, but is the last layer in a graph, that gets its error/epsilon // passed in externally @Setter @Getter protected boolean outputVertex; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java index 73e4b2fc4..61136e0db 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java @@ -40,7 +40,7 @@ public interface GraphVertex extends Trainable, Serializable { /** Get the index of the GraphVertex */ int getVertexIndex(); - /** Get the number of input arrays. For example, a Layer may have only one input array, but in general a GraphVertex + /** Get the number of input arrays. For example, a ILayer may have only one input array, but in general a GraphVertex * may have an arbtrary (>=1) number of input arrays (for example, from multiple other layers) */ int getNumInputArrays(); @@ -85,7 +85,7 @@ public interface GraphVertex extends Trainable, Serializable { /** Set the GraphVertex to be an output vertex */ void setOutputVertex(boolean outputVertex); - /** Get the Layer (if any). Returns null if {@link #hasLayer()} == false */ + /** Get the ILayer (if any). Returns null if {@link #hasLayer()} == false */ Layer getLayer(); /** Set the input activations. 
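The change running through these files replaces the former getConfiguration() accessor on ComputationGraph with getComputationGraphConfiguration(). A rough before/after sketch of an affected call site, limited to methods that actually appear in the hunks above (it assumes an initialized ComputationGraph named "graph"):

    // Before this patch:
    //   java.util.List<String> outputs = graph.getConfiguration().getNetworkOutputs();
    // After this patch:
    java.util.List<String> outputs = graph.getComputationGraphConfiguration().getNetworkOutputs();
    int tbpttFwdLen = graph.getComputationGraphConfiguration().getTbpttFwdLength();
    int epoch = graph.getComputationGraphConfiguration().getEpochCount();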
diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java index fdd05c390..60f3dad0b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java @@ -124,10 +124,10 @@ public class LayerVertex extends BaseGraphVertex { public Pair doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) { if (!canDoBackward()) { if(inputs == null || inputs[0] == null){ - throw new IllegalStateException("Cannot do backward pass: inputs not set. Layer: \"" + vertexName + throw new IllegalStateException("Cannot do backward pass: inputs not set. ILayer: \"" + vertexName + "\" (idx " + vertexIndex + "), numInputs: " + getNumInputArrays()); } else { - throw new IllegalStateException("Cannot do backward pass: all epsilons not set. Layer \"" + vertexName + throw new IllegalStateException("Cannot do backward pass: all epsilons not set. ILayer \"" + vertexName + "\" (idx " + vertexIndex + "), numInputs :" + getNumInputArrays() + "; numOutputs: " + getNumOutputConnections()); } @@ -142,7 +142,7 @@ public class LayerVertex extends BaseGraphVertex { if (tbptt && layer instanceof RecurrentLayer) { //Truncated BPTT for recurrent layers pair = ((RecurrentLayer) layer).tbpttBackpropGradient(epsilon, - graph.getConfiguration().getTbpttBackLength(), workspaceMgr); + graph.getComputationGraphConfiguration().getTbpttBackLength(), workspaceMgr); } else { //Normal backprop pair = layer.backpropGradient(epsilon, workspaceMgr); //epsTotal may be null for OutputLayers diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java index 2bfc6ee97..27eb238d3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/DuplicateToTimeSeriesVertex.java @@ -48,10 +48,10 @@ public class DuplicateToTimeSeriesVertex extends BaseGraphVertex { VertexIndices[] inputVertices, VertexIndices[] outputVertices, String inputName, DataType dataType) { super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.inputName = inputName; - this.inputVertexIndex = graph.getConfiguration().getNetworkInputs().indexOf(inputName); + this.inputVertexIndex = graph.getComputationGraphConfiguration().getNetworkInputs().indexOf(inputName); if (inputVertexIndex == -1) throw new IllegalArgumentException("Invalid input name: \"" + inputName + "\" not found in list " - + "of network inputs (" + graph.getConfiguration().getNetworkInputs() + ")"); + + "of network inputs (" + graph.getComputationGraphConfiguration().getNetworkInputs() + ")"); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java index 0475936d0..4402dc4c5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/LastTimeStepVertex.java @@ -54,10 
+54,10 @@ public class LastTimeStepVertex extends BaseGraphVertex { VertexIndices[] outputVertices, String inputName, DataType dataType) { super(graph, name, vertexIndex, inputVertices, outputVertices, dataType); this.inputName = inputName; - this.inputIdx = graph.getConfiguration().getNetworkInputs().indexOf(inputName); + this.inputIdx = graph.getComputationGraphConfiguration().getNetworkInputs().indexOf(inputName); if (inputIdx == -1) throw new IllegalArgumentException("Invalid input name: \"" + inputName + "\" not found in list " - + "of network inputs (" + graph.getConfiguration().getNetworkInputs() + ")"); + + "of network inputs (" + graph.getComputationGraphConfiguration().getNetworkInputs() + ")"); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java index 359a576a3..86b5dcab3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/rnn/ReverseTimeSeriesVertex.java @@ -48,10 +48,10 @@ public class ReverseTimeSeriesVertex extends BaseGraphVertex { this.inputIdx = -1; } else { // Find the given input - this.inputIdx = graph.getConfiguration().getNetworkInputs().indexOf(inputName); + this.inputIdx = graph.getComputationGraphConfiguration().getNetworkInputs().indexOf(inputName); if (inputIdx == -1) throw new IllegalArgumentException("Invalid input name: \"" + inputName + "\" not found in list " - + "of network inputs (" + graph.getConfiguration().getNetworkInputs() + ")"); + + "of network inputs (" + graph.getComputationGraphConfiguration().getNetworkInputs() + ")"); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java index 3ad4f8b0a..fa03d3c51 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LSTMHelpers.java @@ -79,7 +79,7 @@ public class LSTMHelpers { ) { //Mini-batch data format: for mini-batch size m, nIn inputs, and T time series length - //Data has shape [m,nIn,T]. Layer activations/output has shape [m,nHiddenUnits,T] + //Data has shape [m,nIn,T]. 
ILayer activations/output has shape [m,nHiddenUnits,T] if (input == null || input.length() == 0) throw new IllegalArgumentException("Invalid input: not set or 0 length"); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 18397bd4d..0f81392f9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -26,6 +26,8 @@ import lombok.NonNull; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import lombok.val; +import net.brutex.ai.dnn.api.INeuralNetwork; +import net.brutex.ai.dnn.networks.ArtificialNeuralNetwork; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.bytedeco.javacpp.Pointer; @@ -38,9 +40,7 @@ import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.api.layers.RecurrentLayer; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; -import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -99,1097 +99,1235 @@ import org.nd4j.common.util.OneTimeLogger; import java.io.*; import java.util.*; - +/** + * Artificial Neural Network An artificial neural network (1) takes some input data, and (2) + * transforms this input data by calculating a weighted sum over the inputs and (3) applies a + * non-linear function to this transformation to calculate an intermediate state. The three steps + * above constitute what is known as a layer, and the transformative function is often referred to + * as a unit. The intermediate states—often termed features—are used as the input into another + * layer. + *
+ * Through repetition of these steps, the artificial neural network learns multiple layers of + * non-linear features, which it then combines in a final layer to create a prediction. + *
+ * The neural network learns by generating an error signal that measures the difference between the + * predictions of the network and the desired values and then using this error signal to change the + * weights (or parameters) so that predictions get more accurate. + */ @Slf4j -public class MultiLayerNetwork implements Serializable, Classifier, Layer, org.deeplearning4j.nn.api.NeuralNetwork { +public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serializable, Classifier, Layer, + INeuralNetwork { - //the hidden neural network layers (including output layer) - protected Layer[] layers; - protected LinkedHashMap layerMap = new LinkedHashMap<>(); - - //Current training data: input features and labels - protected INDArray input, labels; - - protected boolean initCalled = false; - protected Collection trainingListeners = new ArrayList<>(); - - protected NeuralNetConfiguration defaultConfiguration; - protected MultiLayerConfiguration layerWiseConfigurations; - protected Gradient gradient; - protected double score; - @Setter - protected boolean initDone = false; - protected INDArray flattenedParams; //Params for all layers are a view/subset of this array - @Getter - protected transient INDArray flattenedGradients; //Gradients for all layers are a view/subset of this array - - protected boolean clearTbpttState = true; //Mainly for unit testing (should be enabled otherwise) - protected transient ThreadLocal lastEtlTime = new ThreadLocal<>(); - protected INDArray mask; - - protected int layerIndex; //For Layer.get/setIndex() - - protected transient Solver solver; //Used to call optimizers during backprop - //Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers - @Getter - protected transient Map helperWorkspaces = new HashMap<>(); + /** + * Workspace for working memory for a single layer: forward pass and backward pass Note that this + * is opened/closed once per op (activate/backpropGradient call) + */ + protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM"; + /** + * Workspace for storing all layers' activations - used only to store activations (layer inputs) + * as part of backprop Not used for inference + */ + protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT"; + /** + * Next 2 workspaces: used for: (a) Inference: holds activations for one layer only (b) Backprop: + * holds activation gradients for one layer only In both cases, they are opened and closed on + * every second layer + */ + protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1"; + protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2"; + /** + * Workspace for output methods that use OutputAdapter + */ + protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM"; + /** + * Workspace for working memory in RNNs - opened and closed once per RNN time step + */ + protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM"; + protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder() + .initialSize(0) + .overallocationLimit(0.05) + .policyLearning(LearningPolicy.FIRST_LOOP) + .policyReset(ResetPolicy.BLOCK_LEFT) + .policySpill(SpillPolicy.REALLOCATE) + .policyAllocation(AllocationPolicy.OVERALLOCATE) + .build(); + protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder() + .initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT) + 
.policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE) + .policyLearning(LearningPolicy.FIRST_LOOP).build(); + //the hidden neural network layers (including output layer) + protected Layer[] layers; + protected LinkedHashMap layerMap = new LinkedHashMap<>(); + //Current training data: input features and labels + protected INDArray input, labels; + protected boolean initCalled = false; + protected Collection trainingListeners = new ArrayList<>(); + protected NeuralNetConfiguration defaultConfiguration; + protected MultiLayerConfiguration layerWiseConfigurations; + protected Gradient gradient; + protected double score; + @Setter + protected boolean initDone = false; + protected INDArray flattenedParams; //Params for all layers are a view/subset of this array + @Getter + protected transient INDArray flattenedGradients; //Gradients for all layers are a view/subset of this array + protected boolean clearTbpttState = true; //Mainly for unit testing (should be enabled otherwise) + protected transient ThreadLocal lastEtlTime = new ThreadLocal<>(); + protected INDArray mask; + protected int layerIndex; //For Layer.get/setIndex() + protected transient Solver solver; //Used to call optimizers during backprop + //Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers + @Getter + protected transient Map helperWorkspaces = new HashMap<>(); + protected WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG; + protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG; - /** - * Workspace for working memory for a single layer: forward pass and backward pass - * Note that this is opened/closed once per op (activate/backpropGradient call) - */ - protected static final String WS_LAYER_WORKING_MEM = "WS_LAYER_WORKING_MEM"; - /** - * Workspace for storing all layers' activations - used only to store activations (layer inputs) as part of backprop - * Not used for inference - */ - protected static final String WS_ALL_LAYERS_ACT = "WS_ALL_LAYERS_ACT"; - /** - * Next 2 workspaces: used for: - * (a) Inference: holds activations for one layer only - * (b) Backprop: holds activation gradients for one layer only - * In both cases, they are opened and closed on every second layer - */ - protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1"; - protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2"; + public MultiLayerNetwork(MultiLayerConfiguration conf) { + this.layerWiseConfigurations = conf; + this.defaultConfiguration = conf.getConf(0).clone(); - /** - * Workspace for output methods that use OutputAdapter - */ - protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM"; + //Working memory: should learn over course of: (a) full forward pass, and (b) full backward pass + //Working memory should be opened once per layer and once per preprocessor, for each of forward and backward passes + int numWorkingMem = 2 * (layerWiseConfigurations.getConfs().size() + + layerWiseConfigurations.getInputPreProcessors().size()); + WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(numWorkingMem); + WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig(layerWiseConfigurations.getConfs().size()); + } - /** - * Workspace for working memory in RNNs - opened and closed once per RNN time step - */ - protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM"; + /** + * Initialize the network based on the configuration (a MultiLayerConfiguration in JSON format) + * and parameters array + * + * @param conf the configuration json + * @param params the 
parameters for the network + */ + public MultiLayerNetwork(String conf, INDArray params) { + this(MultiLayerConfiguration.fromJson(conf)); + init(); + setParameters(params); + } + /** + * Initialize the network based on the configuration and parameters array + * + * @param conf the configuration + * @param params the parameters + */ + public MultiLayerNetwork(MultiLayerConfiguration conf, INDArray params) { + this(conf); + init(); + setParameters(params); + } - protected WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG; + protected static WorkspaceConfiguration getLayerWorkingMemWSConfig(int numWorkingMemCycles) { + return WorkspaceConfiguration.builder() + .initialSize(0) + .overallocationLimit(0.02) + .policyLearning(LearningPolicy.OVER_TIME) + .cyclesBeforeInitialization(numWorkingMemCycles) + .policyReset(ResetPolicy.BLOCK_LEFT) + .policySpill(SpillPolicy.REALLOCATE) + .policyAllocation(AllocationPolicy.OVERALLOCATE) + .build(); + } - protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder() - .initialSize(0) - .overallocationLimit(0.05) - .policyLearning(LearningPolicy.FIRST_LOOP) - .policyReset(ResetPolicy.BLOCK_LEFT) - .policySpill(SpillPolicy.REALLOCATE) - .policyAllocation(AllocationPolicy.OVERALLOCATE) + protected static WorkspaceConfiguration getLayerActivationWSConfig(int numLayers) { + //Activations memory: opened once per layer - for every second layer (preprocessors are within the loop). + //Technically we could set learning to numLayers / 2, but will set to numLayers for simplicity, and also to + // account for a backward pass + return WorkspaceConfiguration.builder() + .initialSize(0) + .overallocationLimit(0.02) + .policyLearning(LearningPolicy.OVER_TIME) + .cyclesBeforeInitialization(numLayers) + .policyReset(ResetPolicy.BLOCK_LEFT) + .policySpill(SpillPolicy.REALLOCATE) + .policyAllocation(AllocationPolicy.OVERALLOCATE) + .build(); + } + + /** + * Restore a MultiLayerNetwork to a file, saved using {@link #save(File)} or + * {@link ModelSerializer} + * + * @param f File to load the network from + * @param loadUpdater If true: load the updater if it is available (i.e., the state array for + * momentum/Adam/rmsprop etc) - use false if no further training is + * required, or true if further training will be undertaken + * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams) + */ + public static MultiLayerNetwork load(File f, boolean loadUpdater) throws IOException { + return ModelSerializer.restoreMultiLayerNetwork(f, loadUpdater); + } + + /** + * This method sets specified CacheMode for all layers within network + * + * @param mode + */ + public void setCacheMode(CacheMode mode) { + if (mode == null) { + mode = CacheMode.NONE; + } + + for (Layer layer : layers) { + layer.setCacheMode(mode); + } + } + + /** + * Get the last ETL time. This in informational, and is the amount of time in milliseconds that + * was required to obtain the last DataSet/MultiDataSet during fitting. A value consistently above + * 0 may indicate a data feeding bottleneck, or no asynchronous data prefetching (async prefetch + * is enabled by default) + * + * @return The last ETL time in milliseconds, if avaliable (or 0 if not) + */ + public long getLastEtlTime() { + Long time = lastEtlTime.get(); + return time == null ? 0L : time; + } + + /** + * Set the last ETL time in milliseconds, for informational/reporting purposes. Generally used + * internally. 
+ * + * @param time ETL time + */ + public void setLastEtlTime(long time) { + lastEtlTime.set(time); + } + + protected void intializeConfigurations() { + if (layerWiseConfigurations == null) { + layerWiseConfigurations = new MultiLayerConfiguration.Builder().build(); + } + + if (layers == null) { + layers = new Layer[getnLayers()]; + } + + if (defaultConfiguration == null) { + defaultConfiguration = new NeuralNetConfiguration.Builder().build(); + } + } + + /** + * Perform layerwise pretraining for one epoch - see {@link #pretrain(DataSetIterator, int)} + */ + public void pretrain(DataSetIterator iter) { + pretrain(iter, 1); + } + + /** + * Perform layerwise unsupervised training on all pre-trainable layers in the network (VAEs, + * Autoencoders, etc), for the specified number of epochs each. For example, if numEpochs=3, then + * layer 0 will be fit for 3 epochs, followed by layer 1 for 3 epochs, and so on.
Note that + * pretraining will be performed on one layer after the other. To perform unsupervised training on + * a single layer, use {@link #pretrainLayer(int, DataSetIterator)} + * + * @param iter Training data + */ + public void pretrain(DataSetIterator iter, int numEpochs) { + if (flattenedGradients == null) { + initGradientsView(); + } + + for (int i = 0; i < getnLayers(); i++) { + pretrainLayer(i, iter, numEpochs); + } + } + + /** + * Fit for one epoch - see {@link #pretrainLayer(int, DataSetIterator, int)} + */ + public void pretrainLayer(int layerIdx, DataSetIterator iter) { + pretrainLayer(layerIdx, iter, 1); + } + + /** + * Perform layerwise unsupervised training on a single pre-trainable layer in the network (VAEs, + * Autoencoders, etc) for the specified number of epochs
If the specified layer index (0 to + * numLayers - 1) is not a pretrainable layer, this is a no-op. + * + * @param layerIdx Index of the layer to train (0 to numLayers-1) + * @param iter Training data + * @param numEpochs Number of epochs to fit the specified layer for + */ + public void pretrainLayer(int layerIdx, DataSetIterator iter, int numEpochs) { + Preconditions.checkState(numEpochs > 0, "Number of epochs (%s) must be a positive number", + numEpochs); + + if (flattenedGradients == null) { + initGradientsView(); + } + if (layerIdx >= layers.length) { + throw new IllegalArgumentException( + "Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + layers.length + + ")"); + } + + Layer layer = layers[layerIdx]; + if (!layer.isPretrainLayer()) { + return; + } + + if (numEpochs > 1 && !iter.resetSupported()) { + throw new IllegalStateException("Cannot fit multiple epochs (" + numEpochs + + ") on an iterator that doesn't support resetting"); + } + + if (!iter.hasNext() && iter.resetSupported()) { + iter.reset(); + } + + log.info( + "Starting unsupervised training on layer " + layerIdx + " for " + numEpochs + " epochs"); + for (int i = 0; i < numEpochs; i++) { + if (i > 0) { + iter.reset(); + } + + while (iter.hasNext()) { + DataSet next = iter.next(); + input = next.getFeatures(); + pretrainLayer(layerIdx, input); + } + } + + int ec = getLayer(layerIdx).conf().getEpochCount() + 1; + getLayer(layerIdx).conf().setEpochCount(ec); + } + + /** + * Perform layerwise unsupervised training on a single pre-trainable layer in the network (VAEs, + * Autoencoders, etc)
If the specified layer index (0 to numLayers - 1) is not a pretrainable + * layer, this is a no-op. + * + * @param layerIdx Index of the layer to train (0 to numLayers-1) + * @param features Training data array + */ + public void pretrainLayer(int layerIdx, INDArray features) { + setInput(features); + setLayerMaskArrays(null, null); + + if (flattenedGradients == null) { + initGradientsView(); + } + if (layerIdx >= layers.length) { + throw new IllegalArgumentException( + "Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + layers.length + + ")"); + } + + LayerWorkspaceMgr workspaceMgr; + if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); + } else { + workspaceMgr = LayerWorkspaceMgr.builder() + .defaultWorkspace(WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); + } + workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); + + Layer layer = layers[layerIdx]; + if (!layer.isPretrainLayer()) { + return; + } + + //Do forward pass to the layer to be pretrained + INDArray outputOfPrevLayer; + if (layerIdx == 0) { + outputOfPrevLayer = input; + } else { + //Yes, this part of training - but we'll do forward psas as inference mode when doing layerwise training + // to effectively freeze earlier layers and not apply dropout etc + outputOfPrevLayer = outputOfLayerDetached(false, FwdPassType.STANDARD, layerIndex - 1, + features, null, null, null); + } + + try (MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { + if (layerWiseConfigurations.getInputPreProcess(layerIdx) != null) { + + if (input.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + outputOfPrevLayer = layerWiseConfigurations.getInputPreProcess(layerIdx) + .preProcess(outputOfPrevLayer, (int) input.size(0), + LayerWorkspaceMgr.noWorkspaces(helperWorkspaces)); + } + + layer.fit(outputOfPrevLayer, workspaceMgr); + } + } + + @Override + public int batchSize() { + //In 99+% of cases, the input and labels dimension 0 size should be identical + //The only real exceptions: space to batch, and batch to space layers + //In those cases, we should base it on the labels size, as this impacts gradient calculation + if (input.size(0) > Integer.MAX_VALUE || labels.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + return labels == null ? (int) input.size(0) : (int) labels.size(0); + } + + @Override + public NeuralNetConfiguration conf() { + return defaultConfiguration; + } + + @Override + public void setConf(NeuralNetConfiguration conf) { + throw new UnsupportedOperationException(); + } + + @Override + public INDArray input() { + return input; + } + + @Override + public ConvexOptimizer getOptimizer() { + return solver.getOptimizer(); + } + + /** + * Get one parameter array for the network.
In MultiLayerNetwork, parameters are keyed like + * "0_W" and "0_b" to mean "weights of layer index 0" and "biases of layer index 0" respectively. + * Numbers increment sequentially, and the suffixes ("W", "b" etc) depend on the layer type, and + * are defined in the relevant parameter initializers for each layer.
Note that the returned + * INDArrays are views of the underlying network parameters, so modifications of the returned + * arrays will impact the parameters of the network. + * + * @param param the key of the parameter + * @return The specified parameter array for the network + * @see #paramTable() paramTable() method, for a map of all parameters + */ + @Override + public INDArray getParam(String param) { + //Get params for MultiLayerNetwork sub layers. + int idx = param.indexOf('_'); + if (idx == -1) { + throw new IllegalStateException( + "Invalid param key: does not have layer separator: \"" + param + "\""); + } + int layerIdx = Integer.parseInt(param.substring(0, idx)); + String newKey = param.substring(idx + 1); + + return layers[layerIdx].getParam(newKey); + } + + /** + * Return a map of all parameters in the network. Parameter names are as described in + * {@link #getParam(String)}. As per {@link #getParam(String)} the returned arrays are views - + * modifications to these will impact the underlying network parameters + * + * @return A map of all parameters in the network + */ + @Override + public Map paramTable() { + return paramTable(false); + } + + /** + * Returns a map of all parameters in the network as per {@link #paramTable()}.
Optionally + * (with backpropParamsOnly=true) only the 'backprop' parameters are returned - that is, any + * parameters involved only in unsupervised layerwise pretraining not standard inference/backprop + * are excluded from the returned list. + * + * @param backpropParamsOnly If true, return backprop params only. If false: return all params + * @return Parameters for the network + */ + public Map paramTable(boolean backpropParamsOnly) { + //Get all parameters from all layers + Map allParams = new LinkedHashMap<>(); + for (int i = 0; i < layers.length; i++) { + Map paramMap = layers[i].paramTable(backpropParamsOnly); + for (Map.Entry entry : paramMap.entrySet()) { + String newKey = i + "_" + entry.getKey(); + allParams.put(newKey, entry.getValue()); + } + } + return allParams; + } + + /** + * Intended for internal use + */ + @Override + public boolean updaterDivideByMinibatch(String paramName) { + int idx = paramName.indexOf('_'); + int layerIdx = Integer.parseInt(paramName.substring(0, idx)); + String subName = paramName.substring(idx + 1); + return getLayer(layerIdx).updaterDivideByMinibatch(subName); + } + + /** + * Set the parameters of the netowrk. Note that the parameter keys must match the format as + * described in {@link #getParam(String)} and {@link #paramTable()}. Note that the values of the + * parameters used as an argument to this method are copied - i.e., it is safe to later + * modify/reuse the values in the provided paramTable without this impacting the network. + * + * @param paramTable Parameters to set + */ + @Override + public void setParamTable(Map paramTable) { + Map currParamTable = paramTable(); + if (!currParamTable.keySet().equals(paramTable.keySet())) { + throw new IllegalArgumentException( + "Cannot set param table: parameter keys do not match.\n" + "Current: " + + currParamTable.keySet() + "\nTo set: " + paramTable.keySet()); + } + + for (String s : paramTable.keySet()) { + INDArray curr = currParamTable.get(s); + INDArray toSet = paramTable.get(s); + if (!Arrays.equals(curr.shape(), toSet.shape())) { + throw new IllegalArgumentException( + "Cannot set parameter table: parameter \"" + s + "\" shapes " + + "do not match. Current = " + Arrays.toString(curr.shape()) + ", to set = " + + Arrays.toString(toSet.shape())); + } + } + + //Now that we've checked ALL params (to avoid leaving net in half-modified state) + for (String s : paramTable.keySet()) { + INDArray curr = currParamTable.get(s); + INDArray toSet = paramTable.get(s); + curr.assign(toSet); + } + } + + /** + * Set the values of a single parameter. See {@link #setParamTable(Map)} and + * {@link #getParam(String)} for more details. + * + * @param key the key of the parameter to set + * @param val the new values for the parameter + */ + @Override + public void setParam(String key, INDArray val) { + //Set params for MultiLayerNetwork sub layers. + int idx = key.indexOf('_'); + if (idx == -1) { + throw new IllegalStateException( + "Invalid param key: not have layer separator: \"" + key + "\""); + } + int layerIdx = Integer.parseInt(key.substring(0, idx)); + String newKey = key.substring(idx + 1); + + layers[layerIdx].setParam(newKey, val); + } + + /** + * Get the configuration for the network + * + * @return Network configuration + */ + public MultiLayerConfiguration getLayerWiseConfigurations() { + return layerWiseConfigurations; + } + + /** + * This method is intended for internal/developer use only. 
+ */ + public void setLayerWiseConfigurations(MultiLayerConfiguration layerWiseConfigurations) { + this.layerWiseConfigurations = layerWiseConfigurations; + } + + /** + * Initialize the MultiLayerNetwork. This should be called once before the network is used. This + * is functionally equivalent to calling {@code init(null, false)}. + * + * @see MultiLayerNetwork#init(INDArray, boolean) + */ + public void init() { + init(null, false); + } + + /** + * Initialize the MultiLayerNetwork, optionally with an existing parameters array. If an existing + * parameters array is specified, it will be used (and the values will not be modified) in the + * network; if no parameters array is specified, parameters will be initialized randomly according + * to the network configuration. + * + * @param parameters Network parameter. May be null. If null: randomly initialize. + * @param cloneParametersArray Whether the parameter array (if any) should be cloned, or used + * directly + */ + public void init(INDArray parameters, boolean cloneParametersArray) { + if (layerWiseConfigurations == null || layers == null) { + intializeConfigurations(); + } + if (initCalled) { + return; + } + + DataType netDtype = getLayerWiseConfigurations().getDataType(); + if (parameters != null && parameters.dataType() != netDtype) { + Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, + "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", + parameters); + if (cloneParametersArray) { + try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + parameters = parameters.castTo(netDtype); + } + } else { + throw new IllegalStateException( + "Error initializing network: Network datatype is set to " + netDtype + + " but provided array has datatype " + parameters.dataType() + + " with cloneParametersArray argument" + + " set to false. 
Cannot initialize net with specified datatype array if that array does not match network datatype"); + } + } + + if (layerMap == null) { + layerMap = new LinkedHashMap<>(); + } + + if (layerWiseConfigurations.getTrainingWorkspaceMode() == null) { + layerWiseConfigurations.setTrainingWorkspaceMode(WorkspaceMode.NONE); + } + + if (layerWiseConfigurations.getInferenceWorkspaceMode() == null) { + layerWiseConfigurations.setInferenceWorkspaceMode(WorkspaceMode.NONE); + } + + if (layerWiseConfigurations.getCacheMode() == null) { + layerWiseConfigurations.setCacheMode(CacheMode.NONE); + } + + OneTimeLogger.info(log, + "Starting MultiLayerNetwork with WorkspaceModes set to [training: {}; inference: {}], cacheMode set to [{}]", + layerWiseConfigurations.getTrainingWorkspaceMode(), + layerWiseConfigurations.getInferenceWorkspaceMode(), + layerWiseConfigurations.getCacheMode()); + + int nLayers = getnLayers(); + + if (nLayers < 1) { + throw new IllegalStateException("Unable to create network: number of layers is less than 1"); + } + + if (this.layers == null || this.layers[0] == null) { + if (this.layers == null) { + this.layers = new Layer[nLayers]; + } + + //First: Work out total length of params + long paramLength = 0; + val nParamsPerLayer = new long[nLayers]; + for (int i = 0; i < nLayers; i++) { + NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i); + conf.getLayer().setDataType(netDtype); + nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf); + paramLength += nParamsPerLayer[i]; + } + + //Create parameters array, if required + boolean initializeParams; + if (parameters != null) { + if (!parameters.isRowVectorOrScalar()) { + throw new IllegalArgumentException("Invalid parameters: should be a row vector"); + } + if (parameters.length() != paramLength) { + throw new IllegalArgumentException("Invalid parameters: expected length " + paramLength + + ", got length " + parameters.length()); + } + + if (cloneParametersArray) { + flattenedParams = parameters.dup(); + } else { + flattenedParams = parameters; + } + + initializeParams = false; + } else if (paramLength > 0) { + flattenedParams = Nd4j.create(netDtype, 1, paramLength); + initializeParams = true; + } else { + //Edge case: 0 params in network + flattenedParams = null; + initializeParams = false; + } + + //Set RNG seed, for repeatability between initializations when set + if (initializeParams) { + Nd4j.getRandom().setSeed(getDefaultConfiguration().getSeed()); + } + + // construct multi-layer + long paramCountSoFar = 0; + for (int i = 0; i < nLayers; i++) { + INDArray paramsView; + if (nParamsPerLayer[i] > 0) { + paramsView = flattenedParams.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(paramCountSoFar, paramCountSoFar + nParamsPerLayer[i])); + } else { + paramsView = null; + } + paramCountSoFar += nParamsPerLayer[i]; + + NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i); + layers[i] = conf.getLayer() + .instantiate(conf, trainingListeners, i, paramsView, initializeParams, netDtype); + layerMap.put(conf.getLayer().getLayerName(), layers[i]); + } + initCalled = true; + } + + //Set parameters in MultiLayerNetwork.defaultConfiguration for later use in BaseOptimizer.setupSearchState() etc + defaultConfiguration.clearVariables(); + List variables = defaultConfiguration.variables(false); + for (int i = 0; i < layers.length; i++) { + if (layers[i] == null) { + throw new IllegalStateException( + "Encountered null layer during initialization for layer " + i + + ": " + 
layerWiseConfigurations.getConf(i).getLayer().getClass().getSimpleName() + + " initialization " + + "returned null layer?"); + } + + for (String s : layers[i].conf().variables()) { + variables.add(i + "_" + s); + } + } + + // now we init solver & optimizer + if (solver == null) { + try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) .build(); - - protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG; - - protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder() - .initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT) - .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE) - .policyLearning(LearningPolicy.FIRST_LOOP).build(); - - - public MultiLayerNetwork(MultiLayerConfiguration conf) { - this.layerWiseConfigurations = conf; - this.defaultConfiguration = conf.getConf(0).clone(); - - //Working memory: should learn over course of: (a) full forward pass, and (b) full backward pass - //Working memory should be opened once per layer and once per preprocessor, for each of forward and backward passes - int numWorkingMem = 2 * (layerWiseConfigurations.getConfs().size() + layerWiseConfigurations.getInputPreProcessors().size()); - WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(numWorkingMem); - WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig(layerWiseConfigurations.getConfs().size()); + solver.initOptimizer(); + } } - protected static WorkspaceConfiguration getLayerWorkingMemWSConfig(int numWorkingMemCycles){ - return WorkspaceConfiguration.builder() - .initialSize(0) - .overallocationLimit(0.02) - .policyLearning(LearningPolicy.OVER_TIME) - .cyclesBeforeInitialization(numWorkingMemCycles) - .policyReset(ResetPolicy.BLOCK_LEFT) - .policySpill(SpillPolicy.REALLOCATE) - .policyAllocation(AllocationPolicy.OVERALLOCATE) - .build(); + //Mark that input modification is allowed. + //TODO When is it safe to NOT skip the very first layer? It's not always safe... + // For example dropout + iterating over List that is used for multiple epochs... + for (int i = 1; i < layers.length; i++) { + layers[i].allowInputModification(true); } - protected static WorkspaceConfiguration getLayerActivationWSConfig(int numLayers){ - //Activations memory: opened once per layer - for every second layer (preprocessors are within the loop). - //Technically we could set learning to numLayers / 2, but will set to numLayers for simplicity, and also to - // account for a backward pass - return WorkspaceConfiguration.builder() - .initialSize(0) - .overallocationLimit(0.02) - .policyLearning(LearningPolicy.OVER_TIME) - .cyclesBeforeInitialization(numLayers) - .policyReset(ResetPolicy.BLOCK_LEFT) - .policySpill(SpillPolicy.REALLOCATE) - .policyAllocation(AllocationPolicy.OVERALLOCATE) - .build(); + synchronizeIterEpochCounts(); + } + + /** + * This method allows you to specificy GradientsAccumulator instance to be used with this + * model
+ *
+ * PLEASE NOTE: Do not use this method unless you understand how to use GradientsAccumulator & + * updates sharing.
PLEASE NOTE: Do not use this method on standalone model + * + * @param accumulator Gradient accumulator to use for the network + */ + public void setGradientsAccumulator(GradientsAccumulator accumulator) { + if (!isInitCalled()) { + init(); } - /** - * This method sets specified CacheMode for all layers within network - * - * @param mode - */ - public void setCacheMode(CacheMode mode) { - if (mode == null) - mode = CacheMode.NONE; - - for (Layer layer : layers) { - layer.setCacheMode(mode); - } + if (solver == null) { + try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + .build(); + } } - /** - * Set the last ETL time in milliseconds, for informational/reporting purposes. Generally used internally. - * @param time ETL time - */ - public void setLastEtlTime(long time) { - lastEtlTime.set(time); - } + solver.getOptimizer().setGradientsAccumulator(accumulator); + } - /** - * Get the last ETL time. This in informational, and is the amount of time in milliseconds that was required - * to obtain the last DataSet/MultiDataSet during fitting. - * A value consistently above 0 may indicate a data feeding bottleneck, or no asynchronous data prefetching (async - * prefetch is enabled by default) - * @return The last ETL time in milliseconds, if avaliable (or 0 if not) - */ - public long getLastEtlTime() { - Long time = lastEtlTime.get(); - return time == null ? 0L : time; - } + public boolean isInitCalled() { + return initCalled; + } - /** - * Initialize the network based on the configuration (a MultiLayerConfiguration in JSON format) and parameters array - * - * @param conf the configuration json - * @param params the parameters for the network - */ - public MultiLayerNetwork(String conf, INDArray params) { - this(MultiLayerConfiguration.fromJson(conf)); + /** + * This method: initializes the flattened gradients array (used in backprop) and sets the + * appropriate subset in all layers. As a general rule, this shouldn't ever need to be called + * manually when doing training via fit(DataSet) or fit(DataSetIterator) + */ + public void initGradientsView() { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + if (layers == null) { init(); - setParameters(params); + } + + int nLayers = layers.length; + + //First: Work out total length of params + long paramLength = 0; + val nParamsPerLayer = new long[nLayers]; + for (int i = 0; i < nLayers; i++) { + NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i); + nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf); + paramLength += nParamsPerLayer[i]; + } + + if (paramLength > 0) { + flattenedGradients = Nd4j.create(flattenedParams.dataType(), new long[]{1, paramLength}, + 'f'); //No need to initialize, as each layer will do it each iteration anyway + } + + long paramsSoFar = 0; + for (int i = 0; i < layers.length; i++) { + if (nParamsPerLayer[i] == 0) { + continue; //This layer doesn't have any parameters... 
+ } + INDArray thisLayerGradView = flattenedGradients.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(paramsSoFar, paramsSoFar + nParamsPerLayer[i])); + layers[i].setBackpropGradientsViewArray(thisLayerGradView); + paramsSoFar += nParamsPerLayer[i]; + } + } + } + + protected INDArray activationFromPrevLayer(int curr, INDArray input, boolean training, + LayerWorkspaceMgr mgr) { + if (getLayerWiseConfigurations().getInputPreProcess(curr) != null) { + input = getLayerWiseConfigurations().getInputPreProcess(curr) + .preProcess(input, getInputMiniBatchSize(), mgr); } + INDArray ret = layers[curr].activate(input, training, mgr); + return ret; + } - /** - * Initialize the network based on the configuration and parameters array - * - * @param conf the configuration - * @param params the parameters - */ - public MultiLayerNetwork(MultiLayerConfiguration conf, INDArray params) { - this(conf); - init(); - setParameters(params); + /** + * Calculate activation for few layers at once. Suitable for autoencoder partial activation. + *
+ * In example: in 10-layer deep autoencoder, layers 0 - 4 inclusive are used for encoding part, + * and layers 5-9 inclusive are used for decoding part. + * + * @param from first layer to be activated, inclusive + * @param to last layer to be activated, inclusive + * @return the activation from the last layer + */ + public INDArray activateSelectedLayers(int from, int to, INDArray input) { + if (input == null) { + throw new IllegalStateException("Unable to perform activation; no input found"); + } + if (from < 0 || from >= layers.length || from >= to) { + throw new IllegalStateException("Unable to perform activation; FROM is out of layer space"); + } + if (to < 1 || to >= layers.length) { + throw new IllegalStateException("Unable to perform activation; TO is out of layer space"); } + try { + LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces(helperWorkspaces); //TODO - protected void intializeConfigurations() { - if (layerWiseConfigurations == null) - layerWiseConfigurations = new MultiLayerConfiguration.Builder().build(); - - if (layers == null) - layers = new Layer[getnLayers()]; - - if (defaultConfiguration == null) - defaultConfiguration = new NeuralNetConfiguration.Builder().build(); + INDArray res = input; + for (int l = from; l <= to; l++) { + res = this.activationFromPrevLayer(l, res, false, mgr); + } + return res; + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; } + } + /** + * Compute all layer activations, from input to output of the output layer. Note that the input is + * included in the list: thus feedForward(in,train).get(0) is the inputs, .get(1) is the + * activations of layer 0, and so on. + * + * @param train Training: if true, perform forward pass/inference at training time. Usually, + * inference is performed with train = false. This impacts whether dropout etc is + * applied or not. + * @return The list of activations for each layer, including the input + */ + public List feedForward(INDArray input, boolean train) { + setInput(input); + return feedForward(train); + } - /** - * Perform layerwise pretraining for one epoch - see {@link #pretrain(DataSetIterator, int)} - */ - public void pretrain(DataSetIterator iter) { - pretrain(iter, 1); + /** + * Compute activations from input to output of the output layer. As per + * {@link #feedForward(INDArray, boolean)} but using the inputs that have previously been set + * using {@link #setInput(INDArray)} + * + * @return the list of activations for each layer + */ + public List feedForward(boolean train) { + try { + return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layers.length - 1, + input, mask, null, true); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; } + } - /** - * Perform layerwise unsupervised training on all pre-trainable layers in the network (VAEs, Autoencoders, etc), for the specified - * number of epochs each. For example, if numEpochs=3, then layer 0 will be fit for 3 epochs, followed by layer 1 - * for 3 epochs, and so on.
- * Note that pretraining will be performed on one layer after the other. To perform unsupervised training on a single layer, - * use {@link #pretrainLayer(int, DataSetIterator)} - * - * @param iter Training data - */ - public void pretrain(DataSetIterator iter, int numEpochs){ - if (flattenedGradients == null) { - initGradientsView(); - } - - for (int i = 0; i < getnLayers(); i++) { - pretrainLayer(i, iter, numEpochs); - } + /** + * Perform feed-forward, optionally (not) clearing the layer input arrays.
Note: when using + * clearInputs=false, there can be some performance and memory overhead: this is because the + * arrays are defined outside of workspaces (which are enabled by default) - otherwise, + * old/invalidated arrays could still be accessed after calling this method. Consequently: Don't + * use clearInputs=false unless you have a use case that requires them to remain after + * feed-forward has been completed + * + * @param train training mode (true) or test mode (false) + * @param clearInputs If false: don't clear the layer inputs + * @return Activations from feed-forward + */ + public List feedForward(boolean train, boolean clearInputs) { + try { + return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layers.length - 1, + input, mask, null, clearInputs); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; } + } - /** - * Fit for one epoch - see {@link #pretrainLayer(int, DataSetIterator, int)} - */ - public void pretrainLayer(int layerIdx, DataSetIterator iter) { - pretrainLayer(layerIdx, iter, 1); + /** + * Compute the activations from the input to the specified layer.
To compute activations for + * all layers, use feedForward(...) methods
Note: output list includes the original input. So + * list.get(0) is always the original input, and list.get(i+1) is the activations of the ith + * layer. + * + * @param layerNum Index of the last layer to calculate activations for. Layers are zero-indexed. + * feedForwardToLayer(i,input) will return the activations for layers 0..i + * (inclusive) + * @param input Input to the network + * @return list of activations. + */ + public List feedForwardToLayer(int layerNum, INDArray input) { + try { + return ffToLayerActivationsDetached(false, FwdPassType.STANDARD, false, layerNum, input, mask, + null, true); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; } + } - /** - * Perform layerwise unsupervised training on a single pre-trainable layer in the network (VAEs, Autoencoders, etc) - * for the specified number of epochs
- * If the specified layer index (0 to numLayers - 1) is not a pretrainable layer, this is a no-op. - * - * @param layerIdx Index of the layer to train (0 to numLayers-1) - * @param iter Training data - * @param numEpochs Number of epochs to fit the specified layer for - */ - public void pretrainLayer(int layerIdx, DataSetIterator iter, int numEpochs) { - Preconditions.checkState(numEpochs > 0, "Number of epochs (%s) must be a positive number", numEpochs); - - if (flattenedGradients == null) { - initGradientsView(); - } - if (layerIdx >= layers.length) { - throw new IllegalArgumentException( - "Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + layers.length + ")"); - } - - Layer layer = layers[layerIdx]; - if (!layer.isPretrainLayer()) - return; - - if(numEpochs > 1 && !iter.resetSupported()) - throw new IllegalStateException("Cannot fit multiple epochs (" + numEpochs + ") on an iterator that doesn't support resetting"); - - if (!iter.hasNext() && iter.resetSupported()) { - iter.reset(); - } - - log.info("Starting unsupervised training on layer " + layerIdx + " for " + numEpochs + " epochs"); - for(int i=0; i 0) - iter.reset(); - - while (iter.hasNext()) { - DataSet next = iter.next(); - input = next.getFeatures(); - pretrainLayer(layerIdx, input); - } - } - - int ec = getLayer(layerIdx).conf().getEpochCount() + 1; - getLayer(layerIdx).conf().setEpochCount(ec); + /** + * Compute the activations from the input to the specified layer.
To compute activations for + * all layers, use feedForward(...) methods
Note: output list includes the original input. So + * list.get(0) is always the original input, and list.get(i+1) is the activations of the ith + * layer. + * + * @param layerNum Index of the last layer to calculate activations for. Layers are zero-indexed. + * feedForwardToLayer(i,input) will return the activations for layers 0..i + * (inclusive) + * @param input Input to the network + * @param train true for training, false for test (i.e., false if using network after + * training) + * @return list of activations. + */ + public List feedForwardToLayer(int layerNum, INDArray input, boolean train) { + try { + int layerVertexIdx = layers[layerNum].getIndex(); + return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layerVertexIdx, input, + mask, null, true); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; } + } - /** - * Perform layerwise unsupervised training on a single pre-trainable layer in the network (VAEs, Autoencoders, etc)
- * If the specified layer index (0 to numLayers - 1) is not a pretrainable layer, this is a no-op. - * - * @param layerIdx Index of the layer to train (0 to numLayers-1) - * @param features Training data array - */ - public void pretrainLayer(int layerIdx, INDArray features) { - setInput(features); - setLayerMaskArrays(null, null); + /** + * Compute the activations from the input to the specified layer, using the currently set input + * for the network.
To compute activations for all layers, use feedForward(...) methods
+ * Note: output list includes the original input. So list.get(0) is always the original input, and + * list.get(i+1) is the activations of the ith layer. + * + * @param layerNum Index of the last layer to calculate activations for. Layers are zero-indexed. + * feedForwardToLayer(i,input) will return the activations for layers 0..i + * (inclusive) + * @param train true for training, false for test (i.e., false if using network after + * training) + * @return list of activations. + */ + public List feedForwardToLayer(int layerNum, boolean train) { + try { + return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layerNum, input, mask, + null, true); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; + } + } - if (flattenedGradients == null) { - initGradientsView(); - } - if (layerIdx >= layers.length) { - throw new IllegalArgumentException( - "Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + layers.length + ")"); + protected void validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, + int layerIdx, + boolean isPreprocessor, String op) { + try { + mgr.validateArrayLocation(arrayType, array, false, layerIdx > 0); + } catch (ND4JWorkspaceException e) { + String layerName = layers[layerIdx].conf().getLayer().getLayerName(); + String clazz; + if (isPreprocessor) { + clazz = layerWiseConfigurations.getInputPreProcess(layerIdx).getClass().getName(); + } else { + clazz = layers[layerIdx].getClass().getName(); + } + throw new IllegalStateException( + op + ": array (" + arrayType + ") workspace validation failed (" + + (isPreprocessor ? "preprocessor" : "layer ") + layerIdx + (layerName != null ? + " - layer name \"" + + layerName + "\"" : "") + " - class: " + clazz + + ") - array is defined in incorrect workspace", e); + } + } + + /** + * Feed-forward through the network - returning all array activations in a list, detached from any + * workspace. Note that no workspace should be active externally when calling this method (an + * exception will be thrown if a workspace is open externally) + * + * @param train Training mode (true) or test/inference mode (false) + * @param fwdPassType Type of forward pass to perform (STANDARD or + * RNN_ACTIVATE_WITH_STORED_STATE only) + * @param storeLastForTBPTT ONLY used if fwdPassType == + * FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE + * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use + * numLayers-1 + * @param input Input to the network + * @param fMask Feature mask array. May be null. + * @param lMask Label mask array. May be null. + * @param clearInputs Whether the layer inputs should be cleared + * @return List of activations (including the input), detached from any workspace + */ + protected synchronized List ffToLayerActivationsDetached(boolean train, + @NonNull FwdPassType fwdPassType, + boolean storeLastForTBPTT, int layerIndex, @NonNull INDArray input, + INDArray fMask, INDArray lMask, boolean clearInputs) { + setInput(input); + setLayerMaskArrays(fMask, lMask); + + //Verify that no workspace is open externally + WorkspaceUtils.assertNoWorkspacesOpen( + "Expected no workspace active in ffToLayerActivationsDetached"); + + LayerWorkspaceMgr workspaceMgr; + WorkspaceMode wsm = (train ? 
layerWiseConfigurations.getTrainingWorkspaceMode() + : layerWiseConfigurations.getInferenceWorkspaceMode()); + if (wsm == WorkspaceMode.NONE) { + workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); + } else { + workspaceMgr = LayerWorkspaceMgr.builder() + .noWorkspaceFor(ArrayType.ACTIVATIONS) + .with(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); + + if (input.isAttached()) { + //Don't leverage out of async DataSetIterator workspaces + workspaceMgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); + } + + if (!clearInputs) { + workspaceMgr.setScopedOutFor(ArrayType.INPUT); + } + } + workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); + + List out = new ArrayList<>(); + out.add(workspaceMgr.leverageTo(ArrayType.INPUT, + input)); //Should be unnecessary (and no op), if layer is implemented correctly + + for (int i = 0; i <= layerIndex; i++) { + try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered( + ArrayType.FF_WORKING_MEM)) { + if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { + input = getLayerWiseConfigurations().getInputPreProcess(i) + .preProcess(input, getInputMiniBatchSize(), workspaceMgr); + //Validation: Exception if invalid (bad preprocessor implementation) + validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, true, + "Feed forward to layer (inference)"); } - LayerWorkspaceMgr workspaceMgr; - if(layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ - workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); + if (fwdPassType == FwdPassType.STANDARD) { + input = layers[i].activate(input, train, workspaceMgr); + } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) { + if (layers[i] instanceof RecurrentLayer) { + input = ((RecurrentLayer) layers[i]).rnnActivateUsingStoredState(input, train, + storeLastForTBPTT, workspaceMgr); + } else if (layers[i] instanceof BaseWrapperLayer + && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) { + RecurrentLayer rl = (RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying(); + input = rl.rnnActivateUsingStoredState(input, train, storeLastForTBPTT, workspaceMgr); + } else if (layers[i] instanceof MultiLayerNetwork) { + List temp = ((MultiLayerNetwork) layers[i]).rnnActivateUsingStoredState(input, + train, storeLastForTBPTT); + input = temp.get(temp.size() - 1); + } else { + input = layers[i].activate(input, train, workspaceMgr); + } } else { - workspaceMgr = LayerWorkspaceMgr.builder() - .defaultWorkspace(WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + throw new IllegalStateException( + "Forward pass type not supported for this method: " + fwdPassType); } - workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); - Layer layer = layers[layerIdx]; - if (!layer.isPretrainLayer()) - return; + //Validation: Exception if invalid (bad layer implementation) + validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, false, + "Feed forward to layer (inference)"); - //Do forward pass to the layer to be pretrained - INDArray outputOfPrevLayer; - if(layerIdx == 0) { - outputOfPrevLayer = input; + out.add(input); + } + if (clearInputs) { + layers[i].clear(); + } + } + + return out; + } + 
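/*
 * Usage sketch (illustrative only, not part of this patch): the feed-forward methods
 * documented above share one convention - the returned list holds the original input at
 * index 0 and the activations of layer i at index i + 1. The small dense network below is
 * hypothetical; only the feedForward / feedForwardToLayer / activateSelectedLayers call
 * semantics are taken from the javadoc above.
 */
import java.util.List;

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class FeedForwardSketch {
    public static void main(String[] args) {
        // Hypothetical two-layer network, purely for illustration.
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .list()
                .layer(new DenseLayer.Builder().nIn(784).nOut(128)
                        .activation(Activation.RELU).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(128).nOut(10).activation(Activation.SOFTMAX).build())
                .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        INDArray in = Nd4j.rand(16, 784); // [minibatch, nIn]

        // feedForward(input, train = false): get(0) is the input itself,
        // get(1) the dense-layer activations, get(2) the output-layer activations.
        List<INDArray> acts = net.feedForward(in, false);
        INDArray denseOut = acts.get(1);

        // Same list convention, but stopping at layer 0 (inclusive).
        List<INDArray> toLayer0 = net.feedForwardToLayer(0, in);

        // activateSelectedLayers(from, to, input) returns only the last activation;
        // with from = 0 and to = 1 it spans both layers of this example network.
        INDArray out = net.activateSelectedLayers(0, 1, in);
    }
}
// Note: feedForward(train, clearInputs = false) keeps each layer's input array around,
// with the memory overhead described in the javadoc above, so it is only worth using
// when those inputs are needed after the forward pass has completed.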
+ /** + * Feed-forward through the network at training time - returning a list of all activations in a + * workspace (WS_ALL_LAYERS_ACT) if workspaces are enabled for training; or detached if no + * workspaces are used.
Note: if using workspaces for training, this method requires that + * WS_ALL_LAYERS_ACT is open externally.
If using NO workspaces, this method requires that no external
 + * workspace is open
Note that this method does NOT clear the inputs to each layer - instead, + * they are in the WS_ALL_LAYERS_ACT workspace for use in later backprop. + * + * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use + * numLayers-1 + * @param fwdPassType Type of forward pass to perform (STANDARD or + * RNN_ACTIVATE_WITH_STORED_STATE only) + * @param storeLastForTBPTT ONLY used if fwdPassType == + * FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE + * @param input Input to network + * @param fMask Feature mask array. May be null + * @param lMask Label mask aray. May be null. + * @return + */ + protected synchronized List ffToLayerActivationsInWs(int layerIndex, + @NonNull FwdPassType fwdPassType, boolean storeLastForTBPTT, + @NonNull INDArray input, INDArray fMask, INDArray lMask) { + setInput(input); + setLayerMaskArrays(fMask, lMask); + + LayerWorkspaceMgr workspaceMgr; + if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + WorkspaceUtils.assertNoWorkspacesOpen( + "Expected no workspace active in ffToLayerActivationsInWs when training workspace is set to NONE"); + workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); + } else { + workspaceMgr = LayerWorkspaceMgr.builder() + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); + + if (input.isAttached()) { + //Don't leverage out of async DataSetIterator workspaces + workspaceMgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); + } + + if (layerWiseConfigurations.getCacheMode() != CacheMode.NONE) { + //For now: store cache mode activations in activations workspace + workspaceMgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); + workspaceMgr.setWorkspace(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, + WS_LAYER_WORKING_MEM_CONFIG); + } + + WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, + "ffToLayerActivationsInWs method requires workspace WS_ALL_LAYERS_ACT to be open"); + } + workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); + + List out = new ArrayList<>(); + out.add(workspaceMgr.leverageTo(ArrayType.INPUT, input)); //Probably unnecessary usually + + boolean traceLog = log.isTraceEnabled(); + + for (int i = 0; i <= layerIndex; i++) { + try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered( + ArrayType.FF_WORKING_MEM)) { + if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { + input = getLayerWiseConfigurations().getInputPreProcess(i) + .preProcess(input, getInputMiniBatchSize(), workspaceMgr); + //Validation: Exception if invalid (bad preprocessor implementation) + validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, true, + "Feed forward to layer (training)"); + } + + if (traceLog) { + log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); + } + + if (fwdPassType == FwdPassType.STANDARD) { + input = layers[i].activate(input, true, workspaceMgr); + } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) { + if (layers[i] instanceof RecurrentLayer) { + input = ((RecurrentLayer) layers[i]).rnnActivateUsingStoredState(input, true, + storeLastForTBPTT, workspaceMgr); + } else if (layers[i] instanceof BaseWrapperLayer + && ((BaseWrapperLayer) layers[i]).getUnderlying() 
instanceof RecurrentLayer) { + RecurrentLayer rl = (RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying(); + input = rl.rnnActivateUsingStoredState(input, true, storeLastForTBPTT, workspaceMgr); + } else if (layers[i] instanceof MultiLayerNetwork) { + List temp = ((MultiLayerNetwork) layers[i]).rnnActivateUsingStoredState(input, + true, storeLastForTBPTT); + input = temp.get(temp.size() - 1); + } else { + input = layers[i].activate(input, true, workspaceMgr); + } } else { - //Yes, this part of training - but we'll do forward psas as inference mode when doing layerwise training - // to effectively freeze earlier layers and not apply dropout etc - outputOfPrevLayer = outputOfLayerDetached(false, FwdPassType.STANDARD, layerIndex-1, features, null, null, null); + throw new IllegalStateException( + "FwdPassType not supported for this method: " + fwdPassType); } - try(MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { - if (layerWiseConfigurations.getInputPreProcess(layerIdx) != null) { - - if (input.size(0) > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - outputOfPrevLayer = layerWiseConfigurations.getInputPreProcess(layerIdx).preProcess(outputOfPrevLayer, (int) input.size(0), - LayerWorkspaceMgr.noWorkspaces(helperWorkspaces)); - } - - layer.fit(outputOfPrevLayer, workspaceMgr); - } - } - - @Override - public int batchSize() { - //In 99+% of cases, the input and labels dimension 0 size should be identical - //The only real exceptions: space to batch, and batch to space layers - //In those cases, we should base it on the labels size, as this impacts gradient calculation - if (input.size(0) > Integer.MAX_VALUE || labels.size(0) > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - return labels == null ? (int) input.size(0) : (int)labels.size(0); - } - - @Override - public NeuralNetConfiguration conf() { - return defaultConfiguration; - } - - @Override - public void setConf(NeuralNetConfiguration conf) { - throw new UnsupportedOperationException(); - } - - @Override - public INDArray input() { - return input; - } - - @Override - public ConvexOptimizer getOptimizer() { - return solver.getOptimizer(); - } - - /** - * Get one parameter array for the network.
- * In MultiLayerNetwork, parameters are keyed like "0_W" and "0_b" to mean "weights of layer index 0" and "biases - * of layer index 0" respectively. Numbers increment sequentially, and the suffixes ("W", "b" etc) depend on the - * layer type, and are defined in the relevant parameter initializers for each layer.
- * Note that the returned INDArrays are views of the underlying network parameters, so modifications of the returned - * arrays will impact the parameters of the network. - * - * @param param the key of the parameter - * @return The specified parameter array for the network - * @see #paramTable() paramTable() method, for a map of all parameters - */ - @Override - public INDArray getParam(String param) { - //Get params for MultiLayerNetwork sub layers. - int idx = param.indexOf('_'); - if (idx == -1) - throw new IllegalStateException("Invalid param key: does not have layer separator: \"" + param + "\""); - int layerIdx = Integer.parseInt(param.substring(0, idx)); - String newKey = param.substring(idx + 1); - - return layers[layerIdx].getParam(newKey); - } - - /** - * Return a map of all parameters in the network. Parameter names are as described in {@link #getParam(String)}. - * As per {@link #getParam(String)} the returned arrays are views - modifications to these will impact - * the underlying network parameters - * @return A map of all parameters in the network - */ - @Override - public Map paramTable() { - return paramTable(false); - } - - /** - * Returns a map of all parameters in the network as per {@link #paramTable()}.
- * Optionally (with backpropParamsOnly=true) only the 'backprop' parameters are returned - that is, any parameters - * involved only in unsupervised layerwise pretraining not standard inference/backprop are excluded from the returned list. - * @param backpropParamsOnly If true, return backprop params only. If false: return all params - * @return Parameters for the network - */ - public Map paramTable(boolean backpropParamsOnly) { - //Get all parameters from all layers - Map allParams = new LinkedHashMap<>(); - for (int i = 0; i < layers.length; i++) { - Map paramMap = layers[i].paramTable(backpropParamsOnly); - for (Map.Entry entry : paramMap.entrySet()) { - String newKey = i + "_" + entry.getKey(); - allParams.put(newKey, entry.getValue()); - } - } - return allParams; - } - - /** - * Intended for internal use - */ - @Override - public boolean updaterDivideByMinibatch(String paramName) { - int idx = paramName.indexOf('_'); - int layerIdx = Integer.parseInt(paramName.substring(0, idx)); - String subName = paramName.substring(idx+1); - return getLayer(layerIdx).updaterDivideByMinibatch(subName); - } - - /** - * Set the parameters of the netowrk. Note that the parameter keys must match the format as described in {@link #getParam(String)} - * and {@link #paramTable()}. Note that the values of the parameters used as an argument to this method are copied - - * i.e., it is safe to later modify/reuse the values in the provided paramTable without this impacting the network. - * - * @param paramTable Parameters to set - */ - @Override - public void setParamTable(Map paramTable) { - Map currParamTable = paramTable(); - if (!currParamTable.keySet().equals(paramTable.keySet())) { - throw new IllegalArgumentException("Cannot set param table: parameter keys do not match.\n" + "Current: " - + currParamTable.keySet() + "\nTo set: " + paramTable.keySet()); + if (input == null) { + throw new IllegalStateException("Layer " + i + " returned null activations"); } - for (String s : paramTable.keySet()) { - INDArray curr = currParamTable.get(s); - INDArray toSet = paramTable.get(s); - if (!Arrays.equals(curr.shape(), toSet.shape())) { - throw new IllegalArgumentException("Cannot set parameter table: parameter \"" + s + "\" shapes " - + "do not match. Current = " + Arrays.toString(curr.shape()) + ", to set = " - + Arrays.toString(toSet.shape())); - } - } + //Validation: Exception if invalid (bad layer implementation) + validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, false, + "Feed forward to layer (training)"); + validateArrayWorkspaces(workspaceMgr, layers[i].input(), ArrayType.INPUT, i, false, + "Feed forward to layer (training)"); - //Now that we've checked ALL params (to avoid leaving net in half-modified state) - for (String s : paramTable.keySet()) { - INDArray curr = currParamTable.get(s); - INDArray toSet = paramTable.get(s); - curr.assign(toSet); + out.add(input); + + if (traceLog) { + log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); } + } } - /** - * Set the values of a single parameter. See {@link #setParamTable(Map)} and {@link #getParam(String)} for more - * details. - * @param key the key of the parameter to set - * @param val the new values for the parameter - */ - @Override - public void setParam(String key, INDArray val) { - //Set params for MultiLayerNetwork sub layers. 
- int idx = key.indexOf('_'); - if (idx == -1) - throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\""); - int layerIdx = Integer.parseInt(key.substring(0, idx)); - String newKey = key.substring(idx + 1); - - layers[layerIdx].setParam(newKey, val); - } - - /** - * Get the configuration for the network - * @return Network configuration - */ - public MultiLayerConfiguration getLayerWiseConfigurations() { - return layerWiseConfigurations; - } - - /** - * This method is intended for internal/developer use only. - */ - public void setLayerWiseConfigurations(MultiLayerConfiguration layerWiseConfigurations) { - this.layerWiseConfigurations = layerWiseConfigurations; - } - - /** - * Initialize the MultiLayerNetwork. This should be called once before the network is used. - * This is functionally equivalent to calling {@code init(null, false)}. - * @see MultiLayerNetwork#init(INDArray, boolean) - */ - public void init() { - init(null, false); - } - - /** - * Initialize the MultiLayerNetwork, optionally with an existing parameters array. - * If an existing parameters array is specified, it will be used (and the values will not be modified) in the network; - * if no parameters array is specified, parameters will be initialized randomly according to the network configuration. - * - * @param parameters Network parameter. May be null. If null: randomly initialize. - * @param cloneParametersArray Whether the parameter array (if any) should be cloned, or used directly - */ - public void init(INDArray parameters, boolean cloneParametersArray) { - if (layerWiseConfigurations == null || layers == null) - intializeConfigurations(); - if (initCalled) - return; - - DataType netDtype = getLayerWiseConfigurations().getDataType(); - if(parameters != null && parameters.dataType() != netDtype){ - Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters); - if(cloneParametersArray){ - try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { - parameters = parameters.castTo(netDtype); - } - } else { - throw new IllegalStateException("Error initializing network: Network datatype is set to " + netDtype - + " but provided array has datatype " + parameters.dataType() + " with cloneParametersArray argument" + - " set to false. 
Cannot initialize net with specified datatype array if that array does not match network datatype"); - } - } - - - if (layerMap == null) - layerMap = new LinkedHashMap<>(); - - if (layerWiseConfigurations.getTrainingWorkspaceMode() == null) - layerWiseConfigurations.setTrainingWorkspaceMode(WorkspaceMode.NONE); - - if (layerWiseConfigurations.getInferenceWorkspaceMode() == null) - layerWiseConfigurations.setInferenceWorkspaceMode(WorkspaceMode.NONE); - - if (layerWiseConfigurations.getCacheMode() == null) - layerWiseConfigurations.setCacheMode(CacheMode.NONE); - - OneTimeLogger.info(log, "Starting MultiLayerNetwork with WorkspaceModes set to [training: {}; inference: {}], cacheMode set to [{}]", - layerWiseConfigurations.getTrainingWorkspaceMode(), - layerWiseConfigurations.getInferenceWorkspaceMode(), - layerWiseConfigurations.getCacheMode()); - - int nLayers = getnLayers(); - - if (nLayers < 1) - throw new IllegalStateException("Unable to create network: number of layers is less than 1"); - - if (this.layers == null || this.layers[0] == null) { - if (this.layers == null) - this.layers = new Layer[nLayers]; - - //First: Work out total length of params - long paramLength = 0; - val nParamsPerLayer = new long[nLayers]; - for (int i = 0; i < nLayers; i++) { - NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i); - conf.getLayer().setDataType(netDtype); - nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf); - paramLength += nParamsPerLayer[i]; - } - - //Create parameters array, if required - boolean initializeParams; - if (parameters != null) { - if (!parameters.isRowVectorOrScalar()) - throw new IllegalArgumentException("Invalid parameters: should be a row vector"); - if (parameters.length() != paramLength) - throw new IllegalArgumentException("Invalid parameters: expected length " + paramLength - + ", got length " + parameters.length()); - - if (cloneParametersArray) - flattenedParams = parameters.dup(); - else - flattenedParams = parameters; - - initializeParams = false; - } else if(paramLength > 0){ - flattenedParams = Nd4j.create(netDtype, 1, paramLength); - initializeParams = true; - } else { - //Edge case: 0 params in network - flattenedParams = null; - initializeParams = false; - } - - //Set RNG seed, for repeatability between initializations when set - if (initializeParams) { - Nd4j.getRandom().setSeed(getDefaultConfiguration().getSeed()); - } - - // construct multi-layer - long paramCountSoFar = 0; - for (int i = 0; i < nLayers; i++) { - INDArray paramsView; - if (nParamsPerLayer[i] > 0) { - paramsView = flattenedParams.get(NDArrayIndex.interval(0,0,true), - NDArrayIndex.interval(paramCountSoFar, paramCountSoFar + nParamsPerLayer[i])); - } else { - paramsView = null; - } - paramCountSoFar += nParamsPerLayer[i]; - - NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i); - layers[i] = conf.getLayer().instantiate(conf, trainingListeners, i, paramsView, initializeParams, netDtype); - layerMap.put(conf.getLayer().getLayerName(), layers[i]); - } - initCalled = true; - } - - //Set parameters in MultiLayerNetwork.defaultConfiguration for later use in BaseOptimizer.setupSearchState() etc - defaultConfiguration.clearVariables(); - List variables = defaultConfiguration.variables(false); - for (int i = 0; i < layers.length; i++) { - if(layers[i] == null){ - throw new IllegalStateException("Encountered null layer during initialization for layer " + i + - ": " + layerWiseConfigurations.getConf(i).getLayer().getClass().getSimpleName() + " initialization " + - 
"returned null layer?"); - } - - for (String s : layers[i].conf().variables()) { - variables.add(i + "_" + s); - } - } - - // now we init solver & optimizer - if (solver == null) { - try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build(); - solver.initOptimizer(); - } - } - - //Mark that input modification is allowed. - //TODO When is it safe to NOT skip the very first layer? It's not always safe... - // For example dropout + iterating over List that is used for multiple epochs... - for( int i=1; i - *
- * PLEASE NOTE: Do not use this method unless you understand how to use GradientsAccumulator & updates sharing.
- * PLEASE NOTE: Do not use this method on standalone model - * - * @param accumulator Gradient accumulator to use for the network - */ - public void setGradientsAccumulator(GradientsAccumulator accumulator) { - if (!isInitCalled()) - init(); - - if (solver == null) { - try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) - .build(); - } - } - - solver.getOptimizer().setGradientsAccumulator(accumulator); - } - - public boolean isInitCalled() { - return initCalled; - } - - /** - * This method: initializes the flattened gradients array (used in backprop) and sets the appropriate subset in all layers. - * As a general rule, this shouldn't ever need to be called manually when doing training via fit(DataSet) or fit(DataSetIterator) - */ - public void initGradientsView() { - try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - if (layers == null) - init(); - - int nLayers = layers.length; - - //First: Work out total length of params - long paramLength = 0; - val nParamsPerLayer = new long[nLayers]; - for (int i = 0; i < nLayers; i++) { - NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i); - nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf); - paramLength += nParamsPerLayer[i]; - } - - if(paramLength > 0) { - flattenedGradients = Nd4j.create(flattenedParams.dataType(), new long[]{1, paramLength}, 'f'); //No need to initialize, as each layer will do it each iteration anyway - } - - long paramsSoFar = 0; - for (int i = 0; i < layers.length; i++) { - if (nParamsPerLayer[i] == 0) - continue; //This layer doesn't have any parameters... - INDArray thisLayerGradView = flattenedGradients.get(NDArrayIndex.interval(0,0,true), - NDArrayIndex.interval(paramsSoFar, paramsSoFar + nParamsPerLayer[i])); - layers[i].setBackpropGradientsViewArray(thisLayerGradView); - paramsSoFar += nParamsPerLayer[i]; - } - } - } - - protected INDArray activationFromPrevLayer(int curr, INDArray input, boolean training, LayerWorkspaceMgr mgr) { - if (getLayerWiseConfigurations().getInputPreProcess(curr) != null) { - input = getLayerWiseConfigurations().getInputPreProcess(curr).preProcess(input, getInputMiniBatchSize(), mgr); - } - - INDArray ret = layers[curr].activate(input, training, mgr); - return ret; - } - - /** - * Calculate activation for few layers at once. Suitable for autoencoder partial activation. - * - * In example: in 10-layer deep autoencoder, layers 0 - 4 inclusive are used for encoding part, and layers 5-9 inclusive are used for decoding part. 
- * - * @param from first layer to be activated, inclusive - * @param to last layer to be activated, inclusive - * @return the activation from the last layer - */ - public INDArray activateSelectedLayers(int from, int to, INDArray input) { - if (input == null) - throw new IllegalStateException("Unable to perform activation; no input found"); - if (from < 0 || from >= layers.length || from >= to) - throw new IllegalStateException("Unable to perform activation; FROM is out of layer space"); - if (to < 1 || to >= layers.length) - throw new IllegalStateException("Unable to perform activation; TO is out of layer space"); - - try { - LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces(helperWorkspaces); //TODO - - INDArray res = input; - for (int l = from; l <= to; l++) { - res = this.activationFromPrevLayer(l, res, false, mgr); - } - return res; - } catch (OutOfMemoryError e){ - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - /** - * Compute all layer activations, from input to output of the output layer. - * Note that the input is included in the list: thus feedForward(in,train).get(0) is the inputs, - * .get(1) is the activations of layer 0, and so on. - * - * @param train Training: if true, perform forward pass/inference at training time. Usually, inference is performed - * with train = false. This impacts whether dropout etc is applied or not. - * @return The list of activations for each layer, including the input - */ - public List feedForward(INDArray input, boolean train) { - setInput(input); - return feedForward(train); - } - - /** - * Compute activations from input to output of the output layer. - * As per {@link #feedForward(INDArray, boolean)} but using the inputs that have previously been set using {@link #setInput(INDArray)} - * - * @return the list of activations for each layer - */ - public List feedForward(boolean train) { - try { - return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layers.length-1, - input, mask, null, true); - } catch (OutOfMemoryError e) { - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - /** - * Perform feed-forward, optionally (not) clearing the layer input arrays.
- * Note: when using clearInputs=false, there can be some performance and memory overhead: this is because the arrays are - * defined outside of workspaces (which are enabled by default) - otherwise, old/invalidated arrays could still be - * accessed after calling this method. Consequently: Don't use clearInputs=false unless you have a use case that - * requires them to remain after feed-forward has been completed - * - * @param train training mode (true) or test mode (false) - * @param clearInputs If false: don't clear the layer inputs - * @return Activations from feed-forward - */ - public List feedForward(boolean train, boolean clearInputs){ - try{ - return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layers.length-1, input, mask, null, clearInputs); - } catch (OutOfMemoryError e) { - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - /** Compute the activations from the input to the specified layer.
- * To compute activations for all layers, use feedForward(...) methods
- * Note: output list includes the original input. So list.get(0) is always the original input, and - * list.get(i+1) is the activations of the ith layer. - * @param layerNum Index of the last layer to calculate activations for. Layers are zero-indexed. - * feedForwardToLayer(i,input) will return the activations for layers 0..i (inclusive) - * @param input Input to the network - * @return list of activations. - */ - public List feedForwardToLayer(int layerNum, INDArray input) { - try{ - return ffToLayerActivationsDetached(false, FwdPassType.STANDARD, false, layerNum, input, mask, null, true); - } catch (OutOfMemoryError e) { - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - /** Compute the activations from the input to the specified layer.
- * To compute activations for all layers, use feedForward(...) methods
- * Note: output list includes the original input. So list.get(0) is always the original input, and - * list.get(i+1) is the activations of the ith layer. - * @param layerNum Index of the last layer to calculate activations for. Layers are zero-indexed. - * feedForwardToLayer(i,input) will return the activations for layers 0..i (inclusive) - * @param input Input to the network - * @param train true for training, false for test (i.e., false if using network after training) - * @return list of activations. - */ - public List feedForwardToLayer(int layerNum, INDArray input, boolean train) { - try { - int layerVertexIdx = layers[layerNum].getIndex(); - return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layerVertexIdx, input, mask, null, true); - } catch (OutOfMemoryError e) { - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - /** Compute the activations from the input to the specified layer, using the currently set input for the network.
- * To compute activations for all layers, use feedForward(...) methods
- * Note: output list includes the original input. So list.get(0) is always the original input, and - * list.get(i+1) is the activations of the ith layer. - * @param layerNum Index of the last layer to calculate activations for. Layers are zero-indexed. - * feedForwardToLayer(i,input) will return the activations for layers 0..i (inclusive) - * @param train true for training, false for test (i.e., false if using network after training) - * @return list of activations. - */ - public List feedForwardToLayer(int layerNum, boolean train) { - try { - return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layerNum, input, mask, null, true); - } catch (OutOfMemoryError e) { - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - - protected void validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, int layerIdx, - boolean isPreprocessor, String op){ - try{ - mgr.validateArrayLocation(arrayType, array, false, layerIdx > 0); - } catch (ND4JWorkspaceException e){ - String layerName = layers[layerIdx].conf().getLayer().getLayerName(); - String clazz; - if(isPreprocessor){ - clazz = layerWiseConfigurations.getInputPreProcess(layerIdx).getClass().getName(); - } else { - clazz = layers[layerIdx].getClass().getName(); - } - throw new IllegalStateException(op + ": array (" + arrayType + ") workspace validation failed (" + - (isPreprocessor ? "preprocessor" : "layer ") + layerIdx + (layerName != null ? " - layer name \"" + - layerName + "\"" : "") + " - class: " + clazz + ") - array is defined in incorrect workspace", e); - } - } - - /** - * Feed-forward through the network - returning all array activations in a list, detached from any workspace. - * Note that no workspace should be active externally when calling this method (an exception will be thrown - * if a workspace is open externally) - * - * @param train Training mode (true) or test/inference mode (false) - * @param fwdPassType Type of forward pass to perform (STANDARD or RNN_ACTIVATE_WITH_STORED_STATE only) - * @param storeLastForTBPTT ONLY used if fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE - * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use numLayers-1 - * @param input Input to the network - * @param fMask Feature mask array. May be null. - * @param lMask Label mask array. May be null. - * @param clearInputs Whether the layer inputs should be cleared - * @return List of activations (including the input), detached from any workspace - */ - protected synchronized List ffToLayerActivationsDetached(boolean train, @NonNull FwdPassType fwdPassType, - boolean storeLastForTBPTT, int layerIndex, @NonNull INDArray input, - INDArray fMask, INDArray lMask, boolean clearInputs){ - setInput(input); - setLayerMaskArrays(fMask, lMask); - - //Verify that no workspace is open externally - WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in ffToLayerActivationsDetached"); - - LayerWorkspaceMgr workspaceMgr; - WorkspaceMode wsm = (train ? 
layerWiseConfigurations.getTrainingWorkspaceMode() : layerWiseConfigurations.getInferenceWorkspaceMode()); - if(wsm == WorkspaceMode.NONE){ - workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); - } else { - workspaceMgr = LayerWorkspaceMgr.builder() - .noWorkspaceFor(ArrayType.ACTIVATIONS) - .with(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); - - if(input.isAttached()){ - //Don't leverage out of async DataSetIterator workspaces - workspaceMgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); - } - - if(!clearInputs){ - workspaceMgr.setScopedOutFor(ArrayType.INPUT); - } - } - workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); - - List out = new ArrayList<>(); - out.add(workspaceMgr.leverageTo(ArrayType.INPUT, input)); //Should be unnecessary (and no op), if layer is implemented correctly - - for( int i=0; i<=layerIndex; i++ ){ - try(MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)){ - if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { - input = getLayerWiseConfigurations().getInputPreProcess(i).preProcess(input, getInputMiniBatchSize(), workspaceMgr); - //Validation: Exception if invalid (bad preprocessor implementation) - validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, true, "Feed forward to layer (inference)"); - } - - if(fwdPassType == FwdPassType.STANDARD){ - input = layers[i].activate(input, train, workspaceMgr); - } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) { - if (layers[i] instanceof RecurrentLayer) { - input = ((RecurrentLayer) layers[i]).rnnActivateUsingStoredState(input, train, - storeLastForTBPTT, workspaceMgr); - } else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer) { - RecurrentLayer rl = (RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying(); - input = rl.rnnActivateUsingStoredState(input, train,storeLastForTBPTT, workspaceMgr); - } else if (layers[i] instanceof MultiLayerNetwork) { - List temp = ((MultiLayerNetwork) layers[i]).rnnActivateUsingStoredState(input, train, storeLastForTBPTT); - input = temp.get(temp.size() - 1); - } else { - input = layers[i].activate(input, train, workspaceMgr); - } - } else { - throw new IllegalStateException("Forward pass type not supported for this method: " + fwdPassType); - } - - //Validation: Exception if invalid (bad layer implementation) - validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, false, "Feed forward to layer (inference)"); - - out.add(input); - } - if(clearInputs) { - layers[i].clear(); - } - } - - return out; - } - - /** - * Feed-forward through the network at training time - returning a list of all activations in a workspace (WS_ALL_LAYERS_ACT) - * if workspaces are enabled for training; or detached if no workspaces are used.
- * Note: if using workspaces for training, this method requires that WS_ALL_LAYERS_ACT is open externally.
- * If using NO workspaces, requires that no external workspace is open
- * Note that this method does NOT clear the inputs to each layer - instead, they are in the WS_ALL_LAYERS_ACT workspace - * for use in later backprop. - * - * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use numLayers-1 - * @param fwdPassType Type of forward pass to perform (STANDARD or RNN_ACTIVATE_WITH_STORED_STATE only) - * @param storeLastForTBPTT ONLY used if fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE - * @param input Input to network - * @param fMask Feature mask array. May be null - * @param lMask Label mask aray. May be null. - * @return - */ - protected synchronized List ffToLayerActivationsInWs(int layerIndex, @NonNull FwdPassType fwdPassType, boolean storeLastForTBPTT, - @NonNull INDArray input, INDArray fMask, INDArray lMask){ - setInput(input); - setLayerMaskArrays(fMask, lMask); - - LayerWorkspaceMgr workspaceMgr; - if(layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ - WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in ffToLayerActivationsInWs when training workspace is set to NONE"); - workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); - } else { - workspaceMgr = LayerWorkspaceMgr.builder() - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); - - if(input.isAttached()){ - //Don't leverage out of async DataSetIterator workspaces - workspaceMgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); - } - - if(layerWiseConfigurations.getCacheMode() != CacheMode.NONE){ - //For now: store cache mode activations in activations workspace - workspaceMgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); - workspaceMgr.setWorkspace(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG); - } - - WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, "ffToLayerActivationsInWs method requires workspace WS_ALL_LAYERS_ACT to be open"); - } - workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); - - List out = new ArrayList<>(); - out.add(workspaceMgr.leverageTo(ArrayType.INPUT, input)); //Probably unnecessary usually - - boolean traceLog = log.isTraceEnabled(); - - for( int i = 0; i <=layerIndex; i++) { - try(MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)){ - if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { - input = getLayerWiseConfigurations().getInputPreProcess(i).preProcess(input, getInputMiniBatchSize(), workspaceMgr); - //Validation: Exception if invalid (bad preprocessor implementation) - validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, true, "Feed forward to layer (training)"); - } - - if(traceLog){ - log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); - } - - if(fwdPassType == FwdPassType.STANDARD){ - input = layers[i].activate(input, true, workspaceMgr); - } else if(fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE){ - if (layers[i] instanceof RecurrentLayer) { - input = ((RecurrentLayer) layers[i]).rnnActivateUsingStoredState(input, true, storeLastForTBPTT, workspaceMgr); - }else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer) { - RecurrentLayer rl = 
(RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying(); - input = rl.rnnActivateUsingStoredState(input, true, storeLastForTBPTT, workspaceMgr); - } else if (layers[i] instanceof MultiLayerNetwork) { - List temp = ((MultiLayerNetwork) layers[i]).rnnActivateUsingStoredState(input, true, storeLastForTBPTT); - input = temp.get(temp.size() - 1); - } else { - input = layers[i].activate(input, true, workspaceMgr); - } - } else { - throw new IllegalStateException("FwdPassType not supported for this method: " + fwdPassType); - } - - if(input == null){ - throw new IllegalStateException("Layer " + i + " returned null activations"); - } - - //Validation: Exception if invalid (bad layer implementation) - validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, false, "Feed forward to layer (training)"); - validateArrayWorkspaces(workspaceMgr, layers[i].input(), ArrayType.INPUT, i, false, "Feed forward to layer (training)"); - - out.add(input); - - if(traceLog){ - log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); - } - } - } - - return out; - } - - /** - * Provide the output of the specified layer, detached from any workspace. This is most commonly used at inference/test - * time, and is more memory efficient than {@link #ffToLayerActivationsDetached(boolean, FwdPassType, boolean, int, INDArray, INDArray, INDArray, boolean)} - * and {@link #ffToLayerActivationsInWs(int, FwdPassType, boolean, INDArray, INDArray, INDArray)}.
- * This method clears all layer inputs. - * - * NOTE: in general, no workspaces should be activated externally for this method! - * This method handles the workspace activation as required - * - * @param train Training mode (true) or test/inference mode (false) - * @param fwdPassType Type of forward pass to perform (STANDARD, RNN_TIMESTEP or RNN_ACTIVATE_WITH_STORED_STATE) - * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use numLayers-1 - * @param input Input to the network - * @param featureMask Input/feature mask array. May be null. - * @param labelsMask Labels mask array. May be null - * @param outputWorkspace Optional - if provided, outputs should be placed in this workspace. NOTE: this workspace - * must be open - * @return Output of the specified layer, detached from any workspace - */ - protected INDArray outputOfLayerDetached(boolean train, @NonNull FwdPassType fwdPassType, int layerIndex, @NonNull INDArray input, - INDArray featureMask, INDArray labelsMask, MemoryWorkspace outputWorkspace){ - setInput(input); - setLayerMaskArrays(featureMask, labelsMask); + return out; + } + + /** + * Provide the output of the specified layer, detached from any workspace. This is most commonly + * used at inference/test time, and is more memory efficient than + * {@link #ffToLayerActivationsDetached(boolean, FwdPassType, boolean, int, INDArray, INDArray, + * INDArray, boolean)} and + * {@link #ffToLayerActivationsInWs(int, FwdPassType, boolean, INDArray, INDArray, INDArray)}.
+ * This method clears all layer inputs. + *

+ * NOTE: in general, no workspaces should be activated externally for this method! This method + * handles the workspace activation as required + * + * @param train Training mode (true) or test/inference mode (false) + * @param fwdPassType Type of forward pass to perform (STANDARD, RNN_TIMESTEP or + * RNN_ACTIVATE_WITH_STORED_STATE) + * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use + * numLayers-1 + * @param input Input to the network + * @param featureMask Input/feature mask array. May be null. + * @param labelsMask Labels mask array. May be null + * @param outputWorkspace Optional - if provided, outputs should be placed in this workspace. + * NOTE: this workspace must be open + * @return Output of the specified layer, detached from any workspace + */ + protected INDArray outputOfLayerDetached(boolean train, @NonNull FwdPassType fwdPassType, + int layerIndex, @NonNull INDArray input, + INDArray featureMask, INDArray labelsMask, MemoryWorkspace outputWorkspace) { + setInput(input); + setLayerMaskArrays(featureMask, labelsMask); /* Idea here: we want to minimize memory, and return only the final array @@ -1203,672 +1341,731 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, org.d Additionally, we'll reconfigure the workspace manager for the *final* layer, so that we don't have to detach */ - if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) { - WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in outputOfLayerDetached", true); - } else { - Preconditions.checkState(outputWorkspace.isScopeActive(), "Workspace \"" + outputWorkspace.getId() + - "\" was provided for the network/layer outputs. When provided, this workspace must be opened before " + - "calling the output method; furthermore, closing the workspace is the responsibility of the user"); + if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) { + WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in outputOfLayerDetached", + true); + } else { + Preconditions.checkState(outputWorkspace.isScopeActive(), + "Workspace \"" + outputWorkspace.getId() + + "\" was provided for the network/layer outputs. When provided, this workspace must be opened before " + + + "calling the output method; furthermore, closing the workspace is the responsibility of the user"); + } + + LayerWorkspaceMgr mgrEven; + LayerWorkspaceMgr mgrOdd; + + WorkspaceMode wsm = train ? layerWiseConfigurations.getTrainingWorkspaceMode() + : layerWiseConfigurations.getInferenceWorkspaceMode(); + if (wsm == WorkspaceMode.NONE) { + mgrEven = LayerWorkspaceMgr.noWorkspaces(); + mgrOdd = mgrEven; + + //Check for external workspace - doesn't make sense to have one with workspace mode NONE + if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) { + throw new IllegalStateException("Workspace \"" + outputWorkspace.getId() + + "\" was provided for the network/layer outputs, however " + (train ? "training" + : "inference") + + " workspace mode is set to NONE. Cannot put output activations into the specified workspace if" + + + "workspaces are disabled for the network. 
use getConfiguration().setTraining/InferenceWorkspaceMode(WorkspaceMode.ENABLED)"); + } + } else { + mgrEven = LayerWorkspaceMgr.builder() + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_1, WS_LAYER_ACT_X_CONFIG) + .with(ArrayType.INPUT, WS_LAYER_ACT_2, + WS_LAYER_ACT_X_CONFIG) //Inputs should always be in the previous WS + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); + + mgrOdd = LayerWorkspaceMgr.builder() + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG) + .with(ArrayType.INPUT, WS_LAYER_ACT_1, + WS_LAYER_ACT_X_CONFIG) //Inputs should always be in the previous WS + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); + } + mgrEven.setHelperWorkspacePointers(helperWorkspaces); + mgrOdd.setHelperWorkspacePointers(helperWorkspaces); + + MemoryWorkspace wsActCloseNext = null; + MemoryWorkspace temp = null; + MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace(); + + boolean traceLog = log.isTraceEnabled(); + + Throwable t = null; + try { + for (int i = 0; i <= layerIndex; i++) { + LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd); + + if (traceLog) { + log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); } - LayerWorkspaceMgr mgrEven; - LayerWorkspaceMgr mgrOdd; + //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) + //Hence: put inputs in working memory + if (i == 0 && wsm != WorkspaceMode.NONE) { + mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG); + } - WorkspaceMode wsm = train ? layerWiseConfigurations.getTrainingWorkspaceMode() : layerWiseConfigurations.getInferenceWorkspaceMode(); - if(wsm == WorkspaceMode.NONE){ - mgrEven = LayerWorkspaceMgr.noWorkspaces(); - mgrOdd = mgrEven; + try (MemoryWorkspace wsFFWorking = mgr.notifyScopeEntered( + ArrayType.FF_WORKING_MEM)) { //Working memory: opened/closed once per layer + //Activations workspaces: opened/closed every second layer. + //So mgrEven (WS_LAYER_ACT_1) open at start of 0, 2, 4, 8; closed at end of 1, 3, 5, 7 etc + //and mgrOdd (WS_LAYER_ACT_2) opened at start of 1, 3, 5, 7; closed at end of 2, 4, 6, 8 etc + temp = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS); - //Check for external workspace - doesn't make sense to have one with workspace mode NONE - if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){ - throw new IllegalStateException("Workspace \"" + outputWorkspace.getId() + - "\" was provided for the network/layer outputs, however " + (train ? "training" : "inference") + - " workspace mode is set to NONE. Cannot put output activations into the specified workspace if" + - "workspaces are disabled for the network. use getConfiguration().setTraining/InferenceWorkspaceMode(WorkspaceMode.ENABLED)"); + //Note that because we're opening activation workspaces not in a simple nested order, we'll manually + // override the previous workspace setting. 
Otherwise, when we close these workspaces, the "current" + // workspace may be set to the incorrect one + temp.setPreviousWorkspace(initialWorkspace); + + if (i == 0 && input.isAttached()) { + //Don't leverage out of async DataSetIterator workspaces + mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); + } + + if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { + input = getLayerWiseConfigurations().getInputPreProcess(i) + .preProcess(input, getInputMiniBatchSize(), mgr); + //Validation: Exception if invalid (bad preprocessor implementation) + validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, + "Output of layer (inference)"); + } + + if (i == layerIndex) { + if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) { + //Place activations in user-specified workspace + mgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), + outputWorkspace.getWorkspaceConfiguration()); + } else { + //Final activations: should be detached + mgr.setScopedOutFor(ArrayType.ACTIVATIONS); } - } else { - mgrEven = LayerWorkspaceMgr.builder() - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_1, WS_LAYER_ACT_X_CONFIG) - .with(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG) //Inputs should always be in the previous WS - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + } - mgrOdd = LayerWorkspaceMgr.builder() - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG) - .with(ArrayType.INPUT, WS_LAYER_ACT_1, WS_LAYER_ACT_X_CONFIG) //Inputs should always be in the previous WS - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + if (fwdPassType == FwdPassType.STANDARD) { + //Standard feed-forward case + if (i > 0 && ConvolutionUtils.layerHasConvolutionLayout(layers[i - 1].conf().getLayer()) + && ConvolutionUtils.layerHasConvolutionLayout(layers[i].conf().getLayer())) { + + CNN2DFormat preLayerFormat = ConvolutionUtils.getFormatForLayer( + layers[i - 1].conf().getLayer()); + CNN2DFormat currLayerFormat = ConvolutionUtils.getFormatForLayer( + layers[i].conf().getLayer()); + if (preLayerFormat != currLayerFormat) { + //NHWC case + if (preLayerFormat == CNN2DFormat.NCHW) { + input = input.permute(0, 3, 1, 2); + } + //NCHW case + else if (preLayerFormat == CNN2DFormat.NHWC) { + input = input.permute(0, 2, 3, 1); + + } else { + throw new IllegalStateException( + "No CNN2DDataFormat type found for previous layer!"); + } + } + + input = layers[i].activate(input, train, mgr); + } else if (i > 0 && Convolution1DUtils.hasRnnDataFormat(layers[i - 1].conf().getLayer()) + && Convolution1DUtils.hasRnnDataFormat(layers[i].conf().getLayer())) { + RNNFormat preLayerFormat = Convolution1DUtils.getRnnFormatFromLayer( + layers[i - 1].conf().getLayer()); + RNNFormat currLayerFormat = Convolution1DUtils.getRnnFormatFromLayer( + layers[i].conf().getLayer()); + //permute for next layer + if (preLayerFormat != currLayerFormat) { + input = input.permute(0, 2, 1); + } + + input = layers[i].activate(input, train, mgr); + + + } else { + input = layers[i].activate(input, train, mgr); + } + } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) { + //rnnTimeStep case + if (layers[i] instanceof RecurrentLayer) { + input = ((RecurrentLayer) 
layers[i]).rnnTimeStep(reshapeTimeStepInput(input), mgr); + } else if (layers[i] instanceof BaseWrapperLayer + && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) { + RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying()); + input = rl.rnnTimeStep(reshapeTimeStepInput(input), mgr); + } else if (layers[i] instanceof MultiLayerNetwork) { + input = ((MultiLayerNetwork) layers[i]).rnnTimeStep(reshapeTimeStepInput(input)); + } else { + input = layers[i].activate(input, false, mgr); + } + } else { + throw new IllegalArgumentException( + "Unsupported forward pass type for this method: " + fwdPassType); + } + layers[i].clear(); + //Validation: Exception if invalid (bad layer implementation) + validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, + "Output of layer (inference)"); + + if (wsActCloseNext != null) { + wsActCloseNext.close(); + } + wsActCloseNext = temp; + temp = null; } - mgrEven.setHelperWorkspacePointers(helperWorkspaces); - mgrOdd.setHelperWorkspacePointers(helperWorkspaces); - MemoryWorkspace wsActCloseNext = null; - MemoryWorkspace temp = null; - MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace(); + if (traceLog) { + log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); + } - boolean traceLog = log.isTraceEnabled(); - - Throwable t = null; + //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) + //Hence: put inputs in working memory -> set back to default for next use of workspace mgr + if (i == 0 && wsm != WorkspaceMode.NONE) { + mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, + WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS + } + } + } catch (Throwable t2) { + t = t2; + } finally { + if (wsActCloseNext != null) { try { - for (int i = 0; i <= layerIndex; i++) { - LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd); - - if (traceLog) { - log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); - } - - //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) - //Hence: put inputs in working memory - if (i == 0 && wsm != WorkspaceMode.NONE) { - mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG); - } - - try (MemoryWorkspace wsFFWorking = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { //Working memory: opened/closed once per layer - //Activations workspaces: opened/closed every second layer. - //So mgrEven (WS_LAYER_ACT_1) open at start of 0, 2, 4, 8; closed at end of 1, 3, 5, 7 etc - //and mgrOdd (WS_LAYER_ACT_2) opened at start of 1, 3, 5, 7; closed at end of 2, 4, 6, 8 etc - temp = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS); - - //Note that because we're opening activation workspaces not in a simple nested order, we'll manually - // override the previous workspace setting. 
Otherwise, when we close these workspaces, the "current" - // workspace may be set to the incorrect one - temp.setPreviousWorkspace(initialWorkspace); - - - if (i == 0 && input.isAttached()) { - //Don't leverage out of async DataSetIterator workspaces - mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); - } - - if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { - input = getLayerWiseConfigurations().getInputPreProcess(i).preProcess(input, getInputMiniBatchSize(), mgr); - //Validation: Exception if invalid (bad preprocessor implementation) - validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, "Output of layer (inference)"); - } - - if (i == layerIndex) { - if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) { - //Place activations in user-specified workspace - mgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration()); - } else { - //Final activations: should be detached - mgr.setScopedOutFor(ArrayType.ACTIVATIONS); - } - } - - if (fwdPassType == FwdPassType.STANDARD) { - //Standard feed-forward case - if(i > 0 && ConvolutionUtils.layerHasConvolutionLayout(layers[i - 1].conf().getLayer()) - && ConvolutionUtils.layerHasConvolutionLayout(layers[i].conf().getLayer())) { - - CNN2DFormat preLayerFormat = ConvolutionUtils.getFormatForLayer(layers[i - 1].conf().getLayer()); - CNN2DFormat currLayerFormat = ConvolutionUtils.getFormatForLayer(layers[i].conf().getLayer()); - if(preLayerFormat != currLayerFormat) { - //NHWC case - if(preLayerFormat == CNN2DFormat.NCHW) { - input = input.permute(0,3,1,2); - } - //NCHW case - else if(preLayerFormat == CNN2DFormat.NHWC) { - input = input.permute(0,2,3,1); - - } - else - throw new IllegalStateException("No CNN2DDataFormat type found for previous layer!"); - } - - input = layers[i].activate(input, train, mgr); - } else if(i > 0 && Convolution1DUtils.hasRnnDataFormat(layers[i - 1].conf().getLayer()) - && Convolution1DUtils.hasRnnDataFormat(layers[i].conf().getLayer())) { - RNNFormat preLayerFormat = Convolution1DUtils.getRnnFormatFromLayer(layers[i - 1].conf().getLayer()); - RNNFormat currLayerFormat = Convolution1DUtils.getRnnFormatFromLayer(layers[i].conf().getLayer()); - //permute for next layer - if(preLayerFormat != currLayerFormat) { - input = input.permute(0,2,1); - } - - input = layers[i].activate(input, train, mgr); - - - } else - input = layers[i].activate(input, train, mgr); - } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) { - //rnnTimeStep case - if (layers[i] instanceof RecurrentLayer) { - input = ((RecurrentLayer) layers[i]).rnnTimeStep(reshapeTimeStepInput(input), mgr); - } else if (layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) { - RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying()); - input = rl.rnnTimeStep(reshapeTimeStepInput(input), mgr); - } else if (layers[i] instanceof MultiLayerNetwork) { - input = ((MultiLayerNetwork) layers[i]).rnnTimeStep(reshapeTimeStepInput(input)); - } else { - input = layers[i].activate(input, false, mgr); - } - } else { - throw new IllegalArgumentException("Unsupported forward pass type for this method: " + fwdPassType); - } - layers[i].clear(); - //Validation: Exception if invalid (bad layer implementation) - validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, "Output of layer (inference)"); - - if (wsActCloseNext != null) { - wsActCloseNext.close(); - } - 
wsActCloseNext = temp; - temp = null; - } - - if (traceLog) { - log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); - } - - //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) - //Hence: put inputs in working memory -> set back to default for next use of workspace mgr - if (i == 0 && wsm != WorkspaceMode.NONE) { - mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS - } - } - } catch (Throwable t2){ - t = t2; - } finally { - if(wsActCloseNext != null){ - try { - wsActCloseNext.close(); - } catch (Throwable t2){ - if(t != null){ - log.error("Encountered second exception while trying to close workspace after initial exception"); - log.error("Original exception:", t); - throw t2; - } - } - } - if(temp != null){ - //Should only be non-null on exception - while(temp.isScopeActive()){ - //For safety, should never occur in theory: a single close() call may not be sufficient, if - // workspace scope was borrowed and not properly closed when exception occurred - try{ - temp.close(); - } catch (Throwable t2){ - if(t != null){ - log.error("Encountered second exception while trying to close workspace after initial exception"); - log.error("Original exception:", t); - throw t2; - } - } - } + wsActCloseNext.close(); + } catch (Throwable t2) { + if (t != null) { + log.error( + "Encountered second exception while trying to close workspace after initial exception"); + log.error("Original exception:", t); + throw t2; + } + } + } + if (temp != null) { + //Should only be non-null on exception + while (temp.isScopeActive()) { + //For safety, should never occur in theory: a single close() call may not be sufficient, if + // workspace scope was borrowed and not properly closed when exception occurred + try { + temp.close(); + } catch (Throwable t2) { + if (t != null) { + log.error( + "Encountered second exception while trying to close workspace after initial exception"); + log.error("Original exception:", t); + throw t2; } + } + } + } - Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); + Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); - if(t != null){ - if(t instanceof RuntimeException){ - throw ((RuntimeException)t); - } - throw new RuntimeException("Error during neural network forward pass", t); - } + if (t != null) { + if (t instanceof RuntimeException) { + throw ((RuntimeException) t); + } + throw new RuntimeException("Error during neural network forward pass", t); + } - if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) { - WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active at the end of outputOfLayerDetached", true); - } else { - Preconditions.checkState(outputWorkspace.isScopeActive(), "Expected output workspace to still be open" + - "at end of outputOfLayerDetached, but it is closed. This suggests an implementation or layer workspace problem"); - } + if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) { + WorkspaceUtils.assertNoWorkspacesOpen( + "Expected no workspace active at the end of outputOfLayerDetached", true); + } else { + Preconditions.checkState(outputWorkspace.isScopeActive(), + "Expected output workspace to still be open" + + "at end of outputOfLayerDetached, but it is closed. 
This suggests an implementation or layer workspace problem"); + } + } + + return input; + } + + private INDArray reshapeTimeStepInput(INDArray input) { + if (input.rank() == 2) { // dynamically reshape to 3D input with one time-step. + long[] inShape = input.shape(); + input = input.reshape(inShape[0], inShape[1], 1); + } + return input; + } + + /** + * Compute activations of all layers from input (inclusive) to output of the final/output layer. + * Equivalent to calling {@link #feedForward(boolean)} with train=false + * + * @return the list of activations for each layer, including the input + */ + public List feedForward() { + return feedForward(false); + } + + /** + * Compute activations of all layers from input (inclusive) to output of the final/output layer. + * Equivalent to calling {@link #feedForward(INDArray, boolean)} with train = false + * + * @return the list of activations for each layer, including the input + */ + public List feedForward(INDArray input) { + if (input == null) { + throw new IllegalStateException("Unable to perform feed forward; no input found"); + } + setInput(input); + return feedForward(); + } + + /** + * Compute the activations from the input to the output layer, given mask arrays (that may be + * null) The masking arrays are used in situations such an one-to-many and many-to-one rucerrent + * neural network (RNN) designs, as well as for supporting time series of varying lengths within + * the same minibatch for RNNs. Other than mask arrays, this is equivalent to calling + * {@link #feedForward(INDArray, boolean)} with train = false + */ + public List feedForward(INDArray input, INDArray featuresMask, INDArray labelsMask) { + setLayerMaskArrays(featuresMask, labelsMask); + List list = feedForward(input); + clearLayerMaskArrays(); + return list; + } + + @Override + public Gradient gradient() { + return gradient; + } + + @Override + public Pair gradientAndScore() { + return new Pair<>(gradient(), score()); + } + + /** + * Clone the MultiLayerNetwork + * + * @return A cloned MultiLayerNetwork with a copy of the configuration, parameters and updater + * identical to the current network. + */ + @Override + public MultiLayerNetwork clone() { + if (!initCalled) { + init(); + } + MultiLayerConfiguration conf = this.layerWiseConfigurations.clone(); + MultiLayerNetwork ret = new MultiLayerNetwork(conf); + ret.init(this.params().dup(), false); + + if (solver != null) { + //If solver is null: updater hasn't been initialized -> getUpdater call will force initialization, however + Updater u = this.getUpdater(); + INDArray updaterState = u.getStateViewArray(); + if (updaterState != null) { + ret.getUpdater().setStateViewArray(ret, updaterState.dup(), false); + } + } + + if (hasAFrozenLayer()) { + //correct layers to frozen layers + Layer[] clonedLayers = ret.getLayers(); + for (int i = 0; i < layers.length; i++) { + if (layers[i] instanceof FrozenLayer) { + clonedLayers[i] = new FrozenLayer(ret.getLayer(i)); + } + } + ret.setLayers(clonedLayers); + } + return ret; + } + + protected boolean hasAFrozenLayer() { + for (int i = 0; i < layers.length - 1; i++) { + if (layers[i] instanceof FrozenLayer) { + return true; + } + } + return false; + } + + /** + * @deprecated To be removed. Use {@link #params()} instead + */ + @Deprecated + public INDArray params(boolean backwardOnly) { + return params(); + } + + /** + * Returns a 1 x m vector where the vector is composed of a flattened vector of all of the + * parameters in the network.
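From the caller's side, the feed-forward plumbing above reduces to a few public calls. The sketch below is an illustrative example only (not part of this patch); it assumes an already built and initialized MultiLayerNetwork and uses upstream Deeplearning4j/ND4J package names, which may differ in this repository.

// Hypothetical helper showing feedForward with optional mask arrays (names are illustrative only).
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.List;

public class FeedForwardExample {
    // Returns the activations of the final layer for a single minibatch.
    // featuresMask/labelsMask may be null, mirroring feedForward(INDArray, INDArray, INDArray) above.
    static INDArray lastLayerActivations(MultiLayerNetwork net, INDArray features,
                                         INDArray featuresMask, INDArray labelsMask) {
        List<INDArray> activations = net.feedForward(features, featuresMask, labelsMask);
        // activations.get(0) is the input itself; the last element is the output layer's activations
        return activations.get(activations.size() - 1);
    }
}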
See {@link #getParam(String)} and {@link #paramTable()} for a + * more useful/interpretable representation of the parameters.
Note that the parameter vector
+ * is not a copy, and changes to the returned INDArray will impact the network parameters.
+ *
+ * @return the parameters for this neural net
+ */
+ @Override
+ public INDArray params() {
+ return flattenedParams;
+ }
+
+ /**
+ * Set the parameters for this model. This expects a linear ndarray which is then unpacked
+ * internally relative to the expected ordering of the model.
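As a caller-side illustration of the params()/setParams() contract described here, the following hypothetical sketch copies parameters from one network into another with an identical configuration; class names follow upstream Deeplearning4j and the helper itself is invented for illustration.

// Hypothetical sketch: copying parameters between two identically configured networks.
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

public class ParamCopyExample {
    static void copyParams(MultiLayerNetwork source, MultiLayerNetwork target) {
        // params() returns the flattened parameter vector as a view, not a copy, so dup() first
        INDArray flat = source.params().dup();
        // setParams() unpacks the flat vector layer by layer, in the model's expected ordering
        target.setParams(flat);
    }
}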
See also: + * {@link #setParamTable(Map)} and {@link #setParam(String, INDArray)} + * + * @param params the parameters for the model + */ + @Override + public void setParams(INDArray params) { + if (flattenedParams == params) { + return; //No op + } + + if (flattenedParams != null && params.length() == flattenedParams.length()) { + if (params != flattenedParams) { + flattenedParams.assign(params); + } + } else { + if (flattenedParams == null) { + flattenedParams = params.dup(); + } + int idx = 0; + for (int i = 0; i < getLayers().length; i++) { + Layer layer = getLayer(i); + long range = layer.numParams(); + if (range <= 0) { + continue; //Some layers: no parameters (subsampling, etc) + } + INDArray get = params.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(idx, range + idx)); + layer.setParams(get); + idx += range; + } + } + } + + @Override + public void setParamsViewArray(INDArray params) { + throw new UnsupportedOperationException("Not yet implemented"); + } + + @Override + public INDArray getGradientsViewArray() { + return flattenedGradients; + } + + @Override + public void setBackpropGradientsViewArray(INDArray gradients) { + int paramsSoFar = 0; + for (Layer layer : layers) { + if (layer.numParams() == 0) { + continue; + } + layer.setBackpropGradientsViewArray(gradients.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(paramsSoFar, paramsSoFar + layer.numParams()))); + paramsSoFar += layer.numParams(); + } + } + + @Override + public TrainingConfig getConfig() { + throw new UnsupportedOperationException("Not supported"); + } + + /** + * Returns the number of parameters in the network + * + * @return The number of parameters + */ + @Override + public long numParams() { + if (!isInitCalled()) { + init(); + } + return flattenedParams == null ? 0 : flattenedParams.length(); //Maybe nul for 0 params net + } + + /** + * Returns the number of parameters in the network + * + * @param backwards If true: exclude any parameters uned only in unsupervised layerwise training + * (such as the decoder parameters in an autoencoder) + * @return The number of parameters + */ + @Override + public long numParams(boolean backwards) { + int length = 0; + for (int i = 0; i < layers.length; i++) { + length += layers[i].numParams(backwards); + } + + return length; + } + + /** + * Sets the input and labels and returns the F1 score for the prediction with respect to the true + * labels + * + * @param data the data to score + * @return the score for the given input,label pairs + */ + @Override + public double f1Score(org.nd4j.linalg.dataset.api.DataSet data) { + return f1Score(data.getFeatures(), data.getLabels()); + } + + /** + * Perform minibatch training on all minibatches in the DataSetIterator, for the specified number + * of epochs. Equvalent to calling {@link #fit(DataSetIterator)} numEpochs times in a loop + * + * @param iterator Training data (DataSetIterator). Iterator must support resetting + * @param numEpochs Number of training epochs, >= 1 + */ + public void fit(@NonNull DataSetIterator iterator, int numEpochs) { + Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. 
Got numEpochs = %s",
+ numEpochs);
+ Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(),
+ "Cannot perform multiple epochs training using an "
+ + "iterator that does not support resetting (iterator.resetSupported() returned false)");
+
+ for (int i = 0; i < numEpochs; i++) {
+ fit(iterator);
+ }
+ }
+
+ /**
+ * Perform minibatch training on all minibatches in the DataSetIterator for 1 epoch.
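To make the fit(DataSetIterator, int) contract concrete, here is a small, self-contained training sketch. It is illustrative only: the toy data is random, and the configuration and iterator classes (NeuralNetConfiguration, DenseLayer, OutputLayer, ListDataSetIterator, Adam) are assumed from upstream Deeplearning4j/ND4J and may be packaged differently in this fork.

import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class FitExample {
    public static void main(String[] args) {
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(42)
                .updater(new Adam(1e-3))
                .list()
                .layer(new DenseLayer.Builder().nIn(4).nOut(16).activation(Activation.RELU).build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(16).nOut(3).activation(Activation.SOFTMAX).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // Toy dataset: 32 random examples with 4 features and 3-class one-hot labels
        INDArray features = Nd4j.rand(32, 4);
        INDArray labels = Nd4j.zeros(32, 3);
        labels.getColumn(0).assign(1.0);   // all class 0, just to have valid shapes
        DataSet toy = new DataSet(features, labels);
        DataSetIterator iter = new ListDataSetIterator<>(toy.asList(), 8);

        // Multiple epochs require an iterator that supports reset(), as checked above
        net.fit(iter, 3);
    }
}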
Note that
+ * this method does not do layerwise pretraining. For layerwise pretraining, use
+ * {@link #pretrain(DataSetIterator)}.
+ * + * @param iterator Training data (DataSetIterator) + */ + @Override + public void fit(DataSetIterator iterator) { + try { + fitHelper(iterator); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; + } + } + + private synchronized void fitHelper(DataSetIterator iterator) { + // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate + DataSetIterator iter; + boolean destructable = false; + if (iterator.asyncSupported()) { + iter = new AsyncDataSetIterator(iterator, + Math.min(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true); + destructable = true; + } else { + iter = iterator; + } + + for (TrainingListener tl : trainingListeners) { + tl.onEpochStart(this); + } + + LayerWorkspaceMgr workspaceMgr; + if (getLayerWiseConfigurations().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); + } else { + workspaceMgr = LayerWorkspaceMgr.builder() + .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + //Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM + // as these should be closed by the time updaters are executed + //Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this + .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .build(); + } + workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); + + update(TaskUtils.buildTask(iter)); + if (!iter.hasNext() && iter.resetSupported()) { + iter.reset(); + } + long time1 = System.currentTimeMillis(); + while (iter.hasNext()) { + + DataSet next = iter.next(); + long time2 = System.currentTimeMillis(); + + lastEtlTime.set((time2 - time1)); + + if (next.getFeatures() == null || next.getLabels() == null) { + break; + } + + // TODO: basically we want to wrap internals of this loop into workspace + + boolean hasMaskArrays = next.hasMaskArrays(); + + if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) { + doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(), + next.getLabelsMaskArray(), workspaceMgr); + } else { + if (hasMaskArrays) { + setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray()); } - return input; - } + setInput(next.getFeatures()); + setLabels(next.getLabels()); - private INDArray reshapeTimeStepInput(INDArray input) { - if (input.rank() == 2) { // dynamically reshape to 3D input with one time-step. - long[] inShape = input.shape(); - input = input.reshape(inShape[0], inShape[1], 1); + if (solver == null) { + try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + .build(); + } } - return input; - } - /** - * Compute activations of all layers from input (inclusive) to output of the final/output layer. 
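The calculateGradients method added above can be called directly when parameter and input gradients are wanted without an updater step. A hypothetical sketch follows; the network and data are supplied by the caller, and the Pair import location is assumed from recent ND4J.

import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.common.primitives.Pair;   // location assumed; older releases use org.nd4j.linalg.primitives.Pair
import org.nd4j.linalg.api.ndarray.INDArray;

public class GradientInspectionExample {
    // Computes parameter gradients and the gradient w.r.t. the input for one labelled minibatch.
    static INDArray inputGradient(MultiLayerNetwork net, INDArray features, INDArray labels) {
        // Mask arrays are null here; pass real masks for variable-length time series
        Pair<Gradient, INDArray> p = net.calculateGradients(features, labels, null, null);
        Gradient paramGradients = p.getFirst();    // per-variable parameter gradients
        INDArray wrtInput = p.getSecond();         // detached input activation gradients
        System.out.println("Gradient variables: " + paramGradients.gradientForVariable().keySet());
        return wrtInput;
    }
}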
- * Equivalent to calling {@link #feedForward(boolean)} with train=false - * - * @return the list of activations for each layer, including the input - */ - public List feedForward() { - return feedForward(false); - } + //TODO CACHE + solver.optimize(workspaceMgr); + } - /** - * Compute activations of all layers from input (inclusive) to output of the final/output layer. - * Equivalent to calling {@link #feedForward(INDArray, boolean)} with train = false - * - * @return the list of activations for each layer, including the input - */ - public List feedForward(INDArray input) { - if (input == null) - throw new IllegalStateException("Unable to perform feed forward; no input found"); - setInput(input); - return feedForward(); - } - - /** - * Compute the activations from the input to the output layer, given mask arrays (that may be null) - * The masking arrays are used in situations such an one-to-many and many-to-one rucerrent neural network (RNN) - * designs, as well as for supporting time series of varying lengths within the same minibatch for RNNs. - * Other than mask arrays, this is equivalent to calling {@link #feedForward(INDArray, boolean)} with train = false - */ - public List feedForward(INDArray input, INDArray featuresMask, INDArray labelsMask) { - setLayerMaskArrays(featuresMask, labelsMask); - List list = feedForward(input); + if (hasMaskArrays) { clearLayerMaskArrays(); - return list; + } + + time1 = System.currentTimeMillis(); + synchronizeIterEpochCounts(); } - - @Override - public Gradient gradient() { - return gradient; + if (!trainingListeners.isEmpty()) { + for (TrainingListener tl : trainingListeners) { + tl.onEpochEnd(this); + } } - @Override - public Pair gradientAndScore() { - return new Pair<>(gradient(), score()); + clearLayersStates(); + + if (destructable) { + ((AsyncDataSetIterator) iter).shutdown(); } + incrementEpochCount(); + } - /** - * Clone the MultiLayerNetwork - * @return A cloned MultiLayerNetwork with a copy of the configuration, parameters and updater identical to the current network. 
- */ - @Override - public MultiLayerNetwork clone() { - if(!initCalled) - init(); - MultiLayerConfiguration conf = this.layerWiseConfigurations.clone(); - MultiLayerNetwork ret = new MultiLayerNetwork(conf); - ret.init(this.params().dup(), false); - - if (solver != null) { - //If solver is null: updater hasn't been initialized -> getUpdater call will force initialization, however - Updater u = this.getUpdater(); - INDArray updaterState = u.getStateViewArray(); - if (updaterState != null) { - ret.getUpdater().setStateViewArray(ret, updaterState.dup(), false); - } - } - - if (hasAFrozenLayer()) { - //correct layers to frozen layers - Layer[] clonedLayers = ret.getLayers(); - for (int i = 0; i < layers.length; i++) { - if (layers[i] instanceof FrozenLayer) { - clonedLayers[i] = new FrozenLayer(ret.getLayer(i)); - } - } - ret.setLayers(clonedLayers); - } - return ret; + /** + * Calculate parameter gradients and input activation gradients given the input and labels, and + * optionally mask arrays + * + * @param features Features for gradient calculation + * @param label Labels for gradient + * @param fMask Features mask array (may be null) + * @param labelMask Label mask array (may be null) + * @return A pair of gradient arrays: parameter gradients (in Gradient object) and input + * activation gradients + */ + public Pair calculateGradients(@NonNull INDArray features, + @NonNull INDArray label, + INDArray fMask, INDArray labelMask) { + try { + return calculateGradientsHelper(features, label, fMask, labelMask); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; } + } - protected boolean hasAFrozenLayer() { - for (int i = 0; i < layers.length - 1; i++) { - if (layers[i] instanceof FrozenLayer) - return true; - } - return false; + private Pair calculateGradientsHelper(INDArray features, INDArray label, + INDArray fMask, + INDArray labelMask) { + setInput(features); + setLabels(label); + setLayerMaskArrays(fMask, labelMask); + + LayerWorkspaceMgr mgr; + if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + mgr = LayerWorkspaceMgr.noWorkspaces(); + } else { + mgr = LayerWorkspaceMgr.builder() + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); + + if (layerWiseConfigurations.getCacheMode() != null) { + //For now: store cache mode activations in activations workspace + mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); + } } + mgr.setHelperWorkspacePointers(helperWorkspaces); - - /** - * @deprecated To be removed. Use {@link #params()} instead - */ - @Deprecated - public INDArray params(boolean backwardOnly) { - return params(); - } - - - /** - * Returns a 1 x m vector where the vector is composed of a flattened vector of all of the parameters in the network.
- * See {@link #getParam(String)} and {@link #paramTable()} for a more useful/interpretable representation of the parameters.
- * Note that the parameter vector is not a copy, and changes to the returned INDArray will impact the network parameters. - * - * @return the parameters for this neural net - */ - @Override - public INDArray params() { - return flattenedParams; - } - - /** - * Set the parameters for this model. - * This expects a linear ndarray which then be unpacked internally relative to the expected ordering of the model.
- * See also: {@link #setParamTable(Map)} and {@link #setParam(String, INDArray)} - * - * @param params the parameters for the model - */ - @Override - public void setParams(INDArray params) { - if (flattenedParams == params) { - return; //No op - } - - if (flattenedParams != null && params.length() == flattenedParams.length()) { - if (params != flattenedParams) { - flattenedParams.assign(params); - } - } else { - if (flattenedParams == null) - flattenedParams = params.dup(); - int idx = 0; - for (int i = 0; i < getLayers().length; i++) { - Layer layer = getLayer(i); - long range = layer.numParams(); - if (range <= 0) - continue; //Some layers: no parameters (subsampling, etc) - INDArray get = params.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(idx, range + idx)); - layer.setParams(get); - idx += range; - } - } - } - - @Override - public void setParamsViewArray(INDArray params) { - throw new UnsupportedOperationException("Not yet implemented"); - } - - @Override - public INDArray getGradientsViewArray() { - return flattenedGradients; - } - - @Override - public void setBackpropGradientsViewArray(INDArray gradients) { - int paramsSoFar = 0; - for (Layer layer : layers) { - if (layer.numParams() == 0) - continue; - layer.setBackpropGradientsViewArray(gradients.get(NDArrayIndex.interval(0,0,true), - NDArrayIndex.interval(paramsSoFar, paramsSoFar + layer.numParams()))); - paramsSoFar += layer.numParams(); - } - } - - @Override - public TrainingConfig getConfig() { - throw new UnsupportedOperationException("Not supported"); - } - - /** - * Returns the number of parameters in the network - * - * @return The number of parameters - */ - @Override - public long numParams() { - if(!isInitCalled()) - init(); - return flattenedParams == null ? 0 : flattenedParams.length(); //Maybe nul for 0 params net - } - - /** - * Returns the number of parameters in the network - * - * @param backwards If true: exclude any parameters uned only in unsupervised layerwise training (such as the decoder - * parameters in an autoencoder) - * @return The number of parameters - */ - @Override - public long numParams(boolean backwards) { - int length = 0; - for (int i = 0; i < layers.length; i++) - length += layers[i].numParams(backwards); - - return length; - } - - /** - * Sets the input and labels and returns the F1 score for the prediction with respect to the true labels - * - * @param data the data to score - * @return the score for the given input,label pairs - */ - @Override - public double f1Score(org.nd4j.linalg.dataset.api.DataSet data) { - return f1Score(data.getFeatures(), data.getLabels()); - } - - /** - * Perform minibatch training on all minibatches in the DataSetIterator, for the specified number of epochs. - * Equvalent to calling {@link #fit(DataSetIterator)} numEpochs times in a loop - * - * @param iterator Training data (DataSetIterator). Iterator must support resetting - * @param numEpochs Number of training epochs, >= 1 - */ - public void fit(@NonNull DataSetIterator iterator, int numEpochs){ - Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", numEpochs); - Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using" + - "iterator thas does not support resetting (iterator.resetSupported() returned false)"); - - for(int i=0; i - * Note that this method does not do layerwise pretraining.
- * For pretraining use method pretrain.. {@link #pretrain(DataSetIterator)}
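fit(DataSetIterator) hands off to doTruncatedBPTT whenever the configuration requests truncated backpropagation through time. The configuration sketch below is illustrative only; the builder methods (backpropType, tBPTTForwardLength, tBPTTBackwardLength) are assumed from upstream Deeplearning4j and may differ in this fork.

import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class TbpttConfigExample {
    // Builds a small recurrent network that trains with truncated BPTT segments of 20 steps,
    // so fit() takes the doTruncatedBPTT path instead of full backpropagation through time.
    static MultiLayerConfiguration tbpttConfig(int nIn, int nOut) {
        return new NeuralNetConfiguration.Builder()
                .list()
                .layer(new LSTM.Builder().nIn(nIn).nOut(32).activation(Activation.TANH).build())
                .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(32).nOut(nOut).activation(Activation.SOFTMAX).build())
                .backpropType(BackpropType.TruncatedBPTT)
                .tBPTTForwardLength(20)   // corresponds to getTbpttFwdLength() used by doTruncatedBPTT
                .tBPTTBackwardLength(20)
                .build();
    }
}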
- * @param iterator Training data (DataSetIterator) - */ - @Override - public void fit(DataSetIterator iterator) { - try{ - fitHelper(iterator); - } catch (OutOfMemoryError e){ - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - private synchronized void fitHelper(DataSetIterator iterator){ - // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate - DataSetIterator iter; - boolean destructable = false; - if (iterator.asyncSupported()) { - iter = new AsyncDataSetIterator(iterator, Math.min(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true); - destructable = true; - } else { - iter = iterator; - } - + //Calculate activations (which are stored in each layer, and used in backprop) + try (MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) { + //First: do a feed-forward through the network + //Note that we don't actually need to do the full forward pass through the output layer right now; but we do + // need the input to the output layer to be set (such that backprop can be done) + List activations = ffToLayerActivationsInWs(layers.length - 2, FwdPassType.STANDARD, + false, input, mask, fMask); + if (!trainingListeners.isEmpty()) { + //TODO: We possibly do want output layer activations in some cases here... for (TrainingListener tl : trainingListeners) { - tl.onEpochStart(this); + tl.onForwardPass(this, activations); } + } + INDArray inputToOutputLayer = activations.get(activations.size() - 1); + if (layerWiseConfigurations.getInputPreProcess(layers.length - 1) != null) { + inputToOutputLayer = layerWiseConfigurations.getInputPreProcess(layers.length - 1) + .preProcess(inputToOutputLayer, getInputMiniBatchSize(), mgr); + //Validate activations location + } + getOutputLayer().setInput(inputToOutputLayer, mgr); - LayerWorkspaceMgr workspaceMgr; - if(getLayerWiseConfigurations().getTrainingWorkspaceMode() == WorkspaceMode.NONE){ - workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); - } else { - workspaceMgr = LayerWorkspaceMgr.builder() - .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - //Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM - // as these should be closed by the time updaters are executed - //Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this - .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .build(); - } - workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); - - update(TaskUtils.buildTask(iter)); - if (!iter.hasNext() && iter.resetSupported()) { - iter.reset(); - } - long time1 = System.currentTimeMillis(); - while (iter.hasNext()) { - - DataSet next = iter.next(); - long time2 = System.currentTimeMillis(); - - lastEtlTime.set((time2 - time1)); - - if (next.getFeatures() == null || next.getLabels() == null) - break; - - // TODO: basically we want to wrap internals of this loop into workspace - - - boolean hasMaskArrays = next.hasMaskArrays(); - - if (layerWiseConfigurations.getBackpropType() == 
BackpropType.TruncatedBPTT) { - doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(), - next.getLabelsMaskArray(), workspaceMgr); - } else { - if (hasMaskArrays) - setLayerMaskArrays(next.getFeaturesMaskArray(), next.getLabelsMaskArray()); - - setInput(next.getFeatures()); - setLabels(next.getLabels()); - - if (solver == null) { - try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) - .build(); - } - } - - //TODO CACHE - solver.optimize(workspaceMgr); - } - - if (hasMaskArrays) - clearLayerMaskArrays(); - - time1 = System.currentTimeMillis(); - synchronizeIterEpochCounts(); - } - - if (!trainingListeners.isEmpty()) { - for (TrainingListener tl : trainingListeners) { - tl.onEpochEnd(this); - } - } - - clearLayersStates(); - - if (destructable) - ((AsyncDataSetIterator) iter).shutdown(); - - incrementEpochCount(); + Pair p = calcBackpropGradients(null, true, false, true); + if (p.getSecond() != null) { + p.setSecond(p.getSecond().detach()); + } + return p; } + } - /** - * Calculate parameter gradients and input activation gradients given the input and labels, and optionally mask arrays - * - * @param features Features for gradient calculation - * @param label Labels for gradient - * @param fMask Features mask array (may be null) - * @param labelMask Label mask array (may be null) - * @return A pair of gradient arrays: parameter gradients (in Gradient object) and input activation gradients - */ - public Pair calculateGradients(@NonNull INDArray features, @NonNull INDArray label, - INDArray fMask, INDArray labelMask) { - try{ - return calculateGradientsHelper(features, label, fMask, labelMask); - } catch (OutOfMemoryError e){ - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } + /** + * Calculate gradients and errors. Used in two places: (a) backprop (for standard multi layer + * network learning) (b) backpropGradient (layer method, for when MultiLayerNetwork is used as a + * layer) + * + * @param epsilon Errors (technically errors .* activations). Not used if + * withOutputLayer = true + * @param withOutputLayer if true: assume last layer is output layer, and calculate errors + * based on labels. In this case, the epsilon input is not used + * (may/should be null). If false: calculate backprop gradients + * @param returnInputActGrad If true: terun the input activation gradients (detached). 
False: + * don't return + * @return Gradients and the error (epsilon) at the input + */ + protected Pair calcBackpropGradients(INDArray epsilon, + boolean withOutputLayer, boolean tbptt, + boolean returnInputActGrad) { + if (flattenedGradients == null) { + initGradientsView(); } + String multiGradientKey; + Gradient gradient = new DefaultGradient(flattenedGradients); - private Pair calculateGradientsHelper(INDArray features, INDArray label, INDArray fMask, - INDArray labelMask){ - setInput(features); - setLabels(label); - setLayerMaskArrays(fMask, labelMask); + LayerWorkspaceMgr mgrEven; + LayerWorkspaceMgr mgrOdd; - LayerWorkspaceMgr mgr; - if(layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ - mgr = LayerWorkspaceMgr.noWorkspaces(); - } else { - mgr = LayerWorkspaceMgr.builder() - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); - - if(layerWiseConfigurations.getCacheMode() != null){ - //For now: store cache mode activations in activations workspace - mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); - } - } - mgr.setHelperWorkspacePointers(helperWorkspaces); - - //Calculate activations (which are stored in each layer, and used in backprop) - try(MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) { - //First: do a feed-forward through the network - //Note that we don't actually need to do the full forward pass through the output layer right now; but we do - // need the input to the output layer to be set (such that backprop can be done) - List activations = ffToLayerActivationsInWs(layers.length - 2, FwdPassType.STANDARD, false, input, mask, fMask); - if (!trainingListeners.isEmpty()) { - //TODO: We possibly do want output layer activations in some cases here... - for (TrainingListener tl : trainingListeners) { - tl.onForwardPass(this, activations); - } - } - INDArray inputToOutputLayer = activations.get(activations.size() - 1); - if (layerWiseConfigurations.getInputPreProcess(layers.length - 1) != null) { - inputToOutputLayer = layerWiseConfigurations.getInputPreProcess(layers.length - 1) - .preProcess(inputToOutputLayer, getInputMiniBatchSize(), mgr); - //Validate activations location - } - getOutputLayer().setInput(inputToOutputLayer, mgr); - - Pair p = calcBackpropGradients(null, true, false, true); - if(p.getSecond() != null){ - p.setSecond( p.getSecond().detach()); - } - return p; - } - } - - /** Calculate gradients and errors. Used in two places: - * (a) backprop (for standard multi layer network learning) - * (b) backpropGradient (layer method, for when MultiLayerNetwork is used as a layer) - * @param epsilon Errors (technically errors .* activations). Not used if withOutputLayer = true - * @param withOutputLayer if true: assume last layer is output layer, and calculate errors based on labels. In this - * case, the epsilon input is not used (may/should be null). - * If false: calculate backprop gradients - * @param returnInputActGrad If true: terun the input activation gradients (detached). 
False: don't return - * @return Gradients and the error (epsilon) at the input - */ - protected Pair calcBackpropGradients(INDArray epsilon, boolean withOutputLayer, boolean tbptt, - boolean returnInputActGrad) { - if (flattenedGradients == null) { - initGradientsView(); - } - String multiGradientKey; - Gradient gradient = new DefaultGradient(flattenedGradients); - - LayerWorkspaceMgr mgrEven; - LayerWorkspaceMgr mgrOdd; - - if(layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ - mgrEven = LayerWorkspaceMgr.noWorkspaces(); - mgrOdd = mgrEven; - WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in calcBackpropGradients when " + - "training workspace is set to none"); - } else { + if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + mgrEven = LayerWorkspaceMgr.noWorkspaces(); + mgrOdd = mgrEven; + WorkspaceUtils.assertNoWorkspacesOpen( + "Expected no workspace active in calcBackpropGradients when " + + "training workspace is set to none"); + } else { /* Workspaces for backprop in MLN share some features with outputOfLayerDetached, in terms of the "two alternating workspaces" idea (but for activation gradients here, instead of activations there). @@ -1884,1422 +2081,1546 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, org.d */ - mgrEven = LayerWorkspaceMgr.builder() - //Activations in context of backprop (preOut methods etc) are not used outside of the layer itself - .with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) //Usually not required here. Exception: OutputLayer dropout - .with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_1, WS_LAYER_ACT_X_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + mgrEven = LayerWorkspaceMgr.builder() + //Activations in context of backprop (preOut methods etc) are not used outside of the layer itself + .with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, + WS_ALL_LAYERS_ACT_CONFIG) //Usually not required here. Exception: OutputLayer dropout + .with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_1, WS_LAYER_ACT_X_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); - mgrOdd = LayerWorkspaceMgr.builder() - //Activations in context of backprop (preOut methods etc) are not used outside of the layer itself - .with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) //Usually not required here. 
Exception: OutputLayer dropout - .with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + mgrOdd = LayerWorkspaceMgr.builder() + //Activations in context of backprop (preOut methods etc) are not used outside of the layer itself + .with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, + WS_ALL_LAYERS_ACT_CONFIG) //Usually not required here. Exception: OutputLayer dropout + .with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); - if(epsilon == null) { - //If epsilon is non-null: external errors use case -> inputs are already detached - WorkspaceUtils.assertOpenActiveAndCurrent(WS_ALL_LAYERS_ACT, "calcBackpropGradients method requires workspace WS_ALL_LAYERS_ACT" + - " to be open when workspaces are used"); - } - } - mgrEven.setHelperWorkspacePointers(helperWorkspaces); - mgrOdd.setHelperWorkspacePointers(helperWorkspaces); - - //calculate and apply the backward gradient for every layer - /* - * Skip the output layer for the indexing and just loop backwards updating the coefficients for each layer. - * (when withOutputLayer == true) - * - * Activate applies the activation function for each layer and sets that as the input for the following layer. - * - * Typical literature contains most trivial case for the error calculation: wT * weights - * This interpretation transpose a few things to get mini batch because ND4J is rows vs columns organization for params - */ - int numLayers = getnLayers(); - //Store gradients is a list; used to ensure iteration order in DefaultGradient linked hash map. i.e., layer 0 first instead of output layer - LinkedList> gradientList = new LinkedList<>(); - - - Pair currPair = null; - MemoryWorkspace wsActGradCloseNext = null; - MemoryWorkspace wsActGradTemp = null; - MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace(); - - boolean traceLog = log.isTraceEnabled(); - - Throwable t = null; - try { - for (int i = layers.length - 1; i >= 0; i--) { - if (layers[i] instanceof FrozenLayer) { - break; - } - - if (traceLog) { - log.trace("About to backprop: {} - {}", i, layers[i].getClass().getSimpleName()); - } - - LayerWorkspaceMgr workspaceMgr = (i % 2 == 0 ? mgrEven : mgrOdd); - - if (withOutputLayer && i == layers.length - 1) { - if (!(getOutputLayer() instanceof IOutputLayer)) { - log.warn("Warning: final layer isn't output layer. 
You cannot use backprop without an output layer."); - return null; - } - - IOutputLayer outputLayer = (IOutputLayer) getOutputLayer(); - if (labels == null && outputLayer.needsLabels()) - throw new IllegalStateException("No labels found"); - outputLayer.setLabels(labels); - } - - //Open activation gradients WS *then* BP working memory, so BP working memory is opened last for use in layers - wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD); - try (MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) { - - //Note that because we're opening activation workspaces not in a simple nested order, we'll manually - // override the previous workspace setting. Otherwise, when we close these workspaces, the "current" - // workspace may be set to the incorrect one - wsActGradTemp.setPreviousWorkspace(initialWorkspace); - wsBPWorking.setPreviousWorkspace(initialWorkspace); - - INDArray eps = (i == layers.length - 1 ? epsilon : currPair.getRight()); //eps is null for OutputLayer - - if (!tbptt) { - //Standard case - currPair = layers[i].backpropGradient(eps, workspaceMgr); - } else { - //TBPTT gradient - if (layers[i] instanceof RecurrentLayer) { - currPair = ((RecurrentLayer) layers[i]).tbpttBackpropGradient(currPair.getSecond(), - layerWiseConfigurations.getTbpttBackLength(), workspaceMgr); - } else { - currPair = layers[i].backpropGradient(currPair.getSecond(), workspaceMgr); - } - } - - if (currPair.getSecond() != null) { - //Edge case: may be null for Embedding layer, for example - validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, - false, "Backprop"); - } - - for (Map.Entry entry : currPair.getFirst().gradientForVariable().entrySet()) { - String origName = entry.getKey(); - multiGradientKey = i + "_" + origName; - gradientList.addLast(new Triple<>(multiGradientKey, entry.getValue(), - currPair.getFirst().flatteningOrderForVariable(origName))); - } - if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { - currPair = new Pair<>(currPair.getFirst(), - this.layerWiseConfigurations.getInputPreProcess(i) - .backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr)); - if (i > 0 && currPair.getSecond() != null) { - validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, - true, "Backprop"); - } - } - - if (i == 0) { - if (returnInputActGrad && currPair.getSecond() != null) { - currPair.setSecond(currPair.getSecond().detach()); - } else { - currPair.setSecond(null); - } - } - - if (wsActGradCloseNext != null) { - wsActGradCloseNext.close(); - } - wsActGradCloseNext = wsActGradTemp; - wsActGradTemp = null; - } - - if (traceLog) { - log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName()); - } - } - } catch (Throwable thr ){ - t = thr; - } finally { - if(wsActGradCloseNext != null){ - try { - wsActGradCloseNext.close(); - } catch (Throwable t2){ - if(t != null){ - log.error("Encountered second exception while trying to close workspace after initial exception"); - log.error("Original exception:", t); - throw t2; - } - } - } - if(wsActGradTemp != null) { - //Should only be non-null on exception - try { - wsActGradTemp.close(); - } catch (Throwable t2) { - if (t != null) { - log.error("Encountered second exception while trying to close workspace after initial exception"); - log.error("Original exception:", t); - throw t2; - } - } - } - Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); - - if(t != null){ - if(t 
instanceof RuntimeException){ - throw ((RuntimeException)t); - } - throw new RuntimeException("Error during neural network forward pass", t); - } - } - - if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { - WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in calcBackpropGradients when " + - "training workspace is set to none"); - } else { - if(epsilon == null) { - //If epsilon != null: external errors use case (inputs are detached instead) - WorkspaceUtils.assertOpenActiveAndCurrent(WS_ALL_LAYERS_ACT, "calcBackpropGradients: WS_ALL_LAYERS_ACT is no" + - " longer the currently open/active workspace"); - } - } - - //Add gradients to Gradients (map), in correct order - for (Triple triple : gradientList) { - gradient.setGradientFor(triple.getFirst(), triple.getSecond(), triple.getThird()); - } - - return new Pair<>(gradient, currPair.getSecond()); + if (epsilon == null) { + //If epsilon is non-null: external errors use case -> inputs are already detached + WorkspaceUtils.assertOpenActiveAndCurrent(WS_ALL_LAYERS_ACT, + "calcBackpropGradients method requires workspace WS_ALL_LAYERS_ACT" + + " to be open when workspaces are used"); + } } + mgrEven.setHelperWorkspacePointers(helperWorkspaces); + mgrOdd.setHelperWorkspacePointers(helperWorkspaces); - protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray, - INDArray labelsMaskArray, LayerWorkspaceMgr workspaceMgr) { - if (input.rank() != 3 || labels.rank() != 3) { - log.warn("Cannot do truncated BPTT with non-3d inputs or labels. Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got " - + Arrays.toString(input.shape()) + "\tand labels with shape " - + Arrays.toString(labels.shape())); - return; - } - if (input.size(2) != labels.size(2)) { - log.warn("Input and label time series have different lengths: {} input length, {} label length", - input.size(2), labels.size(2)); - return; - } - - int fwdLen = layerWiseConfigurations.getTbpttFwdLength(); - update(TaskUtils.buildTask(input, labels)); - val timeSeriesLength = input.size(2); - long nSubsets = timeSeriesLength / fwdLen; - if (timeSeriesLength % fwdLen != 0) - nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20) - - rnnClearPreviousState(); - - for (int i = 0; i < nSubsets; i++) { - long startTimeIdx = (long) i * fwdLen; - long endTimeIdx = startTimeIdx + fwdLen; - if (endTimeIdx > timeSeriesLength) - endTimeIdx = timeSeriesLength; - - if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels, - featuresMaskArray, labelsMaskArray); - - setInput(subsets[0]); - setLabels(subsets[1]); - setLayerMaskArrays(subsets[2], subsets[3]); - - if (solver == null) { - try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) - .build(); - } - } - solver.optimize(workspaceMgr); - - //Finally, update the state of the RNN layers: - updateRnnStateWithTBPTTState(); - } - - rnnClearPreviousState(); - clearLayerMaskArrays(); - } - - private INDArray[] getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray input, INDArray labels, - INDArray fMask, INDArray lMask ){ - INDArray[] out = new INDArray[4]; - out[0] = input.get(NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); - 
out[1] = labels.get(NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); - - if (fMask != null) { - out[2] = fMask.get(NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); - } - if (lMask != null) { - out[3] = lMask.get(NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); - } - - return out; - } - - /** - * Intended for internal/developer use + //calculate and apply the backward gradient for every layer + /* + * Skip the output layer for the indexing and just loop backwards updating the coefficients for each layer. + * (when withOutputLayer == true) + * + * Activate applies the activation function for each layer and sets that as the input for the following layer. + * + * Typical literature contains most trivial case for the error calculation: wT * weights + * This interpretation transpose a few things to get mini batch because ND4J is rows vs columns organization for params */ - public void updateRnnStateWithTBPTTState() { - for (int i = 0; i < layers.length; i++) { + int numLayers = getnLayers(); + //Store gradients is a list; used to ensure iteration order in DefaultGradient linked hash map. i.e., layer 0 first instead of output layer + LinkedList> gradientList = new LinkedList<>(); + + Pair currPair = null; + MemoryWorkspace wsActGradCloseNext = null; + MemoryWorkspace wsActGradTemp = null; + MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace(); + + boolean traceLog = log.isTraceEnabled(); + + Throwable t = null; + try { + for (int i = layers.length - 1; i >= 0; i--) { + if (layers[i] instanceof FrozenLayer) { + break; + } + + if (traceLog) { + log.trace("About to backprop: {} - {}", i, layers[i].getClass().getSimpleName()); + } + + LayerWorkspaceMgr workspaceMgr = (i % 2 == 0 ? mgrEven : mgrOdd); + + if (withOutputLayer && i == layers.length - 1) { + if (!(getOutputLayer() instanceof IOutputLayer)) { + log.warn( + "Warning: final layer isn't output layer. You cannot use backprop without an output layer."); + return null; + } + + IOutputLayer outputLayer = (IOutputLayer) getOutputLayer(); + if (labels == null && outputLayer.needsLabels()) { + throw new IllegalStateException("No labels found"); + } + outputLayer.setLabels(labels); + } + + //Open activation gradients WS *then* BP working memory, so BP working memory is opened last for use in layers + wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD); + try (MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered( + ArrayType.BP_WORKING_MEM)) { + + //Note that because we're opening activation workspaces not in a simple nested order, we'll manually + // override the previous workspace setting. Otherwise, when we close these workspaces, the "current" + // workspace may be set to the incorrect one + wsActGradTemp.setPreviousWorkspace(initialWorkspace); + wsBPWorking.setPreviousWorkspace(initialWorkspace); + + INDArray eps = (i == layers.length - 1 ? 
epsilon + : currPair.getRight()); //eps is null for OutputLayer + + if (!tbptt) { + //Standard case + currPair = layers[i].backpropGradient(eps, workspaceMgr); + } else { + //TBPTT gradient if (layers[i] instanceof RecurrentLayer) { - RecurrentLayer l = ((RecurrentLayer) layers[i]); - l.rnnSetPreviousState(l.rnnGetTBPTTState()); - } else if (layers[i] instanceof MultiLayerNetwork) { - ((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState(); - } - } - } - - /** - * Get the {@link TrainingListener}s set for this network, if any - * @return listeners set for this network - */ - public Collection getListeners() { - return trainingListeners; - } - - /** - * @deprecated Use {@link #getListeners()} - */ - @Deprecated - public Collection getTrainingListeners() { - return trainingListeners; - } - - @Override - public void setListeners(Collection listeners) { - if (layers == null) { - init(); - } - for (Layer layer : layers) { - layer.setListeners(listeners); - } - - if (solver != null) { - solver.setListeners(listeners); - } - - this.trainingListeners.clear(); - if (listeners != null) { - this.trainingListeners.addAll(listeners); - } - } - - /** - * This method ADDS additional TrainingListener to existing listeners - * - * @param listeners - */ - @Override - public void addListeners(TrainingListener... listeners) { - Collections.addAll(trainingListeners, listeners); - - // fixme this is wrong, since it removes existing listeners from the solver - if (solver != null) { - solver.setListeners(this.trainingListeners); - } - } - - @Override - public void setListeners(TrainingListener... listeners) { - Collection cListeners = new ArrayList<>(); - //Check: user might have done setListeners(null) thinking this would clear the current listeners. - //This results in an TrainingListener[1] with a single null value -> results in a NPE later - if (listeners != null && listeners.length > 0) { - for (TrainingListener i : listeners) { - if (i != null) - cListeners.add(i); - } - } - setListeners(cListeners); - } - - /** - * Usable only for classification networks in conjunction with OutputLayer. Cannot be used with RnnOutputLayer, - * CnnLossLayer, or networks used for regression.
- * To get the raw output activations of the output layer, use {@link #output(INDArray)} or similar.
- *
- * Equivalent to argmax(this.output(input)): Returns the predicted class indices corresponding to the predictions - * for each example in the features array. - * - * @param d The input features to perform inference on - * @return The predicted class index for each example - */ - @Override - public int[] predict(INDArray d) { - INDArray output = output(d, Layer.TrainingMode.TEST); - - if (d.size(0) > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - - Preconditions.checkState(output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank()); - return output.argMax(1).toIntVector(); - } - - /** - * As per {@link #predict(INDArray)} but the returned values are looked up from the list of label names - * in the provided DataSet - */ - @Override - public List predict(org.nd4j.linalg.dataset.api.DataSet dataSet) { - Preconditions.checkState(dataSet.getLabelNamesList() != null, "This method can only be used when the DataSet contains a label name list"); - int[] intRet = predict(dataSet.getFeatures()); - List ret = new ArrayList<>(); - for (int i = 0; i < intRet.length; i++) { - ret.add(i, dataSet.getLabelName(intRet[i])); - } - return ret; - } - - /** - * Fit the model for one iteration on the provided data - * - * @param data the examples to classify (one example in each row) - * @param labels the example labels(a binary outcome matrix) - */ - @Override - public void fit(INDArray data, INDArray labels) { - fit(data, labels, null, null); - } - - /** - * Fit the model for one iteration on the provided data - * - * @param features the examples to classify (one example in each row) - * @param labels the example labels(a binary outcome matrix) - * @param featuresMask The mask array for the features (used for variable length time series, etc). May be null. - * @param labelsMask The mask array for the labels (used for variable length time series, etc). May be null. 
- */ - public synchronized void fit(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) { - try{ - fitHelper(features, labels, featuresMask, labelsMask); - } catch (OutOfMemoryError e){ - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - private void fitHelper(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask){ - if(numParams() == 0) { - //No op: can't fit a network with 0 parameters - return; - } - - setInput(features); - setLabels(labels); - this.setLayerMaskArrays(featuresMask, labelsMask); - update(TaskUtils.buildTask(features, labels)); - - LayerWorkspaceMgr workspaceMgr; - if(layerWiseConfigurations.getTrainingWorkspaceMode() == null){ - workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); - } else { - workspaceMgr = LayerWorkspaceMgr.builder() - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - //Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM - // these should be closed by the time updaters are executed - //Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this - .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .build(); - } - workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); - - if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) { - doTruncatedBPTT(features, labels, featuresMask, labelsMask, workspaceMgr); - } else { - if (solver == null) { - try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build(); - } - } - //TODO CACHE WORKSPACE, IF USED??? - solver.optimize(workspaceMgr); - } - - clearLayerMaskArrays(); - clearLayersStates(); - synchronizeIterEpochCounts(); - } - - @Override - public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr){ - throw new UnsupportedOperationException("Not supported: use pretrainLayer"); - } - - - /** - * Fit the model for one iteration on the provided data - * - * @param data the data to train on - */ - @Override - public void fit(org.nd4j.linalg.dataset.api.DataSet data) { - fit(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArray(), data.getLabelsMaskArray()); - } - - /** - * Fit the model for one iteration on the provided data - * - * @param examples the examples to classify (one example in each row) - * @param labels the labels for each example (the number of labels must match - */ - @Override - public void fit(INDArray examples, int[] labels) { - org.deeplearning4j.nn.conf.layers.OutputLayer layerConf = - (org.deeplearning4j.nn.conf.layers.OutputLayer) getOutputLayer().conf().getLayer(); - - if (layerConf.getNOut() > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - fit(examples, FeatureUtil.toOutcomeMatrix(labels, (int) layerConf.getNOut())); - } - - - /** - * Perform inference on the provided input/features - i.e., perform forward pass using the provided input/features - * and return the output of the final layer. - * - * @param input Input to the network - * @param train whether the output is test or train. This mainly affect hyper parameters such as dropout and - * batch normalization, which have different behaviour for test vs. 
train - * @return The network predictions - i.e., the activations of the final layer - */ - public INDArray output(INDArray input, TrainingMode train) { - return output(input, train == TrainingMode.TRAIN); - } - - /** - * Perform inference on the provided input/features - i.e., perform forward pass using the provided input/features - * and return the output of the final layer. - * - * @param input Input to the network - * @param train whether the output is test or train. This mainly affect hyper parameters such as dropout and - * batch normalization, which have different behaviour for test vs. train - * @return The network predictions - i.e., the activations of the final layer - */ - public INDArray output(INDArray input, boolean train) { - return output(input, train, null, null); - } - - /** - * Calculate the output of the network, with masking arrays. The masking arrays are used in situations such - * as one-to-many and many-to-one recurrent neural network (RNN) designs, as well as for supporting time series - * of varying lengths within the same minibatch. - */ - public INDArray output(INDArray input, boolean train, INDArray featuresMask, INDArray labelsMask) { - return output(input, train, featuresMask, labelsMask, null); - } - - /** - * Get the network output, which is optionally placed in the specified memory workspace.
- * If no memory workspace is provided, the output will be detached (not in any workspace).
- * If a memory workspace is provided, the output activation array (i.e., the INDArray returned by this method) - * will be placed in the specified workspace. This workspace must be opened by the user before calling this method - - * and the user is responsible for (a) closing this workspace, and (b) ensuring the output array is not used out - * of scope (i.e., not used after closing the workspace to which it belongs - as this is likely to cause either - * an exception when used, or a crash). - * - * @param input Input to the network - * @param train True for train, false otherwise - * @param outputWorkspace May be null. If not null: the workspace MUST be opened before calling this method. - * @return The output/activations from the network (either detached or in the specified workspace if provided) - */ - public INDArray output(INDArray input, boolean train, MemoryWorkspace outputWorkspace) { - return output(input, train, null, null, outputWorkspace); - } - - /** - * Get the network output, which is optionally placed in the specified memory workspace.
- * If no memory workspace is provided, the output will be detached (not in any workspace).
- * If a memory workspace is provided, the output activation array (i.e., the INDArray returned by this method) - * will be placed in the specified workspace. This workspace must be opened by the user before calling this method - - * and the user is responsible for (a) closing this workspace, and (b) ensuring the output array is not used out - * of scope (i.e., not used after closing the workspace to which it belongs - as this is likely to cause either - * an exception when used, or a crash). - * - * @param input Input to the network - * @param train True for train, false otherwise - * @param outputWorkspace May be null. If not null: the workspace MUST be opened before calling this method. - * @return The output/activations from the network (either detached or in the specified workspace if provided) - */ - public synchronized INDArray output(INDArray input, boolean train, INDArray featuresMask, INDArray labelsMask, MemoryWorkspace outputWorkspace) { - try { - return outputOfLayerDetached(train, FwdPassType.STANDARD, layers.length - 1, input, featuresMask, labelsMask, outputWorkspace); - } catch (OutOfMemoryError e) { - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - /** - * This method uses provided OutputAdapter to return custom object built from INDArray - * - * PLEASE NOTE: This method uses dedicated Workspace for output generation to avoid redundant allocations - * - * @param inputs Input arrays to the netwonk - * @param inputMasks Optional input mask arrays (may be null) - * @param labelMasks Optional label mask arrays (may be null - * @param outputAdapter OutputAdapter instance - * @param T extends Object - * @return T instance produced by OutputAdapter - */ - public synchronized T output(@NonNull INDArray inputs, INDArray inputMasks, INDArray labelMasks, @NonNull OutputAdapter outputAdapter) { - try (val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)) { - if (outputAdapter instanceof ModelAdapter) - return ((ModelAdapter) outputAdapter).apply(this, new INDArray[]{inputs}, new INDArray[]{ inputMasks}, new INDArray[]{labelMasks}); - else - return outputAdapter.apply(output(inputs, false, inputMasks, labelMasks, ws)); - } - } - - /** - * Perform inference on the provided input/features - i.e., perform forward pass using the provided input/features - * and return the output of the final layer. Equivalent to {@link #output(INDArray, boolean)} with train=false - i.e., - * this method is used for inference. - * - * @param input Input to the network - * @return The network predictions - i.e., the activations of the final layer - */ - public INDArray output(INDArray input) { - return output(input, TrainingMode.TEST); - } - - /** - * Generate the output for all examples/batches in the input iterator, and concatenate them into a single array. - * See {@link #output(INDArray)}
- * NOTE 1: The output array can require a considerable amount of memory for iterators with a large number of examples
- * NOTE 2: This method cannot be used for variable length time series outputs, as this would require padding arrays - * for some outputs, or returning a mask array (which cannot be done with this method). For variable length time - * series applications, use one of the other output methods. This method also cannot be used with fully convolutional - * networks with different output sizes (for example, segmentation on different input image sizes). - * - * - * @param iterator Data to pass through the network - * @return output for all examples in the iterator, concatenated into a - */ - public INDArray output(DataSetIterator iterator, boolean train) { - List outList = new ArrayList<>(); - long[] firstOutputShape = null; - while (iterator.hasNext()) { - DataSet next = iterator.next(); - INDArray features = next.getFeatures(); - - if (features == null) - continue; - - INDArray fMask = next.getFeaturesMaskArray(); - INDArray lMask = next.getLabelsMaskArray(); - INDArray output = this.output(features, train, fMask, lMask); - outList.add(output); - if(firstOutputShape == null){ - firstOutputShape = output.shape(); + currPair = ((RecurrentLayer) layers[i]).tbpttBackpropGradient(currPair.getSecond(), + layerWiseConfigurations.getTbpttBackLength(), workspaceMgr); } else { - //Validate that shapes are the same (may not be, for some RNN variable length time series applications) - long[] currShape = output.shape(); - Preconditions.checkState(firstOutputShape.length == currShape.length, "Error during forward pass:" + - "different minibatches have different output array ranks - first minibatch shape %s, last minibatch shape %s", firstOutputShape, currShape); - for( int i=1; i - * This is equivalent to {@link #score(DataSet, boolean)} with training==false. - * @param data the data to score - * @return the score for the given input,label pairs - * @see #score(DataSet, boolean) - */ - public double score(DataSet data) { - return score(data, false); - } - - /** - * Sets the input and labels and calculates the score (value of the output layer loss function plus l1/l2 if applicable) - * for the prediction with respect to the true labels
- * @param data data to calculate score for - * @param training If true: score during training. If false: score at test time. This can affect the application of - * certain features, such as dropout and dropconnect (which are applied at training time only) - * @return the score (value of the loss function) - */ - public double score(DataSet data, boolean training) { - try{ - return scoreHelper(data, training); - } catch (OutOfMemoryError e){ - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - private double scoreHelper(DataSet data, boolean training){ - boolean hasMaskArray = data.hasMaskArrays(); - if (hasMaskArray) - setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray()); - - if (!(getOutputLayer() instanceof IOutputLayer)) { - throw new IllegalStateException("Cannot calculate score if final layer is not an instance of IOutputLayer. " + - "Final layer is of type: " + getOutputLayer().getClass()); - } - - WorkspaceMode wsm = (training ? layerWiseConfigurations.getTrainingWorkspaceMode() : layerWiseConfigurations.getInferenceWorkspaceMode()); - LayerWorkspaceMgr mgr; - if(wsm == WorkspaceMode.NONE){ - mgr = LayerWorkspaceMgr.noWorkspaces(); - } else { - mgr = LayerWorkspaceMgr.builder() - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - //TODO we can probably optimize this - .noWorkspaceFor(ArrayType.ACTIVATIONS) - .noWorkspaceFor(ArrayType.INPUT) - .build(); - } - mgr.setHelperWorkspacePointers(helperWorkspaces); - - INDArray inputToOutputLayer = outputOfLayerDetached(training, FwdPassType.STANDARD,layers.length-2, data.getFeatures(), - data.getFeaturesMaskArray(), data.getLabelsMaskArray(), null); - - if (data.getFeatures().size(0) > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - IOutputLayer ol = (IOutputLayer) getOutputLayer(); - if (getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) != null) { - inputToOutputLayer = getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) - .preProcess(inputToOutputLayer, (int) data.getFeatures().size(0), mgr); - } - ol.setInput(inputToOutputLayer, mgr); //Feedforward doesn't include output layer for efficiency - ol.setLabels(data.getLabels()); - double score; - try(MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { - score = ol.computeScore(calcRegularizationScore(true), training, mgr); - } - - if (hasMaskArray) - clearLayerMaskArrays(); - clearLayersStates(); - - return score; - } - - /** - * As per {@link #scoreExamples(DataSet, boolean)} - the outputs (example scores) for all DataSets in the iterator are concatenated - */ - public INDArray scoreExamples(DataSetIterator iter, boolean addRegularizationTerms) { - List out = new ArrayList<>(); - - while (iter.hasNext()) { - out.add(scoreExamples(iter.next(), addRegularizationTerms)); - } - return Nd4j.toFlattened('f', out); - } - - /**Calculate the score for each example in a DataSet individually. Unlike {@link #score(DataSet)} and {@link #score(DataSet, boolean)} - * this method does not average/sum over examples. This method allows for examples to be scored individually (at test time only), which - * may be useful for example for autoencoder architectures and the like.
- * Each row of the output (assuming addRegularizationTerms == true) is equivalent to calling score(DataSet) with a single example. - * @param data The data to score - * @param addRegularizationTerms If true: add l1/l2 regularization terms (if any) to the score. If false: don't add regularization terms - * @return An INDArray (column vector) of size input.numRows(); the ith entry is the score (loss value) of the ith example - */ - public INDArray scoreExamples(DataSet data, boolean addRegularizationTerms) { - try{ - return scoreExamplesHelper(data, addRegularizationTerms); - } catch (OutOfMemoryError e){ - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - private INDArray scoreExamplesHelper(DataSet data, boolean addRegularizationTerms){ - INDArray inputLast = outputOfLayerDetached(false, FwdPassType.STANDARD,layers.length-2, data.getFeatures(), - data.getFeaturesMaskArray(), data.getLabelsMaskArray(), null); - setLabels(data.getLabels()); - setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray()); - - //TODO we might want workspaces here? - LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces(); - - INDArray out; - if (getOutputLayer() instanceof IOutputLayer) { - IOutputLayer ol = (IOutputLayer) getOutputLayer(); - if(layerWiseConfigurations.getInputPreProcess(layers.length-1) != null){ - - if (data.getFeatures().size(0) > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - inputLast = layerWiseConfigurations.getInputPreProcess(layers.length-1).preProcess(inputLast, - (int) data.getFeatures().size(0), mgr); + for (Map.Entry entry : currPair.getFirst().gradientForVariable() + .entrySet()) { + String origName = entry.getKey(); + multiGradientKey = i + "_" + origName; + gradientList.addLast(new Triple<>(multiGradientKey, entry.getValue(), + currPair.getFirst().flatteningOrderForVariable(origName))); + } + if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { + currPair = new Pair<>(currPair.getFirst(), + this.layerWiseConfigurations.getInputPreProcess(i) + .backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr)); + if (i > 0 && currPair.getSecond() != null) { + validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, + i, + true, "Backprop"); } - ol.setLabels(data.getLabels()); - ol.setInput(inputLast, mgr); - double r = (addRegularizationTerms ? 
calcRegularizationScore(true) : 0); - out = ol.computeScoreForExamples(r, mgr); - } else { - throw new UnsupportedOperationException( - "Cannot calculate score with respect to labels without an OutputLayer"); - } + } - clearLayersStates(); - clearLayerMaskArrays(); - return out; - } - - - @Override - public void fit() { - fit(input, labels); - } - - @Override - public void update(INDArray gradient, String paramType) { - throw new UnsupportedOperationException("Not implemented"); - } - - - /** - * Score of the model (relative to the objective function) - previously calculated on the last minibatch - * - * @return the score of the model (relative to the objective function) - */ - @Override - public double score() { - return score; - } - - /** - * Intended for developer/internal use - */ - public void setScore(double score) { - this.score = score; - } - - @Override - public void computeGradientAndScore(LayerWorkspaceMgr layerWorkspaceMgr){ - computeGradientAndScore(); - } - - public void computeGradientAndScore() { - - if (!(getOutputLayer() instanceof IOutputLayer)) { - throw new DL4JException( - "Cannot calculate gradient and score with respect to labels: final layer is not an IOutputLayer. " + - "Final layer class: " + getOutputLayer().getClass() + ". To calculate gradients and fit a network " + - "using backpropagation, the final layer must be an output layer"); - } - - //Note: Workspace manager is only ose here for score calculation... other workspace managers are used in the - // various FF/backprop methds - LayerWorkspaceMgr mgr; - if(layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE){ - mgr = LayerWorkspaceMgr.noWorkspaces(); - } else { - mgr = LayerWorkspaceMgr.builder() - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); - - if(layerWiseConfigurations.getCacheMode() != null){ - //For now: store cache mode activations in activations workspace - mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); + if (i == 0) { + if (returnInputActGrad && currPair.getSecond() != null) { + currPair.setSecond(currPair.getSecond().detach()); + } else { + currPair.setSecond(null); } + } + + if (wsActGradCloseNext != null) { + wsActGradCloseNext.close(); + } + wsActGradCloseNext = wsActGradTemp; + wsActGradTemp = null; } - boolean tbptt = layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT; - FwdPassType fwdType = (tbptt ? 
FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE : FwdPassType.STANDARD); - synchronizeIterEpochCounts(); - - //Calculate activations (which are stored in each layer, and used in backprop) - try(MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) { - //First: do a feed-forward through the network - //Note that we don't actually need to do the full forward pass through the output layer right now; but we do - // need the input to the output layer to be set (such that backprop can be done) - List activations = ffToLayerActivationsInWs(layers.length - 2, fwdType, tbptt, input, mask, null); - if (!trainingListeners.isEmpty()) { - //TODO: We possibly do want output layer activations in some cases here... - for (TrainingListener tl : trainingListeners) { - tl.onForwardPass(this, activations); - } - } - INDArray inputToOutputLayer = activations.get(activations.size() - 1); - if (layerWiseConfigurations.getInputPreProcess(layers.length - 1) != null) { - inputToOutputLayer = layerWiseConfigurations.getInputPreProcess(layers.length - 1) - .preProcess(inputToOutputLayer, getInputMiniBatchSize(), mgr); - //Validate activations location - } - getOutputLayer().setInput(inputToOutputLayer, mgr); - //Then: compute gradients - Pair pair = calcBackpropGradients(null, true, false, false); - this.gradient = (pair == null ? null : pair.getFirst()); - - //Calculate score - try(MemoryWorkspace wsFF = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { - double r = calcRegularizationScore(true); - score = ((IOutputLayer) getOutputLayer()).computeScore(r, true, mgr); - } - - //Listeners - if (!trainingListeners.isEmpty()) { - try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - for (TrainingListener tl : trainingListeners) { - tl.onBackwardPass(this); - } - } - } + if (traceLog) { + log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName()); } - - //Clear the post noise/dropconnect parameters on the output layer - getOutputLayer().clearNoiseWeightParams(); - } - - /** - * Clear the inputs. Clears optimizer state. 
- */ - public void clear() { - for (Layer layer : layers) - layer.clear(); - - input = null; - labels = null; - solver = null; - } - - @Override - public void applyConstraints(int iteration, int epoch) { - for(Layer l : layers){ - l.applyConstraints(iteration, epoch); - } - } - - - /** - * Set the input array for the network - * - * @param input Input array to set - */ - public void setInput(INDArray input) { - this.input = input; - if (this.layers == null) { - init(); - } - if (input != null) { - if (input.length() == 0) - throw new IllegalArgumentException( - "Invalid input: length 0 (shape: " + Arrays.toString(input.shape()) + ")"); - - if (input.size(0) > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - setInputMiniBatchSize((int) input.size(0)); - } - } - - @Override - public void setInput(INDArray input, LayerWorkspaceMgr mgr){ - throw new UnsupportedOperationException("Not supported"); - } - - /** - * Get the output layer - i.e., the last layer in the netwok - * - * @return - */ - public Layer getOutputLayer() { - Layer ret = getLayers()[getLayers().length - 1]; - if (ret instanceof FrozenLayerWithBackprop) { - ret = ((FrozenLayerWithBackprop) ret).getInsideLayer(); - } - return ret; - } - - - /** - * See {@link #setParams(INDArray)} - */ - public void setParameters(INDArray params) { - setParams(params); - } - - /** - * Intended for internal/developer use - */ - public NeuralNetConfiguration getDefaultConfiguration() { - return defaultConfiguration; - } - - public INDArray getLabels() { - return labels; - } - - public INDArray getInput() { - return input; - } - - - /** - * @param labels Labels to set - */ - public void setLabels(INDArray labels) { - this.labels = labels; - } - - /** - * Get the number of layers in the network - * - * @return the number of layers in the network - */ - public int getnLayers() { - return layerWiseConfigurations.getConfs().size(); - } - - /** - * @return The layers in the network - */ - public synchronized Layer[] getLayers() { - return layers; - } - - public Layer getLayer(int i) { - Preconditions.checkArgument(i >= 0 && i < layers.length, "Invalid layer index: layer index must be 0" + - " to %s (inclusive), got index %s", layers.length-1, i); - return layers[i]; - } - - public Layer getLayer(String name) { - return layerMap.get(name); - } - - public List getLayerNames() { - return new ArrayList<>(layerMap.keySet()); - } - - public void setLayers(Layer[] layers) { - this.layers = layers; - } - - public INDArray getMask() { - return mask; - } - - public void setMask(INDArray mask) { - this.mask = mask; - } - - public INDArray getMaskArray() { - return mask; - } - - @Override - public boolean isPretrainLayer() { - return false; - } - - @Override - public void clearNoiseWeightParams() { - for(Layer l : layers){ - l.clearNoiseWeightParams(); - } - } - - @Override - public void allowInputModification(boolean allow) { - throw new UnsupportedOperationException("Not supported"); - } - - @Override - public Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, - int minibatchSize) { - if (maskArray == null) { - for (int i = 0; i < layers.length; i++) { - layers[i].feedForwardMaskArray(null, null, minibatchSize); - } - } else { - //Do a forward pass through each preprocessor and layer - for (int i = 0; i < layers.length; i++) { - InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(i); - - if (preProcessor != null) { - Pair p = - preProcessor.feedForwardMaskArray(maskArray, currentMaskState, 
minibatchSize); - if (p != null) { - maskArray = p.getFirst(); - currentMaskState = p.getSecond(); - } else { - maskArray = null; - currentMaskState = null; - } - } - - Pair p = - layers[i].feedForwardMaskArray(maskArray, currentMaskState, minibatchSize); - if (p != null) { - maskArray = p.getFirst(); - currentMaskState = p.getSecond(); - } else { - maskArray = null; - currentMaskState = null; - } - } - } - - return new Pair<>(maskArray, currentMaskState); - } - - @Override - public LayerHelper getHelper() { - throw new UnsupportedOperationException("Not supported"); - } - - //========== - //Layer methods - - @Override - public Type type() { - return Type.MULTILAYER; - } - - - /** - * Equivalent to {@link #output(INDArray)} using the input set via {@link #setInput(INDArray)} - */ - public INDArray activate(TrainingMode training) { - return output(input, training == TrainingMode.TRAIN); - } - - /** - * Equivalent to {@link #output(INDArray, TrainingMode)} - */ - public INDArray activate(INDArray input, TrainingMode training) { - return output(input, training == TrainingMode.TRAIN); - } - - @Override - public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { - if (getOutputLayer() instanceof IOutputLayer) - throw new UnsupportedOperationException("Cannot calculate gradients based on epsilon with OutputLayer"); - - return calcBackpropGradients(epsilon, false, false, true); - } - - @Override - public void setIndex(int index) { - layerIndex = index; - } - - @Override - public int getIndex() { - return layerIndex; - } - - @Override - public int getIterationCount() { - return getLayerWiseConfigurations().getIterationCount(); - } - - @Override - public int getEpochCount() { - return getLayerWiseConfigurations().getEpochCount(); - } - - @Override - public void setIterationCount(int iterationCount) { - getLayerWiseConfigurations().setIterationCount(iterationCount); - } - - @Override - public void setEpochCount(int epochCount) { - getLayerWiseConfigurations().setEpochCount(epochCount); - } - - @Override - public double calcRegularizationScore(boolean backpropParamsOnly){ - double scoreSum = 0.0; - for (int i = 0; i < layers.length; i++) { - scoreSum += layers[i].calcRegularizationScore(backpropParamsOnly); - } - return scoreSum; - } - - @Override - public void update(Gradient gradient) { - if (gradient.gradient().length() != numParams(true)) - throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams(true)); - for (Map.Entry entry : gradient.gradientForVariable().entrySet()) { - String key = entry.getKey(); - INDArray val = entry.getValue(); - int idx = key.indexOf('_'); - if (idx == -1) - throw new IllegalStateException("Invalid param key: not have layer separator: \"" + key + "\""); - Integer layerId = Integer.parseInt(key.substring(0, idx)); - String paramType = key.substring(idx + 1); - // Update MLN gradient - this.gradient.gradientForVariable().put(key, val); - // Update layer params - layers[layerId].update(val, paramType); - } - // Update layerwise gradient view - setBackpropGradientsViewArray(gradient.gradient()); - - } - - @Override - public INDArray activate(boolean training, LayerWorkspaceMgr mgr) { - throw new UnsupportedOperationException(); - } - - @Override - public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr mgr) { - throw new UnsupportedOperationException(); - } - - @Override - public void setInputMiniBatchSize(int size) { - if (layers != null) - for (Layer l : layers) - 
l.setInputMiniBatchSize(size); - } - - @Override - public int getInputMiniBatchSize() { - if(!conf().isMiniBatch()) - return 1; - - if (input.size(0) > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - return (int) input.size(0); - } - - @Override - public void setMaskArray(INDArray maskArray) { - throw new UnsupportedOperationException(); - } - - /** - * - * If this MultiLayerNetwork contains one or more RNN layers: conduct forward pass (prediction) - * but using previous stored state for any RNN layers. The activations for the final step are - * also stored in the RNN layers for use next time rnnTimeStep() is called.
- * This method can be used to generate output one or more steps at a time instead of always having to do - * forward pass from t=0. Example uses are for streaming data, and for generating samples from network output - * one step at a time (where samples are then fed back into the network as input)
- * If no previous state is present in RNN layers (i.e., initially or after calling rnnClearPreviousState()), - * the default initialization (usually 0) is used.
- * Supports mini-batch (i.e., multiple predictions/forward pass in parallel) as well as for single examples.
- * @param input Input to network. May be for one or multiple time steps. For single time step: - * input has shape [miniBatchSize,inputSize] or [miniBatchSize,inputSize,1]. miniBatchSize=1 for single example.
- * For multiple time steps: [miniBatchSize,inputSize,inputTimeSeriesLength] - * @return Output activations. If output is RNN layer (such as RnnOutputLayer): if input has shape [miniBatchSize,inputSize] - * i.e., is 2d, output has shape [miniBatchSize,outputSize] (i.e., also 2d).
- * Otherwise output is 3d [miniBatchSize,outputSize,inputTimeSeriesLength] when using RnnOutputLayer. - * @see #rnnTimeStep(INDArray, MemoryWorkspace) For outputting the activations in the specified workspace - */ - public INDArray rnnTimeStep(INDArray input) { - return rnnTimeStep(input, null); - } - - /** - * See {@link #rnnTimeStep(INDArray)} for details
- * If no memory workspace is provided, the output will be detached (not in any workspace).
- * If a memory workspace is provided, the output activation array (i.e., the INDArray returned by this method) - * will be placed in the specified workspace. This workspace must be opened by the user before calling this method - - * and the user is responsible for (a) closing this workspace, and (b) ensuring the output array is not used out - * of scope (i.e., not used after closing the workspace to which it belongs - as this is likely to cause either - * an exception when used, or a crash). - * - * @param input Input activations - * @param outputWorkspace Output workspace. May be null - * @return The output/activations from the network (either detached or in the specified workspace if provided) - */ - public INDArray rnnTimeStep(INDArray input, MemoryWorkspace outputWorkspace ) { + } + } catch (Throwable thr) { + t = thr; + } finally { + if (wsActGradCloseNext != null) { try { - boolean inputIs2d = input.rank() == 2; - INDArray out = outputOfLayerDetached(false, FwdPassType.RNN_TIMESTEP, layers.length - 1, input, null, null, outputWorkspace); - if (inputIs2d && out.rank() == 3 && layers[layers.length - 1].type() == Type.RECURRENT) { - //Return 2d output with shape [miniBatchSize,nOut] - // instead of 3d output with shape [miniBatchSize,nOut,1] - return out.tensorAlongDimension(0, 1, 0); - } - return out; - } catch (OutOfMemoryError e){ - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; + wsActGradCloseNext.close(); + } catch (Throwable t2) { + if (t != null) { + log.error( + "Encountered second exception while trying to close workspace after initial exception"); + log.error("Original exception:", t); + throw t2; + } } - } - - /**Get the state of the RNN layer, as used in rnnTimeStep(). - * @param layer Number/index of the layer. - * @return Hidden state, or null if layer is not an RNN layer - */ - public Map rnnGetPreviousState(int layer) { - if (layer < 0 || layer >= layers.length) - throw new IllegalArgumentException("Invalid layer number"); - Layer l = layers[layer]; - if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer){ - l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying(); + } + if (wsActGradTemp != null) { + //Should only be non-null on exception + try { + wsActGradTemp.close(); + } catch (Throwable t2) { + if (t != null) { + log.error( + "Encountered second exception while trying to close workspace after initial exception"); + log.error("Original exception:", t); + throw t2; + } } - if (!(l instanceof RecurrentLayer)) - throw new IllegalArgumentException("Layer is not an RNN layer"); - return ((RecurrentLayer) l).rnnGetPreviousState(); - } + } + Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); - /**Set the state of the RNN layer. - * @param layer The number/index of the layer. 
- * @param state The state to set the specified layer to - */ - public void rnnSetPreviousState(int layer, Map state) { - if (layer < 0 || layer >= layers.length) - throw new IllegalArgumentException("Invalid layer number"); - Layer l = layers[layer]; - if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer){ - l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying(); + if (t != null) { + if (t instanceof RuntimeException) { + throw ((RuntimeException) t); } - if (!(l instanceof RecurrentLayer)) - throw new IllegalArgumentException("Layer is not an RNN layer"); - RecurrentLayer r = (RecurrentLayer) l; - r.rnnSetPreviousState(state); + throw new RuntimeException("Error during neural network forward pass", t); + } } - /** Clear the previous state of the RNN layers (if any). - */ - public void rnnClearPreviousState() { - if (layers == null) - return; - for (int i = 0; i < layers.length; i++) { - if (layers[i] instanceof RecurrentLayer) - ((RecurrentLayer) layers[i]).rnnClearPreviousState(); - else if (layers[i] instanceof MultiLayerNetwork) { - ((MultiLayerNetwork) layers[i]).rnnClearPreviousState(); - } else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){ - ((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying()).rnnClearPreviousState(); - } + if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + WorkspaceUtils.assertNoWorkspacesOpen( + "Expected no workspace active in calcBackpropGradients when " + + "training workspace is set to none"); + } else { + if (epsilon == null) { + //If epsilon != null: external errors use case (inputs are detached instead) + WorkspaceUtils.assertOpenActiveAndCurrent(WS_ALL_LAYERS_ACT, + "calcBackpropGradients: WS_ALL_LAYERS_ACT is no" + + " longer the currently open/active workspace"); + } + } + + //Add gradients to Gradients (map), in correct order + for (Triple triple : gradientList) { + gradient.setGradientFor(triple.getFirst(), triple.getSecond(), triple.getThird()); + } + + return new Pair<>(gradient, currPair.getSecond()); + } + + protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray, + INDArray labelsMaskArray, LayerWorkspaceMgr workspaceMgr) { + if (input.rank() != 3 || labels.rank() != 3) { + log.warn( + "Cannot do truncated BPTT with non-3d inputs or labels. 
Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got " + + Arrays.toString(input.shape()) + "\tand labels with shape " + + Arrays.toString(labels.shape())); + return; + } + if (input.size(2) != labels.size(2)) { + log.warn( + "Input and label time series have different lengths: {} input length, {} label length", + input.size(2), labels.size(2)); + return; + } + + int fwdLen = layerWiseConfigurations.getTbpttFwdLength(); + update(TaskUtils.buildTask(input, labels)); + val timeSeriesLength = input.size(2); + long nSubsets = timeSeriesLength / fwdLen; + if (timeSeriesLength % fwdLen != 0) { + nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20) + } + + rnnClearPreviousState(); + + for (int i = 0; i < nSubsets; i++) { + long startTimeIdx = (long) i * fwdLen; + long endTimeIdx = startTimeIdx + fwdLen; + if (endTimeIdx > timeSeriesLength) { + endTimeIdx = timeSeriesLength; + } + + if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels, + featuresMaskArray, labelsMaskArray); + + setInput(subsets[0]); + setLabels(subsets[1]); + setLayerMaskArrays(subsets[2], subsets[3]); + + if (solver == null) { + try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + .build(); } + } + solver.optimize(workspaceMgr); + + //Finally, update the state of the RNN layers: + updateRnnStateWithTBPTTState(); } - /** Similar to rnnTimeStep and feedForward() methods. Difference here is that this method:
- * (a) like rnnTimeStep does forward pass using stored state for RNN layers, and
- * (b) unlike rnnTimeStep does not modify the RNN layer state
- * Therefore multiple calls to this method with the same input should have the same output.
- * Typically used during training only. Use rnnTimeStep for prediction/forward pass at test time. - * @param input Input to network - * @param training Whether training or not - * @param storeLastForTBPTT set to true if used as part of truncated BPTT training - * @return Activations for each layer (including input, as per feedforward() etc) - */ - public List rnnActivateUsingStoredState(INDArray input, boolean training, boolean storeLastForTBPTT) { - return ffToLayerActivationsDetached(training, FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE, storeLastForTBPTT, layers.length-1, input, mask, null, false); + rnnClearPreviousState(); + clearLayerMaskArrays(); + } + + private INDArray[] getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray input, + INDArray labels, + INDArray fMask, INDArray lMask) { + INDArray[] out = new INDArray[4]; + out[0] = input.get(NDArrayIndex.all(), NDArrayIndex.all(), + NDArrayIndex.interval(startTimeIdx, endTimeIdx)); + out[1] = labels.get(NDArrayIndex.all(), NDArrayIndex.all(), + NDArrayIndex.interval(startTimeIdx, endTimeIdx)); + + if (fMask != null) { + out[2] = fMask.get(NDArrayIndex.all(), + NDArrayIndex.interval(startTimeIdx, endTimeIdx)); + } + if (lMask != null) { + out[3] = lMask.get(NDArrayIndex.all(), + NDArrayIndex.interval(startTimeIdx, endTimeIdx)); } - /** Get the updater for this MultiLayerNetwork - * @return Updater for MultiLayerNetwork - */ - public Updater getUpdater() { - return getUpdater(true); + return out; + } + + /** + * Intended for internal/developer use + */ + public void updateRnnStateWithTBPTTState() { + for (int i = 0; i < layers.length; i++) { + if (layers[i] instanceof RecurrentLayer) { + RecurrentLayer l = ((RecurrentLayer) layers[i]); + l.rnnSetPreviousState(l.rnnGetTBPTTState()); + } else if (layers[i] instanceof MultiLayerNetwork) { + ((MultiLayerNetwork) layers[i]).updateRnnStateWithTBPTTState(); + } + } + } + + /** + * Get the {@link TrainingListener}s set for this network, if any + * + * @return listeners set for this network + */ + public Collection getListeners() { + return trainingListeners; + } + + @Override + public void setListeners(Collection listeners) { + if (layers == null) { + init(); + } + for (Layer layer : layers) { + layer.setListeners(listeners); } - public Updater getUpdater(boolean initializeIfReq) { - if (solver == null && initializeIfReq) { - synchronized(this){ - if(solver == null) { //May have been created while waiting for lock - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build(); - solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this)); - } - } + if (solver != null) { + solver.setListeners(listeners); + } + + this.trainingListeners.clear(); + if (listeners != null) { + this.trainingListeners.addAll(listeners); + } + } + + @Override + public void setListeners(TrainingListener... listeners) { + Collection cListeners = new ArrayList<>(); + //Check: user might have done setListeners(null) thinking this would clear the current listeners. 
+    //This results in a TrainingListener[1] with a single null value -> results in an NPE later
+    if (listeners != null && listeners.length > 0) {
+      for (TrainingListener i : listeners) {
+        if (i != null) {
+          cListeners.add(i);
+        }
+      }
+    }
+    setListeners(cListeners);
+  }
+
+  /**
+   * @deprecated Use {@link #getListeners()}
+   */
+  @Deprecated
+  public Collection<TrainingListener> getTrainingListeners() {
+    return trainingListeners;
+  }
+
+  /**
+   * This method ADDS additional TrainingListener to existing listeners
+   *
+   * @param listeners
+   */
+  @Override
+  public void addListeners(TrainingListener... listeners) {
+    Collections.addAll(trainingListeners, listeners);
+
+    // fixme this is wrong, since it removes existing listeners from the solver
+    if (solver != null) {
+      solver.setListeners(this.trainingListeners);
+    }
+  }
+
+  /**
+   * Usable only for classification networks in conjunction with OutputLayer. Cannot be used with
+   * RnnOutputLayer, CnnLossLayer, or networks used for regression.<br>
To get the raw output + * activations of the output layer, use {@link #output(INDArray)} or similar.
+ *
+ * Equivalent to argmax(this.output(input)): Returns the predicted class indices corresponding to + * the predictions for each example in the features array. + * + * @param d The input features to perform inference on + * @return The predicted class index for each example + */ + @Override + public int[] predict(INDArray d) { + INDArray output = output(d, Layer.TrainingMode.TEST); + + if (d.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); } - /** Set the updater for the MultiLayerNetwork */ - public void setUpdater(Updater updater) { - if (solver == null) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build(); - } - solver.getOptimizer().setUpdater(updater); + Preconditions.checkState(output.rank() == 2, + "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", + output.rank()); + return output.argMax(1).toIntVector(); + } + + /** + * As per {@link #predict(INDArray)} but the returned values are looked up from the list of label + * names in the provided DataSet + */ + @Override + public List predict(org.nd4j.linalg.dataset.api.DataSet dataSet) { + Preconditions.checkState(dataSet.getLabelNamesList() != null, + "This method can only be used when the DataSet contains a label name list"); + int[] intRet = predict(dataSet.getFeatures()); + List ret = new ArrayList<>(); + for (int i = 0; i < intRet.length; i++) { + ret.add(i, dataSet.getLabelName(intRet[i])); + } + return ret; + } + + /** + * Fit the model for one iteration on the provided data + * + * @param data the examples to classify (one example in each row) + * @param labels the example labels(a binary outcome matrix) + */ + @Override + public void fit(INDArray data, INDArray labels) { + fit(data, labels, null, null); + } + + /** + * Fit the model for one iteration on the provided data + * + * @param features the examples to classify (one example in each row) + * @param labels the example labels(a binary outcome matrix) + * @param featuresMask The mask array for the features (used for variable length time series, + * etc). May be null. + * @param labelsMask The mask array for the labels (used for variable length time series, etc). + * May be null. + */ + public synchronized void fit(INDArray features, INDArray labels, INDArray featuresMask, + INDArray labelsMask) { + try { + fitHelper(features, labels, featuresMask, labelsMask); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; + } + } + + private void fitHelper(INDArray features, INDArray labels, INDArray featuresMask, + INDArray labelsMask) { + if (numParams() == 0) { + //No op: can't fit a network with 0 parameters + return; } - /**Set the mask arrays for features and labels. Mask arrays are typically used in situations such as one-to-many - * and many-to-one learning with recurrent neural networks, as well as for supporting time series of varying lengths - * within the same minibatch.
- * For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape - * [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape [miniBatchSize,timeSeriesLength] - * and contain values 0 or 1 at each element (to specify whether a given input/example is present - or merely padding - - * at a given time step).
- * NOTE: This method is not usually used directly. Instead, methods such as {@link #feedForward(INDArray, INDArray, INDArray)} - * and {@link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking internally. - * @param featuresMaskArray Mask array for features (input) - * @param labelsMaskArray Mask array for labels (output) - * @see #clearLayerMaskArrays() - */ - public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) { - if (featuresMaskArray != null) { + setInput(features); + setLabels(labels); + this.setLayerMaskArrays(featuresMask, labelsMask); + update(TaskUtils.buildTask(features, labels)); - if (featuresMaskArray.size(0) > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - //New approach: use feedForwardMaskArray method - feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0)); + LayerWorkspaceMgr workspaceMgr; + if (layerWiseConfigurations.getTrainingWorkspaceMode() == null) { + workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); + } else { + workspaceMgr = LayerWorkspaceMgr.builder() + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + //Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM + // these should be closed by the time updaters are executed + //Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this + .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .build(); + } + workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); + + if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) { + doTruncatedBPTT(features, labels, featuresMask, labelsMask, workspaceMgr); + } else { + if (solver == null) { + try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + .build(); + } + } + //TODO CACHE WORKSPACE, IF USED??? + solver.optimize(workspaceMgr); + } + + clearLayerMaskArrays(); + clearLayersStates(); + synchronizeIterEpochCounts(); + } + + @Override + public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) { + throw new UnsupportedOperationException("Not supported: use pretrainLayer"); + } + + /** + * Fit the model for one iteration on the provided data + * + * @param data the data to train on + */ + @Override + public void fit(org.nd4j.linalg.dataset.api.DataSet data) { + fit(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArray(), + data.getLabelsMaskArray()); + } + + /** + * Fit the model for one iteration on the provided data + * + * @param examples the examples to classify (one example in each row) + * @param labels the labels for each example (the number of labels must match + */ + @Override + public void fit(INDArray examples, int[] labels) { + org.deeplearning4j.nn.conf.layers.OutputLayer layerConf = + (org.deeplearning4j.nn.conf.layers.OutputLayer) getOutputLayer().conf().getLayer(); + + if (layerConf.getNOut() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + fit(examples, FeatureUtil.toOutcomeMatrix(labels, (int) layerConf.getNOut())); + } + + /** + * Perform inference on the provided input/features - i.e., perform forward pass using the + * provided input/features and return the output of the final layer. + * + * @param input Input to the network + * @param train whether the output is test or train. 
This mainly affects hyperparameters such as
+   *              dropout and batch normalization, which have different behaviour for test vs.
+   *              train
+   * @return The network predictions - i.e., the activations of the final layer
+   */
+  public INDArray output(INDArray input, TrainingMode train) {
+    return output(input, train == TrainingMode.TRAIN);
+  }
+
+  /**
+   * Perform inference on the provided input/features - i.e., perform forward pass using the
+   * provided input/features and return the output of the final layer.
+   *
+   * @param input Input to the network
+   * @param train whether the output is test or train. This mainly affects hyperparameters such as
+   *              dropout and batch normalization, which have different behaviour for test vs.
+   *              train
+   * @return The network predictions - i.e., the activations of the final layer
+   */
+  public INDArray output(INDArray input, boolean train) {
+    return output(input, train, null, null);
+  }
+
+  /**
+   * Calculate the output of the network, with masking arrays. The masking arrays are used in
+   * situations such as one-to-many and many-to-one recurrent neural network (RNN) designs, as well
+   * as for supporting time series of varying lengths within the same minibatch.
+   */
+  public INDArray output(INDArray input, boolean train, INDArray featuresMask,
+      INDArray labelsMask) {
+    return output(input, train, featuresMask, labelsMask, null);
+  }
+
+  /**
+   * Get the network output, which is optionally placed in the specified memory workspace.<br>
If no + * memory workspace is provided, the output will be detached (not in any workspace).
If a + * memory workspace is provided, the output activation array (i.e., the INDArray returned by this + * method) will be placed in the specified workspace. This workspace must be opened by the user + * before calling this method - and the user is responsible for (a) closing this workspace, and + * (b) ensuring the output array is not used out of scope (i.e., not used after closing the + * workspace to which it belongs - as this is likely to cause either an exception when used, or a + * crash). + * + * @param input Input to the network + * @param train True for train, false otherwise + * @param outputWorkspace May be null. If not null: the workspace MUST be opened before calling + * this method. + * @return The output/activations from the network (either detached or in the specified workspace + * if provided) + */ + public INDArray output(INDArray input, boolean train, MemoryWorkspace outputWorkspace) { + return output(input, train, null, null, outputWorkspace); + } + + /** + * Get the network output, which is optionally placed in the specified memory workspace.
If no + * memory workspace is provided, the output will be detached (not in any workspace).
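As a usage sketch for the workspace-managed output overload documented here (assuming a trained MultiLayerNetwork named net, an input array named features, and a caller-chosen WorkspaceConfiguration; the workspace id "INFERENCE_WS" is illustrative only):

    WorkspaceConfiguration wsConf = WorkspaceConfiguration.builder().build();
    try (MemoryWorkspace ws = Nd4j.getWorkspaceManager()
            .getAndActivateWorkspace(wsConf, "INFERENCE_WS")) {
        INDArray out = net.output(features, false, ws);  // activations are placed in (and only valid inside) 'ws'
        INDArray detached = out.detach();                // detach if the result must outlive the workspace
    }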
If a + * memory workspace is provided, the output activation array (i.e., the INDArray returned by this + * method) will be placed in the specified workspace. This workspace must be opened by the user + * before calling this method - and the user is responsible for (a) closing this workspace, and + * (b) ensuring the output array is not used out of scope (i.e., not used after closing the + * workspace to which it belongs - as this is likely to cause either an exception when used, or a + * crash). + * + * @param input Input to the network + * @param train True for train, false otherwise + * @param outputWorkspace May be null. If not null: the workspace MUST be opened before calling + * this method. + * @return The output/activations from the network (either detached or in the specified workspace + * if provided) + */ + public synchronized INDArray output(INDArray input, boolean train, INDArray featuresMask, + INDArray labelsMask, MemoryWorkspace outputWorkspace) { + try { + return outputOfLayerDetached(train, FwdPassType.STANDARD, layers.length - 1, input, + featuresMask, labelsMask, outputWorkspace); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; + } + } + + /** + * This method uses provided OutputAdapter to return custom object built from INDArray + *

+ * PLEASE NOTE: This method uses dedicated Workspace for output generation to avoid redundant + * allocations + * + * @param inputs Input arrays to the netwonk + * @param inputMasks Optional input mask arrays (may be null) + * @param labelMasks Optional label mask arrays (may be null + * @param outputAdapter OutputAdapter instance + * @param T extends Object + * @return T instance produced by OutputAdapter + */ + public synchronized T output(@NonNull INDArray inputs, INDArray inputMasks, + INDArray labelMasks, @NonNull OutputAdapter outputAdapter) { + try (val ws = Nd4j.getWorkspaceManager() + .getAndActivateWorkspace(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)) { + if (outputAdapter instanceof ModelAdapter) { + return ((ModelAdapter) outputAdapter).apply(this, new INDArray[]{inputs}, + new INDArray[]{inputMasks}, new INDArray[]{labelMasks}); + } else { + return outputAdapter.apply(output(inputs, false, inputMasks, labelMasks, ws)); + } + } + } + + /** + * Perform inference on the provided input/features - i.e., perform forward pass using the + * provided input/features and return the output of the final layer. Equivalent to + * {@link #output(INDArray, boolean)} with train=false - i.e., this method is used for inference. + * + * @param input Input to the network + * @return The network predictions - i.e., the activations of the final layer + */ + public INDArray output(INDArray input) { + return output(input, TrainingMode.TEST); + } + + /** + * Generate the output for all examples/batches in the input iterator, and concatenate them into a + * single array. See {@link #output(INDArray)}
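A minimal sketch of the OutputAdapter-based overload above; the adapter shown (arg-max of the final activations) and the assumption that OutputAdapter exposes a single apply(INDArray...) method are illustrative, not taken from this patch:

    OutputAdapter<Integer> argMaxAdapter = new OutputAdapter<Integer>() {
        @Override
        public Integer apply(INDArray... outputs) {
            // reduce the raw output activations to a single predicted class index
            return outputs[0].argMax(1).getInt(0);
        }
    };
    Integer predicted = net.output(features, null, null, argMaxAdapter);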
NOTE 1: The output array can require a + * considerable amount of memory for iterators with a large number of examples
NOTE 2: This + * method cannot be used for variable length time series outputs, as this would require padding + * arrays for some outputs, or returning a mask array (which cannot be done with this method). For + * variable length time series applications, use one of the other output methods. This method also + * cannot be used with fully convolutional networks with different output sizes (for example, + * segmentation on different input image sizes). + * + * @param iterator Data to pass through the network + * @return output for all examples in the iterator, concatenated into a + */ + public INDArray output(DataSetIterator iterator, boolean train) { + List outList = new ArrayList<>(); + long[] firstOutputShape = null; + while (iterator.hasNext()) { + DataSet next = iterator.next(); + INDArray features = next.getFeatures(); + + if (features == null) { + continue; + } + + INDArray fMask = next.getFeaturesMaskArray(); + INDArray lMask = next.getLabelsMaskArray(); + INDArray output = this.output(features, train, fMask, lMask); + outList.add(output); + if (firstOutputShape == null) { + firstOutputShape = output.shape(); + } else { + //Validate that shapes are the same (may not be, for some RNN variable length time series applications) + long[] currShape = output.shape(); + Preconditions.checkState(firstOutputShape.length == currShape.length, + "Error during forward pass:" + + "different minibatches have different output array ranks - first minibatch shape %s, last minibatch shape %s", + firstOutputShape, currShape); + for (int i = 1; i < currShape.length; + i++) { //Skip checking minibatch dimension, fine if this varies + Preconditions.checkState(firstOutputShape[i] == currShape[i], + "Current output shape does not match first" + + " output array shape at position %s: all dimensions must match other than the first dimension.\n" + + + " For variable length output size/length use cases such as for RNNs with multiple sequence lengths," + + + " use one of the other (non iterator) output methods. First batch output shape: %s, current batch output shape: %s", + i, firstOutputShape, currShape); + } + } + } + return Nd4j.concat(0, outList.toArray(new INDArray[outList.size()])); + } + + /** + * Equivalent to {@link #output(DataSetIterator, boolean)} with train=false + */ + public INDArray output(DataSetIterator iterator) { + return output(iterator, false); + } + + /** + * Perform inference and then calculate the F1 score of the output(input) vs. the labels. + * + * @param input the input to perform inference with + * @param labels the true labels + * @return the score for the given input,label pairs + */ + @Override + public double f1Score(INDArray input, INDArray labels) { + feedForward(input); + setLabels(labels); + Evaluation eval = new Evaluation(); + eval.eval(labels, output(input)); + return eval.f1(); + } + + /** + * @deprecated Will be removed in a future release + */ + @Deprecated + @Override + public int numLabels() { + return (int) labels.size(1); + } + + /** + * Sets the input and labels and calculates the score (value of the output layer loss function + * plus l1/l2 if applicable) for the prediction with respect to the true labels
This is + * equivalent to {@link #score(DataSet, boolean)} with training==false. + * + * @param data the data to score + * @return the score for the given input,label pairs + * @see #score(DataSet, boolean) + */ + public double score(DataSet data) { + return score(data, false); + } + + /** + * Sets the input and labels and calculates the score (value of the output layer loss function + * plus l1/l2 if applicable) for the prediction with respect to the true labels
+ * + * @param data data to calculate score for + * @param training If true: score during training. If false: score at test time. This can affect + * the application of certain features, such as dropout and dropconnect (which are + * applied at training time only) + * @return the score (value of the loss function) + */ + public double score(DataSet data, boolean training) { + try { + return scoreHelper(data, training); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; + } + } + + private double scoreHelper(DataSet data, boolean training) { + boolean hasMaskArray = data.hasMaskArrays(); + if (hasMaskArray) { + setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray()); + } + + if (!(getOutputLayer() instanceof IOutputLayer)) { + throw new IllegalStateException( + "Cannot calculate score if final layer is not an instance of IOutputLayer. " + + "Final layer is of type: " + getOutputLayer().getClass()); + } + + WorkspaceMode wsm = (training ? layerWiseConfigurations.getTrainingWorkspaceMode() + : layerWiseConfigurations.getInferenceWorkspaceMode()); + LayerWorkspaceMgr mgr; + if (wsm == WorkspaceMode.NONE) { + mgr = LayerWorkspaceMgr.noWorkspaces(); + } else { + mgr = LayerWorkspaceMgr.builder() + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + //TODO we can probably optimize this + .noWorkspaceFor(ArrayType.ACTIVATIONS) + .noWorkspaceFor(ArrayType.INPUT) + .build(); + } + mgr.setHelperWorkspacePointers(helperWorkspaces); + + INDArray inputToOutputLayer = outputOfLayerDetached(training, FwdPassType.STANDARD, + layers.length - 2, data.getFeatures(), + data.getFeaturesMaskArray(), data.getLabelsMaskArray(), null); + + if (data.getFeatures().size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + IOutputLayer ol = (IOutputLayer) getOutputLayer(); + if (getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) != null) { + inputToOutputLayer = getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) + .preProcess(inputToOutputLayer, (int) data.getFeatures().size(0), mgr); + } + ol.setInput(inputToOutputLayer, mgr); //Feedforward doesn't include output layer for efficiency + ol.setLabels(data.getLabels()); + double score; + try (MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { + score = ol.computeScore(calcRegularizationScore(true), training, mgr); + } + + if (hasMaskArray) { + clearLayerMaskArrays(); + } + clearLayersStates(); + + return score; + } + + /** + * As per {@link #scoreExamples(DataSet, boolean)} - the outputs (example scores) for all DataSets + * in the iterator are concatenated + */ + public INDArray scoreExamples(DataSetIterator iter, boolean addRegularizationTerms) { + List out = new ArrayList<>(); + + while (iter.hasNext()) { + out.add(scoreExamples(iter.next(), addRegularizationTerms)); + } + return Nd4j.toFlattened('f', out); + } + + /** + * Calculate the score for each example in a DataSet individually. Unlike {@link #score(DataSet)} + * and {@link #score(DataSet, boolean)} this method does not average/sum over examples. This + * method allows for examples to be scored individually (at test time only), which may be useful + * for example for autoencoder architectures and the like.
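For instance, a per-example scoring sketch (assuming net is an autoencoder-style network and testData is a DataSet; the anomaly-detection reading is only an illustration):

    INDArray exampleScores = net.scoreExamples(testData, false);  // one loss value per example (row)
    double worst = exampleScores.maxNumber().doubleValue();       // e.g. the most poorly reconstructed example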
Each row of the output (assuming + * addRegularizationTerms == true) is equivalent to calling score(DataSet) with a single example. + * + * @param data The data to score + * @param addRegularizationTerms If true: add l1/l2 regularization terms (if any) to the score. If + * false: don't add regularization terms + * @return An INDArray (column vector) of size input.numRows(); the ith entry is the score (loss + * value) of the ith example + */ + public INDArray scoreExamples(DataSet data, boolean addRegularizationTerms) { + try { + return scoreExamplesHelper(data, addRegularizationTerms); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; + } + } + + private INDArray scoreExamplesHelper(DataSet data, boolean addRegularizationTerms) { + INDArray inputLast = outputOfLayerDetached(false, FwdPassType.STANDARD, layers.length - 2, + data.getFeatures(), + data.getFeaturesMaskArray(), data.getLabelsMaskArray(), null); + setLabels(data.getLabels()); + setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray()); + + //TODO we might want workspaces here? + LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces(); + + INDArray out; + if (getOutputLayer() instanceof IOutputLayer) { + IOutputLayer ol = (IOutputLayer) getOutputLayer(); + if (layerWiseConfigurations.getInputPreProcess(layers.length - 1) != null) { + + if (data.getFeatures().size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + inputLast = layerWiseConfigurations.getInputPreProcess(layers.length - 1) + .preProcess(inputLast, + (int) data.getFeatures().size(0), mgr); + } + ol.setLabels(data.getLabels()); + ol.setInput(inputLast, mgr); + double r = (addRegularizationTerms ? calcRegularizationScore(true) : 0); + out = ol.computeScoreForExamples(r, mgr); + } else { + throw new UnsupportedOperationException( + "Cannot calculate score with respect to labels without an OutputLayer"); + } + + clearLayersStates(); + clearLayerMaskArrays(); + return out; + } + + @Override + public void fit() { + fit(input, labels); + } + + @Override + public void update(INDArray gradient, String paramType) { + throw new UnsupportedOperationException("Not implemented"); + } + + /** + * Score of the model (relative to the objective function) - previously calculated on the last + * minibatch + * + * @return the score of the model (relative to the objective function) + */ + @Override + public double score() { + return score; + } + + /** + * Intended for developer/internal use + */ + public void setScore(double score) { + this.score = score; + } + + @Override + public void computeGradientAndScore(LayerWorkspaceMgr layerWorkspaceMgr) { + computeGradientAndScore(); + } + + public void computeGradientAndScore() { + + if (!(getOutputLayer() instanceof IOutputLayer)) { + throw new DL4JException( + "Cannot calculate gradient and score with respect to labels: final layer is not an IOutputLayer. " + + + "Final layer class: " + getOutputLayer().getClass() + + ". To calculate gradients and fit a network " + + "using backpropagation, the final layer must be an output layer"); + } + + //Note: Workspace manager is only ose here for score calculation... 
other workspace managers are used in the + // various FF/backprop methds + LayerWorkspaceMgr mgr; + if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + mgr = LayerWorkspaceMgr.noWorkspaces(); + } else { + mgr = LayerWorkspaceMgr.builder() + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); + + if (layerWiseConfigurations.getCacheMode() != null) { + //For now: store cache mode activations in activations workspace + mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); + } + } + + boolean tbptt = layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT; + FwdPassType fwdType = (tbptt ? FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE + : FwdPassType.STANDARD); + synchronizeIterEpochCounts(); + + //Calculate activations (which are stored in each layer, and used in backprop) + try (MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) { + //First: do a feed-forward through the network + //Note that we don't actually need to do the full forward pass through the output layer right now; but we do + // need the input to the output layer to be set (such that backprop can be done) + List activations = ffToLayerActivationsInWs(layers.length - 2, fwdType, tbptt, + input, mask, null); + if (!trainingListeners.isEmpty()) { + //TODO: We possibly do want output layer activations in some cases here... + for (TrainingListener tl : trainingListeners) { + tl.onForwardPass(this, activations); + } + } + INDArray inputToOutputLayer = activations.get(activations.size() - 1); + if (layerWiseConfigurations.getInputPreProcess(layers.length - 1) != null) { + inputToOutputLayer = layerWiseConfigurations.getInputPreProcess(layers.length - 1) + .preProcess(inputToOutputLayer, getInputMiniBatchSize(), mgr); + //Validate activations location + } + getOutputLayer().setInput(inputToOutputLayer, mgr); + //Then: compute gradients + Pair pair = calcBackpropGradients(null, true, false, false); + this.gradient = (pair == null ? null : pair.getFirst()); + + //Calculate score + try (MemoryWorkspace wsFF = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { + double r = calcRegularizationScore(true); + score = ((IOutputLayer) getOutputLayer()).computeScore(r, true, mgr); + } + + //Listeners + if (!trainingListeners.isEmpty()) { + try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + for (TrainingListener tl : trainingListeners) { + tl.onBackwardPass(this); + } + } + } + } + + //Clear the post noise/dropconnect parameters on the output layer + getOutputLayer().clearNoiseWeightParams(); + } + + /** + * Clear the inputs. Clears optimizer state. 
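A rough sketch of driving computeGradientAndScore(), documented just above, by hand (assuming features and labels arrays are available; fit(...) normally performs these steps internally):

    net.setInput(features);
    net.setLabels(labels);
    net.computeGradientAndScore();      // forward pass + backprop; stores gradient and score
    double loss = net.score();          // score from the computation above
    Gradient grad = net.gradient();     // per-parameter gradients, keyed by variable name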
+ */ + public void clear() { + for (Layer layer : layers) { + layer.clear(); + } + + input = null; + labels = null; + solver = null; + } + + @Override + public void applyConstraints(int iteration, int epoch) { + for (Layer l : layers) { + l.applyConstraints(iteration, epoch); + } + } + + @Override + public void setInput(INDArray input, LayerWorkspaceMgr mgr) { + throw new UnsupportedOperationException("Not supported"); + } + + /** + * Get the output layer - i.e., the last layer in the netwok + * + * @return + */ + public Layer getOutputLayer() { + Layer ret = getLayers()[getLayers().length - 1]; + if (ret instanceof FrozenLayerWithBackprop) { + ret = ((FrozenLayerWithBackprop) ret).getInsideLayer(); + } + return ret; + } + + + /** + * See {@link #setParams(INDArray)} + */ + public void setParameters(INDArray params) { + setParams(params); + } + + /** + * Intended for internal/developer use + */ + public NeuralNetConfiguration getDefaultConfiguration() { + return defaultConfiguration; + } + + public INDArray getLabels() { + return labels; + } + + /** + * @param labels Labels to set + */ + public void setLabels(INDArray labels) { + this.labels = labels; + } + + public INDArray getInput() { + return input; + } + + /** + * Set the input array for the network + * + * @param input Input array to set + */ + public void setInput(INDArray input) { + this.input = input; + if (this.layers == null) { + init(); + } + if (input != null) { + if (input.length() == 0) { + throw new IllegalArgumentException( + "Invalid input: length 0 (shape: " + Arrays.toString(input.shape()) + ")"); + } + + if (input.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + setInputMiniBatchSize((int) input.size(0)); + } + } + + /** + * Get the number of layers in the network + * + * @return the number of layers in the network + */ + public int getnLayers() { + return layerWiseConfigurations.getConfs().size(); + } + + /** + * @return The layers in the network + */ + public synchronized Layer[] getLayers() { + return layers; + } + + public void setLayers(Layer[] layers) { + this.layers = layers; + } + + public Layer getLayer(int i) { + Preconditions.checkArgument(i >= 0 && i < layers.length, + "Invalid layer index: layer index must be 0" + + " to %s (inclusive), got index %s", layers.length - 1, i); + return layers[i]; + } + + public Layer getLayer(String name) { + return layerMap.get(name); + } + + public List getLayerNames() { + return new ArrayList<>(layerMap.keySet()); + } + + public INDArray getMask() { + return mask; + } + + public void setMask(INDArray mask) { + this.mask = mask; + } + + public INDArray getMaskArray() { + return mask; + } + + @Override + public void setMaskArray(INDArray maskArray) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isPretrainLayer() { + return false; + } + + @Override + public void clearNoiseWeightParams() { + for (Layer l : layers) { + l.clearNoiseWeightParams(); + } + } + + @Override + public void allowInputModification(boolean allow) { + throw new UnsupportedOperationException("Not supported"); + } + + //========== + //Layer methods + + @Override + public Pair feedForwardMaskArray(INDArray maskArray, + MaskState currentMaskState, + int minibatchSize) { + if (maskArray == null) { + for (int i = 0; i < layers.length; i++) { + layers[i].feedForwardMaskArray(null, null, minibatchSize); + } + } else { + //Do a forward pass through each preprocessor and layer + for (int i = 0; i < layers.length; i++) { + InputPreProcessor preProcessor = 
getLayerWiseConfigurations().getInputPreProcess(i); + + if (preProcessor != null) { + Pair p = + preProcessor.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize); + if (p != null) { + maskArray = p.getFirst(); + currentMaskState = p.getSecond(); + } else { + maskArray = null; + currentMaskState = null; + } + } + + Pair p = + layers[i].feedForwardMaskArray(maskArray, currentMaskState, minibatchSize); + if (p != null) { + maskArray = p.getFirst(); + currentMaskState = p.getSecond(); + } else { + maskArray = null; + currentMaskState = null; + } + } + } + + return new Pair<>(maskArray, currentMaskState); + } + + @Override + public LayerHelper getHelper() { + throw new UnsupportedOperationException("Not supported"); + } + + @Override + public Type type() { + return Type.MULTILAYER; + } + + /** + * Equivalent to {@link #output(INDArray)} using the input set via {@link #setInput(INDArray)} + */ + public INDArray activate(TrainingMode training) { + return output(input, training == TrainingMode.TRAIN); + } + + /** + * Equivalent to {@link #output(INDArray, TrainingMode)} + */ + public INDArray activate(INDArray input, TrainingMode training) { + return output(input, training == TrainingMode.TRAIN); + } + + @Override + public Pair backpropGradient(INDArray epsilon, + LayerWorkspaceMgr workspaceMgr) { + if (getOutputLayer() instanceof IOutputLayer) { + throw new UnsupportedOperationException( + "Cannot calculate gradients based on epsilon with OutputLayer"); + } + + return calcBackpropGradients(epsilon, false, false, true); + } + + @Override + public int getIndex() { + return layerIndex; + } + + @Override + public void setIndex(int index) { + layerIndex = index; + } + + @Override + public int getIterationCount() { + return getLayerWiseConfigurations().getIterationCount(); + } + + @Override + public void setIterationCount(int iterationCount) { + getLayerWiseConfigurations().setIterationCount(iterationCount); + } + + @Override + public int getEpochCount() { + return getLayerWiseConfigurations().getEpochCount(); + } + + @Override + public void setEpochCount(int epochCount) { + getLayerWiseConfigurations().setEpochCount(epochCount); + } + + @Override + public double calcRegularizationScore(boolean backpropParamsOnly) { + double scoreSum = 0.0; + for (int i = 0; i < layers.length; i++) { + scoreSum += layers[i].calcRegularizationScore(backpropParamsOnly); + } + return scoreSum; + } + + @Override + public void update(Gradient gradient) { + if (gradient.gradient().length() != numParams(true)) { + throw new IllegalArgumentException( + "Invalid input: expect gradients array of length " + numParams(true)); + } + for (Map.Entry entry : gradient.gradientForVariable().entrySet()) { + String key = entry.getKey(); + INDArray val = entry.getValue(); + int idx = key.indexOf('_'); + if (idx == -1) { + throw new IllegalStateException( + "Invalid param key: not have layer separator: \"" + key + "\""); + } + Integer layerId = Integer.parseInt(key.substring(0, idx)); + String paramType = key.substring(idx + 1); + // Update MLN gradient + this.gradient.gradientForVariable().put(key, val); + // Update layer params + layers[layerId].update(val, paramType); + } + // Update layerwise gradient view + setBackpropGradientsViewArray(gradient.gradient()); + + } + + @Override + public INDArray activate(boolean training, LayerWorkspaceMgr mgr) { + throw new UnsupportedOperationException(); + } + + @Override + public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr mgr) { + throw new 
UnsupportedOperationException(); + } + + @Override + public int getInputMiniBatchSize() { + if (!conf().isMiniBatch()) { + return 1; + } + + if (input.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + return (int) input.size(0); + } + + @Override + public void setInputMiniBatchSize(int size) { + if (layers != null) { + for (Layer l : layers) { + l.setInputMiniBatchSize(size); + } + } + } + + /** + * If this MultiLayerNetwork contains one or more RNN layers: conduct forward pass (prediction) + * but using previous stored state for any RNN layers. The activations for the final step are also + * stored in the RNN layers for use next time rnnTimeStep() is called.
This method can be used + * to generate output one or more steps at a time instead of always having to do forward pass from + * t=0. Example uses are for streaming data, and for generating samples from network output one + * step at a time (where samples are then fed back into the network as input)
If no previous + * state is present in RNN layers (i.e., initially or after calling rnnClearPreviousState()), the + * default initialization (usually 0) is used.
Supports mini-batch (i.e., multiple + * predictions/forward pass in parallel) as well as for single examples.
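A step-by-step generation sketch (assuming net ends in an RNN output layer, firstStep has shape [1, nIn, 1], numSteps is the number of steps to generate, and prepareNextInput is a hypothetical helper that turns the previous output into the next input):

    net.rnnClearPreviousState();
    INDArray out = net.rnnTimeStep(firstStep);      // state is stored inside the RNN layers
    for (int t = 1; t < numSteps; t++) {
        INDArray next = prepareNextInput(out);      // hypothetical: sample/transform the previous output
        out = net.rnnTimeStep(next);                // continues from the stored state
    }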
+ * + * @param input Input to network. May be for one or multiple time steps. For single time step: + * input has shape [miniBatchSize,inputSize] or [miniBatchSize,inputSize,1]. + * miniBatchSize=1 for single example.
For multiple time steps: + * [miniBatchSize,inputSize,inputTimeSeriesLength] + * @return Output activations. If output is RNN layer (such as RnnOutputLayer): if input has shape + * [miniBatchSize,inputSize] i.e., is 2d, output has shape [miniBatchSize,outputSize] (i.e., also + * 2d).
Otherwise output is 3d [miniBatchSize,outputSize,inputTimeSeriesLength] when using + * RnnOutputLayer. + * @see #rnnTimeStep(INDArray, MemoryWorkspace) For outputting the activations in the specified + * workspace + */ + public INDArray rnnTimeStep(INDArray input) { + return rnnTimeStep(input, null); + } + + /** + * See {@link #rnnTimeStep(INDArray)} for details
If no memory workspace is provided, the + * output will be detached (not in any workspace).
If a memory workspace is provided, the + * output activation array (i.e., the INDArray returned by this method) will be placed in the + * specified workspace. This workspace must be opened by the user before calling this method - and + * the user is responsible for (a) closing this workspace, and (b) ensuring the output array is + * not used out of scope (i.e., not used after closing the workspace to which it belongs - as this + * is likely to cause either an exception when used, or a crash). + * + * @param input Input activations + * @param outputWorkspace Output workspace. May be null + * @return The output/activations from the network (either detached or in the specified workspace + * if provided) + */ + public INDArray rnnTimeStep(INDArray input, MemoryWorkspace outputWorkspace) { + try { + boolean inputIs2d = input.rank() == 2; + INDArray out = outputOfLayerDetached(false, FwdPassType.RNN_TIMESTEP, layers.length - 1, + input, null, null, outputWorkspace); + if (inputIs2d && out.rank() == 3 && layers[layers.length - 1].type() == Type.RECURRENT) { + //Return 2d output with shape [miniBatchSize,nOut] + // instead of 3d output with shape [miniBatchSize,nOut,1] + return out.tensorAlongDimension(0, 1, 0); + } + return out; + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; + } + } + + /** + * Get the state of the RNN layer, as used in rnnTimeStep(). + * + * @param layer Number/index of the layer. + * @return Hidden state, or null if layer is not an RNN layer + */ + public Map rnnGetPreviousState(int layer) { + if (layer < 0 || layer >= layers.length) { + throw new IllegalArgumentException("Invalid layer number"); + } + Layer l = layers[layer]; + if (l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) { + l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying(); + } + if (!(l instanceof RecurrentLayer)) { + throw new IllegalArgumentException("Layer is not an RNN layer"); + } + return ((RecurrentLayer) l).rnnGetPreviousState(); + } + + /** + * Set the state of the RNN layer. + * + * @param layer The number/index of the layer. + * @param state The state to set the specified layer to + */ + public void rnnSetPreviousState(int layer, Map state) { + if (layer < 0 || layer >= layers.length) { + throw new IllegalArgumentException("Invalid layer number"); + } + Layer l = layers[layer]; + if (l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) { + l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying(); + } + if (!(l instanceof RecurrentLayer)) { + throw new IllegalArgumentException("Layer is not an RNN layer"); + } + RecurrentLayer r = (RecurrentLayer) l; + r.rnnSetPreviousState(state); + } + + /** + * Clear the previous state of the RNN layers (if any). + */ + public void rnnClearPreviousState() { + if (layers == null) { + return; + } + for (int i = 0; i < layers.length; i++) { + if (layers[i] instanceof RecurrentLayer) { + ((RecurrentLayer) layers[i]).rnnClearPreviousState(); + } else if (layers[i] instanceof MultiLayerNetwork) { + ((MultiLayerNetwork) layers[i]).rnnClearPreviousState(); + } else if (layers[i] instanceof BaseWrapperLayer + && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) { + ((RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying()).rnnClearPreviousState(); + } + } + } + + /** + * Similar to rnnTimeStep and feedForward() methods. Difference here is that this method:
(a) + * like rnnTimeStep does forward pass using stored state for RNN layers, and
(b) unlike + * rnnTimeStep does not modify the RNN layer state
Therefore multiple calls to this method + * with the same input should have the same output.
Typically used during training only. Use + * rnnTimeStep for prediction/forward pass at test time. + * + * @param input Input to network + * @param training Whether training or not + * @param storeLastForTBPTT set to true if used as part of truncated BPTT training + * @return Activations for each layer (including input, as per feedforward() etc) + */ + public List rnnActivateUsingStoredState(INDArray input, boolean training, + boolean storeLastForTBPTT) { + return ffToLayerActivationsDetached(training, FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE, + storeLastForTBPTT, layers.length - 1, input, mask, null, false); + } + + /** + * Get the updater for this MultiLayerNetwork + * + * @return Updater for MultiLayerNetwork + */ + public Updater getUpdater() { + return getUpdater(true); + } + + /** + * Set the updater for the MultiLayerNetwork + */ + public void setUpdater(Updater updater) { + if (solver == null) { + solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build(); + } + solver.getOptimizer().setUpdater(updater); + } + + public Updater getUpdater(boolean initializeIfReq) { + if (solver == null && initializeIfReq) { + synchronized (this) { + if (solver == null) { //May have been created while waiting for lock + solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + .build(); + solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this)); + } + } + } + if (solver != null) { + return solver.getOptimizer().getUpdater(initializeIfReq); + } + return null; + } + + /** + * Set the mask arrays for features and labels. Mask arrays are typically used in situations such + * as one-to-many and many-to-one learning with recurrent neural networks, as well as for + * supporting time series of varying lengths within the same minibatch.
For example, with RNN + * data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape + * [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape + * [miniBatchSize,timeSeriesLength] and contain values 0 or 1 at each element (to specify whether + * a given input/example is present - or merely padding - at a given time step).
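To make the mask layout concrete, a small sketch (shapes and values are arbitrary; assumes net and nIn are defined, and uses standard ND4J indexing):

    INDArray features = Nd4j.zeros(2, nIn, 5);        // 2 sequences, maximum length 5
    INDArray featuresMask = Nd4j.zeros(2, 5);
    featuresMask.getRow(0).assign(1);                 // first sequence: all 5 steps present
    featuresMask.get(NDArrayIndex.point(1), NDArrayIndex.interval(0, 3)).assign(1); // second: only 3 steps
    INDArray out = net.output(features, false, featuresMask, null);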
+ * NOTE: This method is not usually used directly. Instead, methods such as + * {@link #feedForward(INDArray, INDArray, INDArray)} + * and {@link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking + * internally. + * + * @param featuresMaskArray Mask array for features (input) + * @param labelsMaskArray Mask array for labels (output) + * @see #clearLayerMaskArrays() + */ + public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) { + if (featuresMaskArray != null) { + + if (featuresMaskArray.size(0) > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + //New approach: use feedForwardMaskArray method + feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0)); /* @@ -3308,837 +3629,883 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, org.d // non-zero (i.e., activationFunction(0*weights + bias) != 0 in general) //This assumes that the time series input is masked - i.e., values are 0 at the padded time steps, // so we don't need to do anything for the recurrent layer - + //Now, if mask array is 2d -> need to reshape to 1d (column vector) in the exact same order // as is done for 3d -> 2d time series reshaping INDArray reshapedFeaturesMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(featuresMaskArray); - + for( int i=0; i See {@link #setLayerMaskArrays(INDArray, INDArray)} + * for details on mask arrays. + */ + public void clearLayerMaskArrays() { + for (Layer layer : layers) { + layer.setMaskArray(null); + } + } + + /** + * Evaluate the network (classification performance) + * + * @param iterator Iterator to evaluate on + * @return Evaluation object; results of evaluation on all examples in the data set + */ + public T evaluate(@NonNull DataSetIterator iterator) { + return (T) evaluate(iterator, null); + } + + /** + * Evaluate the network (classification performance). Can only be used with MultiDataSetIterator + * instances with a single input/output array + * + * @param iterator Iterator to evaluate on + * @return Evaluation object; results of evaluation on all examples in the data set + */ + public Evaluation evaluate(@NonNull MultiDataSetIterator iterator) { + return evaluate(new MultiDataSetWrapperIterator(iterator)); + } + + /** + * Evaluate the network for regression performance + * + * @param iterator Data to evaluate on + * @return Regression evaluation + */ + public T evaluateRegression(DataSetIterator iterator) { + return (T) doEvaluation(iterator, new RegressionEvaluation(iterator.totalOutcomes()))[0]; + } + + /** + * Evaluate the network for regression performance Can only be used with MultiDataSetIterator + * instances with a single input/output array + * + * @param iterator Data to evaluate on + */ + public org.nd4j.evaluation.regression.RegressionEvaluation evaluateRegression( + MultiDataSetIterator iterator) { + return evaluateRegression(new MultiDataSetWrapperIterator(iterator)); + } + + /** + * @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection + * of appropriate ROC/threshold configuration + */ + @Deprecated + public T evaluateROC(DataSetIterator iterator) { + return evaluateROC(iterator, 0); + } + + /** + * Evaluate the network (must be a binary classifier) on the specified data, using the {@link ROC} + * class + * + * @param iterator Data to evaluate on + * @param rocThresholdSteps Number of threshold steps to use with {@link ROC} - see that class for + * details. 
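A typical evaluation sketch over a held-out iterator, using the evaluate/evaluateROC methods introduced above (assuming testIter is a DataSetIterator; the 100 threshold steps and the printed stats are illustrative):

    Evaluation eval = net.evaluate(testIter);        // classification metrics (accuracy, F1, ...)
    System.out.println(eval.stats());

    testIter.reset();
    ROC roc = net.evaluateROC(testIter, 100);        // binary classifiers only
    System.out.println("AUC: " + roc.calculateAUC());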
+ * @return ROC evaluation on the given dataset + */ + public T evaluateROC(DataSetIterator iterator, int rocThresholdSteps) { + Layer outputLayer = getOutputLayer(); + if (getLayerWiseConfigurations().isValidateOutputLayerConfig()) { + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), + ROC.class); + } + return (T) doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0]; + } + + /** + * @deprecated To be removed - use {@link #evaluateROCMultiClass(DataSetIterator, int)} to enforce + * selection of appropriate ROC/threshold configuration + */ + @Deprecated + public T evaluateROCMultiClass(DataSetIterator iterator) { + return evaluateROCMultiClass(iterator, 0); + } + + /** + * Evaluate the network on the specified data, using the {@link ROCMultiClass} class + * + * @param iterator Data to evaluate on + * @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass} + * @return Multi-class ROC evaluation on the given dataset + */ + public T evaluateROCMultiClass(DataSetIterator iterator, + int rocThresholdSteps) { + Layer outputLayer = getOutputLayer(); + if (getLayerWiseConfigurations().isValidateOutputLayerConfig()) { + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), + ROCMultiClass.class); + } + return (T) doEvaluation(iterator, + new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0]; + } + + /** + * Perform evaluation using an arbitrary IEvaluation instance. + * + * @param iterator data to evaluate on + */ + public T[] doEvaluation(DataSetIterator iterator, T... evaluations) { + try { + return doEvaluationHelper(iterator, evaluations); + } catch (OutOfMemoryError e) { + CrashReportingUtil.writeMemoryCrashDump(this, e); + throw e; + } + } + + public T[] doEvaluationHelper(DataSetIterator iterator, + T... evaluations) { + if (!iterator.hasNext() && iterator.resetSupported()) { + iterator.reset(); } - /** Remove the mask arrays from all layers.
- * See {@link #setLayerMaskArrays(INDArray, INDArray)} for details on mask arrays. - */ - public void clearLayerMaskArrays() { - for (Layer layer : layers) { - layer.setMaskArray(null); - } + DataSetIterator iter = + iterator.asyncSupported() ? new AsyncDataSetIterator(iterator, 2, true) : iterator; + + WorkspaceMode cMode = layerWiseConfigurations.getTrainingWorkspaceMode(); + layerWiseConfigurations.setTrainingWorkspaceMode( + layerWiseConfigurations.getInferenceWorkspaceMode()); + + //First: let's determine if we should do 'split feed forward' for long time series + //The idea: RNN 20k time steps. Train using TBPTT length 100 -> 200 segments of length 100. If we naively + // just use .output(INDArray) here, then our memory requirements are 200x larger than if we did the same + // evaluation in segments... + //Only do this if TBPTT is enabled - if not, it means we can train without TBPTT and hence should be able + // to test without splitting also + boolean useRnnSegments = (layerWiseConfigurations.getBackpropType() + == BackpropType.TruncatedBPTT); + + MemoryWorkspace outputWs; + if (getLayerWiseConfigurations().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED) { + outputWs = Nd4j.getWorkspaceManager() + .getWorkspaceForCurrentThread(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM); + } else { + outputWs = new DummyWorkspace(); } - /** - * Evaluate the network (classification performance) - * - * @param iterator Iterator to evaluate on - * @return Evaluation object; results of evaluation on all examples in the data set - */ - public T evaluate(@NonNull DataSetIterator iterator) { - return (T)evaluate(iterator, null); - } + while (iter.hasNext()) { + DataSet next = iter.next(); - /** - * Evaluate the network (classification performance). - * Can only be used with MultiDataSetIterator instances with a single input/output array - * - * @param iterator Iterator to evaluate on - * @return Evaluation object; results of evaluation on all examples in the data set - */ - public Evaluation evaluate(@NonNull MultiDataSetIterator iterator) { - return evaluate(new MultiDataSetWrapperIterator(iterator)); - } + if (next.getFeatures() == null || next.getLabels() == null) { + continue; + } - /** - * Evaluate the network for regression performance - * @param iterator Data to evaluate on - * @return Regression evaluation - */ - public T evaluateRegression(DataSetIterator iterator) { - return (T)doEvaluation(iterator, new RegressionEvaluation(iterator.totalOutcomes()))[0]; - } + INDArray features = next.getFeatures(); + INDArray labels = next.getLabels(); + INDArray fMask = next.getFeaturesMaskArray(); + INDArray lMask = next.getLabelsMaskArray(); + List meta = next.getExampleMetaData(); - /** - * Evaluate the network for regression performance - * Can only be used with MultiDataSetIterator instances with a single input/output array - * @param iterator Data to evaluate on - */ - public org.nd4j.evaluation.regression.RegressionEvaluation evaluateRegression(MultiDataSetIterator iterator) { - return evaluateRegression(new MultiDataSetWrapperIterator(iterator)); - } + if (!useRnnSegments) { + //Standard/non-RNN case: + try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) { + INDArray out = outputOfLayerDetached(false, FwdPassType.STANDARD, layers.length - 1, + features, fMask, lMask, ws); - /** - * @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration - */ - @Deprecated - public T evaluateROC(DataSetIterator iterator){ - 
return evaluateROC(iterator, 0); - } - - /** - * Evaluate the network (must be a binary classifier) on the specified data, using the {@link ROC} class - * - * @param iterator Data to evaluate on - * @param rocThresholdSteps Number of threshold steps to use with {@link ROC} - see that class for details. - * @return ROC evaluation on the given dataset - */ - public T evaluateROC(DataSetIterator iterator, int rocThresholdSteps) { - Layer outputLayer = getOutputLayer(); - if(getLayerWiseConfigurations().isValidateOutputLayerConfig()){ - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class); - } - return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0]; - } - - /** - * @deprecated To be removed - use {@link #evaluateROCMultiClass(DataSetIterator, int)} to enforce selection of appropriate ROC/threshold configuration - */ - @Deprecated - public T evaluateROCMultiClass(DataSetIterator iterator) { - return evaluateROCMultiClass(iterator, 0); - } - - /** - * Evaluate the network on the specified data, using the {@link ROCMultiClass} class - * - * @param iterator Data to evaluate on - * @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass} - * @return Multi-class ROC evaluation on the given dataset - */ - public T evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) { - Layer outputLayer = getOutputLayer(); - if(getLayerWiseConfigurations().isValidateOutputLayerConfig()){ - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class); - } - return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0]; - } - - /** - * Perform evaluation using an arbitrary IEvaluation instance. - * - * @param iterator data to evaluate on - */ - public T[] doEvaluation(DataSetIterator iterator, T... evaluations) { - try{ - return doEvaluationHelper(iterator, evaluations); - } catch (OutOfMemoryError e){ - CrashReportingUtil.writeMemoryCrashDump(this, e); - throw e; - } - } - - public T[] doEvaluationHelper(DataSetIterator iterator, T... evaluations) { - if (!iterator.hasNext() && iterator.resetSupported()) { - iterator.reset(); - } - - DataSetIterator iter = iterator.asyncSupported() ? new AsyncDataSetIterator(iterator, 2, true) : iterator; - - WorkspaceMode cMode = layerWiseConfigurations.getTrainingWorkspaceMode(); - layerWiseConfigurations.setTrainingWorkspaceMode(layerWiseConfigurations.getInferenceWorkspaceMode()); - - //First: let's determine if we should do 'split feed forward' for long time series - //The idea: RNN 20k time steps. Train using TBPTT length 100 -> 200 segments of length 100. If we naively - // just use .output(INDArray) here, then our memory requirements are 200x larger than if we did the same - // evaluation in segments... 
- //Only do this if TBPTT is enabled - if not, it means we can train without TBPTT and hence should be able - // to test without splitting also - boolean useRnnSegments = (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT); - - MemoryWorkspace outputWs; - if(getLayerWiseConfigurations().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED){ - outputWs = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM); - } else { - outputWs = new DummyWorkspace(); - } - - while (iter.hasNext()) { - DataSet next = iter.next(); - - if (next.getFeatures() == null || next.getLabels() == null) - continue; - - - INDArray features = next.getFeatures(); - INDArray labels = next.getLabels(); - INDArray fMask = next.getFeaturesMaskArray(); - INDArray lMask = next.getLabelsMaskArray(); - List meta = next.getExampleMetaData(); - - - if (!useRnnSegments) { - //Standard/non-RNN case: - try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) { - INDArray out = outputOfLayerDetached(false, FwdPassType.STANDARD, layers.length - 1, features, fMask, lMask, ws); - - try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { - for (T evaluation : evaluations) - evaluation.eval(labels, out, lMask, meta); - } - } - } else { - rnnClearPreviousState(); - - - //Get subset of features and labels: - val fwdLen = layerWiseConfigurations.getTbpttFwdLength(); - val tsLength = features.size(2); - long nSubsets = tsLength / fwdLen; - if (tsLength % fwdLen != 0) - nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20) - for (int i = 0; i < nSubsets; i++) { - val startTimeIdx = i * fwdLen; - val endTimeIdx = Math.min(startTimeIdx + fwdLen, tsLength); - - if (endTimeIdx > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - INDArray[] subsets = getSubsetsForTbptt(startTimeIdx, (int) endTimeIdx, features, labels, fMask, lMask); - - setLayerMaskArrays(subsets[2], subsets[3]); - - try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) { - INDArray outSub = rnnTimeStep(subsets[0], ws); - try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { - for (T evaluation : evaluations) - evaluation.eval(subsets[1], outSub, subsets[3]); - } - } - } + try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + for (T evaluation : evaluations) { + evaluation.eval(labels, out, lMask, meta); } - - //Clear inputs, masks etc. Important to avoid leaking invalidated/out of scope arrays between iterations - clearLayersStates(); + } } + } else { + rnnClearPreviousState(); - if (iterator.asyncSupported()) - ((AsyncDataSetIterator) iter).shutdown(); - - layerWiseConfigurations.setTrainingWorkspaceMode(cMode); - - return evaluations; - } - - /** - * Evaluate the network on the provided data set. Used for evaluating the performance of classifiers - * - * @param iterator Data to undertake evaluation on - * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator - */ - public Evaluation evaluate(DataSetIterator iterator, List labelsList) { - return evaluate(iterator, labelsList, 1); - } - - @Override - public INDArray updaterState() { - return getUpdater() != null ? 
getUpdater().getStateViewArray() : null; - } - - @Override - public void fit(MultiDataSet dataSet) { - if (dataSet.getFeatures().length == 1 && dataSet.getLabels().length == 1) { - INDArray features = dataSet.getFeatures(0); - INDArray labels = dataSet.getLabels(0); - INDArray fMask = null; - INDArray lMask = null; - - if (dataSet.getFeaturesMaskArrays() != null) - fMask = dataSet.getFeaturesMaskArrays()[0]; - - if (dataSet.getFeaturesMaskArrays() != null) - lMask = dataSet.getLabelsMaskArrays()[0]; - - DataSet ds = new DataSet(features, labels, fMask, lMask); - fit(ds); - } else { - throw new DL4JInvalidInputException( - "MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array." + - "Please consider use of ComputationGraph"); + //Get subset of features and labels: + val fwdLen = layerWiseConfigurations.getTbpttFwdLength(); + val tsLength = features.size(2); + long nSubsets = tsLength / fwdLen; + if (tsLength % fwdLen != 0) { + nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20) } - } + for (int i = 0; i < nSubsets; i++) { + val startTimeIdx = i * fwdLen; + val endTimeIdx = Math.min(startTimeIdx + fwdLen, tsLength); - /** - * Perform minibatch training on all minibatches in the MultiDataSetIterator, for the specified number of epochs. - * Equvalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a loop - * - * @param iterator Training data (DataSetIterator). Iterator must support resetting - * @param numEpochs Number of training epochs, >= 1 - */ - public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs){ - Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", numEpochs); - Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), "Cannot perform multiple epochs training using" + - "iterator has does not support resetting (iterator.resetSupported() returned false)"); - - for(int i = 0; i < numEpochs; i++) { - fit(iterator); - } - } - - /** - * Perform minibatch training on all minibatches in the MultiDataSetIterator.
- * Note: The MultiDataSets in the MultiDataSetIterator must have exactly 1 input and output array (as - * MultiLayerNetwork only supports 1 input and 1 output) - * - * @param iterator Training data (DataSetIterator). Iterator must support resetting - */ - @Override - public void fit(MultiDataSetIterator iterator) { - fit(new MultiDataSetWrapperIterator(iterator)); - } - - @Override - public T[] doEvaluation(MultiDataSetIterator iterator, T[] evaluations) { - return doEvaluation(new MultiDataSetWrapperIterator(iterator), evaluations); - } - - /** - * Evaluate the network (for classification) on the provided data set, with top N accuracy in addition to standard accuracy. - * For 'standard' accuracy evaluation only, use topN = 1 - * - * @param iterator Iterator (data) to evaluate on - * @param labelsList List of labels. May be null. - * @param topN N value for top N accuracy evaluation - * @return Evaluation object, summarizing the results of the evaluation on the provided DataSetIterator - */ - public Evaluation evaluate(DataSetIterator iterator, List labelsList, int topN) { - if (layers == null || !(getOutputLayer() instanceof IOutputLayer)) { - throw new IllegalStateException("Cannot evaluate network with no output layer"); - } - if (labelsList == null) { - try { - labelsList = iterator.getLabels(); - } catch (Throwable t){ } //Ignore, maybe UnsupportedOperationException etc - } - - Layer outputLayer = getOutputLayer(); - if(getLayerWiseConfigurations().isValidateOutputLayerConfig()){ - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class); - } - - Evaluation e = new org.deeplearning4j.eval.Evaluation(labelsList, topN); - doEvaluation(iterator, e); - - return e; - } - - protected void update(Task task) { - if (!initDone) { - initDone = true; - Heartbeat heartbeat = Heartbeat.getInstance(); - task = ModelSerializer.taskByModel(this); - Environment env = EnvironmentUtils.buildEnvironment(); - heartbeat.reportEvent(Event.STANDALONE, env, task); - } - } - - /** - * String detailing the architecture of the multilayernetwork. - * Columns are LayerIndex with layer type, nIn, nOut, Total number of parameters and the Shapes of the parameters - * Will also give information about frozen layers, if any. - * @return Summary as a string - * @see #memoryInfo(int, InputType) - */ - public String summary() { - return summary(null); - } - - /** - * String detailing the architecture of the multilayernetwork. - * Will also display activation size when given an input type. - * Columns are LayerIndex with layer type, nIn, nOut, Total number of parameters, Shapes of the parameters, Input activation shape, Output activation shape - * Will also give information about frozen layers, if any. - * @return Summary as a string - * @see #memoryInfo(int, InputType) - */ - public String summary(InputType inputType) { - StringBuilder ret = new StringBuilder(); - ret.append("\n"); - - List lines = new ArrayList<>(); - if(inputType == null){ - lines.add(new String[]{"LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape"}); - } else { - lines.add(new String[]{"LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape", "InputShape", "OutputShape"}); - } - int[] maxLength = new int[inputType == null ? 
4 : 6]; - String[] header = lines.get(0); - for( int i=0; i 0) { - paramShape = ""; - if (currentLayer instanceof BidirectionalLayer) { // Bidirectional layer is not an FFL - BidirectionalLayer bi = (BidirectionalLayer) currentLayer; - in = String.valueOf(((Bidirectional)bi.conf().getLayer()).getNIn()); - out = String.valueOf(((Bidirectional)bi.conf().getLayer()).getNOut()); - } else { - try { - in = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNIn()); - out = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNOut()); - } - catch (Exception e) { // Some layers, like PReLU, are just BaseLayers (but have parameters) - } - } - Set paraNames = currentLayer.paramTable().keySet(); - for (String aP : paraNames) { - String paramS = ArrayUtils.toString(currentLayer.paramTable().get(aP).shape()); - paramShape += aP + ":" + paramS + ", "; - } - paramShape = paramShape.subSequence(0, paramShape.lastIndexOf(",")).toString(); - } - if (currentLayer instanceof FrozenLayer) { - frozenParams += currentLayer.numParams(); - classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\."); - className = "Frozen " + classNameArr[classNameArr.length - 1]; - } - - String[] line; - if (inputType == null) { - line = new String[]{name + " (" + className + ")", in + "," + out, paramCount, paramShape}; - } else { - line = new String[]{name + " (" + className + ")", in + "," + out, paramCount,paramShape,inShape,outShape}; - } - for( int i=0; iautomatically when using iterator-based fitting methods, such as - * {@link #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, INDArray/INDArray etc), - * the network has no way to know when one epoch ends and another starts. In such situations, this method - * can be used to increment the epoch counter.
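For example, a manual epoch loop along those lines (assuming batches is a pre-built list of DataSet minibatches and numEpochs is chosen by the caller):

    for (int epoch = 0; epoch < numEpochs; epoch++) {
        for (DataSet batch : batches) {
            net.fit(batch);              // one parameter update per minibatch
        }
        net.incrementEpochCount();       // keeps epoch-based LR schedules etc. in step
    }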
- * Note that the epoch counter is used for situations such as some learning rate schedules, and the like. - * - * The current epoch count can be obtained using {@code MultiLayerConfiguration.getLayerwiseConfiguration().getEpochCount()} - */ - public void incrementEpochCount(){ - layerWiseConfigurations.setEpochCount(layerWiseConfigurations.getEpochCount() + 1); - synchronizeIterEpochCounts(); - } - - - protected void synchronizeIterEpochCounts() { - //TODO: this is necessary for some schedules - but the redundant values are a little ugly... - int currIter = getIterationCount(); - int currEpoch = getEpochCount(); - for(Layer l : layers) { - l.setIterationCount(currIter); - l.setEpochCount(currEpoch); - } - } - - /** - * Save the MultiLayerNetwork to a file. Restore using {@link #load(File, boolean)}. - * Note that this saves the updater (i.e., the state array for momentum/Adam/rmsprop etc), which is desirable - * if further training will be undertaken. - * - * @param f File to save the network to - * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams) - * @see #save(File, boolean) - */ - public void save( File f ) throws IOException { - save(f, true); - } - - /** - * Save the MultiLayerNetwork to a file. Restore using {@link #load(File, boolean)}. - * - * @param f File to save the network to - * @param saveUpdater If true: save the updater (i.e., the state array for momentum/Adam/rmsprop etc), which should - * usually be saved if further training is required - * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams) - * @see #save(File, boolean) - */ - public void save(File f, boolean saveUpdater) throws IOException{ - ModelSerializer.writeModel(this, f, saveUpdater); - } - - /** - * Restore a MultiLayerNetwork to a file, saved using {@link #save(File)} or {@link ModelSerializer} - * @param f File to load the network from - * @param loadUpdater If true: load the updater if it is available (i.e., the state array for momentum/Adam/rmsprop - * etc) - use false if no further training is required, or true if further training - * will be undertaken - * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams) - */ - public static MultiLayerNetwork load(File f, boolean loadUpdater) throws IOException { - return ModelSerializer.restoreMultiLayerNetwork(f, loadUpdater); - } - - /** - * Convert this MultiLayerNetwork to a ComputationGraph - * - * @return ComputationGraph equivalent to this network (including parameters and updater state) - */ - public ComputationGraph toComputationGraph(){ - return NetworkUtils.toComputationGraph(this); - } - - /** - * Return a copy of the network with the parameters and activations set to use the specified (floating point) data type. - * If the existing datatype is the same as the requested dataype, the original network will be returned unchanged. - * Only floating point datatypes (DOUBLE, FLOAT, HALF) may be used. - * - * @param dataType Datatype to convert the network to - * @return The network, set to use the specified datatype for the parameters and activations - */ - public MultiLayerNetwork convertDataType(@NonNull DataType dataType){ - Preconditions.checkState(dataType.isFPType(), "Invalid DataType: %s. 
Can only convert network to a floating point type", dataType); - if(dataType == params().dataType()){ - return this; - } - - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - INDArray newParams = params().castTo(dataType); - String jsonConfig = getLayerWiseConfigurations().toJson(); - MultiLayerConfiguration newConf = MultiLayerConfiguration.fromJson(jsonConfig); - newConf.setDataType(dataType); - MultiLayerNetwork newNet = new MultiLayerNetwork(newConf); - newNet.init(newParams, false); - - Updater u = getUpdater(false); - if(u != null && u.getStateViewArray() != null){ - INDArray oldUpdaterState = u.getStateViewArray(); - newNet.getUpdater(true).getStateViewArray().assign(oldUpdaterState); - } - return newNet; - } - } - - /** - * Set the learning rate for all layers in the network to the specified value. Note that if any learning rate - * schedules are currently present, these will be removed in favor of the new (fixed) learning rate.
- *
- * Note: This method not free from a performance point of view: a proper learning rate schedule - * should be used in preference to calling this method at every iteration. - * - * @param newLr New learning rate for all layers - * @see #setLearningRate(ISchedule) - * @see #setLearningRate(int, double) - */ - public void setLearningRate(double newLr){ - NetworkUtils.setLearningRate(this, newLr); - } - - /** - * Set the learning rate schedule for all layers in the network to the specified schedule. - * This schedule will replace any/all existing schedules, and also any fixed learning rate values.
- * Note that the iteration/epoch counts will not be reset. Use {@link MultiLayerConfiguration#setIterationCount(int)} - * and {@link MultiLayerConfiguration#setEpochCount(int)} if this is required - * - * @param newLr New learning rate schedule for all layers - * @see #setLearningRate(ISchedule) - * @see #setLearningRate(int, double) - */ - public void setLearningRate(ISchedule newLr){ - NetworkUtils.setLearningRate(this, newLr); - } - - /** - * Set the learning rate for a single layer in the network to the specified value. Note that if any learning rate - * schedules are currently present, these will be removed in favor of the new (fixed) learning rate.
- *
- * Note: This method not free from a performance point of view: a proper learning rate schedule - * should be used in preference to calling this method at every iteration. Note also that - * {@link #setLearningRate(double)} should also be used in preference, when all layers need to be set to a new LR - * - * @param layerNumber Number of the layer to set the LR for - * @param newLr New learning rate for a single layer - * @see #setLearningRate(ISchedule) - * @see #setLearningRate(int, double) - */ - public void setLearningRate(int layerNumber, double newLr){ - NetworkUtils.setLearningRate(this, layerNumber, newLr); - } - - /** - * Set the learning rate schedule for a single layer in the network to the specified value.
- * Note also that {@link #setLearningRate(ISchedule)} should also be used in preference, when all layers need - * to be set to a new LR schedule.
- * This schedule will replace any/all existing schedules, and also any fixed learning rate values.
- * Note also that the iteration/epoch counts will not be reset. Use {@link MultiLayerConfiguration#setIterationCount(int)} - * and {@link MultiLayerConfiguration#setEpochCount(int)} if this is required - * - * @param layerNumber Number of the layer to set the LR schedule for - * @param newLr New learning rate for a single layer - * @see #setLearningRate(ISchedule) - * @see #setLearningRate(int, double) - */ - public void setLearningRate(int layerNumber, ISchedule newLr){ - NetworkUtils.setLearningRate(this, layerNumber, newLr); - } - - /** - * Get the current learning rate, for the specified layer, from the network. - * Note: If the layer has no learning rate (no parameters, or an updater without a learning rate) then null is returned - * @param layerNumber Layer number to get the learning rate for - * @return Learning rate for the specified layer, or null - */ - public Double getLearningRate(int layerNumber){ - return NetworkUtils.getLearningRate(this, layerNumber); - } - - /** - * Return the layer size (number of units) for the specified layer.
- * Note that the meaning of the "layer size" can depend on the type of layer. For example:
- * - DenseLayer, OutputLayer, recurrent layers: number of units (nOut configuration option)
- * - ConvolutionLayer: the channels (number of channels)
- * - Subsampling layers, global pooling layers, etc: size of 0 is always returned
- * - * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive - * @return Size of the layer - */ - public int layerSize(int layer) { - if (layer < 0 || layer > layers.length) { - throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " - + (layers.length - 1) + " inclusive"); - } - org.deeplearning4j.nn.conf.layers.Layer conf = layers[layer].conf().getLayer(); - if (conf == null || !(conf instanceof FeedForwardLayer)) { - return 0; - } - FeedForwardLayer ffl = (FeedForwardLayer) conf; - - if (ffl.getNOut() > Integer.MAX_VALUE) + if (endTimeIdx > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); - return (int) ffl.getNOut(); - } + } + INDArray[] subsets = getSubsetsForTbptt(startTimeIdx, (int) endTimeIdx, features, labels, + fMask, lMask); - /** - * Return the input size (number of inputs) for the specified layer.
- * Note that the meaning of the "input size" can depend on the type of layer. For example:
- * - DenseLayer, OutputLayer, etc: the feature vector size (nIn configuration option)
- * - Recurrent layers: the feature vector size per time step (nIn configuration option)
- * - ConvolutionLayer: the channels (number of channels)
- * - Subsampling layers, global pooling layers, etc: size of 0 is always returned
- * - * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive - * @return Size of the layer - */ - public int layerInputSize(int layer) { - if (layer < 0 || layer > layers.length) { - throw new IllegalArgumentException("Invalid layer index: " + layer + ". Layer index must be between 0 and " - + (layers.length - 1) + " inclusive"); + setLayerMaskArrays(subsets[2], subsets[3]); + + try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) { + INDArray outSub = rnnTimeStep(subsets[0], ws); + try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + for (T evaluation : evaluations) { + evaluation.eval(subsets[1], outSub, subsets[3]); + } + } + } } - org.deeplearning4j.nn.conf.layers.Layer conf = layers[layer].conf().getLayer(); - if (conf == null || !(conf instanceof FeedForwardLayer)) { - return 0; + } + + //Clear inputs, masks etc. Important to avoid leaking invalidated/out of scope arrays between iterations + clearLayersStates(); + } + + if (iterator.asyncSupported()) { + ((AsyncDataSetIterator) iter).shutdown(); + } + + layerWiseConfigurations.setTrainingWorkspaceMode(cMode); + + return evaluations; + } + + /** + * Evaluate the network on the provided data set. Used for evaluating the performance of + * classifiers + * + * @param iterator Data to undertake evaluation on + * @return Evaluation object, summarizing the results of the evaluation on the provided + * DataSetIterator + */ + public Evaluation evaluate(DataSetIterator iterator, List labelsList) { + return evaluate(iterator, labelsList, 1); + } + + @Override + public INDArray updaterState() { + return getUpdater() != null ? getUpdater().getStateViewArray() : null; + } + + @Override + public void fit(MultiDataSet dataSet) { + if (dataSet.getFeatures().length == 1 && dataSet.getLabels().length == 1) { + INDArray features = dataSet.getFeatures(0); + INDArray labels = dataSet.getLabels(0); + INDArray fMask = null; + INDArray lMask = null; + + if (dataSet.getFeaturesMaskArrays() != null) { + fMask = dataSet.getFeaturesMaskArrays()[0]; + } + + if (dataSet.getFeaturesMaskArrays() != null) { + lMask = dataSet.getLabelsMaskArrays()[0]; + } + + DataSet ds = new DataSet(features, labels, fMask, lMask); + fit(ds); + } else { + throw new DL4JInvalidInputException( + "MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array." + + "Please consider use of ComputationGraph"); + } + } + + /** + * Perform minibatch training on all minibatches in the MultiDataSetIterator, for the specified + * number of epochs. Equvalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a + * loop + * + * @param iterator Training data (DataSetIterator). Iterator must support resetting + * @param numEpochs Number of training epochs, >= 1 + */ + public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs) { + Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", + numEpochs); + Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), + "Cannot perform multiple epochs training using" + + "iterator has does not support resetting (iterator.resetSupported() returned false)"); + + for (int i = 0; i < numEpochs; i++) { + fit(iterator); + } + } + + /** + * Perform minibatch training on all minibatches in the MultiDataSetIterator.
Note: The + * MultiDataSets in the MultiDataSetIterator must have exactly 1 input and output array (as + * MultiLayerNetwork only supports 1 input and 1 output) + * + * @param iterator Training data (DataSetIterator). Iterator must support resetting + */ + @Override + public void fit(MultiDataSetIterator iterator) { + fit(new MultiDataSetWrapperIterator(iterator)); + } + + @Override + public T[] doEvaluation(MultiDataSetIterator iterator, T[] evaluations) { + return doEvaluation(new MultiDataSetWrapperIterator(iterator), evaluations); + } + + /** + * Evaluate the network (for classification) on the provided data set, with top N accuracy in + * addition to standard accuracy. For 'standard' accuracy evaluation only, use topN = 1 + * + * @param iterator Iterator (data) to evaluate on + * @param labelsList List of labels. May be null. + * @param topN N value for top N accuracy evaluation + * @return Evaluation object, summarizing the results of the evaluation on the provided + * DataSetIterator + */ + public Evaluation evaluate(DataSetIterator iterator, List labelsList, int topN) { + if (layers == null || !(getOutputLayer() instanceof IOutputLayer)) { + throw new IllegalStateException("Cannot evaluate network with no output layer"); + } + if (labelsList == null) { + try { + labelsList = iterator.getLabels(); + } catch (Throwable t) { + } //Ignore, maybe UnsupportedOperationException etc + } + + Layer outputLayer = getOutputLayer(); + if (getLayerWiseConfigurations().isValidateOutputLayerConfig()) { + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), + Evaluation.class); + } + + Evaluation e = new org.deeplearning4j.eval.Evaluation(labelsList, topN); + doEvaluation(iterator, e); + + return e; + } + + protected void update(Task task) { + if (!initDone) { + initDone = true; + Heartbeat heartbeat = Heartbeat.getInstance(); + task = ModelSerializer.taskByModel(this); + Environment env = EnvironmentUtils.buildEnvironment(); + heartbeat.reportEvent(Event.STANDALONE, env, task); + } + } + + /** + * String detailing the architecture of the multilayernetwork. Columns are LayerIndex with layer + * type, nIn, nOut, Total number of parameters and the Shapes of the parameters Will also give + * information about frozen layers, if any. + * + * @return Summary as a string + * @see #memoryInfo(int, InputType) + */ + public String summary() { + return summary(null); + } + + /** + * String detailing the architecture of the multilayernetwork. Will also display activation size + * when given an input type. Columns are LayerIndex with layer type, nIn, nOut, Total number of + * parameters, Shapes of the parameters, Input activation shape, Output activation shape Will also + * give information about frozen layers, if any. + * + * @return Summary as a string + * @see #memoryInfo(int, InputType) + */ + public String summary(InputType inputType) { + StringBuilder ret = new StringBuilder(); + ret.append("\n"); + + List lines = new ArrayList<>(); + if (inputType == null) { + lines.add(new String[]{"LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape"}); + } else { + lines.add(new String[]{"LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape", + "InputShape", "OutputShape"}); + } + int[] maxLength = new int[inputType == null ? 
4 : 6]; + String[] header = lines.get(0); + for (int i = 0; i < header.length; i++) { + maxLength[i] = header[i].length(); + } + + int frozenParams = 0; + for (org.deeplearning4j.nn.api.Layer currentLayer : getLayers()) { + String name = currentLayer.conf().getLayer().getLayerName(); + if (name == null) { + name = String.valueOf(currentLayer.getIndex()); + } + String paramShape = "-"; + String in = "-"; + String out = "-"; + String[] classNameArr = currentLayer.getClass().getName().split("\\."); + String className = classNameArr[classNameArr.length - 1]; + String paramCount = String.format("%,d", currentLayer.numParams()); + String inShape = ""; + String outShape = ""; + InputPreProcessor preProcessor; + InputType outType; + if (inputType != null) { + preProcessor = getLayerWiseConfigurations().getInputPreProcess(currentLayer.getIndex()); + inShape = inputType.toString(); + if (preProcessor != null) { + inputType = preProcessor.getOutputType(inputType); + inShape += "--> " + inputType.toString(); } - FeedForwardLayer ffl = (FeedForwardLayer) conf; - - if (ffl.getNIn() > Integer.MAX_VALUE) - throw new ND4JArraySizeException(); - return (int) ffl.getNIn(); - } - - /** - * Indicates whether some other object is "equal to" this one. - *

- * The {@code equals} method implements an equivalence relation - * on non-null object references: - *

    - *
  • It is reflexive: for any non-null reference value - * {@code x}, {@code x.equals(x)} should return - * {@code true}. - *
  • It is symmetric: for any non-null reference values - * {@code x} and {@code y}, {@code x.equals(y)} - * should return {@code true} if and only if - * {@code y.equals(x)} returns {@code true}. - *
  • It is transitive: for any non-null reference values - * {@code x}, {@code y}, and {@code z}, if - * {@code x.equals(y)} returns {@code true} and - * {@code y.equals(z)} returns {@code true}, then - * {@code x.equals(z)} should return {@code true}. - *
  • It is consistent: for any non-null reference values - * {@code x} and {@code y}, multiple invocations of - * {@code x.equals(y)} consistently return {@code true} - * or consistently return {@code false}, provided no - * information used in {@code equals} comparisons on the - * objects is modified. - *
  • For any non-null reference value {@code x}, - * {@code x.equals(null)} should return {@code false}. - *
- *

- * The {@code equals} method for class {@code Object} implements - * the most discriminating possible equivalence relation on objects; - * that is, for any non-null reference values {@code x} and - * {@code y}, this method returns {@code true} if and only - * if {@code x} and {@code y} refer to the same object - * ({@code x == y} has the value {@code true}). - *

- * Note that it is generally necessary to override the {@code hashCode} - * method whenever this method is overridden, so as to maintain the - * general contract for the {@code hashCode} method, which states - * that equal objects must have equal hash codes. - * - * @param obj the reference object with which to compare. - * @return {@code true} if this object is the same as the obj - * argument; {@code false} otherwise. - * @see #hashCode() - * @see HashMap - */ - @Override - public boolean equals(Object obj) { - if (obj == null) - return false; - if (obj instanceof MultiLayerNetwork) { - MultiLayerNetwork network = (MultiLayerNetwork) obj; - boolean paramsEquals = network.params().equals(params()); - boolean confEquals = getLayerWiseConfigurations().equals(network.getLayerWiseConfigurations()); - boolean updaterEquals = getUpdater().equals(network.getUpdater()); - return paramsEquals && confEquals && updaterEquals; + outType = currentLayer.conf().getLayer().getOutputType(currentLayer.getIndex(), inputType); + outShape = outType.toString(); + inputType = outType; + } + if (currentLayer.numParams() > 0) { + paramShape = ""; + if (currentLayer instanceof BidirectionalLayer) { // Bidirectional layer is not an FFL + BidirectionalLayer bi = (BidirectionalLayer) currentLayer; + in = String.valueOf(((Bidirectional) bi.conf().getLayer()).getNIn()); + out = String.valueOf(((Bidirectional) bi.conf().getLayer()).getNOut()); + } else { + try { + in = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNIn()); + out = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNOut()); + } catch ( + Exception e) { // Some layers, like PReLU, are just BaseLayers (but have parameters) + } } - return false; - } - - private void writeObject(ObjectOutputStream oos) throws IOException { - ModelSerializer.writeModel(this, oos, true); - } - - private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { - val mln = ModelSerializer.restoreMultiLayerNetwork(ois, true); - - this.defaultConfiguration = mln.defaultConfiguration.clone(); - this.layerWiseConfigurations = mln.layerWiseConfigurations.clone(); - this.init(); - this.flattenedParams.assign(mln.flattenedParams); - - int numWorkingMem = 2 * (layerWiseConfigurations.getConfs().size() + layerWiseConfigurations.getInputPreProcessors().size()); - WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(numWorkingMem); - WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig(layerWiseConfigurations.getConfs().size()); - - if (mln.getUpdater() != null && mln.getUpdater(false).getStateViewArray() != null) - this.getUpdater(true).getStateViewArray().assign(mln.getUpdater(false).getStateViewArray()); - } - - /** - * Close the network and deallocate all native memory, including: parameters, gradients, updater memory and workspaces - * Note that the network should not be used again for any purpose after it has been closed - */ - @Override - public void close(){ - //Close the INDArray and dealloc - if(flattenedParams.closeable()) - flattenedParams.close(); - - if(flattenedGradients != null && flattenedGradients.closeable()) - flattenedGradients.close(); - - Updater u = getUpdater(false); - if(u != null && u.getStateViewArray() != null) { - INDArray state = u.getStateViewArray(); - if(state.closeable()) - state.close(); + Set paraNames = currentLayer.paramTable().keySet(); + for (String aP : paraNames) { + String paramS = ArrayUtils.toString(currentLayer.paramTable().get(aP).shape()); + paramShape += aP + ":" + 
paramS + ", "; } + paramShape = paramShape.subSequence(0, paramShape.lastIndexOf(",")).toString(); + } + if (currentLayer instanceof FrozenLayer) { + frozenParams += currentLayer.numParams(); + classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName() + .split("\\."); + className = "Frozen " + classNameArr[classNameArr.length - 1]; + } - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - System.gc(); + String[] line; + if (inputType == null) { + line = new String[]{name + " (" + className + ")", in + "," + out, paramCount, paramShape}; + } else { + line = new String[]{name + " (" + className + ")", in + "," + out, paramCount, paramShape, + inShape, outShape}; + } + for (int i = 0; i < line.length; i++) { + maxLength[i] = Math.max(maxLength[i], line[i] == null ? 0 : line[i].length()); + } + lines.add(line); } + + StringBuilder sbFormat = new StringBuilder(); + int totalLength = 0; + int pos = 0; + for (int length : maxLength) { + int currLength; + if (pos++ == maxLength.length - 1) { + currLength = length; + } else { + currLength = length + 3; + } + sbFormat.append("%-").append(currLength).append("s"); + totalLength += currLength; + } + sbFormat.append("\n"); + String format = sbFormat.toString(); + + ret.append(StringUtils.repeat("=", totalLength)) + .append("\n"); + + boolean first = true; + for (String[] line : lines) { + String formatted = String.format(format, (Object[]) line); + ret.append(formatted); + if (first) { + ret.append(StringUtils.repeat("=", totalLength)).append("\n"); + first = false; + } + } + + ret.append(StringUtils.repeat("-", totalLength)); + ret.append(String.format("\n%30s %,d", "Total Parameters: ", params().length())); + ret.append( + String.format("\n%30s %,d", "Trainable Parameters: ", params().length() - frozenParams)); + ret.append(String.format("\n%30s %,d", "Frozen Parameters: ", frozenParams)); + ret.append("\n"); + ret.append(StringUtils.repeat("=", totalLength)); + ret.append("\n"); + return ret.toString(); + } + + /** + * Generate information regarding memory use for the network, for the given input type and + * minibatch size. Note that when using workspaces or CuDNN, the network should be trained for + * some iterations so that the memory workspaces have time to initialize. Without this, the memory + * requirements during training may be underestimated. + *

+ * Note also that this is the same information that is generated during an OOM crash when training + * or performing inference. + * + * @param minibatch Minibatch size to estimate memory for + * @param inputType Input type to the network + * @return A String with information about network memory use information + */ + public String memoryInfo(int minibatch, InputType inputType) { + return CrashReportingUtil.generateMemoryStatus(this, minibatch, inputType); + } + + /** + * This method just makes sure there's no state preserved within layers + */ + public void clearLayersStates() { + for (Layer layer : layers) { + layer.clear(); + layer.clearNoiseWeightParams(); + } + } + + /** + * Increment the epoch count (in the underlying {@link MultiLayerConfiguration} by 1). Note that + * this is done automatically when using iterator-based fitting methods, such as + * {@link #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, + * INDArray/INDArray etc), the network has no way to know when one epoch ends and another starts. + * In such situations, this method can be used to increment the epoch counter.
Note that the + * epoch counter is used for situations such as some learning rate schedules, and the like. + *

+ * The current epoch count can be obtained using + * {@code MultiLayerConfiguration.getLayerwiseConfiguration().getEpochCount()} + */ + public void incrementEpochCount() { + layerWiseConfigurations.setEpochCount(layerWiseConfigurations.getEpochCount() + 1); + synchronizeIterEpochCounts(); + } + + protected void synchronizeIterEpochCounts() { + //TODO: this is necessary for some schedules - but the redundant values are a little ugly... + int currIter = getIterationCount(); + int currEpoch = getEpochCount(); + for (Layer l : layers) { + l.setIterationCount(currIter); + l.setEpochCount(currEpoch); + } + } + + /** + * Save the MultiLayerNetwork to a file. Restore using {@link #load(File, boolean)}. Note that + * this saves the updater (i.e., the state array for momentum/Adam/rmsprop etc), which is + * desirable if further training will be undertaken. + * + * @param f File to save the network to + * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams) + * @see #save(File, boolean) + */ + public void save(File f) throws IOException { + save(f, true); + } + + /** + * Save the MultiLayerNetwork to a file. Restore using {@link #load(File, boolean)}. + * + * @param f File to save the network to + * @param saveUpdater If true: save the updater (i.e., the state array for momentum/Adam/rmsprop + * etc), which should usually be saved if further training is required + * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams) + * @see #save(File, boolean) + */ + public void save(File f, boolean saveUpdater) throws IOException { + ModelSerializer.writeModel(this, f, saveUpdater); + } + + /** + * Convert this MultiLayerNetwork to a ComputationGraph + * + * @return ComputationGraph equivalent to this network (including parameters and updater state) + */ + public ComputationGraph toComputationGraph() { + return NetworkUtils.toComputationGraph(this); + } + + /** + * Return a copy of the network with the parameters and activations set to use the specified + * (floating point) data type. If the existing datatype is the same as the requested dataype, the + * original network will be returned unchanged. Only floating point datatypes (DOUBLE, FLOAT, + * HALF) may be used. + * + * @param dataType Datatype to convert the network to + * @return The network, set to use the specified datatype for the parameters and activations + */ + public MultiLayerNetwork convertDataType(@NonNull DataType dataType) { + Preconditions.checkState(dataType.isFPType(), + "Invalid DataType: %s. Can only convert network to a floating point type", dataType); + if (dataType == params().dataType()) { + return this; + } + + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + INDArray newParams = params().castTo(dataType); + String jsonConfig = getLayerWiseConfigurations().toJson(); + MultiLayerConfiguration newConf = MultiLayerConfiguration.fromJson(jsonConfig); + newConf.setDataType(dataType); + MultiLayerNetwork newNet = new MultiLayerNetwork(newConf); + newNet.init(newParams, false); + + Updater u = getUpdater(false); + if (u != null && u.getStateViewArray() != null) { + INDArray oldUpdaterState = u.getStateViewArray(); + newNet.getUpdater(true).getStateViewArray().assign(oldUpdaterState); + } + return newNet; + } + } + + /** + * Set the learning rate for all layers in the network to the specified value. Note that if any + * learning rate schedules are currently present, these will be removed in favor of the new + * (fixed) learning rate.
+ *
+ * Note: This method is not free from a performance point of view: a proper learning
+ * rate schedule
+ * should be used in preference to calling this method at every iteration.
+ *
+ * @param newLr New learning rate for all layers
+ * @see #setLearningRate(ISchedule)
+ * @see #setLearningRate(int, double)
+ */
+ public void setLearningRate(double newLr) {
+ NetworkUtils.setLearningRate(this, newLr);
+ }
+
+ /**
+ * Set the learning rate schedule for all layers in the network to the specified schedule. This
+ * schedule will replace any/all existing schedules, and also any fixed learning rate values.<br>
+ * Note that the iteration/epoch counts will not be reset. Use + * {@link MultiLayerConfiguration#setIterationCount(int)} and + * {@link MultiLayerConfiguration#setEpochCount(int)} if this is required + * + * @param newLr New learning rate schedule for all layers + * @see #setLearningRate(ISchedule) + * @see #setLearningRate(int, double) + */ + public void setLearningRate(ISchedule newLr) { + NetworkUtils.setLearningRate(this, newLr); + } + + /** + * Set the learning rate for a single layer in the network to the specified value. Note that if + * any learning rate schedules are currently present, these will be removed in favor of the new + * (fixed) learning rate.
+ *
+ * Note: This method is not free from a performance point of view: a proper learning
+ * rate schedule
+ * should be used in preference to calling this method at every iteration. Note also that
+ * {@link #setLearningRate(double)} should also be used in preference, when all layers need to be
+ * set to a new LR
+ *
+ * @param layerNumber Number of the layer to set the LR for
+ * @param newLr New learning rate for a single layer
+ * @see #setLearningRate(ISchedule)
+ * @see #setLearningRate(int, double)
+ */
+ public void setLearningRate(int layerNumber, double newLr) {
+ NetworkUtils.setLearningRate(this, layerNumber, newLr);
+ }
+
+ /**
+ * Set the learning rate schedule for a single layer in the network to the specified value.<br>
+ * Note also that {@link #setLearningRate(ISchedule)} should also be used in preference, when all + * layers need to be set to a new LR schedule.
This schedule will replace any/all existing + * schedules, and also any fixed learning rate values.
Note also that the iteration/epoch + * counts will not be reset. Use {@link MultiLayerConfiguration#setIterationCount(int)} and + * {@link MultiLayerConfiguration#setEpochCount(int)} if this is required + * + * @param layerNumber Number of the layer to set the LR schedule for + * @param newLr New learning rate for a single layer + * @see #setLearningRate(ISchedule) + * @see #setLearningRate(int, double) + */ + public void setLearningRate(int layerNumber, ISchedule newLr) { + NetworkUtils.setLearningRate(this, layerNumber, newLr); + } + + /** + * Get the current learning rate, for the specified layer, from the network. Note: If the layer + * has no learning rate (no parameters, or an updater without a learning rate) then null is + * returned + * + * @param layerNumber Layer number to get the learning rate for + * @return Learning rate for the specified layer, or null + */ + public Double getLearningRate(int layerNumber) { + return NetworkUtils.getLearningRate(this, layerNumber); + } + + /** + * Return the layer size (number of units) for the specified layer.
Note that the meaning of + * the "layer size" can depend on the type of layer. For example:
- DenseLayer, OutputLayer, + * recurrent layers: number of units (nOut configuration option)
- ConvolutionLayer: the + * channels (number of channels)
- Subsampling layers, global pooling layers, etc: size of 0 + * is always returned
+ * + * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive + * @return Size of the layer + */ + public int layerSize(int layer) { + if (layer < 0 || layer > layers.length) { + throw new IllegalArgumentException( + "Invalid layer index: " + layer + ". Layer index must be between 0 and " + + (layers.length - 1) + " inclusive"); + } + org.deeplearning4j.nn.conf.layers.Layer conf = layers[layer].conf().getLayer(); + if (conf == null || !(conf instanceof FeedForwardLayer)) { + return 0; + } + FeedForwardLayer ffl = (FeedForwardLayer) conf; + + if (ffl.getNOut() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + return (int) ffl.getNOut(); + } + + /** + * Return the input size (number of inputs) for the specified layer.
Note that the meaning of + * the "input size" can depend on the type of layer. For example:
- DenseLayer, OutputLayer, + * etc: the feature vector size (nIn configuration option)
- Recurrent layers: the feature + * vector size per time step (nIn configuration option)
- ConvolutionLayer: the + * channels (number of channels)
- Subsampling layers, global pooling layers, etc: size of 0 + * is always returned
+ * + * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive + * @return Size of the layer + */ + public int layerInputSize(int layer) { + if (layer < 0 || layer > layers.length) { + throw new IllegalArgumentException( + "Invalid layer index: " + layer + ". Layer index must be between 0 and " + + (layers.length - 1) + " inclusive"); + } + org.deeplearning4j.nn.conf.layers.Layer conf = layers[layer].conf().getLayer(); + if (conf == null || !(conf instanceof FeedForwardLayer)) { + return 0; + } + FeedForwardLayer ffl = (FeedForwardLayer) conf; + + if (ffl.getNIn() > Integer.MAX_VALUE) { + throw new ND4JArraySizeException(); + } + return (int) ffl.getNIn(); + } + + /** + * Indicates whether some other object is "equal to" this one. + *

+ * The {@code equals} method implements an equivalence relation on non-null object references: + *

    + *
  • It is reflexive: for any non-null reference value + * {@code x}, {@code x.equals(x)} should return + * {@code true}. + *
  • It is symmetric: for any non-null reference values + * {@code x} and {@code y}, {@code x.equals(y)} + * should return {@code true} if and only if + * {@code y.equals(x)} returns {@code true}. + *
  • It is transitive: for any non-null reference values + * {@code x}, {@code y}, and {@code z}, if + * {@code x.equals(y)} returns {@code true} and + * {@code y.equals(z)} returns {@code true}, then + * {@code x.equals(z)} should return {@code true}. + *
  • It is consistent: for any non-null reference values + * {@code x} and {@code y}, multiple invocations of + * {@code x.equals(y)} consistently return {@code true} + * or consistently return {@code false}, provided no + * information used in {@code equals} comparisons on the + * objects is modified. + *
  • For any non-null reference value {@code x}, + * {@code x.equals(null)} should return {@code false}. + *
+ *

+ * The {@code equals} method for class {@code Object} implements + * the most discriminating possible equivalence relation on objects; + * that is, for any non-null reference values {@code x} and + * {@code y}, this method returns {@code true} if and only + * if {@code x} and {@code y} refer to the same object + * ({@code x == y} has the value {@code true}). + *

+ * Note that it is generally necessary to override the {@code hashCode} + * method whenever this method is overridden, so as to maintain the + * general contract for the {@code hashCode} method, which states + * that equal objects must have equal hash codes. + * + * @param obj the reference object with which to compare. + * @return {@code true} if this object is the same as the obj argument; {@code false} otherwise. + * @see #hashCode() + * @see HashMap + */ + @Override + public boolean equals(Object obj) { + if (obj == null) { + return false; + } + if (obj instanceof MultiLayerNetwork) { + MultiLayerNetwork network = (MultiLayerNetwork) obj; + boolean paramsEquals = network.params().equals(params()); + boolean confEquals = getLayerWiseConfigurations().equals( + network.getLayerWiseConfigurations()); + boolean updaterEquals = getUpdater().equals(network.getUpdater()); + return paramsEquals && confEquals && updaterEquals; + } + return false; + } + + private void writeObject(ObjectOutputStream oos) throws IOException { + ModelSerializer.writeModel(this, oos, true); + } + + private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { + val mln = ModelSerializer.restoreMultiLayerNetwork(ois, true); + + this.defaultConfiguration = mln.defaultConfiguration.clone(); + this.layerWiseConfigurations = mln.layerWiseConfigurations.clone(); + this.init(); + this.flattenedParams.assign(mln.flattenedParams); + + int numWorkingMem = 2 * (layerWiseConfigurations.getConfs().size() + + layerWiseConfigurations.getInputPreProcessors().size()); + WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(numWorkingMem); + WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig(layerWiseConfigurations.getConfs().size()); + + if (mln.getUpdater() != null && mln.getUpdater(false).getStateViewArray() != null) { + this.getUpdater(true).getStateViewArray().assign(mln.getUpdater(false).getStateViewArray()); + } + } + + /** + * Close the network and deallocate all native memory, including: parameters, gradients, updater + * memory and workspaces Note that the network should not be used again for any purpose after it + * has been closed + */ + @Override + public void close() { + //Close the INDArray and dealloc + if (flattenedParams.closeable()) { + flattenedParams.close(); + } + + if (flattenedGradients != null && flattenedGradients.closeable()) { + flattenedGradients.close(); + } + + Updater u = getUpdater(false); + if (u != null && u.getStateViewArray() != null) { + INDArray state = u.getStateViewArray(); + if (state.closeable()) { + state.close(); + } + } + + Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); + System.gc(); + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java index 52ae7c891..b941cf636 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java @@ -572,7 +572,7 @@ public class TransferLearning { */ public GraphBuilder(ComputationGraph origGraph) { this.origGraph = origGraph; - this.origConfig = origGraph.getConfiguration().clone(); + this.origConfig = origGraph.getComputationGraphConfiguration().clone(); } /** diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java index f6f3a35c1..a6f7d6c4f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java @@ -242,7 +242,7 @@ public class TransferLearningHelper { } Set frozenInputVerticesSorted = new HashSet<>(); - frozenInputVerticesSorted.addAll(origGraph.getConfiguration().getNetworkInputs()); + frozenInputVerticesSorted.addAll(origGraph.getComputationGraphConfiguration().getNetworkInputs()); frozenInputVerticesSorted.removeAll(allFrozen); //remove input vertices - just to add back in a predictable order for (String existingInput : frozenInputVerticesSorted) { @@ -328,7 +328,7 @@ public class TransferLearningHelper { String anInput = graphInputs.get(i); if (origGraph.getVertex(anInput).isInputVertex()) { //was an original input to the graph - int inputIndex = origGraph.getConfiguration().getNetworkInputs().indexOf(anInput); + int inputIndex = origGraph.getComputationGraphConfiguration().getNetworkInputs().indexOf(anInput); featuresNow[i] = origGraph.getInput(inputIndex); } else { //needs to be grabbed from the internal activations diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java index 4f4d1690f..91d24de46 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java @@ -385,14 +385,14 @@ public abstract class BaseMultiLayerUpdater implements Updater /** * Pre-apply: Apply gradient normalization/clipping * - * @param layer Layer to apply gradient normalization/clipping for + * @param layer ILayer to apply gradient normalization/clipping for * @param gradient Gradient to update * @param iteration The current iteration (i.e., number of parameter updates so far) */ public void preApply(Trainable layer, Gradient gradient, int iteration) { if (layer.getConfig() == null || layer.numParams() == 0) { - //Layer does not have parameters -> no gradient + //ILayer does not have parameters -> no gradient return; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/TrainingListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/TrainingListener.java index 7c96fd750..81a2d8465 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/TrainingListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/TrainingListener.java @@ -54,7 +54,7 @@ public interface TrainingListener { * only at training time * * @param model Model - * @param activations Layer activations (including input) + * @param activations ILayer activations (including input) */ void onForwardPass(Model model, List activations); @@ -63,7 +63,7 @@ public interface TrainingListener { * only at training time * * @param model Model - * @param activations Layer activations (including input) + * @param activations ILayer activations (including input) */ void onForwardPass(Model model, Map activations); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java index 4ebf2e050..550e4425b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java @@ -247,7 +247,7 @@ public class CheckpointListener extends BaseTrainingListener implements Serializ if (model instanceof MultiLayerNetwork) { return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getIterationCount(); } else if (model instanceof ComputationGraph) { - return ((ComputationGraph) model).getConfiguration().getIterationCount(); + return ((ComputationGraph) model).getComputationGraphConfiguration().getIterationCount(); } else { return model.conf().getIterationCount(); } @@ -257,7 +257,7 @@ public class CheckpointListener extends BaseTrainingListener implements Serializ if (model instanceof MultiLayerNetwork) { return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getEpochCount(); } else if (model instanceof ComputationGraph) { - return ((ComputationGraph) model).getConfiguration().getEpochCount(); + return ((ComputationGraph) model).getComputationGraphConfiguration().getEpochCount(); } else { return model.conf().getEpochCount(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BaseOptimizer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BaseOptimizer.java index 3a8bfee10..42ce490e5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BaseOptimizer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BaseOptimizer.java @@ -336,7 +336,7 @@ public abstract class BaseOptimizer implements ConvexOptimizer { if (model instanceof MultiLayerNetwork) { return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getIterationCount(); } else if (model instanceof ComputationGraph) { - return ((ComputationGraph) model).getConfiguration().getIterationCount(); + return ((ComputationGraph) model).getComputationGraphConfiguration().getIterationCount(); } else { return model.conf().getIterationCount(); } @@ -347,7 +347,7 @@ public abstract class BaseOptimizer implements ConvexOptimizer { MultiLayerConfiguration conf = ((MultiLayerNetwork) model).getLayerWiseConfigurations(); conf.setIterationCount(conf.getIterationCount() + incrementBy); } else if (model instanceof ComputationGraph) { - ComputationGraphConfiguration conf = ((ComputationGraph) model).getConfiguration(); + ComputationGraphConfiguration conf = ((ComputationGraph) model).getComputationGraphConfiguration(); conf.setIterationCount(conf.getIterationCount() + incrementBy); } else { model.conf().setIterationCount(model.conf().getIterationCount() + incrementBy); @@ -358,7 +358,7 @@ public abstract class BaseOptimizer implements ConvexOptimizer { if (model instanceof MultiLayerNetwork) { return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getEpochCount(); } else if (model instanceof ComputationGraph) { - return ((ComputationGraph) model).getConfiguration().getEpochCount(); + return ((ComputationGraph) model).getComputationGraphConfiguration().getEpochCount(); } else { return model.conf().getEpochCount(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java index 32c40bdfc..53bed93a2 100644 --- 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java @@ -79,7 +79,7 @@ public class Convolution1DUtils { * @return the format for the layer */ public static RNNFormat getRnnFormatFromLayer(Layer layer) { - Preconditions.checkState(hasRnnDataFormat(layer),"Layer of type " + layer.getClass().getName() + " and name " + layer.getLayerName() + " does not have an RNNFormat"); + Preconditions.checkState(hasRnnDataFormat(layer),"ILayer of type " + layer.getClass().getName() + " and name " + layer.getLayerName() + " does not have an RNNFormat"); if(layer instanceof SimpleRnn) { SimpleRnn simpleRnn = (SimpleRnn) layer; return simpleRnn.getRnnDataFormat(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java index ac28ced80..5227ad77f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java @@ -320,12 +320,12 @@ public class CrashReportingUtil { appendHelperInformation(sb, mln.getLayers()); appendActivationShapes(mln, (inputTypes == null || inputTypes.length == 0 ? null : inputTypes[0]), minibatch, sb, bytesPerElement); } else { - sb.append(f("Backprop Type", cg.getConfiguration().getBackpropType())); - if(cg.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT){ - sb.append(f("TBPTT Length", cg.getConfiguration().getTbpttFwdLength() + "/" + cg.getConfiguration().getTbpttBackLength())); + sb.append(f("Backprop Type", cg.getComputationGraphConfiguration().getBackpropType())); + if(cg.getComputationGraphConfiguration().getBackpropType() == BackpropType.TruncatedBPTT){ + sb.append(f("TBPTT Length", cg.getComputationGraphConfiguration().getTbpttFwdLength() + "/" + cg.getComputationGraphConfiguration().getTbpttBackLength())); } - sb.append(f("Workspace Mode: Training", cg.getConfiguration().getTrainingWorkspaceMode())); - sb.append(f("Workspace Mode: Inference", cg.getConfiguration().getInferenceWorkspaceMode())); + sb.append(f("Workspace Mode: Training", cg.getComputationGraphConfiguration().getTrainingWorkspaceMode())); + sb.append(f("Workspace Mode: Inference", cg.getComputationGraphConfiguration().getInferenceWorkspaceMode())); appendLayerInformation(sb, cg.getLayers(), bytesPerElement); appendHelperInformation(sb, cg.getLayers()); appendActivationShapes(cg, sb, bytesPerElement); @@ -461,13 +461,13 @@ public class CrashReportingUtil { List l = new ArrayList<>(layerClasses.keySet()); Collections.sort(l); sb.append(f("Number of Layers", layers.length)); - sb.append("Layer Counts\n"); + sb.append("ILayer Counts\n"); for(String s : l){ sb.append(" ").append(f(s, layerClasses.get(s))); } - sb.append("Layer Parameter Breakdown\n"); + sb.append("ILayer Parameter Breakdown\n"); String format = " %-3s %-20s %-20s %-20s %-20s"; - sb.append(String.format(format, "Idx", "Name", "Layer Type", "Layer # Parameters", "Layer Parameter Memory")).append("\n"); + sb.append(String.format(format, "Idx", "Name", "ILayer Type", "ILayer # Parameters", "ILayer Parameter Memory")).append("\n"); for(Layer layer : layers){ long numParams = layer.numParams(); sb.append(String.format(format, layer.getIndex(), layer.conf().getLayer().getLayerName(), @@ -477,13 +477,13 @@ public class CrashReportingUtil { } private static void 
appendHelperInformation(StringBuilder sb, org.deeplearning4j.nn.api.Layer[] layers){ - sb.append("\n----- Layer Helpers - Memory Use -----\n"); + sb.append("\n----- ILayer Helpers - Memory Use -----\n"); int helperCount = 0; long helperWithMemCount = 0L; long totalHelperMem = 0L; - //Layer index, layer name, layer class, helper class, total memory, breakdown + //ILayer index, layer name, layer class, helper class, total memory, breakdown String format = "%-3s %-20s %-25s %-30s %-12s %s"; boolean header = false; for(Layer l : layers){ @@ -509,7 +509,7 @@ public class CrashReportingUtil { if(!header){ - sb.append(String.format(format, "#", "Layer Name", "Layer Class", "Helper Class", "Total Memory", "Memory Breakdown")) + sb.append(String.format(format, "#", "ILayer Name", "ILayer Class", "Helper Class", "Total Memory", "Memory Breakdown")) .append("\n"); header = true; } @@ -551,7 +551,7 @@ public class CrashReportingUtil { sb.append(f("Input Shape", Arrays.toString(inputShape))); List inputTypes = net.getLayerWiseConfigurations().getLayerActivationTypes(inputType); String format = "%-3s %-20s %-20s %-42s %-20s %-12s %-12s"; - sb.append(String.format(format, "Idx", "Name", "Layer Type", "Activations Type", "Activations Shape", + sb.append(String.format(format, "Idx", "Name", "ILayer Type", "Activations Type", "Activations Shape", "# Elements", "Memory")).append("\n"); org.deeplearning4j.nn.api.Layer[] layers = net.getLayers(); long totalActivationBytes = 0; @@ -598,11 +598,11 @@ public class CrashReportingUtil { for( int i=0; i inputTypes = net.getConfiguration().getLayerActivationTypes(inputType); + Map inputTypes = net.getComputationGraphConfiguration().getLayerActivationTypes(inputType); GraphIndices indices = net.calculateIndices(); String format = "%-3s %-20s %-20s %-42s %-20s %-12s %-12s"; - sb.append(String.format(format, "Idx", "Name", "Layer Type", "Activations Type", "Activations Shape", + sb.append(String.format(format, "Idx", "Name", "ILayer Type", "Activations Type", "Activations Shape", "# Elements", "Memory")).append("\n"); org.deeplearning4j.nn.api.Layer[] layers = net.getLayers(); long totalActivationBytes = 0; @@ -633,7 +633,7 @@ public class CrashReportingUtil { sb.append(String.format(format, i, layerName, className, it, Arrays.toString(shape), (numElements < 0 ? 
"" : String.valueOf(numElements)), fBytes(bytes))).append("\n"); - if(!net.getConfiguration().getNetworkOutputs().contains(layerName)){ + if(!net.getComputationGraphConfiguration().getNetworkOutputs().contains(layerName)){ totalExOutput += bytes; } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java index ae7e2e2df..e636334fd 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java @@ -141,7 +141,7 @@ public class ModelSerializer { if (model instanceof MultiLayerNetwork) { json = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toJson(); } else if (model instanceof ComputationGraph) { - json = ((ComputationGraph) model).getConfiguration().toJson(); + json = ((ComputationGraph) model).getComputationGraphConfiguration().toJson(); } ZipEntry config = new ZipEntry(CONFIGURATION_JSON); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java index 7ed0a4bcb..4348be74a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java @@ -199,7 +199,7 @@ public class NetworkUtils { * Note: If the layer has no learning rate (no parameters, or an updater without a learning rate) then null is returned * * @param net Network - * @param layerNumber Layer number to get the learning rate for + * @param layerNumber ILayer number to get the learning rate for * @return Learning rate for the specified layer, or null */ public static Double getLearningRate(MultiLayerNetwork net, int layerNumber) { @@ -321,13 +321,13 @@ public class NetworkUtils { * Note: If the layer has no learning rate (no parameters, or an updater without a learning rate) then null is returned * * @param net Network - * @param layerName Layer name to get the learning rate for + * @param layerName ILayer name to get the learning rate for * @return Learning rate for the specified layer, or null */ public static Double getLearningRate(ComputationGraph net, String layerName) { Layer l = net.getLayer(layerName).conf().getLayer(); - int iter = net.getConfiguration().getIterationCount(); - int epoch = net.getConfiguration().getEpochCount(); + int iter = net.getComputationGraphConfiguration().getIterationCount(); + int epoch = net.getComputationGraphConfiguration().getEpochCount(); if (l instanceof BaseLayer) { BaseLayer bl = (BaseLayer) l; IUpdater u = bl.getIUpdater(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java index 08a3d086a..fb3d9ea64 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java @@ -68,7 +68,7 @@ public class OutputLayerUtil { * * If the specified layer is not an output layer, this is a no-op * @param layerName Name of the layer - * @param layer Layer + * @param layer ILayer */ public static void validateOutputLayer(String layerName, Layer layer){ IActivation activation; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java index df4583cd8..eb5814b49 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java @@ -440,7 +440,7 @@ public class TimeSeriesUtils { /** * Get the {@link RNNFormat} from the RNN layer, accounting for the presence of wrapper layers like Bidirectional, * LastTimeStep, etc - * @param layer Layer to get the RNNFormat from + * @param layer ILayer to get the RNNFormat from */ public static RNNFormat getFormatFromRnnLayer(Layer layer){ if(layer instanceof BaseRecurrentLayer){ diff --git a/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java new file mode 100644 index 000000000..06c322a57 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java @@ -0,0 +1,127 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +import static net.brutex.ai.dnn.api.dnn.*; +import static org.junit.jupiter.api.Assertions.*; + +import java.util.Iterator; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.commons.lang3.RandomUtils; +import org.deeplearning4j.datasets.iterator.FloatsDataSetIterator; +import org.deeplearning4j.nn.conf.CacheMode; +import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.conf.Updater; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.ActivationLayer; +import org.deeplearning4j.nn.weights.WeightInitXavier; +import org.junit.jupiter.api.Test; +import org.nd4j.common.primitives.Pair; +import org.nd4j.linalg.activations.impl.ActivationIdentity; +import org.nd4j.linalg.activations.impl.ActivationLReLU; +import org.nd4j.linalg.learning.config.Adam; + + +class dnnTest { + + @Test + void testFFLayer() { + int numFeatures = 128; + int batchSize = 10; + int numRows = 1000; + AtomicInteger cnt = new AtomicInteger(0); + FloatsDataSetIterator iterator = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize); + + assertTrue(iterator.hasNext()); + + /** + * MultiLayerConfiguration confxx = new NeuralNetConfiguration.Builder() + * .seed(42) + * .updater(UPDATER) + * .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + * .gradientNormalizationThreshold(GRADIENT_THRESHOLD) + * .weightInit(WeightInit.XAVIER) + * .activation(Activation.IDENTITY) + * .list(genLayers()) + * .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) + * // .inputPreProcessor("CNN1", new 
FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS)) + * .build(); + */ + + /** + * new DenseLayer.Builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(), + * new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), + * new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), + * new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), + * new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(), + * new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), + * new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH) + */ + dnn.conf() + .seed(42) + .updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() ) + .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold( 100 ) + .weightInit( new WeightInitXavier() ) + .activation( new ActivationIdentity() ) + .inputType( InputType.convolutional( 28, 28, 1)) + .layer( dnn.DenseLayer(10,30).build() ) + .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build() ) + + ; + + + } + + protected static Iterable> floatIterable(final int totalRows, final int numColumns) { + return new Iterable>() { + @Override + public Iterator> iterator() { + return new Iterator>() { + private final AtomicInteger cnt = new AtomicInteger(0); + + @Override + public boolean hasNext() { + return cnt.incrementAndGet() <= totalRows; + } + + @Override + public Pair next() { + float[] features = new float[numColumns]; + float[] labels = new float[numColumns]; + for (int i = 0; i < numColumns; i++) { + features[i] = (float) i; + labels[i] = RandomUtils.nextFloat(0, 5); + } + return Pair.makePair(features, labels); + } + + @Override + public void remove() { + // no-op + } + }; + } + }; + } + +} \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/conf/layer/FFLayerTest.java b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/conf/layer/FFLayerTest.java new file mode 100644 index 000000000..2fa944000 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/conf/layer/FFLayerTest.java @@ -0,0 +1,47 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.conf.layer; + +import net.brutex.ai.dnn.api.IModel; +import net.brutex.ai.dnn.api.INeuralNetworkConfiguration; +import net.brutex.ai.dnn.api.ILayerConfiguration; +import org.junit.jupiter.api.Test; + +class FFLayerTest { + + @Test + void instantiate() { + ILayerConfiguration ff_conf = FeedForwardLayerConfiguration.builder().build(); + INeuralNetworkConfiguration net_conf = net.brutex.ai.dnn.conf.NeuralNetworkConfiguration.builder() + .layerConfiguration(ff_conf) + .build(); + IModel network = net.brutex.ai.dnn.impl.network.NeuralNetwork.builder().name("Test Network") + .configuration(net_conf) + .build(); + ff_conf.instantiate(network); + + } + + @Test + void getOutputType() { + } +} \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/test/java/org/deeplearning4j/nn/layers/HelperUtilsTest.java b/cavis-dnn/cavis-dnn-nn/src/test/java/org/deeplearning4j/nn/layers/HelperUtilsTest.java index bd05f187f..a3d21fb0c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/test/java/org/deeplearning4j/nn/layers/HelperUtilsTest.java +++ b/cavis-dnn/cavis-dnn-nn/src/test/java/org/deeplearning4j/nn/layers/HelperUtilsTest.java @@ -34,7 +34,7 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; /** */ -@DisplayName("Activation Layer Test") +@DisplayName("Activation ILayer Test") public class HelperUtilsTest extends BaseDL4JTest { @Override diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java index 20dcd51d9..9f32446ae 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java @@ -29,7 +29,6 @@ import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.parallelism.inference.InferenceMode; import org.deeplearning4j.parallelism.inference.LoadBalanceMode; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -195,7 +194,7 @@ public class InplaceParallelInference extends ParallelInference { for (int e = 0; e < workers; e++) { if (sourceModel instanceof ComputationGraph) { // building configuration with shared parameters - val model = new ComputationGraph(ComputationGraphConfiguration.fromJson(((ComputationGraph) sourceModel).getConfiguration().toJson())); + val model = new ComputationGraph(ComputationGraphConfiguration.fromJson(((ComputationGraph) sourceModel).getComputationGraphConfiguration().toJson())); model.init(params, false); Nd4j.getExecutioner().commit(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java index 52a28606e..8547e7b9f 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java +++ 
b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java @@ -458,7 +458,7 @@ public class ParallelInference { if (protoModel instanceof ComputationGraph) { if (!rootDevice) { this.replicatedModel = new ComputationGraph(ComputationGraphConfiguration - .fromJson(((ComputationGraph) protoModel).getConfiguration().toJson())); + .fromJson(((ComputationGraph) protoModel).getComputationGraphConfiguration().toJson())); this.replicatedModel.init(); synchronized (locker) { diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java index a1909795a..be706234f 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java @@ -329,7 +329,7 @@ public class DefaultTrainer extends Thread implements Trainer { } else if (originalModel instanceof ComputationGraph) { if (!onRootModel) { ComputationGraphConfiguration conf = ComputationGraphConfiguration - .fromJson(((ComputationGraph) originalModel).getConfiguration().toJson()); + .fromJson(((ComputationGraph) originalModel).getComputationGraphConfiguration().toJson()); conf.setTrainingWorkspaceMode(workspaceMode); this.replicatedModel = new ComputationGraph(conf); @@ -354,7 +354,7 @@ public class DefaultTrainer extends Thread implements Trainer { } else { this.replicatedModel = originalModel; this.replicatedModel.init(); - ((ComputationGraph) replicatedModel).getConfiguration().setTrainingWorkspaceMode(workspaceMode); + ((ComputationGraph) replicatedModel).getComputationGraphConfiguration().setTrainingWorkspaceMode(workspaceMode); } } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java index 67b120ddf..e460ddc2f 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java @@ -102,7 +102,7 @@ public class SparkComputationGraph extends SparkListenable { TrainingMaster trainingMaster) { sc = javaSparkContext; this.trainingMaster = trainingMaster; - this.conf = network.getConfiguration().clone(); + this.conf = network.getComputationGraphConfiguration().clone(); this.network = network; this.network.init(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java index 3fa3312d7..b7da3d143 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java @@ -56,7 +56,7 @@ public class CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWith if (!(l instanceof 
VariationalAutoencoder)) { throw new RuntimeException( "Cannot use CGVaeReconstructionErrorWithKeyFunction on network that doesn't have a VAE " - + "layer as layer 0. Layer type: " + l.getClass()); + + "layer as layer 0. ILayer type: " + l.getClass()); } return (VariationalAutoencoder) l; } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java index a71912367..43defe37f 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java @@ -58,7 +58,7 @@ public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstruc if (!(l instanceof VariationalAutoencoder)) { throw new RuntimeException( "Cannot use CGVaeReconstructionProbWithKeyFunction on network that doesn't have a VAE " - + "layer as layer 0. Layer type: " + l.getClass()); + + "layer as layer 0. ILayer type: " + l.getClass()); } return (VariationalAutoencoder) l; } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java index e1c2f760d..a0bcca02b 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java @@ -59,7 +59,7 @@ public class VaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKe if (!(l instanceof VariationalAutoencoder)) { throw new RuntimeException( "Cannot use VaeReconstructionErrorWithKeyFunction on network that doesn't have a VAE " - + "layer as layer 0. Layer type: " + l.getClass()); + + "layer as layer 0. ILayer type: " + l.getClass()); } return (VariationalAutoencoder) l; } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java index 12fbbbeb6..d65084dc5 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java @@ -59,7 +59,7 @@ public class VaeReconstructionProbWithKeyFunction extends BaseVaeReconstructi if (!(l instanceof VariationalAutoencoder)) { throw new RuntimeException( "Cannot use VaeReconstructionProbWithKeyFunction on network that doesn't have a VAE " - + "layer as layer 0. Layer type: " + l.getClass()); + + "layer as layer 0. 
ILayer type: " + l.getClass()); } return (VariationalAutoencoder) l; } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java index 3a2170bc3..4a0252b28 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java @@ -292,7 +292,7 @@ public class ParameterAveragingTrainingMaster @Override public ParameterAveragingTrainingWorker getWorkerInstance(SparkComputationGraph graph) { - NetBroadcastTuple tuple = new NetBroadcastTuple(graph.getNetwork().getConfiguration(), + NetBroadcastTuple tuple = new NetBroadcastTuple(graph.getNetwork().getComputationGraphConfiguration(), graph.getNetwork().params(), graph.getNetwork().getUpdater().getStateViewArray()); if (collectTrainingStats) @@ -731,7 +731,7 @@ public class ParameterAveragingTrainingMaster int numUpdates = averagingFrequency; conf.setIterationCount(conf.getIterationCount() + numUpdates); } else { - ComputationGraphConfiguration conf = graph.getNetwork().getConfiguration(); + ComputationGraphConfiguration conf = graph.getNetwork().getComputationGraphConfiguration(); int numUpdates = averagingFrequency; conf.setIterationCount(conf.getIterationCount() + numUpdates); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java index 887696af3..c899fae04 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java @@ -118,7 +118,7 @@ public class TestFrozenLayers extends BaseSparkTest { boolean isFrozen = entry.getKey().startsWith("0_") || entry.getKey().startsWith("1_"); if (isFrozen) { - //Layer should be frozen -> no change + //ILayer should be frozen -> no change assertEquals(orig, now, entry.getKey()); } else { //Not frozen -> should be different @@ -195,7 +195,7 @@ public class TestFrozenLayers extends BaseSparkTest { boolean isFrozen = entry.getKey().startsWith("0_") || entry.getKey().startsWith("1_"); if (isFrozen) { - //Layer should be frozen -> no change + //ILayer should be frozen -> no change assertEquals(orig, now, entry.getKey()); } else { //Not frozen -> should be different diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java index 48a30034a..c2c24a617 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java @@ -835,12 +835,12 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { 
JavaRDD rdd = sc.parallelize(list); - assertEquals(0, sparkNet.getNetwork().getConfiguration().getIterationCount()); + assertEquals(0, sparkNet.getNetwork().getComputationGraphConfiguration().getIterationCount()); sparkNet.fit(rdd); - assertEquals(minibatchesPerWorkerPerEpoch, sparkNet.getNetwork().getConfiguration().getIterationCount()); + assertEquals(minibatchesPerWorkerPerEpoch, sparkNet.getNetwork().getComputationGraphConfiguration().getIterationCount()); sparkNet.fit(rdd); assertEquals(2 * minibatchesPerWorkerPerEpoch, - sparkNet.getNetwork().getConfiguration().getIterationCount()); + sparkNet.getNetwork().getComputationGraphConfiguration().getIterationCount()); sparkNet.getTrainingMaster().deleteTempFiles(sc); } @@ -1076,11 +1076,11 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { for(int i=0; i<3; i++ ){ assertEquals(i, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount()); - assertEquals(i, sn2.getNetwork().getConfiguration().getEpochCount()); + assertEquals(i, sn2.getNetwork().getComputationGraphConfiguration().getEpochCount()); sn1.fit(rdd); sn2.fit(rdd); assertEquals(i+1, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount()); - assertEquals(i+1, sn2.getNetwork().getConfiguration().getEpochCount()); + assertEquals(i+1, sn2.getNetwork().getComputationGraphConfiguration().getEpochCount()); } } } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java index a9e2a213b..7e521f0c1 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java @@ -375,8 +375,8 @@ public class SharedTrainingWrapper { ((MultiLayerNetwork) model).setIterationCount(ModelParameterServer.getInstance().getStartPosition().getFirst()); ((MultiLayerNetwork) model).setEpochCount(ModelParameterServer.getInstance().getStartPosition().getSecond()); } else if (originalModel instanceof ComputationGraph) { - ((ComputationGraph) model).getConfiguration().setIterationCount(ModelParameterServer.getInstance().getStartPosition().getFirst()); - ((ComputationGraph) model).getConfiguration().setEpochCount(ModelParameterServer.getInstance().getStartPosition().getSecond()); + ((ComputationGraph) model).getComputationGraphConfiguration().setIterationCount(ModelParameterServer.getInstance().getStartPosition().getFirst()); + ((ComputationGraph) model).getComputationGraphConfiguration().setEpochCount(ModelParameterServer.getInstance().getStartPosition().getSecond()); } // if we're going to extend iteratation for debugging purposes - let's do that here @@ -421,7 +421,7 @@ public class SharedTrainingWrapper { // ok. 
attaching accumulator to model if (model instanceof ComputationGraph) { - ((ComputationGraph) originalModel).getConfiguration() + ((ComputationGraph) originalModel).getComputationGraphConfiguration() .setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode()); ((ComputationGraph) originalModel).setGradientsAccumulator(accumulator); } else if (model instanceof MultiLayerNetwork) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java index f0b6bc151..1a11d70a5 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java @@ -295,7 +295,7 @@ public class SharedTrainingMaster extends BaseTrainingMaster layerNames = getlayerNames(); for (String s : layerNames) { @@ -728,7 +728,7 @@ public class SbeStatsReport implements StatsReport, AgronaPersistable { pne.next().paramName(s); } - //Layer names + //ILayer names List layerNames = getlayerNames(); UpdateEncoder.LayerNamesEncoder lne = ue.layerNamesCount(layerNames.size()); for (String s : layerNames) { diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java index ff6f00901..274e670f6 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java @@ -182,7 +182,7 @@ public class TrainModuleUtils { long inputSize = (i == 0 ? va.getNIn() : encLayerSizes[i - 1]); long outputSize = encLayerSizes[i]; encoderInfo.put("Input Size", String.valueOf(inputSize)); - encoderInfo.put("Layer Size", String.valueOf(outputSize)); + encoderInfo.put("ILayer Size", String.valueOf(outputSize)); encoderInfo.put("Num Parameters", String.valueOf((inputSize + 1) * outputSize)); encoderInfo.put("Activation Function", va.getActivationFn().toString()); layerInfo.add(encoderInfo); @@ -197,7 +197,7 @@ public class TrainModuleUtils { long inputSize = encLayerSizes[encLayerSizes.length - 1]; long outputSize = va.getNOut(); latentInfo.put("Input Size", String.valueOf(inputSize)); - latentInfo.put("Layer Size", String.valueOf(outputSize)); + latentInfo.put("ILayer Size", String.valueOf(outputSize)); latentInfo.put("Num Parameters", String.valueOf((inputSize + 1) * outputSize * 2)); latentInfo.put("Activation Function", va.getPzxActivationFn().toString()); layerInfo.add(latentInfo); @@ -216,7 +216,7 @@ public class TrainModuleUtils { inputSize = (i == 0 ? 
va.getNOut() : decLayerSizes[i - 1]); outputSize = decLayerSizes[i]; decoderInfo.put("Input Size", String.valueOf(inputSize)); - decoderInfo.put("Layer Size", String.valueOf(outputSize)); + decoderInfo.put("ILayer Size", String.valueOf(outputSize)); decoderInfo.put("Num Parameters", String.valueOf((inputSize + 1) * outputSize)); decoderInfo.put("Activation Function", va.getActivationFn().toString()); layerInfo.add(decoderInfo); @@ -231,7 +231,7 @@ public class TrainModuleUtils { inputSize = decLayerSizes[decLayerSizes.length - 1]; outputSize = va.getNIn(); reconstructionInfo.put("Input Size", String.valueOf(inputSize)); - reconstructionInfo.put("Layer Size", String.valueOf(outputSize)); + reconstructionInfo.put("ILayer Size", String.valueOf(outputSize)); reconstructionInfo.put("Num Parameters", String .valueOf((inputSize + 1) * va.getOutputDistribution().distributionInputSize((int) va.getNIn()))); reconstructionInfo.put("Distribution", va.getOutputDistribution().toString()); diff --git a/cavis-ui/cavis-ui-vertx/src/main/resources/templates/TrainingModel.html.ftl b/cavis-ui/cavis-ui-vertx/src/main/resources/templates/TrainingModel.html.ftl index 859aae287..51d63af6b 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/resources/templates/TrainingModel.html.ftl +++ b/cavis-ui/cavis-ui-vertx/src/main/resources/templates/TrainingModel.html.ftl @@ -103,7 +103,7 @@
@@ -179,7 +179,7 @@
@@ -244,7 +244,7 @@ - + diff --git a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java index a61ae386d..44d9dff3c 100644 --- a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java +++ b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java @@ -65,7 +65,7 @@ public class TestUtils { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); ComputationGraph restored = ModelSerializer.restoreComputationGraph(bais, true); - assertEquals(net.getConfiguration(), restored.getConfiguration()); + assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration()); assertEquals(net.params(), restored.params()); return restored; diff --git a/settings.gradle b/settings.gradle index 80b29bef8..d7875c751 100644 --- a/settings.gradle +++ b/settings.gradle @@ -100,7 +100,7 @@ include ':cavis-dnn:cavis-dnn-data:cavis-dnn-data-utility-iterators' include ':cavis-dnn:cavis-dnn-modelimport' include ':cavis-dnn:cavis-dnn-nlp' include ':cavis-dnn:cavis-dnn-nn' -include ':cavis-dnn:cavis-dnn-nn-api' +//include ':cavis-dnn:cavis-dnn-nn-api' include ':cavis-dnn:cavis-dnn-nn-parent' include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-server' include ':cavis-dnn:cavis-dnn-nn-parent:cavis-dnn-nn-client' From 3edb90dbd11545b22b51f6eba1a0f67a4f7fee92 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 7 Apr 2023 14:28:47 +0200 Subject: [PATCH 121/126] Playing with some new code 2 Signed-off-by: brian --- .../TupleStreamDataSetIteratorTest.java | 2 +- .../ModelTupleStreamIntegrationTest.java | 4 +- .../solr/handler/ModelTupleStreamTest.java | 6 +- .../solr/ltr/model/ScoringModelTest.java | 6 +- .../remote/JsonModelServerTest.java | 10 +- .../pw/SharedTrainingWrapper.java | 2 +- .../training/SharedTrainingMaster.java | 2 +- .../training/SharedTrainingWorker.java | 4 +- .../spark/parameterserver/BaseSparkTest.java | 6 +- .../train/GradientSharingTrainingTest.java | 12 +- .../spark/api/worker/NetBroadcastTuple.java | 10 +- ...eVaeReconstructionProbWithKeyFunction.java | 2 +- .../score/BaseVaeScoreWithKeyFunction.java | 2 +- .../impl/evaluation/EvaluationRunner.java | 4 +- ...VaeReconstructionErrorWithKeyFunction.java | 2 +- ...GVaeReconstructionProbWithKeyFunction.java | 2 +- .../impl/multilayer/SparkDl4jMultiLayer.java | 16 +- .../scoring/FeedForwardWithKeyFunction.java | 6 +- .../scoring/ScoreExamplesFunction.java | 4 +- .../scoring/ScoreExamplesWithKeyFunction.java | 6 +- .../scoring/ScoreFlatMapFunction.java | 4 +- ...VaeReconstructionErrorWithKeyFunction.java | 6 +- .../VaeReconstructionProbWithKeyFunction.java | 6 +- .../ParameterAveragingTrainingMaster.java | 6 +- .../deeplearning4j/spark/BaseSparkTest.java | 6 +- .../spark/TestEarlyStoppingSpark.java | 12 +- .../TestEarlyStoppingSparkCompGraph.java | 10 +- .../org/deeplearning4j/spark/TestKryo.java | 6 +- .../spark/datavec/TestPreProcessedData.java | 8 +- .../spark/impl/TestKryoWarning.java | 6 +- .../impl/customlayer/TestCustomLayer.java | 6 +- .../impl/graph/TestSparkComputationGraph.java | 10 +- .../spark/impl/misc/TestFrozenLayers.java | 4 +- .../impl/multilayer/TestMiscFunctions.java | 12 +- .../multilayer/TestSparkDl4jMultiLayer.java | 4 +- ...arameterAveragingSparkVsSingleMachine.java | 16 +- ...TestSparkMultiLayerParameterAveraging.java | 52 +- .../stats/TestTrainingStatsCollection.java | 4 +- .../spark/ui/TestListeners.java | 4 +- 
.../network/MultiLayerNetworkHandler.java | 6 +- .../ActorCriticFactoryCompGraphStdConv.java | 2 +- .../ActorCriticFactoryCompGraphStdDense.java | 2 +- .../ActorCriticFactorySeparateStdDense.java | 10 +- .../rl4j/network/ac/ActorCriticSeparate.java | 10 +- .../deeplearning4j/rl4j/network/dqn/DQN.java | 6 +- .../rl4j/network/dqn/DQNFactoryStdConv.java | 8 +- .../rl4j/network/dqn/DQNFactoryStdDense.java | 6 +- .../org/deeplearning4j/rl4j/NStepRnn.java | 2 +- .../deeplearning4j/rl4j/RobotLakeExample.java | 2 +- .../org/deeplearning4j/rl4j/TMazeExample.java | 2 +- .../network/MultiLayerNetworkHandlerTest.java | 8 +- .../rl4j/policy/PolicyTest.java | 4 +- README.md | 6 +- .../src/test/java/net/brutex/gan/App.java | 76 +- .../src/test/java/net/brutex/gan/GAN.java | 15 +- .../net/brutex/gan/MnistDCGANExample.java | 10 +- .../java/net/brutex/gan/MnistSimpleGAN.java | 9 +- .../test/java/net/brutex/spark/BrianTest.java | 7 +- .../java/net/brutex/spark/BrianTest2.java | 5 +- .../java/net/brutex/spark/TestServer.java | 18 +- .../java/net/brutex/spark/TestServer2.java | 8 +- .../IntegrationTestBaselineGenerator.java | 10 +- .../integration/IntegrationTestRunner.java | 40 +- .../deeplearning4j/integration/TestCase.java | 4 +- .../deeplearning4j/integration/TestUtils.java | 8 +- .../testcases/dl4j/CNN1DTestCases.java | 4 +- .../testcases/dl4j/CNN2DTestCases.java | 20 +- .../testcases/dl4j/CNN3DTestCases.java | 7 +- .../testcases/dl4j/MLPTestCases.java | 11 +- .../testcases/dl4j/RNNTestCases.java | 18 +- .../testcases/dl4j/UnsupervisedTestCases.java | 4 +- build.gradle | 5 +- .../net/brutex/ai/dnn/core/util/ANSI.java | 52 + .../listener/SystemInfoFilePrintListener.java | 16 +- .../listener/SystemInfoPrintListener.java | 16 +- .../core/util/ModelGuesser.java | 14 +- .../LayerHelperValidationUtil.java | 15 +- .../java/org/deeplearning4j/RandomTests.java | 8 +- .../java/org/deeplearning4j/TestUtils.java | 12 +- .../iterator/DataSetIteratorTest.java | 15 +- .../earlystopping/TestEarlyStopping.java | 81 +- .../TestEarlyStoppingCompGraph.java | 22 +- .../org/deeplearning4j/eval/EvalTest.java | 28 +- .../eval/EvaluationToolsTests.java | 5 +- .../java/org/deeplearning4j/eval/ROCTest.java | 6 +- .../eval/RegressionEvalTest.java | 5 +- .../exceptions/TestInvalidConfigurations.java | 41 +- .../exceptions/TestInvalidInput.java | 29 +- .../gradientcheck/AttentionLayerTest.java | 25 +- .../gradientcheck/BNGradientCheckTest.java | 76 +- .../gradientcheck/CNN1DGradientCheckTest.java | 37 +- .../gradientcheck/CNN3DGradientCheckTest.java | 38 +- .../gradientcheck/CNNGradientCheckTest.java | 110 +- .../CapsnetGradientCheckTest.java | 10 +- .../gradientcheck/DropoutGradientCheck.java | 13 +- .../GlobalPoolingGradientCheckTests.java | 13 +- .../gradientcheck/GradientCheckTests.java | 27 +- .../GradientCheckTestsComputationGraph.java | 58 +- .../GradientCheckTestsMasking.java | 19 +- .../gradientcheck/LRNGradientCheckTests.java | 6 +- .../gradientcheck/LSTMGradientCheckTests.java | 46 +- .../LossFunctionGradientCheck.java | 9 +- .../NoBiasGradientCheckTests.java | 13 +- .../OutputLayerGradientChecks.java | 13 +- .../gradientcheck/RnnGradientChecks.java | 13 +- .../UtilLayerGradientChecks.java | 15 +- .../gradientcheck/VaeGradientCheckTests.java | 11 +- .../gradientcheck/YoloGradientCheckTests.java | 8 +- .../ComputationGraphConfigurationTest.java | 30 +- .../org/deeplearning4j/nn/conf/JsonTest.java | 6 +- .../MultiLayerNeuralNetConfigurationTest.java | 728 ++--- .../MultiNeuralNetConfLayerBuilderTest.java | 16 +- 
.../nn/conf/NeuralNetConfigurationTest.java | 60 +- .../nn/conf/constraints/TestConstraints.java | 41 +- .../nn/conf/dropout/TestDropout.java | 39 +- .../nn/conf/graph/ElementWiseVertexTest.java | 12 +- .../nn/conf/graph/ShiftVertexTest.java | 4 +- .../nn/conf/layers/LayerBuilderTest.java | 14 +- .../nn/conf/layers/LayerConfigTest.java | 61 +- .../layers/LayerConfigValidationTest.java | 49 +- .../conf/preprocessor/CNNProcessorTest.java | 43 +- .../preprocessor/CustomPreprocessorTest.java | 15 +- .../conf/preprocessor/TestPreProcessors.java | 47 +- .../nn/conf/weightnoise/TestWeightNoise.java | 29 +- .../deeplearning4j/nn/dtypes/DTypeTests.java | 86 +- .../nn/graph/ComputationGraphTestRNN.java | 32 +- .../nn/graph/TestCompGraphCNN.java | 4 +- .../nn/graph/TestCompGraphUnsupervised.java | 7 +- .../nn/graph/TestComputationGraphNetwork.java | 162 +- .../nn/graph/TestSetGetParameters.java | 2 +- .../nn/graph/TestVariableLengthTSCG.java | 10 +- .../nn/graph/graphnodes/TestGraphNodes.java | 10 +- .../nn/layers/ActivationLayerTest.java | 59 +- .../nn/layers/AutoEncoderTest.java | 2 +- .../nn/layers/BaseLayerTest.java | 9 +- .../nn/layers/CacheModeTest.java | 20 +- .../nn/layers/CenterLossOutputLayerTest.java | 4 +- .../nn/layers/DropoutLayerTest.java | 27 +- .../nn/layers/FrozenLayerTest.java | 84 +- .../layers/FrozenLayerWithBackpropTest.java | 28 +- .../nn/layers/OutputLayerTest.java | 49 +- .../nn/layers/RepeatVectorTest.java | 6 +- .../deeplearning4j/nn/layers/SeedTest.java | 6 +- .../deeplearning4j/nn/layers/TestDropout.java | 8 +- .../nn/layers/capsule/CapsNetMNISTTest.java | 5 +- .../nn/layers/capsule/CapsuleLayerTest.java | 5 +- .../capsule/CapsuleStrengthLayerTest.java | 5 +- .../layers/capsule/PrimaryCapsulesTest.java | 5 +- .../convolution/ConvDataFormatTests.java | 22 +- .../layers/convolution/Convolution3DTest.java | 8 +- .../ConvolutionLayerSetupTest.java | 190 +- .../convolution/ConvolutionLayerTest.java | 130 +- .../LocallyConnectedLayerTest.java | 16 +- .../layers/convolution/SpaceToDepthTest.java | 6 +- .../convolution/SubsamplingLayerTest.java | 38 +- .../convolution/TestConvolutionModes.java | 39 +- .../layers/convolution/Upsampling1DTest.java | 6 +- .../layers/convolution/Upsampling2DTest.java | 6 +- .../layers/custom/TestCustomActivation.java | 13 +- .../nn/layers/custom/TestCustomLayers.java | 37 +- .../custom/testclasses/CustomLayer.java | 4 +- .../custom/testclasses/CustomOutputLayer.java | 5 +- .../layers/feedforward/dense/DenseTest.java | 9 +- .../embedding/EmbeddingLayerTest.java | 72 +- .../normalization/BatchNormalizationTest.java | 55 +- .../normalization/LocalResponseTest.java | 11 +- .../objdetect/TestYolo2OutputLayer.java | 11 +- .../nn/layers/ocnn/OCNNOutputLayerTest.java | 19 +- .../pooling/GlobalPoolingMaskingTests.java | 17 +- .../layers/recurrent/BidirectionalTest.java | 22 +- .../GravesBidirectionalLSTMTest.java | 59 +- .../nn/layers/recurrent/GravesLSTMTest.java | 27 +- .../layers/recurrent/MaskZeroLayerTest.java | 6 +- .../layers/recurrent/RnnDataFormatTests.java | 6 +- .../recurrent/TestLastTimeStepLayer.java | 4 +- .../recurrent/TestRecurrentWeightInit.java | 2 +- .../nn/layers/recurrent/TestRnnLayers.java | 24 +- .../nn/layers/recurrent/TestSimpleRnn.java | 5 +- .../layers/recurrent/TestTimeDistributed.java | 17 +- .../samediff/SameDiffCustomLayerTests.java | 7 +- .../nn/layers/samediff/TestSameDiffConv.java | 15 +- .../nn/layers/samediff/TestSameDiffDense.java | 39 +- .../samediff/TestSameDiffDenseVertex.java | 4 +- 
.../layers/samediff/TestSameDiffLambda.java | 8 +- .../layers/samediff/TestSameDiffOutput.java | 11 +- .../testlayers/MinimalSameDiffDense.java | 2 +- .../samediff/testlayers/SameDiffConv.java | 7 +- .../samediff/testlayers/SameDiffDense.java | 5 +- .../testlayers/SameDiffMSEOutputLayer.java | 2 +- .../nn/layers/variational/TestVAE.java | 46 +- .../nn/misc/CloseNetworkTests.java | 7 +- .../deeplearning4j/nn/misc/LargeNetTest.java | 8 +- .../deeplearning4j/nn/misc/TestLrChanges.java | 36 +- .../nn/misc/TestMemoryReports.java | 21 +- .../nn/misc/TestNetConversion.java | 13 +- .../nn/misc/WorkspaceTests.java | 58 +- .../nn/mkldnn/ValidateMKLDNN.java | 15 +- .../nn/multilayer/BackPropMLPTest.java | 11 +- .../nn/multilayer/MultiLayerTest.java | 2745 +++++++++-------- .../nn/multilayer/MultiLayerTestRNN.java | 62 +- .../nn/multilayer/TestMasking.java | 15 +- .../nn/multilayer/TestSetGetParameters.java | 21 +- .../nn/multilayer/TestVariableLengthTS.java | 24 +- .../rl/TestMultiModelGradientApplication.java | 13 +- .../nn/transferlearning/TestFrozenLayers.java | 11 +- .../TestTransferLearningModelSerializer.java | 15 +- .../TransferLearningCompGraphTest.java | 35 +- .../TransferLearningComplex.java | 15 +- .../TransferLearningHelperTest.java | 46 +- .../TransferLearningMLNTest.java | 134 +- .../nn/updater/TestGradientNormalization.java | 30 +- .../nn/updater/TestUpdaters.java | 83 +- .../nn/updater/custom/TestCustomUpdater.java | 7 +- .../nn/weights/WeightInitIdentityTest.java | 6 +- .../solver/BackTrackLineSearchTest.java | 15 +- .../optimize/solver/TestOptimizers.java | 51 +- .../listener/TestCheckpointListener.java | 3 +- .../listener/TestFailureListener.java | 7 +- .../optimizer/listener/TestListeners.java | 27 +- .../parallelism/RandomTests.java | 13 +- .../listener/TestSystemInfoPrintListener.java | 3 +- .../regressiontest/MiscRegressionTests.java | 9 +- .../regressiontest/RegressionTest050.java | 14 +- .../regressiontest/RegressionTest060.java | 24 +- .../regressiontest/RegressionTest071.java | 24 +- .../regressiontest/RegressionTest080.java | 24 +- .../regressiontest/RegressionTest100a.java | 20 +- .../regressiontest/RegressionTest100b3.java | 22 +- .../regressiontest/RegressionTest100b4.java | 50 +- .../regressiontest/RegressionTest100b6.java | 50 +- .../customlayer100a/CustomLayer.java | 4 +- .../customlayer100a/CustomLayerImpl.java | 6 +- .../CompareTrainingImplementations.java | 7 +- .../util/CrashReportingUtilTest.java | 7 +- .../deeplearning4j/util/ModelGuesserTest.java | 27 +- .../util/ModelSerializerTest.java | 29 +- .../util/ModelValidatorTests.java | 5 +- .../nn/modelimport/keras/KerasLayer.java | 21 +- .../nn/modelimport/keras/KerasModel.java | 14 +- .../modelimport/keras/KerasModelImport.java | 10 +- .../keras/KerasSequentialModel.java | 40 +- .../modelimport/keras/layers/TFOpLayer.java | 8 +- .../keras/layers/TFOpLayerImpl.java | 3 +- .../keras/layers/recurrent/KerasLSTM.java | 2 +- .../layers/recurrent/KerasSimpleRnn.java | 4 +- .../layers/wrappers/KerasBidirectional.java | 11 +- .../keras/utils/KerasLayerUtils.java | 8 +- .../keras/utils/KerasModelUtils.java | 6 +- .../configurations/FullModelComparisons.java | 5 +- .../Keras1ModelConfigurationTest.java | 6 +- .../Keras2ModelConfigurationTest.java | 11 +- .../configurations/KerasModelImportTest.java | 12 +- .../keras/e2e/KerasCustomLayerTest.java | 3 +- .../keras/e2e/KerasModelEndToEndTest.java | 10 +- .../models/word2vec/Word2VecTestsSmall.java | 6 +- cavis-dnn/cavis-dnn-nn/build.gradle | 3 +- 
.../java/net/brutex/ai/dnn/api/Animal.java | 68 + .../ai/dnn/api/IActivationFunction.java | 57 + .../java/net/brutex/ai/dnn/api/IModel.java | 244 +- .../net/brutex/ai/dnn/api/INeuralNetwork.java | 122 - .../dnn/api/INeuralNetworkConfiguration.java | 6 +- .../java/net/brutex/ai/dnn/api/IUnit.java | 47 + .../java/net/brutex/ai/dnn/api/LayerType.java | 52 + .../main/java/net/brutex/ai/dnn/api/NN.java | 42 + .../dnn/conf/NeuralNetworkConfiguration.java | 705 ----- .../conf/layer/DenseLayerConfiguration.java | 62 - .../layer/FeedForwardLayerConfiguration.java | 99 - .../dnn/networks/ArtificialNeuralNetwork.java | 52 +- .../EarlyStoppingConfiguration.java | 6 +- .../EarlyStoppingModelSaver.java | 4 +- .../earlystopping/EarlyStoppingResult.java | 4 +- .../listener/EarlyStoppingListener.java | 4 +- .../saver/InMemoryModelSaver.java | 4 +- .../scorecalc/AutoencoderScoreCalculator.java | 12 +- .../ClassificationScoreCalculator.java | 4 +- .../scorecalc/DataSetLossCalculator.java | 10 +- .../scorecalc/ROCScoreCalculator.java | 4 +- .../scorecalc/RegressionScoreCalculator.java | 4 +- .../scorecalc/ScoreCalculator.java | 4 +- .../VAEReconErrorScoreCalculator.java | 12 +- .../VAEReconProbScoreCalculator.java | 12 +- .../base/BaseIEvaluationScoreCalculator.java | 4 +- .../scorecalc/base/BaseScoreCalculator.java | 4 +- .../trainer/BaseEarlyStoppingTrainer.java | 6 +- .../trainer/EarlyStoppingTrainer.java | 5 +- .../trainer/IEarlyStoppingTrainer.java | 4 +- .../gradientcheck/GradientCheckUtil.java | 41 +- .../nn/adapters/YoloModelAdapter.java | 4 +- .../nn/api/AbstractParamInitializer.java} | 22 +- .../org/deeplearning4j/nn/api/Classifier.java | 3 +- .../java/org/deeplearning4j/nn/api/Layer.java | 41 +- .../java/org/deeplearning4j/nn/api/Model.java | 237 -- .../deeplearning4j/nn/api/ModelAdapter.java | 3 +- .../nn/api/ParamInitializer.java | 24 +- .../org/deeplearning4j/nn/api/Trainable.java | 25 +- .../nn/api/layers/RecurrentLayer.java | 2 - .../conf/ComputationGraphConfiguration.java | 85 +- .../nn/conf/MultiLayerConfiguration.java | 841 ----- .../NeuralNetBaseBuilderConfiguration.java | 1021 ++++++ .../nn/conf/NeuralNetConfiguration.java | 2163 +++++++------ .../nn/conf/constraint/BaseConstraint.java | 4 +- .../nn/conf/graph/LayerVertex.java | 38 +- .../nn/conf/layers/ActivationLayer.java | 13 +- .../nn/conf/layers/AutoEncoder.java | 9 +- .../nn/conf/layers/BaseLayer.java | 5 +- .../nn/conf/layers/BaseUpsamplingLayer.java | 4 +- .../nn/conf/layers/BatchNormalization.java | 12 +- .../nn/conf/layers/CapsuleLayer.java | 2 +- .../nn/conf/layers/CapsuleStrengthLayer.java | 2 +- .../nn/conf/layers/CenterLossOutputLayer.java | 10 +- .../nn/conf/layers/Cnn3DLossLayer.java | 8 +- .../nn/conf/layers/CnnLossLayer.java | 8 +- .../nn/conf/layers/Convolution1DLayer.java | 9 +- .../nn/conf/layers/Convolution3D.java | 8 +- .../nn/conf/layers/ConvolutionLayer.java | 11 +- .../nn/conf/layers/Deconvolution2D.java | 9 +- .../nn/conf/layers/Deconvolution3D.java | 10 +- .../nn/conf/layers/DenseLayer.java | 15 +- .../conf/layers/DepthwiseConvolution2D.java | 8 +- .../nn/conf/layers/DropoutLayer.java | 13 +- .../nn/conf/layers/EmbeddingLayer.java | 9 +- .../conf/layers/EmbeddingSequenceLayer.java | 8 +- .../nn/conf/layers/FeedForwardLayer.java | 3 + .../nn/conf/layers/GlobalPoolingLayer.java | 10 +- .../conf/layers/GravesBidirectionalLSTM.java | 10 +- .../nn/conf/layers/GravesLSTM.java | 10 +- .../deeplearning4j/nn/conf/layers/LSTM.java | 9 +- .../{Layer.java => LayerConfiguration.java} | 64 +- 
.../nn/conf/layers/LayerValidation.java | 6 +- .../layers/LocalResponseNormalization.java | 14 +- .../nn/conf/layers/LocallyConnected1D.java | 7 +- .../nn/conf/layers/LocallyConnected2D.java | 7 +- .../nn/conf/layers/LossLayer.java | 8 +- .../nn/conf/layers/NoParamLayer.java | 5 +- .../nn/conf/layers/OutputLayer.java | 7 +- .../nn/conf/layers/PReLULayer.java | 7 +- .../nn/conf/layers/PrimaryCapsules.java | 2 +- .../conf/layers/RecurrentAttentionLayer.java | 4 +- .../nn/conf/layers/RnnLossLayer.java | 9 +- .../nn/conf/layers/RnnOutputLayer.java | 7 +- .../conf/layers/SeparableConvolution2D.java | 9 +- .../nn/conf/layers/SpaceToBatchLayer.java | 10 +- .../nn/conf/layers/SpaceToDepthLayer.java | 10 +- .../nn/conf/layers/Subsampling1DLayer.java | 8 +- .../nn/conf/layers/Subsampling3DLayer.java | 10 +- .../nn/conf/layers/SubsamplingLayer.java | 10 +- .../nn/conf/layers/Upsampling1D.java | 9 +- .../nn/conf/layers/Upsampling2D.java | 8 +- .../nn/conf/layers/Upsampling3D.java | 10 +- .../nn/conf/layers/ZeroPadding1DLayer.java | 10 +- .../nn/conf/layers/ZeroPadding3DLayer.java | 10 +- .../nn/conf/layers/ZeroPaddingLayer.java | 10 +- .../conf/layers/convolutional/Cropping1D.java | 12 +- .../conf/layers/convolutional/Cropping2D.java | 12 +- .../conf/layers/convolutional/Cropping3D.java | 12 +- .../misc/ElementWiseMultiplicationLayer.java | 10 +- .../nn/conf/layers/misc/FrozenLayer.java | 66 +- .../layers/misc/FrozenLayerWithBackprop.java | 25 +- .../nn/conf/layers/misc/RepeatVector.java | 9 +- .../layers/objdetect/Yolo2OutputLayer.java | 13 +- .../conf/layers/recurrent/Bidirectional.java | 28 +- .../conf/layers/recurrent/LastTimeStep.java | 9 +- .../nn/conf/layers/recurrent/SimpleRnn.java | 8 +- .../layers/recurrent/TimeDistributed.java | 10 +- .../samediff/AbstractSameDiffLayer.java | 23 +- .../conf/layers/samediff/SameDiffLayer.java | 10 +- .../layers/samediff/SameDiffOutputLayer.java | 9 +- .../conf/layers/samediff/SameDiffVertex.java | 19 +- .../nn/conf/layers/util/MaskLayer.java | 8 +- .../nn/conf/layers/util/MaskZeroLayer.java | 14 +- .../variational/VariationalAutoencoder.java | 8 +- .../conf/layers/wrapper/BaseWrapperLayer.java | 26 +- .../nn/conf/ocnn/OCNNOutputLayer.java | 8 +- .../conf/serde/BaseNetConfigDeserializer.java | 28 +- ...utationGraphConfigurationDeserializer.java | 14 +- .../nn/conf/serde/JsonMappers.java | 6 +- ...> NeuralNetConfigurationDeserializer.java} | 25 +- .../conf/serde/legacy/LegacyJsonFormat.java | 2 +- .../nn/conf/weightnoise/DropConnect.java | 6 +- .../nn/conf/weightnoise/WeightNoise.java | 6 +- .../nn/graph/ComputationGraph.java | 124 +- .../nn/graph/vertex/BaseGraphVertex.java | 12 +- .../nn/graph/vertex/BaseWrapperVertex.java | 4 +- .../nn/graph/vertex/GraphVertex.java | 2 +- .../nn/graph/vertex/impl/FrozenVertex.java | 21 + .../nn/graph/vertex/impl/LayerVertex.java | 6 +- .../nn/layers/AbstractLayer.java | 1133 ++++--- .../nn/layers/ActivationLayer.java | 3 +- .../deeplearning4j/nn/layers/BaseLayer.java | 968 +++--- .../nn/layers/BaseOutputLayer.java | 9 +- .../nn/layers/BasePretrainNetwork.java | 30 +- .../nn/layers/DropoutLayer.java | 3 +- .../deeplearning4j/nn/layers/FrozenLayer.java | 6 +- .../nn/layers/FrozenLayerWithBackprop.java | 2 +- .../deeplearning4j/nn/layers/LossLayer.java | 3 +- .../deeplearning4j/nn/layers/OutputLayer.java | 3 +- .../nn/layers/RepeatVector.java | 3 +- .../nn/layers/convolution/Cnn3DLossLayer.java | 15 +- .../nn/layers/convolution/CnnLossLayer.java | 3 +- .../convolution/Convolution1DLayer.java | 9 +- 
.../convolution/Convolution3DLayer.java | 7 +- .../layers/convolution/ConvolutionLayer.java | 15 +- .../layers/convolution/Cropping1DLayer.java | 8 +- .../layers/convolution/Cropping2DLayer.java | 8 +- .../layers/convolution/Cropping3DLayer.java | 7 +- .../convolution/Deconvolution2DLayer.java | 7 +- .../convolution/Deconvolution3DLayer.java | 5 +- .../DepthwiseConvolution2DLayer.java | 7 +- .../SeparableConvolution2DLayer.java | 7 +- .../nn/layers/convolution/SpaceToBatch.java | 3 +- .../nn/layers/convolution/SpaceToDepth.java | 3 +- .../convolution/ZeroPadding1DLayer.java | 7 +- .../convolution/ZeroPadding3DLayer.java | 7 +- .../layers/convolution/ZeroPaddingLayer.java | 5 +- .../subsampling/Subsampling1DLayer.java | 3 +- .../subsampling/Subsampling3DLayer.java | 5 +- .../subsampling/SubsamplingLayer.java | 5 +- .../convolution/upsampling/Upsampling1D.java | 5 +- .../convolution/upsampling/Upsampling2D.java | 3 +- .../convolution/upsampling/Upsampling3D.java | 3 +- .../nn/layers/feedforward/PReLU.java | 3 +- .../feedforward/autoencoder/AutoEncoder.java | 3 +- .../layers/feedforward/dense/DenseLayer.java | 3 +- .../ElementWiseMultiplicationLayer.java | 5 +- .../feedforward/embedding/EmbeddingLayer.java | 3 +- .../embedding/EmbeddingSequenceLayer.java | 3 +- .../nn/layers/mkldnn/MKLDNNLSTMHelper.java | 2 +- .../normalization/BatchNormalization.java | 14 +- .../LocalResponseNormalization.java | 6 +- .../nn/layers/objdetect/Yolo2OutputLayer.java | 3 +- .../nn/layers/ocnn/OCNNOutputLayer.java | 13 +- .../nn/layers/ocnn/OCNNParamInitializer.java | 32 +- .../nn/layers/pooling/GlobalPoolingLayer.java | 18 +- .../layers/recurrent/BaseRecurrentLayer.java | 3 +- .../layers/recurrent/BidirectionalLayer.java | 137 +- .../recurrent/GravesBidirectionalLSTM.java | 15 +- .../nn/layers/recurrent/GravesLSTM.java | 7 +- .../nn/layers/recurrent/LSTM.java | 8 +- .../nn/layers/recurrent/LSTMHelpers.java | 2 +- .../layers/recurrent/LastTimeStepLayer.java | 4 +- .../nn/layers/recurrent/RnnLossLayer.java | 3 +- .../nn/layers/recurrent/RnnOutputLayer.java | 3 +- .../nn/layers/recurrent/SimpleRnn.java | 3 +- .../layers/samediff/SameDiffGraphVertex.java | 2 +- .../nn/layers/samediff/SameDiffLayer.java | 13 +- .../layers/samediff/SameDiffOutputLayer.java | 13 +- .../training/CenterLossOutputLayer.java | 11 +- .../nn/layers/util/MaskLayer.java | 3 +- .../variational/VariationalAutoencoder.java | 169 +- .../nn/layers/wrapper/BaseWrapperLayer.java | 489 +-- .../nn/multilayer/MultiLayerNetwork.java | 563 ++-- .../BatchNormalizationParamInitializer.java | 50 +- .../params/BidirectionalParamInitializer.java | 48 +- .../nn/params/CenterLossParamInitializer.java | 18 +- .../params/Convolution3DParamInitializer.java | 22 +- .../params/ConvolutionParamInitializer.java | 50 +- .../Deconvolution3DParamInitializer.java | 23 +- .../params/DeconvolutionParamInitializer.java | 9 +- .../nn/params/DefaultParamInitializer.java | 74 +- .../DepthwiseConvolutionParamInitializer.java | 40 +- .../params/ElementWiseParamInitializer.java | 16 +- .../nn/params/EmptyParamInitializer.java | 26 +- .../params/FrozenLayerParamInitializer.java | 120 +- ...ozenLayerWithBackpropParamInitializer.java | 42 +- ...avesBidirectionalLSTMParamInitializer.java | 30 +- .../nn/params/GravesLSTMParamInitializer.java | 30 +- .../nn/params/LSTMParamInitializer.java | 38 +- .../nn/params/PReLUParamInitializer.java | 35 +- .../nn/params/PretrainParamInitializer.java | 17 +- .../nn/params/SameDiffParamInitializer.java | 30 +- 
.../SeparableConvolutionParamInitializer.java | 44 +- .../nn/params/SimpleRnnParamInitializer.java | 33 +- ...ariationalAutoencoderParamInitializer.java | 24 +- .../params/WrapperLayerParamInitializer.java | 56 +- .../FineTuneConfiguration.java | 1479 ++++----- .../nn/transferlearning/TransferLearning.java | 151 +- .../TransferLearningHelper.java | 15 +- .../nn/updater/BaseMultiLayerUpdater.java | 10 +- .../nn/updater/LayerUpdater.java | 4 +- .../nn/updater/MultiLayerUpdater.java | 2 +- .../nn/updater/UpdaterCreator.java | 4 +- .../graph/ComputationGraphUpdater.java | 5 +- .../org/deeplearning4j/optimize/Solver.java | 8 +- .../optimize/api/BaseTrainingListener.java | 16 +- .../optimize/api/ConvexOptimizer.java | 4 +- .../optimize/api/IterationListener.java | 4 +- .../optimize/api/TrainingListener.java | 24 +- .../listeners/CheckpointListener.java | 24 +- .../CollectScoresIterationListener.java | 4 +- .../listeners/CollectScoresListener.java | 4 +- .../ComposableIterationListener.java | 4 +- .../listeners/EvaluativeListener.java | 12 +- .../listeners/FailureTestingListener.java | 34 +- .../listeners/PerformanceListener.java | 6 +- .../listeners/ScoreIterationListener.java | 6 +- .../listeners/ScoreToChartListener.java | 4 +- .../listeners/SleepyTrainingListener.java | 16 +- .../listeners/TimeIterationListener.java | 6 +- .../callbacks/EvaluationCallback.java | 4 +- .../callbacks/ModelSavingCallback.java | 6 +- .../optimize/solvers/BackTrackLineSearch.java | 11 +- .../optimize/solvers/BaseOptimizer.java | 31 +- .../optimize/solvers/ConjugateGradient.java | 4 +- .../optimize/solvers/LBFGS.java | 4 +- .../optimize/solvers/LineGradientDescent.java | 4 +- .../solvers/StochasticGradientDescent.java | 4 +- .../EncodedGradientsAccumulator.java | 4 +- .../util/Convolution1DUtils.java | 5 +- .../deeplearning4j/util/ConvolutionUtils.java | 18 +- .../util/CrashReportingUtil.java | 24 +- .../util/DL4JModelValidator.java | 12 +- .../deeplearning4j/util/ModelSerializer.java | 30 +- .../org/deeplearning4j/util/NetworkUtils.java | 35 +- .../deeplearning4j/util/OutputLayerUtil.java | 8 +- .../deeplearning4j/util/TimeSeriesUtils.java | 4 +- .../main/resources/simplelogger.properties | 22 + .../java/net/brutex/ai/dnn/api/dnnTest.java | 23 +- .../brutex/ai/dnn/conf/layer/FFLayerTest.java | 11 - .../ParameterServerTrainer.java | 10 +- .../ParameterServerTrainerContext.java | 10 +- .../ParameterServerParallelWrapperTest.java | 8 +- .../EarlyStoppingParallelTrainer.java | 8 +- .../parallelism/InplaceParallelInference.java | 26 +- .../parallelism/ParallelInference.java | 30 +- .../parallelism/ParallelWrapper.java | 9 +- .../factory/DefaultTrainerContext.java | 10 +- .../factory/SymmetricTrainerContext.java | 11 +- .../parallelism/factory/TrainerContext.java | 10 +- .../parallelism/main/ParallelWrapperMain.java | 4 +- .../parallelism/trainer/DefaultTrainer.java | 18 +- .../parallelism/trainer/SymmetricTrainer.java | 4 +- .../parallelism/trainer/Trainer.java | 10 +- .../InplaceParallelInferenceTest.java | 8 +- .../parallelism/ParallelInferenceTest.java | 25 +- .../parallelism/ParallelWrapperTest.java | 10 +- .../parallelism/TestListeners.java | 31 +- .../TestParallelEarlyStopping.java | 7 +- .../TestParallelEarlyStoppingUI.java | 3 +- .../factory/DefaultTrainerContextTest.java | 11 +- .../factory/SymmetricTrainerContextTest.java | 8 +- .../main/ParallelWrapperMainTest.java | 8 +- .../spark/api/TrainingHook.java | 10 +- .../spark/api/worker/NetBroadcastTuple.java | 10 +- .../BaseSparkEarlyStoppingTrainer.java | 4 +- 
...eVaeReconstructionProbWithKeyFunction.java | 2 +- .../score/BaseVaeScoreWithKeyFunction.java | 2 +- .../impl/evaluation/EvaluationRunner.java | 10 +- .../impl/graph/SparkComputationGraph.java | 2 +- ...VaeReconstructionErrorWithKeyFunction.java | 2 +- ...GVaeReconstructionProbWithKeyFunction.java | 2 +- .../impl/multilayer/SparkDl4jMultiLayer.java | 18 +- .../scoring/FeedForwardWithKeyFunction.java | 6 +- .../scoring/ScoreExamplesFunction.java | 5 +- .../scoring/ScoreExamplesWithKeyFunction.java | 6 +- .../scoring/ScoreFlatMapFunction.java | 4 +- ...VaeReconstructionErrorWithKeyFunction.java | 9 +- .../VaeReconstructionProbWithKeyFunction.java | 6 +- .../ParameterAveragingTrainingMaster.java | 6 +- .../ParameterAveragingTrainingWorker.java | 4 +- .../deeplearning4j/spark/BaseSparkTest.java | 7 +- .../spark/TestEarlyStoppingSpark.java | 11 +- .../TestEarlyStoppingSparkCompGraph.java | 10 +- .../org/deeplearning4j/spark/TestKryo.java | 13 +- .../spark/datavec/TestPreProcessedData.java | 7 +- .../spark/impl/TestKryoWarning.java | 5 +- .../impl/customlayer/TestCustomLayer.java | 5 +- .../impl/customlayer/layer/CustomLayer.java | 4 +- .../impl/graph/TestSparkComputationGraph.java | 10 +- .../spark/impl/misc/TestFrozenLayers.java | 23 +- .../impl/multilayer/TestMiscFunctions.java | 11 +- .../multilayer/TestSparkDl4jMultiLayer.java | 4 +- ...arameterAveragingSparkVsSingleMachine.java | 16 +- ...TestSparkMultiLayerParameterAveraging.java | 63 +- .../stats/TestTrainingStatsCollection.java | 4 +- .../spark/ui/TestListeners.java | 3 +- .../ParameterServerTrainingHook.java | 10 +- .../pw/SharedTrainingWrapper.java | 8 +- .../training/SharedTrainingMaster.java | 2 +- .../training/SharedTrainingWorker.java | 4 +- .../spark/parameterserver/BaseSparkTest.java | 7 +- .../train/GradientSharingTrainingTest.java | 18 +- .../deeplearning4j/plot/BarnesHutTsne.java | 8 +- .../ConvolutionalIterationListener.java | 14 +- .../org/deeplearning4j/ui/ManualTests.java | 24 +- .../ui/weights/TestConvolutionalListener.java | 7 +- .../ui/model/stats/BaseStatsListener.java | 52 +- .../ui/stats/TestStatsListener.java | 7 +- .../ui/stats/TestTransferStatsCollection.java | 3 +- .../ui/module/train/TrainModule.java | 27 +- .../ui/module/train/TrainModuleUtils.java | 30 +- .../deeplearning4j/ui/TestRemoteReceiver.java | 5 +- .../org/deeplearning4j/ui/TestVertxUI.java | 14 +- .../deeplearning4j/ui/TestVertxUIManual.java | 9 +- .../ui/TestVertxUIMultiSession.java | 9 +- .../deeplearning4j/zoo/InstantiableModel.java | 6 +- .../java/org/deeplearning4j/zoo/ZooModel.java | 6 +- .../org/deeplearning4j/zoo/model/AlexNet.java | 21 +- .../deeplearning4j/zoo/model/Darknet19.java | 13 +- .../zoo/model/FaceNetNN4Small2.java | 9 +- .../zoo/model/InceptionResNetV1.java | 12 +- .../org/deeplearning4j/zoo/model/LeNet.java | 15 +- .../org/deeplearning4j/zoo/model/NASNet.java | 8 +- .../deeplearning4j/zoo/model/ResNet50.java | 13 +- .../deeplearning4j/zoo/model/SimpleCNN.java | 17 +- .../deeplearning4j/zoo/model/SqueezeNet.java | 12 +- .../zoo/model/TextGenerationLSTM.java | 15 +- .../deeplearning4j/zoo/model/TinyYOLO.java | 9 +- .../org/deeplearning4j/zoo/model/UNet.java | 9 +- .../org/deeplearning4j/zoo/model/VGG16.java | 8 +- .../org/deeplearning4j/zoo/model/VGG19.java | 10 +- .../deeplearning4j/zoo/model/Xception.java | 11 +- .../org/deeplearning4j/zoo/model/YOLO2.java | 9 +- .../deeplearning4j/zoo/TestInstantiation.java | 5 +- .../org/deeplearning4j/zoo/TestUtils.java | 2 +- 614 files changed, 12080 insertions(+), 11594 deletions(-) 
create mode 100644 cavis-dnn/cavis-dnn-core/src/main/java/net/brutex/ai/dnn/core/util/ANSI.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/Animal.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IActivationFunction.java delete mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetwork.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IUnit.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/LayerType.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java delete mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java delete mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/DenseLayerConfiguration.java delete mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FeedForwardLayerConfiguration.java rename cavis-dnn/cavis-dnn-nn/src/main/java/{net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java => org/deeplearning4j/nn/api/AbstractParamInitializer.java} (67%) delete mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Model.java delete mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java rename cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/{Layer.java => LayerConfiguration.java} (90%) rename cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/{MultiLayerConfigurationDeserializer.java => NeuralNetConfigurationDeserializer.java} (89%) create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/resources/simplelogger.properties diff --git a/.old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java b/.old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java index 5c21d354a..67ad09bd1 100644 --- a/.old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java +++ b/.old/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java @@ -205,7 +205,7 @@ public class TupleStreamDataSetIteratorTest extends SolrCloudTestCase { public void modelFitTest() throws Exception { final MultiLayerNetwork model = new MultiLayerNetwork( - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .list( new OutputLayer.Builder(LossFunction.MSE) .nIn(3) diff --git a/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java index 7c0505605..c2c260fdd 100644 --- a/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java +++ 
b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java @@ -35,7 +35,7 @@ import org.apache.solr.client.solrj.request.UpdateRequest; import org.apache.solr.cloud.SolrCloudTestCase; import org.apache.solr.common.params.ModifiableSolrParams; import org.deeplearning4j.nn.api.Model; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -153,7 +153,7 @@ public class ModelTupleStreamIntegrationTest extends SolrCloudTestCase { final int numInputs = 3; final int numOutputs = 2; - final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + final NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list( new OutputLayer.Builder() .nIn(numInputs) diff --git a/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java index a638fa14a..c6a05607b 100644 --- a/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java +++ b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java @@ -43,7 +43,7 @@ import org.apache.solr.core.SolrResourceLoader; import org.apache.solr.handler.SolrDefaultStreamFactory; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -242,7 +242,7 @@ public class ModelTupleStreamTest { protected Model buildMultiLayerNetworkModel(int numInputs, int numOutputs) throws Exception { - final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + final NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list( new OutputLayer.Builder() .nIn(numInputs) @@ -274,7 +274,7 @@ public class ModelTupleStreamTest { protected Model buildComputationGraphModel(int numInputs, int numOutputs) throws Exception { - final ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + final ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("inputLayer") .addLayer("outputLayer", diff --git a/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java index 7f77c6c0c..1986511bb 100644 --- a/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java +++ b/.old/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/ltr/model/ScoringModelTest.java @@ -42,7 +42,7 @@ import org.apache.solr.ltr.norm.Normalizer; import 
org.apache.solr.request.SolrQueryRequest; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -192,7 +192,7 @@ public class ScoringModelTest { protected Model buildMultiLayerNetworkModel(int numFeatures) throws Exception { - final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + final NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list( new OutputLayer.Builder().nIn(numFeatures).nOut(1).lossFunction(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).build() ) @@ -217,7 +217,7 @@ public class ScoringModelTest { protected Model buildComputationGraphModel(int numFeatures) throws Exception { - final ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + final ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("inputLayer") .addLayer("outputLayer", diff --git a/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java index 1de161c2e..dd75472c6 100644 --- a/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java +++ b/.old/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/JsonModelServerTest.java @@ -23,7 +23,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.layers.*; @@ -70,7 +70,7 @@ public class JsonModelServerTest extends BaseDL4JTest { private static final MultiLayerNetwork model; static { - val conf = new NeuralNetConfiguration.Builder() + val conf = NeuralNetConfiguration.builder() .seed(119) .updater(new Adam(0.119f)) .weightInit(WeightInit.XAVIER) @@ -541,7 +541,7 @@ public class JsonModelServerTest extends BaseDL4JTest { @Test public void testMlnMnist() throws Exception { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new DenseLayer.Builder().nIn(784).nOut(10).build()) .layer(new LossLayer.Builder().activation(Activation.SOFTMAX).build()) @@ -597,7 +597,7 @@ public class JsonModelServerTest extends BaseDL4JTest { @Test public void testCompGraph() throws Exception { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("input1", "input2") .addLayer("L1", new DenseLayer.Builder().nIn(3).nOut(4).build(), "input1") @@ -652,7 +652,7 @@ public class JsonModelServerTest extends BaseDL4JTest { @Test public void testCompGraph_1() throws Exception { - ComputationGraphConfiguration conf = new 
NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .updater(new Sgd(0.01)) .graphBuilder() .addInputs("input") diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java index f3f2cee80..23bc7566d 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java @@ -425,7 +425,7 @@ public class SharedTrainingWrapper { .setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode()); ((ComputationGraph) originalModel).setGradientsAccumulator(accumulator); } else if (model instanceof MultiLayerNetwork) { - ((MultiLayerNetwork) originalModel).getLayerWiseConfigurations() + ((MultiLayerNetwork) originalModel).getConfiguration() .setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode()); ((MultiLayerNetwork) originalModel).setGradientsAccumulator(accumulator); } diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java index 2a17ab3e1..d02cb4234 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java @@ -261,7 +261,7 @@ public class SharedTrainingMaster extends BaseTrainingMaster extends BaseVa /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param useLogProbability If true: use log probability. False: use raw probability. 
* @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java index 4140b8a53..cfcc93b78 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java @@ -45,7 +45,7 @@ public abstract class BaseVaeScoreWithKeyFunction implements PairFlatMapFunct /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param batchSize Batch size to use when scoring */ public BaseVaeScoreWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java index 8550c6e3c..426682d69 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java @@ -27,7 +27,7 @@ import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator; import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.common.base.Preconditions; @@ -131,7 +131,7 @@ public class EvaluationRunner { cg.init(deviceLocalParams.get(), false); m = cg; } else { - MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(json.getValue()); + NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(json.getValue()); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(deviceLocalParams.get(), false); m = net; diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java index d8aadc3f1..e13b5f9b6 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java @@ -33,7 +33,7 @@ public class 
CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWith /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param batchSize Batch size to use when scoring */ public CGVaeReconstructionErrorWithKeyFunction(Broadcast params, Broadcast jsonConfig, diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java index 57c568239..e9455092c 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java @@ -33,7 +33,7 @@ public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstruc /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param useLogProbability If true: use log probability. False: use raw probability. * @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java index be7780f2f..054520c70 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java @@ -35,7 +35,7 @@ import org.datavec.spark.util.BroadcastHadoopConfigHolder; import org.deeplearning4j.core.loader.DataSetLoader; import org.deeplearning4j.core.loader.MultiDataSetLoader; import org.deeplearning4j.core.loader.impl.SerializedDataSetLoader; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.spark.api.TrainingMaster; @@ -80,7 +80,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { public static final int DEFAULT_ROC_THRESHOLD_STEPS = 32; public static final int DEFAULT_EVAL_WORKERS = 4; private transient JavaSparkContext sc; - private MultiLayerConfiguration conf; + private NeuralNetConfiguration conf; private MultiLayerNetwork network; private double lastScore; private int defaultEvaluationWorkers = DEFAULT_EVAL_WORKERS; @@ -104,7 +104,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { * @param sparkContext the spark context to use * @param conf the configuration of the network */ - public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerConfiguration conf, + public SparkDl4jMultiLayer(SparkContext sparkContext, NeuralNetConfiguration conf, TrainingMaster 
trainingMaster) { this(new JavaSparkContext(sparkContext), initNetwork(conf), trainingMaster); } @@ -115,14 +115,14 @@ public class SparkDl4jMultiLayer extends SparkListenable { * @param sc the spark context to use * @param conf the configuration of the network */ - public SparkDl4jMultiLayer(JavaSparkContext sc, MultiLayerConfiguration conf, TrainingMaster trainingMaster) { + public SparkDl4jMultiLayer(JavaSparkContext sc, NeuralNetConfiguration conf, TrainingMaster trainingMaster) { this(sc.sc(), conf, trainingMaster); } public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerNetwork network, TrainingMaster trainingMaster) { sc = javaSparkContext; - this.conf = network.getLayerWiseConfigurations().clone(); + this.conf = network.getConfiguration().clone(); this.network = network; if (!network.isInitCalled()) network.init(); @@ -132,7 +132,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { SparkUtils.checkKryoConfiguration(javaSparkContext, log); } - private static MultiLayerNetwork initNetwork(MultiLayerConfiguration conf) { + private static MultiLayerNetwork initNetwork(NeuralNetConfiguration conf) { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); return net; @@ -315,8 +315,8 @@ public class SparkDl4jMultiLayer extends SparkListenable { * @return the multi layer network that was fitDataSet */ public MultiLayerNetwork fitLabeledPoint(JavaRDD rdd) { - int nLayers = network.getLayerWiseConfigurations().getConfs().size(); - FeedForwardLayer ffl = (FeedForwardLayer) network.getLayerWiseConfigurations().getConf(nLayers - 1).getLayer(); + int nLayers = network.getConfiguration().getConfs().size(); + FeedForwardLayer ffl = (FeedForwardLayer) network.getConfiguration().getConf(nLayers - 1).getLayer(); JavaRDD ds = MLLibUtil.fromLabeledPoint(sc, rdd, ffl.getNOut()); return fit(ds); } diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java index 510f2e4d4..c064c81d0 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java @@ -22,7 +22,7 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSetUtil; @@ -49,7 +49,7 @@ public class FeedForwardWithKeyFunction /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param batchSize Batch size to use for forward pass (use > 1 for efficiency) */ public FeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { @@ -65,7 +65,7 @@ public class FeedForwardWithKeyFunction return Collections.emptyIterator(); } - MultiLayerNetwork network = new 
MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); + MultiLayerNetwork network = new MultiLayerNetwork(NeuralNetConfiguration.fromJson(jsonConfig.getValue())); network.init(); INDArray val = params.value().unsafeDuplication(); if (val.length() != network.numParams(false)) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java index 6c3878da5..a8990125d 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java @@ -23,7 +23,7 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; import org.apache.spark.api.java.function.DoubleFlatMapFunction; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -60,7 +60,7 @@ public class ScoreExamplesFunction implements DoubleFlatMapFunction implements PairFlatMapFunction implements PairFlatMapFunction, DataSetIterator iter = new IteratorDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate - MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json)); + MultiLayerNetwork network = new MultiLayerNetwork(NeuralNetConfiguration.fromJson(json)); network.init(); INDArray val = params.value().unsafeDuplication(); //.value() object will be shared by all executors on each machine -> OK, as params are not modified by score function if (val.length() != network.numParams(false)) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java index 3f7c5ba6c..95c0c721e 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java @@ -22,7 +22,7 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction; @@ -36,7 +36,7 @@ public class VaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKe /** * @param params MultiLayerNetwork parameters - * @param jsonConfig 
MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param batchSize Batch size to use when scoring */ public VaeReconstructionErrorWithKeyFunction(Broadcast params, Broadcast jsonConfig, @@ -47,7 +47,7 @@ public class VaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKe @Override public VariationalAutoencoder getVaeLayer() { MultiLayerNetwork network = - new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) jsonConfig.getValue())); + new MultiLayerNetwork(NeuralNetConfiguration.fromJson((String) jsonConfig.getValue())); network.init(); INDArray val = ((INDArray) params.value()).unsafeDuplication(); if (val.length() != network.numParams(false)) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java index d9dd8a155..18890d020 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java @@ -22,7 +22,7 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction; @@ -34,7 +34,7 @@ public class VaeReconstructionProbWithKeyFunction extends BaseVaeReconstructi /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param useLogProbability If true: use log probability. False: use raw probability. 
* @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} @@ -47,7 +47,7 @@ public class VaeReconstructionProbWithKeyFunction extends BaseVaeReconstructi @Override public VariationalAutoencoder getVaeLayer() { MultiLayerNetwork network = - new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String) jsonConfig.getValue())); + new MultiLayerNetwork(NeuralNetConfiguration.fromJson((String) jsonConfig.getValue())); network.init(); INDArray val = ((INDArray) params.value()).unsafeDuplication(); if (val.length() != network.numParams(false)) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java index 8d8532e0b..411422884 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java @@ -41,7 +41,7 @@ import org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.core.storage.StatsStorageRouterProvider; import org.deeplearning4j.core.storage.StorageMetaData; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.TrainingListener; @@ -274,7 +274,7 @@ public class ParameterAveragingTrainingMaster @Override public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network) { - NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getLayerWiseConfigurations(), + NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getConfiguration(), network.getNetwork().params(), network.getNetwork().getUpdater().getStateViewArray()); if (collectTrainingStats) @@ -726,7 +726,7 @@ public class ParameterAveragingTrainingMaster if (params != null) { //Params may be null for edge case (empty RDD) if (network != null) { - MultiLayerConfiguration conf = network.getNetwork().getLayerWiseConfigurations(); + NeuralNetConfiguration conf = network.getNetwork().getConfiguration(); int numUpdates = averagingFrequency; conf.setIterationCount(conf.getIterationCount() + numUpdates); } else { diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java index e00f8d6d3..686560ffc 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java @@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.BaseDL4JTest; -import 
org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; @@ -129,8 +129,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable return 4; } - protected MultiLayerConfiguration getBasicConf() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) + protected NeuralNetConfiguration getBasicConf() { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123) .updater(new Nesterovs(0.1, 0.9)).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3) .activation(Activation.TANH).build()) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java index ed8de3623..7154808f6 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java @@ -35,7 +35,7 @@ import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationC import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition; import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -68,7 +68,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { //Spark tests don't run on windows return; } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd()).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) @@ -123,7 +123,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(10.0)) //Intentionally huge LR .weightInit(WeightInit.XAVIER).list() @@ -163,7 +163,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(1e-6)).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) @@ -209,7 +209,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new 
NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.0)).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) @@ -246,7 +246,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { //Spark tests don't run on windows return; } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd()).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java index 3de17a742..76fa0e65b 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java @@ -71,7 +71,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { //Spark tests don't run on windows return; } - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3) @@ -124,7 +124,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(2.0)) //Intentionally huge LR .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") @@ -165,7 +165,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(1e-6)).weightInit(WeightInit.XAVIER).graphBuilder() .addInputs("in") @@ -213,7 +213,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.0)).weightInit(WeightInit.XAVIER).graphBuilder() .addInputs("in") @@ -253,7 +253,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { //Spark tests don't run on windows return; } - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() 
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java index 33023d605..47f1807d0 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/TestKryo.java @@ -22,7 +22,7 @@ package org.deeplearning4j.spark; import org.apache.spark.serializer.SerializerInstance; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.conf.graph.*; @@ -68,14 +68,14 @@ public class TestKryo extends BaseSparkKryoTest { Map m = new HashMap<>(); m.put(0, 0.5); m.put(10, 0.1); - MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder() .updater(new Nadam(new MapSchedule(ScheduleType.ITERATION,m))).list().layer(0, new OutputLayer.Builder().nIn(10).nOut(10).build()) .build(); testSerialization(mlc, si); - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration cgc = NeuralNetConfiguration.builder() .dist(new UniformDistribution(-1, 1)) .updater(new Adam(new MapSchedule(ScheduleType.ITERATION,m))) .graphBuilder() diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java index 714c3ffb6..946f8816f 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java @@ -30,7 +30,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.spark.BaseSparkTest; @@ -84,7 +84,7 @@ public class TestPreProcessedData extends BaseSparkTest { iter.next().save(f2); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(Updater.RMSPROP) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(Updater.RMSPROP) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(3) .activation(Activation.TANH).build()) @@ -134,7 +134,7 @@ public class TestPreProcessedData extends 
BaseSparkTest { iter.next().save(f2); } - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(Updater.RMSPROP) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().updater(Updater.RMSPROP) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(3) @@ -188,7 +188,7 @@ public class TestPreProcessedData extends BaseSparkTest { mds.save(f2); } - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(Updater.RMSPROP) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().updater(Updater.RMSPROP) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(3) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java index ec2195081..6aa102fb4 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java @@ -23,7 +23,7 @@ package org.deeplearning4j.spark.impl; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.spark.api.TrainingMaster; @@ -40,7 +40,7 @@ public class TestKryoWarning { try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new OutputLayer.Builder().nIn(10).nOut(10).build()) .build(); @@ -57,7 +57,7 @@ public class TestKryoWarning { try { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("0", new OutputLayer.Builder().nIn(10).nOut(10).build(), "in").setOutputs("0") .build(); diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java index b3c96333d..1b7bf1052 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java @@ -22,7 +22,7 @@ package org.deeplearning4j.spark.impl.customlayer; import com.sun.jna.Platform; import org.apache.spark.api.java.JavaRDD; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import 
org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -51,8 +51,8 @@ public class TestCustomLayer extends BaseSparkTest { } //Basic test - checks whether exceptions etc are thrown with custom layers + spark //Custom layers are tested more extensively in dl4j core - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new Sgd(0.1)).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new CustomLayer(3.14159)).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java index cc6e5f9ec..7a28146fb 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java @@ -77,7 +77,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { public static ComputationGraph getBasicNetIris2Class() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .graphBuilder().addInputs("in") .addLayer("l0", new DenseLayer.Builder().nIn(4).nOut(10).build(), "in") .addLayer("l1", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) @@ -104,7 +104,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { while (iter.hasNext()) list.add(iter.next()); - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration config = NeuralNetConfiguration.builder() .updater(new Sgd(0.1)) .graphBuilder().addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", @@ -138,7 +138,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { @Test public void testDistributedScoring() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.1) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.1) .seed(123).updater(new Nesterovs(0.1, 0.9)).graphBuilder() .addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3) @@ -217,7 +217,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { //@Ignore("AB 2019/05/23 - Failing on CI only - passing locally. 
Possible precision or threading issue") public void testSeedRepeatability() throws Exception { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(Updater.RMSPROP) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(Updater.RMSPROP) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(4) @@ -414,7 +414,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { JavaRDD rdd = sc.parallelize(l); // simple model - val modelConf = new NeuralNetConfiguration.Builder() + val modelConf = NeuralNetConfiguration.builder() .updater(new Adam(0.01)) .weightInit(WeightInit.XAVIER_UNIFORM) .biasInit(0) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java index 887696af3..f0d15745d 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java @@ -53,7 +53,7 @@ public class TestFrozenLayers extends BaseSparkTest { @Test public void testSparkFrozenLayers() { - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.TANH); FineTuneConfiguration finetune = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build(); @@ -136,7 +136,7 @@ public class TestFrozenLayers extends BaseSparkTest { int nIn = 6; int nOut = 3; - ComputationGraph origModel = new ComputationGraph(new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + ComputationGraph origModel = new ComputationGraph(NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.TANH).graphBuilder().addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(6).nOut(5).build(), "in") .addLayer("1", new DenseLayer.Builder().nIn(5).nOut(4).build(), "0") diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java index 550ccc9b2..adc3d5508 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java @@ -23,7 +23,7 @@ package org.deeplearning4j.spark.impl.multilayer; import org.apache.spark.api.java.JavaPairRDD; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.*; import 
org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution; @@ -57,7 +57,7 @@ public class TestMiscFunctions extends BaseSparkTest { @Test public void testFeedForwardWithKey() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(3).nOut(3) .activation(Activation.SOFTMAX).build()) @@ -107,7 +107,7 @@ public class TestMiscFunctions extends BaseSparkTest { @Test public void testFeedForwardWithKeyInputMask() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .list() .layer( new LSTM.Builder().nIn(4).nOut(3).build()) .layer(new GlobalPoolingLayer(PoolingType.AVG)) @@ -162,7 +162,7 @@ public class TestMiscFunctions extends BaseSparkTest { @Test public void testFeedForwardWithKeyGraph() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .graphBuilder().addInputs("in1", "in2") .addLayer("0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "in1") .addLayer("1", new DenseLayer.Builder().nIn(4).nOut(3).build(), "in2").addLayer("2", @@ -220,7 +220,7 @@ public class TestMiscFunctions extends BaseSparkTest { int nIn = 10; - MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list() .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .reconstructionDistribution( new GaussianReconstructionDistribution(Activation.IDENTITY)) @@ -259,7 +259,7 @@ public class TestMiscFunctions extends BaseSparkTest { int nIn = 10; - MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder() .list().layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .reconstructionDistribution(new LossFunctionWrapper( diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java index c64618557..e66e8bb9d 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java @@ -25,7 +25,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.spark.api.java.JavaRDD; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -102,7 +102,7 @@ public class TestSparkDl4jMultiLayer 
extends BaseSparkTest { //---------------------------------- //Create network configuration and conduct network training - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java index cbe7247bd..e5faa2884 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java @@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -63,9 +63,9 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { } - private static MultiLayerConfiguration getConf(int seed, IUpdater updater) { + private static NeuralNetConfiguration getConf(int seed, IUpdater updater) { Nd4j.getRandom().setSeed(seed); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder() @@ -74,9 +74,9 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { return conf; } - private static MultiLayerConfiguration getConfCNN(int seed, IUpdater updater) { + private static NeuralNetConfiguration getConfCNN(int seed, IUpdater updater) { Nd4j.getRandom().setSeed(seed); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list() .layer(0, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0) @@ -85,13 +85,13 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { .activation(Activation.TANH).build()) .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10) .build()) - .setInputType(InputType.convolutional(10, 10, 3)).build(); + .inputType(InputType.convolutional(10, 10, 3)).build(); return conf; } private static ComputationGraphConfiguration getGraphConf(int seed, IUpdater updater) { Nd4j.getRandom().setSeed(seed); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() 
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder() .addInputs("in") @@ -105,7 +105,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { private static ComputationGraphConfiguration getGraphConfCNN(int seed, IUpdater updater) { Nd4j.getRandom().setSeed(seed); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder() .addInputs("in") diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java index bc1ced484..8907c2165 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java @@ -37,7 +37,7 @@ import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BatchNormalization; @@ -127,7 +127,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { .toJavaRDD().map(new TestFn()); DataSet d = new IrisDataSetIterator(150, 150).next(); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER) .activation(Activation.RELU).build()) @@ -162,8 +162,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { .getAbsolutePath()) .toJavaRDD().map(new TestFn()); - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(123) .updater(new Adam(1e-6)) .weightInit(WeightInit.XAVIER) .list() @@ -275,7 +275,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { //Spark tests don't run on windows return; } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3) .activation(Activation.TANH).build()) @@ -300,7 +300,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { @Test public void testDistributedScoring() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.1) 
+ NeuralNetConfiguration conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.1) .seed(123).updater(new Nesterovs(0.1, 0.9)).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3) .activation(Activation.TANH).build()) @@ -389,7 +389,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { list.add(iter.next()); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) .activation(Activation.TANH).build()) @@ -453,7 +453,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) .activation(Activation.TANH).build()) @@ -523,7 +523,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) .activation(Activation.TANH).build()) @@ -611,7 +611,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) @@ -684,7 +684,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { //Spark tests don't run on windows return; } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(4) @@ -769,7 +769,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { list.add(iter.next()); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) .activation(Activation.TANH).build()) @@ -791,13 +791,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { JavaRDD rdd = sc.parallelize(list); - assertEquals(0, sparkNet.getNetwork().getLayerWiseConfigurations().getIterationCount()); + assertEquals(0, 
sparkNet.getNetwork().getConfiguration().getIterationCount()); sparkNet.fit(rdd); assertEquals(minibatchesPerWorkerPerEpoch, - sparkNet.getNetwork().getLayerWiseConfigurations().getIterationCount()); + sparkNet.getNetwork().getConfiguration().getIterationCount()); sparkNet.fit(rdd); assertEquals(2 * minibatchesPerWorkerPerEpoch, - sparkNet.getNetwork().getLayerWiseConfigurations().getIterationCount()); + sparkNet.getNetwork().getConfiguration().getIterationCount()); sparkNet.getTrainingMaster().deleteTempFiles(sc); } @@ -819,7 +819,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { list.add(iter.next()); } - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) @@ -860,7 +860,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { int nIn = 8; Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(new RmsProp()) .weightInit(WeightInit.XAVIER).list() .layer(0, new VariationalAutoencoder.Builder().nIn(8).nOut(10).encoderLayerSizes(12) .decoderLayerSizes(13).reconstructionDistribution( @@ -896,7 +896,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { int nIn = 8; Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new RmsProp()) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(new RmsProp()) .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .addLayer("0", new VariationalAutoencoder.Builder().nIn(8).nOut(10).encoderLayerSizes(12) .decoderLayerSizes(13).reconstructionDistribution( @@ -936,8 +936,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { int nOut = 2; int layerSize = 10; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(layerSize).build()) .layer(1, new OutputLayer.Builder().nIn(layerSize).nOut(nOut) .activation(Activation.SOFTMAX).lossFunction( @@ -991,8 +991,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { int nOut = 3; int layerSize = 10; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(layerSize).build()) .layer(1, new OutputLayer.Builder().nIn(layerSize).nOut(nOut) .activation(Activation.SOFTMAX).lossFunction( @@ -1045,12 +1045,12 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { //Spark tests don't run on windows return; } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new OutputLayer.Builder().nIn(4).nOut(3).build()) .build(); - ComputationGraphConfiguration conf2 = new 
NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .addLayer("out", new OutputLayer.Builder().nIn(4).nOut(3).build(), "in") @@ -1081,11 +1081,11 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { for(int i=0; i<3; i++ ){ - assertEquals(i, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount()); + assertEquals(i, sn1.getNetwork().getConfiguration().getEpochCount()); assertEquals(i, sn2.getNetwork().getConfiguration().getEpochCount()); sn1.fit(rdd); sn2.fit(rdd); - assertEquals(i+1, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount()); + assertEquals(i+1, sn1.getNetwork().getConfiguration().getEpochCount()); assertEquals(i+1, sn2.getNetwork().getConfiguration().getEpochCount()); } } diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java index f4939e369..fc446048f 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java @@ -26,7 +26,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -67,7 +67,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest { try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new OutputLayer.Builder().nIn(10).nOut(10).build()) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java index 6f79d7595..6d8a9e9bd 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java @@ -27,7 +27,7 @@ import org.deeplearning4j.core.storage.Persistable; import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -60,7 +60,7 @@ public class TestListeners extends BaseSparkTest { JavaSparkContext sc = 
getContext(); int nExecutors = numExecutors(); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER) .activation(Activation.RELU).build()) diff --git a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandler.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandler.java index 58389d74e..37bd5c2a9 100644 --- a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandler.java +++ b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandler.java @@ -20,7 +20,7 @@ package org.deeplearning4j.rl4j.network; import lombok.Getter; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -41,7 +41,7 @@ public class MultiLayerNetworkHandler implements INetworkHandler { @Getter private final boolean recurrent; - private final MultiLayerConfiguration configuration; + private final NeuralNetConfiguration configuration; private final String labelName; private final String gradientName; private final int inputFeatureIdx; @@ -59,7 +59,7 @@ public class MultiLayerNetworkHandler implements INetworkHandler { int inputFeatureIdx) { this.model = model; recurrent = model.getOutputLayer() instanceof RnnOutputLayer; - configuration = model.getLayerWiseConfigurations(); + configuration = model.getConfiguration(); this.labelName = labelName; this.gradientName = gradientName; this.inputFeatureIdx = inputFeatureIdx; diff --git a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java index cda26645f..ed8ceacda 100644 --- a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java +++ b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdConv.java @@ -59,7 +59,7 @@ public class ActorCriticFactoryCompGraphStdConv implements ActorCriticFactoryCom int w = (((shapeInputs[2] - 8) / 4 + 1) - 4) / 2 + 1; ComputationGraphConfiguration.GraphBuilder confB = - new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) + NeuralNetConfiguration.builder().seed(Constants.NEURAL_NET_SEED) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(conf.getUpdater() != null ? 
conf.getUpdater() : new Adam()) .weightInit(WeightInit.XAVIER) diff --git a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java index 65e409b83..f05d43f3b 100644 --- a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java +++ b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactoryCompGraphStdDense.java @@ -49,7 +49,7 @@ public class ActorCriticFactoryCompGraphStdDense implements ActorCriticFactoryCo nIn *= i; } ComputationGraphConfiguration.GraphBuilder confB = - new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) + NeuralNetConfiguration.builder().seed(Constants.NEURAL_NET_SEED) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) .weightInit(WeightInit.XAVIER) diff --git a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java index 8f8b739d8..80cb6384b 100644 --- a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java +++ b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticFactorySeparateStdDense.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Value; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -56,7 +56,7 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep for (int i : numInputs) { nIn *= i; } - NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) + NeuralNetConfiguration.ListBuilder confB = NeuralNetConfiguration.builder().seed(Constants.NEURAL_NET_SEED) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(conf.getUpdater() != null ? conf.getUpdater() : new Adam()) .weightInit(WeightInit.XAVIER) @@ -81,7 +81,7 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep } confB.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn)); - MultiLayerConfiguration mlnconf2 = confB.build(); + NeuralNetConfiguration mlnconf2 = confB.build(); MultiLayerNetwork model = new MultiLayerNetwork(mlnconf2); model.init(); if (conf.getListeners() != null) { @@ -90,7 +90,7 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep model.setListeners(new ScoreIterationListener(Constants.NEURAL_NET_ITERATION_LISTENER)); } - NeuralNetConfiguration.ListBuilder confB2 = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) + NeuralNetConfiguration.ListBuilder confB2 = NeuralNetConfiguration.builder().seed(Constants.NEURAL_NET_SEED) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(conf.getUpdater() != null ? 
conf.getUpdater() : new Adam()) .weightInit(WeightInit.XAVIER) @@ -116,7 +116,7 @@ public class ActorCriticFactorySeparateStdDense implements ActorCriticFactorySep } confB2.setInputType(conf.isUseLSTM() ? InputType.recurrent(nIn) : InputType.feedForward(nIn)); - MultiLayerConfiguration mlnconf = confB2.build(); + NeuralNetConfiguration mlnconf = confB2.build(); MultiLayerNetwork model2 = new MultiLayerNetwork(mlnconf); model2.init(); if (conf.getListeners() != null) { diff --git a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java index 9daeb1af8..8ae8f1944 100644 --- a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java +++ b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/ac/ActorCriticSeparate.java @@ -23,7 +23,7 @@ package org.deeplearning4j.rl4j.network.ac; import lombok.Getter; import org.apache.commons.lang3.NotImplementedException; import org.deeplearning4j.nn.api.NeuralNetwork; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -131,7 +131,7 @@ public class ActorCriticSeparate implements IAct @Override public void applyGradients(Gradients gradients) { int batchSize = (int)gradients.getBatchSize(); - MultiLayerConfiguration valueConf = valueNet.getLayerWiseConfigurations(); + NeuralNetConfiguration valueConf = valueNet.getConfiguration(); int valueIterationCount = valueConf.getIterationCount(); int valueEpochCount = valueConf.getEpochCount(); Gradient valueGradient = gradients.getGradient(CommonGradientNames.ActorCritic.Value); @@ -145,7 +145,7 @@ public class ActorCriticSeparate implements IAct } valueConf.setIterationCount(valueIterationCount + 1); - MultiLayerConfiguration policyConf = policyNet.getLayerWiseConfigurations(); + NeuralNetConfiguration policyConf = policyNet.getConfiguration(); int policyIterationCount = policyConf.getIterationCount(); int policyEpochCount = policyConf.getEpochCount(); Gradient policyGradient = gradients.getGradient(CommonGradientNames.ActorCritic.Policy); @@ -191,7 +191,7 @@ public class ActorCriticSeparate implements IAct @Deprecated public void applyGradient(Gradient[] gradient, int batchSize) { - MultiLayerConfiguration valueConf = valueNet.getLayerWiseConfigurations(); + NeuralNetConfiguration valueConf = valueNet.getConfiguration(); int valueIterationCount = valueConf.getIterationCount(); int valueEpochCount = valueConf.getEpochCount(); valueNet.getUpdater().update(valueNet, gradient[0], valueIterationCount, valueEpochCount, batchSize, LayerWorkspaceMgr.noWorkspaces()); @@ -204,7 +204,7 @@ public class ActorCriticSeparate implements IAct } valueConf.setIterationCount(valueIterationCount + 1); - MultiLayerConfiguration policyConf = policyNet.getLayerWiseConfigurations(); + NeuralNetConfiguration policyConf = policyNet.getConfiguration(); int policyIterationCount = policyConf.getIterationCount(); int policyEpochCount = policyConf.getEpochCount(); policyNet.getUpdater().update(policyNet, gradient[1], policyIterationCount, policyEpochCount, batchSize, LayerWorkspaceMgr.noWorkspaces()); diff --git a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java 
b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java index 8338884a2..c292432b2 100644 --- a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java +++ b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java @@ -22,7 +22,7 @@ package org.deeplearning4j.rl4j.network.dqn; import org.apache.commons.lang3.NotImplementedException; import org.deeplearning4j.nn.api.NeuralNetwork; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -157,7 +157,7 @@ public class DQN implements IDQN { public void applyGradients(Gradients gradients) { Gradient qValues = gradients.getGradient(CommonGradientNames.QValues); - MultiLayerConfiguration mlnConf = mln.getLayerWiseConfigurations(); + NeuralNetConfiguration mlnConf = mln.getConfiguration(); int iterationCount = mlnConf.getIterationCount(); int epochCount = mlnConf.getEpochCount(); mln.getUpdater().update(mln, qValues, iterationCount, epochCount, (int)gradients.getBatchSize(), LayerWorkspaceMgr.noWorkspaces()); @@ -172,7 +172,7 @@ public class DQN implements IDQN { } public void applyGradient(Gradient[] gradient, int batchSize) { - MultiLayerConfiguration mlnConf = mln.getLayerWiseConfigurations(); + NeuralNetConfiguration mlnConf = mln.getConfiguration(); int iterationCount = mlnConf.getIterationCount(); int epochCount = mlnConf.getEpochCount(); mln.getUpdater().update(mln, gradient[0], iterationCount, epochCount, batchSize, LayerWorkspaceMgr.noWorkspaces()); diff --git a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java index cf683aa35..bb64200bd 100644 --- a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java +++ b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdConv.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Value; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -55,7 +55,7 @@ public class DQNFactoryStdConv implements DQNFactory { throw new AssertionError("Impossible to apply convolutional layer on a shape == 1"); - NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) + NeuralNetConfiguration.ListBuilder confB = NeuralNetConfiguration.builder().seed(Constants.NEURAL_NET_SEED) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .l2(conf.getL2()) .updater(conf.getUpdater() != null ? 
conf.getUpdater() : new Adam()) @@ -71,8 +71,8 @@ public class DQNFactoryStdConv implements DQNFactory { confB.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nOut(numOutputs) .build()); - confB.setInputType(InputType.convolutional(shapeInputs[1], shapeInputs[2], shapeInputs[0])); - MultiLayerConfiguration mlnconf = confB.build(); + confB.inputType(InputType.convolutional(shapeInputs[1], shapeInputs[2], shapeInputs[0])); + NeuralNetConfiguration mlnconf = confB.build(); MultiLayerNetwork model = new MultiLayerNetwork(mlnconf); model.init(); if (conf.getListeners() != null) { diff --git a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java index d35a5f064..15b33170a 100644 --- a/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java +++ b/.old/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQNFactoryStdDense.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Value; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -54,7 +54,7 @@ public class DQNFactoryStdDense implements DQNFactory { nIn *= i; } - NeuralNetConfiguration.ListBuilder confB = new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) + NeuralNetConfiguration.ListBuilder confB = NeuralNetConfiguration.builder().seed(Constants.NEURAL_NET_SEED) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(conf.getUpdater() != null ? 
conf.getUpdater() : new Adam()) .weightInit(WeightInit.XAVIER) @@ -82,7 +82,7 @@ public class DQNFactoryStdDense implements DQNFactory { ); - MultiLayerConfiguration mlnconf = confB.build(); + NeuralNetConfiguration mlnconf = confB.build(); MultiLayerNetwork model = new MultiLayerNetwork(mlnconf); model.init(); if (conf.getListeners() != null) { diff --git a/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/NStepRnn.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/NStepRnn.java index 5cd403cee..dc23edd6e 100644 --- a/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/NStepRnn.java +++ b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/NStepRnn.java @@ -134,7 +134,7 @@ public class NStepRnn { } private static ComputationGraphConfiguration.GraphBuilder buildBaseNetworkConfiguration(int lstmLayerSize, int dl1Size, int dl2Size) { - return new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) + return NeuralNetConfiguration.builder().seed(Constants.NEURAL_NET_SEED) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Adam()) .weightInit(WeightInit.XAVIER) diff --git a/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/RobotLakeExample.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/RobotLakeExample.java index 4f95632a0..adbd6a3c5 100644 --- a/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/RobotLakeExample.java +++ b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/RobotLakeExample.java @@ -188,7 +188,7 @@ public class RobotLakeExample { } private static ComputationGraphConfiguration.GraphBuilder buildBaseNetworkConfiguration() { - return new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) + return NeuralNetConfiguration.builder().seed(Constants.NEURAL_NET_SEED) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Adam()) .weightInit(WeightInit.XAVIER) diff --git a/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/TMazeExample.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/TMazeExample.java index 617e436df..64c971e00 100644 --- a/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/TMazeExample.java +++ b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/TMazeExample.java @@ -174,7 +174,7 @@ public class TMazeExample { } private static ComputationGraphConfiguration.GraphBuilder buildBaseNetworkConfiguration() { - return new NeuralNetConfiguration.Builder().seed(Constants.NEURAL_NET_SEED) + return NeuralNetConfiguration.builder().seed(Constants.NEURAL_NET_SEED) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Adam()) .weightInit(WeightInit.XAVIER) diff --git a/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java index 0f5b51407..69d305b31 100644 --- a/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java +++ b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/network/MultiLayerNetworkHandlerTest.java @@ -21,7 +21,7 @@ package org.deeplearning4j.rl4j.network; import org.deeplearning4j.nn.api.Updater; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer; import 
org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -52,7 +52,7 @@ public class MultiLayerNetworkHandlerTest { private MultiLayerNetwork modelMock; private TrainingListener trainingListenerMock; - private MultiLayerConfiguration configurationMock; + private NeuralNetConfiguration configurationMock; private MultiLayerNetworkHandler sut; @@ -60,10 +60,10 @@ public class MultiLayerNetworkHandlerTest { modelMock = mock(MultiLayerNetwork.class); trainingListenerMock = mock(TrainingListener.class); - configurationMock = mock(MultiLayerConfiguration.class); + configurationMock = mock(NeuralNetConfiguration.class); when(configurationMock.getIterationCount()).thenReturn(123); when(configurationMock.getEpochCount()).thenReturn(234); - when(modelMock.getLayerWiseConfigurations()).thenReturn(configurationMock); + when(modelMock.getConfiguration()).thenReturn(configurationMock); if(setupRecurrent) { when(modelMock.getOutputLayer()).thenReturn(new RnnOutputLayer(null, null)); diff --git a/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java index f74713466..f0ff3f641 100644 --- a/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java +++ b/.old/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/policy/PolicyTest.java @@ -166,9 +166,9 @@ public class PolicyTest { @Test public void testACPolicy() throws Exception { - ComputationGraph cg = new ComputationGraph(new NeuralNetConfiguration.Builder().seed(444).graphBuilder().addInputs("input") + ComputationGraph cg = new ComputationGraph(NeuralNetConfiguration.builder().seed(444).graphBuilder().addInputs("input") .addLayer("output", new OutputLayer.Builder().nOut(1).lossFunction(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).build(), "input").setOutputs("output").build()); - MultiLayerNetwork mln = new MultiLayerNetwork(new NeuralNetConfiguration.Builder().seed(555).list() + MultiLayerNetwork mln = new MultiLayerNetwork(NeuralNetConfiguration.builder().seed(555).list() .layer(0, new OutputLayer.Builder().nOut(1).lossFunction(LossFunctions.LossFunction.XENT).activation(Activation.SIGMOID).build()).build()); ACPolicy policy = new ACPolicy(new DummyAC(mln), true, Nd4j.getRandom()); diff --git a/README.md b/README.md index e3eb6ba84..d1e64a639 100644 --- a/README.md +++ b/README.md @@ -48,12 +48,12 @@ Deeplearning4J offers a very high level API for defining even complex neural net you how LeNet, a convolutional neural network, is defined in DL4J. 
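The hunks above and below all apply the same mechanical migration: the removed `MultiLayerConfiguration` class and the old `new NeuralNetConfiguration.Builder()...list()...setInputType(...)` chain are replaced by the consolidated `NeuralNetConfiguration.builder()...layer(...)...inputType(...)` API. A minimal before/after sketch of that pattern is given here for orientation; the builder methods outside stock DL4J (`NeuralNetConfiguration.builder()`, `inputType(...)`, passing a `NeuralNetConfiguration` to `MultiLayerNetwork`) are assumed to behave exactly as shown in these patches and are not verified independently.

```java
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class BuilderMigrationSketch {

    public static MultiLayerNetwork build() {
        // Old style (pre-patch):
        //   MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        //       .seed(123).list().layer(0, ...).setInputType(...).build();
        // New style (as introduced by this patch series):
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                .seed(123)
                .updater(new Adam(1e-3))
                .weightInit(WeightInit.XAVIER)
                // .list() is gone; layers are added directly on the builder
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(20)
                        .activation(Activation.RELU).build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .nIn(20).nOut(3).activation(Activation.SOFTMAX).build())
                // setInputType(...) is renamed to inputType(...)
                .inputType(InputType.feedForward(4))
                .build(); // build() now yields a NeuralNetConfiguration, not a MultiLayerConfiguration
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        return net;
    }
}
```

This sketch only illustrates the rename; the README hunk that follows shows the same substitution applied to the LeNet example.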
```java -MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() +NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(seed) .l2(0.0005) .weightInit(WeightInit.XAVIER) .updater(new Adam(1e-3)) - .list() + .layer(new ConvolutionLayer.Builder(5, 5) .stride(1,1) .nOut(20) @@ -78,7 +78,7 @@ MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .nOut(outputNum) .activation(Activation.SOFTMAX) .build()) - .setInputType(InputType.convolutionalFlat(28,28,1)) + .inputType(InputType.convolutionalFlat(28,28,1)) .build(); ``` diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java index fca68610a..c03d9f5c2 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java @@ -46,13 +46,12 @@ import org.datavec.image.transform.ResizeImageTransform; import org.datavec.image.transform.ShowImageTransform; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ActivationLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.DropoutLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -61,7 +60,6 @@ import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.optimize.listeners.ScoreToChartListener; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -92,8 +90,8 @@ public class App { private static JPanel panel; private static JPanel panel2; - private static Layer[] genLayers() { - return new Layer[] { + private static LayerConfiguration[] genLayers() { + return new LayerConfiguration[] { new DenseLayer.Builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(), new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), @@ -103,33 +101,33 @@ public class App { new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH) .build() }; - - - } + } /** * Returns a network config that takes in a 10x10 random number and produces a 28x28 grayscale image. 
* * @return config */ - private static MultiLayerConfiguration generator() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + private static NeuralNetConfiguration generator() { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(42) .updater(UPDATER) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalizationThreshold(GRADIENT_THRESHOLD) - .weightInit(WeightInit.XAVIER) + //.weightInit(WeightInit.XAVIER) + .weightInitFn(new WeightInitXavier()) .activation(Activation.IDENTITY) - .list(genLayers()) - .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) + .layersFromArray(genLayers()) + .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS)) .build(); + ((NeuralNetConfiguration) conf).init(); return conf; } - private static Layer[] disLayers() { - return new Layer[]{ + private static LayerConfiguration[] disLayers() { + return new LayerConfiguration[]{ new DenseLayer.Builder().nOut(X_DIM*Y_DIM*CHANNELS*2).build(), //input is set by setInputType on the network new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), new DropoutLayer.Builder(1 - 0.5).build(), @@ -146,45 +144,50 @@ public class App { }; } - private static MultiLayerConfiguration discriminator() { + private static NeuralNetConfiguration discriminator() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(42) .updater(UPDATER) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalizationThreshold(GRADIENT_THRESHOLD) .weightInit(WeightInit.XAVIER) + //.weightInitFn(new WeightInitXavier()) + //.activationFn(new ActivationIdentity()) .activation(Activation.IDENTITY) - .list(disLayers()) - .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) + .layersFromArray(disLayers()) + .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) .build(); + ((NeuralNetConfiguration) conf).init(); return conf; } - private static MultiLayerConfiguration gan() { - Layer[] genLayers = genLayers(); - Layer[] disLayers = Arrays.stream(disLayers()) + private static NeuralNetConfiguration gan() { + LayerConfiguration[] genLayers = genLayers(); + LayerConfiguration[] disLayers = Arrays.stream(disLayers()) .map((layer) -> { - if (layer instanceof DenseLayer || layer instanceof OutputLayer) { - return new FrozenLayerWithBackprop(layer); + if (layer instanceof DenseLayer || layer instanceof OutputLayer) { + return new FrozenLayerWithBackprop(layer); } else { return layer; } - }).toArray(Layer[]::new); - Layer[] layers = ArrayUtils.addAll(genLayers, disLayers); + }).toArray(LayerConfiguration[]::new); + LayerConfiguration[] layers = ArrayUtils.addAll(genLayers, disLayers); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(42) .updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() ) .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer) .gradientNormalizationThreshold( 100 ) - .weightInit( new WeightInitXavier() ) - .activation( new ActivationIdentity()) - .list( layers ) - .setInputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) + //.weightInitFn( new WeightInitXavier() ) //this is internal + .weightInit( WeightInit.XAVIER) + //.activationFn( new ActivationIdentity()) //this is internal + .activation( Activation.IDENTITY ) + 
.layersFromArray( layers ) + .inputType( InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) .build(); - +((NeuralNetConfiguration) conf).init(); return conf; } @@ -195,6 +198,8 @@ public class App { } public static void main(String... args) throws Exception { + + log.info("\u001B[32m Some \u001B[1m green \u001B[22m text \u001B[0m \u001B[7m Inverted\u001B[0m "); Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000); // MnistDataSetIterator trainData = new MnistDataSetIterator(128, true, 45); @@ -220,9 +225,10 @@ public class App { MultiLayerNetwork gen = new MultiLayerNetwork(generator()); MultiLayerNetwork dis = new MultiLayerNetwork(discriminator()); MultiLayerNetwork gan = new MultiLayerNetwork(gan()); - gen.init(); - dis.init(); - gan.init(); + gen.init(); log.debug("Generator network: {}", gen); + dis.init(); log.debug("Discriminator network: {}", dis); + gan.init(); log.debug("Complete GAN network: {}", gan); + copyParams(gen, dis, gan); diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java b/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java index 25473fc9e..659c6ab32 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java @@ -25,6 +25,7 @@ import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.nd4j.evaluation.classification.Evaluation; @@ -199,13 +200,13 @@ public class GAN { Layer[] disLayers = ganDiscriminator.getLayers(); Layer[] layers = ArrayUtils.addAll(genLayers, disLayers); - MultiLayerConfiguration genConf = generator.getLayerWiseConfigurations(); - MultiLayerConfiguration disConf = ganDiscriminator.getLayerWiseConfigurations(); - org.deeplearning4j.nn.conf.layers.Layer[] confLayers = new org.deeplearning4j.nn.conf.layers.Layer[layers.length]; + NeuralNetConfiguration genConf = generator.getConfiguration(); + NeuralNetConfiguration disConf = ganDiscriminator.getConfiguration(); + LayerConfiguration[] confLayers = new LayerConfiguration[layers.length]; Map preProcessors = new HashMap<>(); for (int i = 0; i < layers.length; i++) { - confLayers[i] = layers[i].conf().getLayer(); + confLayers[i] = layers[i].getLayerConfiguration(); if (i < numGenLayers) { preProcessors.put(i, genConf.getInputPreProcess(i)); } else { @@ -213,7 +214,7 @@ public class GAN { } } - MultiLayerConfiguration ganConf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration ganConf = NeuralNetConfiguration.builder() .seed(seed) .updater(updater) .biasUpdater(biasUpdater) @@ -224,7 +225,7 @@ public class GAN { .trainingWorkspaceMode(trainingWorkSpaceMode) .inferenceWorkspaceMode(inferenceWorkspaceMode) .cacheMode(cacheMode) - .list(confLayers) + .layersFromArray(confLayers) .inputPreProcessors(preProcessors) .build(); gan = new MultiLayerNetwork(ganConf); @@ -267,7 +268,7 @@ public class GAN { } /** - * GAN builder, used as a starting point for creating a MultiLayerConfiguration or + * GAN builder, used as a starting point for creating a NeuralNetConfiguration or * ComputationGraphConfiguration.
*/ public static class Builder implements Cloneable { diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java index d0e5bb73d..07e6a148a 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java @@ -100,7 +100,7 @@ public class MnistDCGANExample { public static void main(String[] args) throws Exception { Supplier genSupplier = () -> { - return new MultiLayerNetwork(new NeuralNetConfiguration.Builder().list() + return new MultiLayerNetwork(NeuralNetConfiguration.builder() .layer(0, new DenseLayer.Builder().nIn(latentDim).nOut(width / 2 * height / 2 * 128) .activation(Activation.LEAKYRELU).weightInit(WeightInit.NORMAL).build()) .layer(1, new Convolution2D.Builder().nIn(128).nOut(128).kernelSize(5, 5) @@ -119,16 +119,16 @@ public class MnistDCGANExample { .inputPreProcessor(1, new FeedForwardToCnnPreProcessor(height / 2, width / 2, 128)) .inputPreProcessor(6, new CnnToFeedForwardPreProcessor(height, width, channels)) - .setInputType(InputType.feedForward(latentDim)) + .inputType(InputType.feedForward(latentDim)) .build()); }; GAN.DiscriminatorProvider discriminatorProvider = (updater) -> { - return new MultiLayerNetwork(new NeuralNetConfiguration.Builder() + return new MultiLayerNetwork(NeuralNetConfiguration.builder() .updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build()) //.gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) //.gradientNormalizationThreshold(100.0) - .list() + .layer(0, new Convolution2D.Builder().nIn(channels).nOut(64).kernelSize(3, 3) .activation(Activation.LEAKYRELU).build()) .layer(1, new Convolution2D.Builder().nIn(64).nOut(64).kernelSize(3, 3).stride(2, 2) @@ -142,7 +142,7 @@ public class MnistDCGANExample { .layer(6, new LossLayer.Builder().lossFunction(LossFunctions.LossFunction.XENT).build()) .inputPreProcessor(0, new FeedForwardToCnnPreProcessor(height, width, channels)) .inputPreProcessor(4, new CnnToFeedForwardPreProcessor(2, 2, 64)) - .setInputType(InputType.convolutionalFlat(height, width, channels)) + .inputType(InputType.convolutionalFlat(height, width, channels)) .build()); }; diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/MnistSimpleGAN.java b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistSimpleGAN.java index 037a0be9d..be3014f3c 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/MnistSimpleGAN.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistSimpleGAN.java @@ -23,7 +23,6 @@ package net.brutex.gan; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.ActivationLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -57,12 +56,12 @@ public class MnistSimpleGAN { public static MultiLayerNetwork getGenerator() { - MultiLayerConfiguration genConf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration genConf = NeuralNetConfiguration.builder() .weightInit(WeightInit.XAVIER) .activation(Activation.IDENTITY) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalizationThreshold(100) - .list() + .layer(new 
DenseLayer.Builder().nIn(100).nOut(256).weightInit(WeightInit.NORMAL).build()) .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) .layer(new DenseLayer.Builder().nIn(256).nOut(512).build()) @@ -76,14 +75,14 @@ public class MnistSimpleGAN { public static MultiLayerNetwork getDiscriminator(IUpdater updater) { - MultiLayerConfiguration discConf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration discConf = NeuralNetConfiguration.builder() .seed(42) .updater(updater) .weightInit(WeightInit.XAVIER) .activation(Activation.IDENTITY) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalizationThreshold(100) - .list() + .layer(new DenseLayer.Builder().nIn(784).nOut(1024).updater(updater).build()) .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) .layer(new DropoutLayer.Builder(1 - 0.5).build()) diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java index bc0aafa13..75965d7b5 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest.java @@ -35,7 +35,6 @@ import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StringType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; @@ -43,12 +42,10 @@ import org.datavec.api.transform.TransformProcess; import org.datavec.api.transform.filter.FilterInvalidValues; import org.datavec.api.transform.schema.Schema; import org.datavec.api.Writable; -import org.datavec.spark.transform.Normalization; import org.datavec.spark.transform.SparkTransformExecutor; import org.datavec.spark.transform.misc.StringToWritablesFunction; import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.EmnistDataSetIterator.Set; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -287,10 +284,10 @@ public class BrianTest extends BaseSparkSessionTest { //Define Network - MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration multiLayerConfiguration = NeuralNetConfiguration.builder() .seed(123) .updater(new Nesterovs(0.1, 0.9)) - .list() + .layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER) .activation(Activation.RELU).l2(0.001).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER) diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java index f32c3c4de..9195933ff 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/BrianTest2.java @@ -37,7 +37,6 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.Writable; import org.datavec.spark.transform.SparkTransformExecutor; import org.datavec.spark.transform.misc.StringToWritablesFunction; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import 
org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -295,10 +294,10 @@ public class BrianTest2 /*extends BaseDL4JTest*/ { */ //Define Network - MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration multiLayerConfiguration = NeuralNetConfiguration.builder() .seed(123) .updater(new Nesterovs(0.1, 0.9)) - .list() + .layer(0, new DenseLayer.Builder().nIn(5).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).l2(0.001).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(20).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) //.layer(2, new DenseLayerConfiguration.Builder().nIn(9).nOut(9).weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer.java b/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer.java index b81f70fc8..0cf2e5676 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer.java @@ -21,7 +21,6 @@ package net.brutex.spark; -import lombok.extern.log4j.Log4j2; //import net.brutex.ai.performance.storage.PostgresStatsStorage; import lombok.extern.slf4j.Slf4j; import org.datavec.api.records.reader.RecordReader; @@ -29,22 +28,17 @@ import org.datavec.api.records.reader.impl.collection.ListStringRecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.split.ListStringSplit; -import org.deeplearning4j.core.storage.StatsStorage; -import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; + import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.DenseLayer; + import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.optimize.listeners.ScoreIterationListener; -import org.deeplearning4j.ui.api.UIServer; -import org.deeplearning4j.ui.model.stats.StatsListener; -import org.deeplearning4j.ui.model.storage.FileStatsStorage; -import org.junit.jupiter.api.AfterAll; + import org.deeplearning4j.ui.api.UIServer; + import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; import org.nd4j.evaluation.classification.Evaluation; @@ -86,13 +80,13 @@ public class TestServer { int i = 2000; int numClasses = 10; int numBatchSize = 100; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(1234) .weightInit(WeightInit.XAVIER) .updater(new Nesterovs.Builder().learningRate(0.15).build()) .activation(Activation.RELU) .l2(0) - .list() + //.layer(0, new ConvolutionLayer.Builder().nIn(1).kernelSize(1, 5).stride(1,1).padding(0,2).nOut(1).name("1st Filter").updater(new Adam.Builder().learningRate(0.2).build()).build()) //.layer(1, new 
ConvolutionLayer.Builder().nIn(1).kernelSize(1, 2).stride(1,2).padding(0,0).nOut(1).name("2nd Filter").updater(new Adam.Builder().learningRate(0.1).build()).build()) // .layer(1, new DenseLayerConfiguration.Builder().nIn(10).nOut(64).activation(Activation.RELU).build()) diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java b/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java index ac625f2b6..c2d6f739c 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java @@ -21,7 +21,6 @@ package net.brutex.spark; -import lombok.extern.log4j.Log4j2; //import net.brutex.ai.performance.storage.PostgresStatsStorage; import lombok.extern.slf4j.Slf4j; import org.datavec.api.records.reader.RecordReader; @@ -32,9 +31,8 @@ import org.datavec.api.split.ListStringSplit; import org.datavec.image.recordreader.ImageRecordReader; import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.DenseLayer; + import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor; @@ -122,13 +120,13 @@ public class TestServer2 { int i = 2000; int numClasses = 10; int numBatchSize = 100; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(1234) .weightInit(WeightInit.XAVIER) .updater(new Nesterovs.Builder().learningRate(0.15).build()) .activation(Activation.RELU) .l2(0) - .list() + //.layer(0, new ConvolutionLayer.Builder().nIn(1).kernelSize(1, 5).stride(1,1).padding(0,2).nOut(1).name("1st Filter").updater(new Adam.Builder().learningRate(0.2).build()).build()) //.layer(1, new ConvolutionLayer.Builder().nIn(1).kernelSize(1, 2).stride(1,2).padding(0,0).nOut(1).name("2nd Filter").updater(new Adam.Builder().learningRate(0.1).build()).build()) // .layer(1, new DenseLayerConfiguration.Builder().nIn(10).nOut(64).activation(Activation.RELU).build()) diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java index 8111d2b7d..0842ebfd4 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java @@ -21,14 +21,14 @@ package org.deeplearning4j.integration; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.io.FileUtils; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.integration.testcases.samediff.SameDiffCNNCases; import org.deeplearning4j.integration.testcases.samediff.SameDiffMLPTestCases; import org.deeplearning4j.integration.testcases.samediff.SameDiffRNNTestCases; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import 
org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.listeners.CollectScoresListener; @@ -135,12 +135,12 @@ public class IntegrationTestBaselineGenerator { MultiLayerNetwork mln = null; ComputationGraph cg = null; SameDiff sd = null; - Model m = null; + IModel m = null; if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) { Object config = tc.getConfiguration(); String json = null; - if (config instanceof MultiLayerConfiguration) { - MultiLayerConfiguration mlc = (MultiLayerConfiguration) config; + if (config instanceof NeuralNetConfiguration) { + NeuralNetConfiguration mlc = (NeuralNetConfiguration) config; json = mlc.toJson(); mln = new MultiLayerNetwork(mlc); mln.init(); diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java index 489c8021d..870f4022a 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java @@ -25,18 +25,18 @@ import com.google.common.collect.ImmutableSet; import com.google.common.reflect.ClassPath; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.io.FileUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.common.config.DL4JClassLoading; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.integration.util.CountingMultiDataSetIterator; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.LayerVertex; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.GraphVertex; import org.deeplearning4j.nn.layers.BaseOutputLayer; @@ -177,22 +177,22 @@ public class IntegrationTestRunner { MultiLayerNetwork mln = null; ComputationGraph cg = null; SameDiff sd = null; - Model m = null; + IModel m = null; if (tc.getTestType() == TestCase.TestType.RANDOM_INIT) { log.info("Checking RANDOM_INIT test case: saved model vs. 
initialized model"); //Checking randomly initialized model: File savedModel = new File(testBaseDir, IntegrationTestRunner.RANDOM_INIT_UNTRAINED_MODEL_FILENAME); Object config = tc.getConfiguration(); - if (config instanceof MultiLayerConfiguration) { - MultiLayerConfiguration mlc = (MultiLayerConfiguration) config; + if (config instanceof NeuralNetConfiguration) { + NeuralNetConfiguration mlc = (NeuralNetConfiguration) config; mln = new MultiLayerNetwork(mlc); mln.init(); m = mln; MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true); - assertEquals(loaded.getLayerWiseConfigurations(), mln.getLayerWiseConfigurations(), "Configs not equal"); + assertEquals(loaded.getConfiguration(), mln.getConfiguration(), "Configs not equal"); assertEquals( loaded.params(), mln.params(), "Params not equal"); - assertEquals( loaded.paramTable(), mln.paramTable(), "Param table not equal"); + assertEquals( loaded.getParamTable(), mln.getParamTable(), "Param table not equal"); } else if(config instanceof ComputationGraphConfiguration ){ ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config; cg = new ComputationGraph(cgc); @@ -426,8 +426,8 @@ public class IntegrationTestRunner { boolean isTbptt; int tbpttLength; if(modelType == ModelType.MLN){ - isTbptt = mln.getLayerWiseConfigurations().getBackpropType() == BackpropType.TruncatedBPTT; - tbpttLength = mln.getLayerWiseConfigurations().getTbpttFwdLength(); + isTbptt = mln.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT; + tbpttLength = mln.getConfiguration().getTbpttFwdLength(); } else if(modelType == ModelType.CG) { isTbptt = cg.getComputationGraphConfiguration().getBackpropType() == BackpropType.TruncatedBPTT; tbpttLength = cg.getComputationGraphConfiguration().getTbpttFwdLength(); @@ -606,7 +606,7 @@ public class IntegrationTestRunner { if (modelType == ModelType.MLN) { ModelSerializer.writeModel(m, f, true); MultiLayerNetwork restored = MultiLayerNetwork.load(f, true); - assertEquals(mln.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); + assertEquals(mln.getConfiguration(), restored.getConfiguration()); assertEquals(mln.params(), restored.params()); } else if(modelType == ModelType.CG){ ModelSerializer.writeModel(m, f, true); @@ -722,7 +722,7 @@ public class IntegrationTestRunner { } //Work out which layers, vertices etc we have seen - so we can (at the end of all tests) log our integration test coverage - private static void collectCoverageInformation(Model m){ + private static void collectCoverageInformation(IModel m){ boolean isMLN = (m instanceof MultiLayerNetwork); MultiLayerNetwork mln = (isMLN ? (MultiLayerNetwork)m : null); ComputationGraph cg = (!isMLN ? 
(ComputationGraph)m : null); @@ -735,14 +735,14 @@ public class IntegrationTestRunner { layers = cg.getLayers(); } for (org.deeplearning4j.nn.api.Layer l : layers) { - Layer lConf = l.conf().getLayer(); + LayerConfiguration lConf = l.getLayerConfiguration(); layerConfClassesSeen.put(lConf.getClass(), layerConfClassesSeen.getOrDefault(lConf.getClass(), 0) + 1); } //Collect preprocessor coverage information: Collection preProcessors; if (isMLN) { - preProcessors = mln.getLayerWiseConfigurations().getInputPreProcessors().values(); + preProcessors = mln.getConfiguration().getInputPreProcessors().values(); } else { preProcessors = new ArrayList<>(); for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getComputationGraphConfiguration().getVertices().values()) { @@ -767,7 +767,7 @@ public class IntegrationTestRunner { } - private static void checkLayerClearance(Model m) { + private static void checkLayerClearance(IModel m) { //Check that the input fields for all layers have been cleared org.deeplearning4j.nn.api.Layer[] layers; if (m instanceof MultiLayerNetwork) { @@ -801,7 +801,7 @@ public class IntegrationTestRunner { } } - private static void validateLayerIterCounts(Model m, int expEpoch, int expIter){ + private static void validateLayerIterCounts(IModel m, int expEpoch, int expIter){ //Check that the iteration and epoch counts - on the layers - are synced org.deeplearning4j.nn.api.Layer[] layers; if (m instanceof MultiLayerNetwork) { @@ -817,7 +817,7 @@ public class IntegrationTestRunner { } - private static Map getFrozenLayerParamCopies(Model m){ + private static Map getFrozenLayerParamCopies(IModel m){ Map out = new LinkedHashMap<>(); org.deeplearning4j.nn.api.Layer[] layers; if (m instanceof MultiLayerNetwork) { @@ -832,7 +832,7 @@ public class IntegrationTestRunner { if(m instanceof MultiLayerNetwork){ paramPrefix = l.getIndex() + "_"; } else { - paramPrefix = l.conf().getLayer().getLayerName() + "_"; + paramPrefix = l.getLayerConfiguration().getLayerName() + "_"; } Map paramTable = l.paramTable(); for(Map.Entry e : paramTable.entrySet()){ @@ -854,7 +854,7 @@ public class IntegrationTestRunner { return out; } - public static void checkFrozenParams(Map copiesBeforeTraining, Model m){ + public static void checkFrozenParams(Map copiesBeforeTraining, IModel m){ for(Map.Entry e : copiesBeforeTraining.entrySet()){ INDArray actual = m.getParam(e.getKey()); assertEquals(e.getValue(), actual, e.getKey()); @@ -939,7 +939,7 @@ public class IntegrationTestRunner { } private static boolean isLayerConfig(Class c) { - return Layer.class.isAssignableFrom(c); + return LayerConfiguration.class.isAssignableFrom(c); } private static boolean isPreprocessorConfig(Class c) { diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestCase.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestCase.java index b2d76f04a..41afafa4e 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestCase.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestCase.java @@ -21,7 +21,7 @@ package org.deeplearning4j.integration; import lombok.Data; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.evaluation.IEvaluation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -89,7 +89,7 @@ public abstract class TestCase { /** * Required for pretrained models (testType == TestType.PRETRAINED) */ - public Model getPretrainedModel() throws 
Exception { + public IModel getPretrainedModel() throws Exception { throw new RuntimeException("Implementations must override this method if used"); } diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java index e03f2a523..bbe38a662 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java @@ -22,7 +22,7 @@ package org.deeplearning4j.integration; import org.apache.commons.compress.utils.IOUtils; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.util.ModelSerializer; @@ -48,15 +48,15 @@ public class TestUtils { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); - assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); + assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } - //Also check the MultiLayerConfiguration is serializable (required by Spark etc) - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); + //Also check the NeuralNetConfiguration is serializable (required by Spark etc) + NeuralNetConfiguration conf = net.getConfiguration(); serializeDeserializeJava(conf); return restored; diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java index d65a0a9cc..ec116ca31 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN1DTestCases.java @@ -80,12 +80,12 @@ public class CNN1DTestCases { CharacterIterator iter = CharacterIterator.getShakespeareIterator(miniBatchSize,exampleLength); int nOut = iter.totalOutcomes(); - return new NeuralNetConfiguration.Builder() + return ((NeuralNetConfiguration.NeuralNetConfigurationBuilder)NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .seed(12345) .weightInit(WeightInit.XAVIER) .updater(new Adam(0.01)) - .convolutionMode(ConvolutionMode.Same) + .convolutionMode(ConvolutionMode.Same)) .graphBuilder() .addInputs("in") .layer("0", new Convolution1DLayer.Builder().nOut(32).activation(Activation.TANH).kernelSize(3).stride(1).build(), "in") diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java index 3b351e277..4b7b3f7a3 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN2DTestCases.java @@ -32,7 +32,7 @@ import org.deeplearning4j.datasets.fetchers.DataSetType; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import 
org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.TinyImageNetDataSetIterator; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -98,13 +98,13 @@ public class CNN2DTestCases { int outputNum = 10; // The number of possible outcomes int seed = 123; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .seed(seed) .l2(0.0005) .weightInit(WeightInit.XAVIER) .updater(new Nesterovs(0.01, 0.9)) - .list() + .layer(0, new ConvolutionLayer.Builder(5, 5) //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied .nIn(nChannels) @@ -132,7 +132,7 @@ public class CNN2DTestCases { .nOut(outputNum) .activation(Activation.SOFTMAX) .build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) //See note below + .inputType(InputType.convolutionalFlat(28, 28, 1)) //See note below .build(); return conf; @@ -207,7 +207,7 @@ public class CNN2DTestCases { } @Override - public Model getPretrainedModel() throws Exception { + public IModel getPretrainedModel() throws Exception { VGG16 vgg16 = VGG16.builder() .seed(12345) .build(); @@ -294,7 +294,7 @@ public class CNN2DTestCases { } @Override - public Model getPretrainedModel() throws Exception { + public IModel getPretrainedModel() throws Exception { int nClasses = 10; int nBoxes = 5; double lambdaNoObj = 0.5; @@ -403,20 +403,20 @@ public class CNN2DTestCases { } @Override - public Model getPretrainedModel() throws Exception { + public IModel getPretrainedModel() throws Exception { Map lrSchedule = new HashMap<>(); lrSchedule.put(0, 0.01); lrSchedule.put(1000, 0.005); lrSchedule.put(3000, 0.001); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .seed(12345) .l2(0.0005) .weightInit(WeightInit.XAVIER) .updater(new Nesterovs(0.01, 0.9)) - .list() + .layer(0, new ConvolutionLayer.Builder(5, 5) //nIn and nOut specify depth. 
nIn here is the nChannels and nOut is the number of filters to be applied .nIn(1) @@ -446,7 +446,7 @@ public class CNN2DTestCases { .nOut(10) .activation(Activation.SOFTMAX) .build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) //See note below + .inputType(InputType.convolutionalFlat(28, 28, 1)) //See note below .build(); diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java index f856d5159..157116ba9 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/CNN3DTestCases.java @@ -24,7 +24,6 @@ import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator; import org.deeplearning4j.integration.ModelType; import org.deeplearning4j.integration.TestCase; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Convolution3D; @@ -76,13 +75,13 @@ public class CNN3DTestCases { int outputNum = 10; // The number of possible outcomes int seed = 123; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(seed) .l2(0.0005) .weightInit(WeightInit.XAVIER) .updater(new Nesterovs(0.01, 0.9)) .convolutionMode(ConvolutionMode.Same) - .list() + .layer(new Convolution3D.Builder(3,3,3) .dataFormat(Convolution3D.DataFormat.NCDHW) .nIn(nChannels) @@ -98,7 +97,7 @@ public class CNN3DTestCases { .nOut(outputNum) .activation(Activation.SOFTMAX) .build()) - .setInputType(InputType.convolutional3D(8,8,8,nChannels)) + .inputType(InputType.convolutional3D(8,8,8,nChannels)) .build(); return conf; diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java index 9a58e5138..69e9fa4cd 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/MLPTestCases.java @@ -28,7 +28,6 @@ import org.datavec.api.split.FileSplit; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -93,7 +92,7 @@ public class MLPTestCases { @Override public Object getConfiguration() { - return new NeuralNetConfiguration.Builder() + return NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .seed(12345) .updater(new Adam(new MapSchedule.Builder(ScheduleType.ITERATION) @@ -104,13 +103,13 @@ public class MLPTestCases { .add(14, 1e-2) .build())) .l1(1e-3).l2(1e-3) - .list() + .layer(new DenseLayer.Builder().activation(Activation.TANH).nOut(64).build()) .layer(new OutputLayer.Builder().nOut(10) .lossFunction(LossFunctions.LossFunction.MCXENT) 
.activation(Activation.SOFTMAX) .build()) - .setInputType(InputType.convolutionalFlat(28,28,1)) + .inputType(InputType.convolutionalFlat(28,28,1)) .build(); } @@ -198,11 +197,11 @@ public class MLPTestCases { int numHiddenNodes = 20; //log.info("Build model...."); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .seed(seed) .updater(new Nesterovs(learningRate, 0.9)) - .list() + .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes) .weightInit(WeightInit.XAVIER) .activation(Activation.RELU) diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java index a2cf437fe..edb312c0f 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/RNNTestCases.java @@ -112,20 +112,20 @@ public class RNNTestCases { int lstmLayerSize = 200; //Number of units in each GravesLSTM layer int tbpttLength = 50; //Length for truncated backpropagation through time. i.e., do parameter updates ever 50 characters - return new NeuralNetConfiguration.Builder() + return NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .seed(12345) .l2(0.001) .weightInit(WeightInit.XAVIER) .updater(new Adam(1e-3)) - .list() + .layer(0, new LSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize) .activation(Activation.TANH).build()) .layer(1, new LSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize) .activation(Activation.TANH).build()) .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX) //MCXENT + softmax for classification .nIn(lstmLayerSize).nOut(nOut).build()) - .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(tbpttLength).tBPTTBackwardLength(tbpttLength) + .backpropType(BackpropType.TruncatedBPTT).tbpttFwdLength(tbpttLength).tbpttBackLength(tbpttLength) .build(); } @@ -195,19 +195,19 @@ public class RNNTestCases { @Override public Object getConfiguration() throws Exception { - return new NeuralNetConfiguration.Builder() + return NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .seed(12345) .updater(new Adam(5e-2)) .l1(1e-3).l2(1e-3) - .list() + .layer(0, new LSTM.Builder().activation(Activation.TANH).nOut(10).build()) .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()) .layer(new OutputLayer.Builder().nOut(6) .lossFunction(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX) .build()) - .setInputType(InputType.recurrent(1)) + .inputType(InputType.recurrent(1)) .build(); } @@ -316,19 +316,19 @@ public class RNNTestCases { @Override public Object getConfiguration() throws Exception { - return new NeuralNetConfiguration.Builder() + return NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .seed(12345) .updater(new Adam(5e-2)) .l1(1e-3).l2(1e-3) - .list() + .layer(0, new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build())) .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()) .layer(new OutputLayer.Builder().nOut(6) .lossFunction(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX) .build()) - .setInputType(InputType.recurrent(1)) + .inputType(InputType.recurrent(1)) .build(); } diff --git 
a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java index 574e3be2d..84b60ffd6 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/testcases/dl4j/UnsupervisedTestCases.java @@ -72,13 +72,13 @@ public class UnsupervisedTestCases { @Override public Object getConfiguration() { - return new NeuralNetConfiguration.Builder() + return NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .seed(12345) .updater(new Adam(1e-3)) .weightInit(WeightInit.XAVIER) .l2(1e-4) - .list() + .layer(0, new VariationalAutoencoder.Builder() .activation(Activation.TANH) .encoderLayerSizes(256, 256) //2 encoder layers, each of size 256 diff --git a/build.gradle b/build.gradle index 20e45b528..a3a070c2b 100644 --- a/build.gradle +++ b/build.gradle @@ -66,9 +66,9 @@ allprojects { Project proj -> plugins.withType(JavaPlugin) { sourceCompatibility = JavaVersion.VERSION_11 - targetCompatibility = JavaVersion.VERSION_1_8 + targetCompatibility = JavaVersion.VERSION_11 tasks.withType(JavaCompile) { - options.release = 8 + options.release = 11 } dependencies { @@ -86,7 +86,6 @@ allprojects { Project proj -> testImplementation 'org.junit.jupiter:junit-jupiter-engine' testImplementation 'org.junit.jupiter:junit-jupiter-api' testImplementation 'org.junit.jupiter:junit-jupiter-params' - implementation "org.slf4j:slf4j-api" implementation "org.slf4j:slf4j-simple" diff --git a/cavis-dnn/cavis-dnn-core/src/main/java/net/brutex/ai/dnn/core/util/ANSI.java b/cavis-dnn/cavis-dnn-core/src/main/java/net/brutex/ai/dnn/core/util/ANSI.java new file mode 100644 index 000000000..bd2247445 --- /dev/null +++ b/cavis-dnn/cavis-dnn-core/src/main/java/net/brutex/ai/dnn/core/util/ANSI.java @@ -0,0 +1,52 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.core.util; + +/** + * ANSI colour codes + */ +public enum ANSI { + BLACK("\u001B[30m"), + RED("\u001B[31m"), + GREEN("\u001B[32m"), + YELLOW("\u001B[33m"), + BLUE("\u001B[34m"), + PURPLE("\u001B[35m"), + CYAN("\u001B[36m"), + WHITE("\u001B[37m"), + + ANSI_RESET("\u001B[0m"), + + BLACK_BACKGROUND("\u001B[40m"), + RED_BACKGROUND("\u001B[41m"), + GREEN_BACKGROUND("\u001B[42m"), + YELLOW_BACKGROUND("\u001B[43m"), + BLUE_BACKGROUND("\u001B[44m"), + PURPLE_BACKGROUND("\u001B[45m"), + CYAN_BACKGROUND("\u001B[46m"), + WHITE_BACKGROUND("\u001B[47m"); + + String code; + ANSI(String code) { + this.code = code; + } +} diff --git a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemInfoFilePrintListener.java b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemInfoFilePrintListener.java index 88f8a2bd8..d9e3d7b6f 100644 --- a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemInfoFilePrintListener.java +++ b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemInfoFilePrintListener.java @@ -23,8 +23,8 @@ package org.deeplearning4j.core.listener; import lombok.NonNull; import lombok.Builder; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.io.FileUtils; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.ndarray.INDArray; import oshi.json.SystemInfo; @@ -56,12 +56,12 @@ public class SystemInfoFilePrintListener implements TrainingListener { } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { } @Override - public void onEpochStart(Model model) { + public void onEpochStart(IModel model) { if(!printOnEpochStart || printFileTarget == null) return; @@ -70,7 +70,7 @@ public class SystemInfoFilePrintListener implements TrainingListener { } @Override - public void onEpochEnd(Model model) { + public void onEpochEnd(IModel model) { if(!printOnEpochEnd || printFileTarget == null) return; @@ -79,7 +79,7 @@ public class SystemInfoFilePrintListener implements TrainingListener { } @Override - public void onForwardPass(Model model, List activations) { + public void onForwardPass(IModel model, List activations) { if(!printOnBackwardPass || printFileTarget == null) return; @@ -88,7 +88,7 @@ public class SystemInfoFilePrintListener implements TrainingListener { } @Override - public void onForwardPass(Model model, Map activations) { + public void onForwardPass(IModel model, Map activations) { if(!printOnForwardPass || printFileTarget == null) return; @@ -97,7 +97,7 @@ public class SystemInfoFilePrintListener implements TrainingListener { } @Override - public void onGradientCalculation(Model model) { + public void onGradientCalculation(IModel model) { if(!printOnGradientCalculation || printFileTarget == null) return; @@ -107,7 +107,7 @@ public class SystemInfoFilePrintListener implements TrainingListener { } @Override - public void onBackwardPass(Model model) { + public void onBackwardPass(IModel model) { if(!printOnBackwardPass || printFileTarget == null) return; diff --git a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemInfoPrintListener.java 
b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemInfoPrintListener.java index 5b115d542..e4bdfcda6 100644 --- a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemInfoPrintListener.java +++ b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/listener/SystemInfoPrintListener.java @@ -22,7 +22,7 @@ package org.deeplearning4j.core.listener; import lombok.Builder; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.ndarray.INDArray; import oshi.json.SystemInfo; @@ -49,12 +49,12 @@ public class SystemInfoPrintListener implements TrainingListener { private static final String SYSTEM_INFO = "System info on epoch end: "; @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { } @Override - public void onEpochStart(Model model) { + public void onEpochStart(IModel model) { if(!printOnEpochStart) return; @@ -64,7 +64,7 @@ public class SystemInfoPrintListener implements TrainingListener { } @Override - public void onEpochEnd(Model model) { + public void onEpochEnd(IModel model) { if(!printOnEpochEnd) return; @@ -74,7 +74,7 @@ public class SystemInfoPrintListener implements TrainingListener { } @Override - public void onForwardPass(Model model, List activations) { + public void onForwardPass(IModel model, List activations) { if(!printOnBackwardPass) return; @@ -84,7 +84,7 @@ public class SystemInfoPrintListener implements TrainingListener { } @Override - public void onForwardPass(Model model, Map activations) { + public void onForwardPass(IModel model, Map activations) { if(!printOnForwardPass) return; @@ -94,7 +94,7 @@ public class SystemInfoPrintListener implements TrainingListener { } @Override - public void onGradientCalculation(Model model) { + public void onGradientCalculation(IModel model) { if(!printOnGradientCalculation) return; @@ -104,7 +104,7 @@ public class SystemInfoPrintListener implements TrainingListener { } @Override - public void onBackwardPass(Model model) { + public void onBackwardPass(IModel model) { if(!printOnBackwardPass) return; SystemInfo systemInfo = new SystemInfo(); diff --git a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/ModelGuesser.java b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/ModelGuesser.java index 70b250978..3ab6eec8f 100644 --- a/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/ModelGuesser.java +++ b/cavis-dnn/cavis-dnn-core/src/main/java/org/deeplearning4j/core/util/ModelGuesser.java @@ -21,13 +21,13 @@ package org.deeplearning4j.core.util; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.deeplearning4j.common.util.DL4JFileUtils; import org.deeplearning4j.common.config.DL4JSystemProperties; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.linalg.dataset.api.preprocessor.Normalizer; @@ -80,7 +80,7 @@ public class ModelGuesser { //note here that we load json BEFORE YAML. 
YAML //turns out to load just fine *accidentally* try { - return MultiLayerConfiguration.fromJson(input); + return NeuralNetConfiguration.fromJson(input); } catch (Exception e) { log.warn("Tried multi layer config from json", e); try { @@ -96,7 +96,7 @@ public class ModelGuesser { } catch (Exception e3) { log.warn("Tried computation graph from json"); try { - return MultiLayerConfiguration.fromYaml(input); + return NeuralNetConfiguration.fromYaml(input); } catch (Exception e4) { log.warn("Tried multi layer configuration from yaml"); try { @@ -142,7 +142,7 @@ public class ModelGuesser { * @return the loaded model * @throws Exception */ - public static Model loadModelGuess(String path) throws Exception { + public static IModel loadModelGuess(String path) throws Exception { try { return ModelSerializer.restoreMultiLayerNetwork(new File(path), true); } catch (Exception e) { @@ -185,7 +185,7 @@ public class ModelGuesser { * @return the loaded model * @throws Exception */ - public static Model loadModelGuess(InputStream stream) throws Exception { + public static IModel loadModelGuess(InputStream stream) throws Exception { return loadModelGuess(stream, null); } @@ -194,7 +194,7 @@ public class ModelGuesser { * @param stream Stream of the model file * @param tempDirectory Temporary/working directory. May be null. */ - public static Model loadModelGuess(InputStream stream, File tempDirectory) throws Exception { + public static IModel loadModelGuess(InputStream stream, File tempDirectory) throws Exception { //Currently (Nov 2017): KerasModelImport doesn't support loading from input streams //Simplest solution here: write to a temporary file File f; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java index cc1220762..8da3ff4e5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java @@ -26,6 +26,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.common.config.DL4JClassLoading; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.listeners.CollectScoresListener; @@ -99,7 +100,7 @@ public class LayerHelperValidationUtil { //Don't allow fallback: for(Layer l : netOrig.getLayers()){ - org.deeplearning4j.nn.conf.layers.Layer lConf = l.conf().getLayer(); + LayerConfiguration lConf = l.getLayerConfiguration(); if(lConf instanceof ConvolutionLayer){ ((ConvolutionLayer) lConf).setCudnnAllowFallback(false); } else if(lConf instanceof SubsamplingLayer){ @@ -108,12 +109,12 @@ public class LayerHelperValidationUtil { } - MultiLayerNetwork net1NoHelper = new MultiLayerNetwork(netOrig.getLayerWiseConfigurations().clone()); + MultiLayerNetwork net1NoHelper = new MultiLayerNetwork(netOrig.getConfiguration().clone()); net1NoHelper.init(); log.info("Removing all layer helpers from network copy 1"); removeHelpers(net1NoHelper.getLayers(), null); - MultiLayerNetwork net2With = new MultiLayerNetwork(netOrig.getLayerWiseConfigurations().clone()); + MultiLayerNetwork net2With = new MultiLayerNetwork(netOrig.getConfiguration().clone()); net2With.init(); 
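The accessor renames running through these hunks (getLayerWiseConfigurations() to getConfiguration(), paramTable() to getParamTable(), conf().getLayer() to getLayerConfiguration()) all follow one pattern. A minimal sketch of a configuration round trip under the renamed API, assuming only the methods that appear in this patch set and an already initialised MultiLayerNetwork named net, might look like:

    // Sketch only; uses the renamed accessors introduced by this patch set.
    NeuralNetConfiguration conf = net.getConfiguration();                        // was net.getLayerWiseConfigurations()
    String json = conf.toJson();
    NeuralNetConfiguration restoredConf = NeuralNetConfiguration.fromJson(json); // was MultiLayerConfiguration.fromJson(json)

    MultiLayerNetwork restored = new MultiLayerNetwork(restoredConf);
    restored.init();
    restored.params().assign(net.params());

    assertEquals(net.getConfiguration(), restored.getConfiguration());           // was getLayerWiseConfigurations()
    assertEquals(net.getParamTable().keySet(), restored.getParamTable().keySet()); // was paramTable()
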
net2With.params().assign(netOrig.params()); log.info("Removing all except for specified helpers from network copy 2: " + t.getAllowHelpersForClasses()); @@ -133,7 +134,7 @@ public class LayerHelperValidationUtil { enableCppHelpers(); } List ff2 = net2With.feedForward(t.getFeatures(), train); - List paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet()); + List paramKeys = new ArrayList<>(net1NoHelper.getParamTable().keySet()); Collections.sort(paramKeys); for (String p : paramKeys) { INDArray p1 = net1NoHelper.getParam(p); @@ -224,7 +225,7 @@ public class LayerHelperValidationUtil { } net2With.computeGradientAndScore(); - List paramKeys = new ArrayList<>(net1NoHelper.paramTable().keySet()); + List paramKeys = new ArrayList<>(net1NoHelper.getParamTable().keySet()); Collections.sort(paramKeys); for(String p : paramKeys){ INDArray g1 = net1NoHelper.gradient().gradientForVariable().get(p); @@ -252,7 +253,7 @@ public class LayerHelperValidationUtil { Preconditions.checkNotNull(t.getData(), "DataSetIterator is not set (null)"); log.info("Testing run-to-run consistency of training with layer helper"); - net2With = new MultiLayerNetwork(netOrig.getLayerWiseConfigurations().clone()); + net2With = new MultiLayerNetwork(netOrig.getConfiguration().clone()); net2With.init(); net2With.params().assign(netOrig.params()); log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses()); @@ -264,7 +265,7 @@ public class LayerHelperValidationUtil { for( int i=0; i<2; i++ ) { - net2With = new MultiLayerNetwork(netOrig.getLayerWiseConfigurations().clone()); + net2With = new MultiLayerNetwork(netOrig.getConfiguration().clone()); net2With.init(); net2With.params().assign(netOrig.params()); log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/RandomTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/RandomTests.java index 63b13e660..d939dab81 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/RandomTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/RandomTests.java @@ -23,19 +23,15 @@ package org.deeplearning4j; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.jupiter.api.Test; -import org.nd4j.common.resources.Resources; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.RmsProp; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.nio.file.Files; import java.util.concurrent.CountDownLatch; //@Ignore @@ -44,8 +40,8 @@ public class RandomTests extends BaseDL4JTest { @Test public void testReproduce() throws Exception { - final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + final NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .layer(0, new 
org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(10) .activation(Activation.TANH).build()) .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java index cecc969ac..6e4456ef2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -23,7 +23,7 @@ package org.deeplearning4j; import org.apache.commons.compress.utils.IOUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; @@ -66,15 +66,15 @@ public class TestUtils { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); - assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); + assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); } - //Also check the MultiLayerConfiguration is serializable (required by Spark etc) - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); + //Also check the NeuralNetConfiguration is serializable (required by Spark etc) + NeuralNetConfiguration conf = net.getConfiguration(); serializeDeserializeJava(conf); return restored; @@ -317,14 +317,14 @@ public class TestUtils { for(Layer l : layers){ //Don't use instanceof here - there are sub conv subclasses if(l.getClass() == ConvolutionLayer.class || l instanceof SubsamplingLayer || l instanceof BatchNormalization || l instanceof LSTM){ - Preconditions.checkNotNull(l.getHelper(), l.conf().getLayer().getLayerName()); + Preconditions.checkNotNull(l.getHelper(), l.getLayerConfiguration().getLayerName()); } } } public static void assertHelpersAbsent(Layer[] layers) throws Exception { for(Layer l : layers){ - Preconditions.checkState(l.getHelper() == null, l.conf().getLayer().getLayerName()); + Preconditions.checkState(l.getHelper() == null, l.getLayerConfiguration().getLayerName()); } } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java index dc9b3ffcf..f391f35f9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java @@ -31,8 +31,8 @@ import org.deeplearning4j.datasets.iterator.impl.*; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; 
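The recurring configuration change in these test hunks is the move from new NeuralNetConfiguration.Builder() with .list() and .setInputType(...) producing a MultiLayerConfiguration, to the flattened NeuralNetConfiguration.builder() where layers are attached directly and the input type is set with .inputType(...). A minimal sketch of the new style, assuming the renamed fork API exactly as it appears in this patch set (imports as in the surrounding tests):

    // Sketch only; illustrates the builder style these hunks migrate to.
    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
            .seed(12345)
            .updater(new Adam(1e-3))
            .weightInit(WeightInit.XAVIER)
            // layers are added directly; the old .list() step is gone
            .layer(0, new DenseLayer.Builder().nIn(784).nOut(64).activation(Activation.RELU).build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nOut(10).activation(Activation.SOFTMAX).build())
            // was .setInputType(...)
            .inputType(InputType.convolutionalFlat(28, 28, 1))
            .build();

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
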
import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -168,9 +168,9 @@ public class DataSetIteratorTest extends BaseDL4JTest { LFWDataSetIterator lfw = new LFWDataSetIterator(batchSize, numSamples, new int[] {numRows, numColumns, numChannels}, outputNum, false, true, 1.0, new Random(seed)); - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(numChannels).nOut(6) .weightInit(WeightInit.XAVIER).activation(Activation.RELU).build()) .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) @@ -178,8 +178,7 @@ public class DataSetIteratorTest extends BaseDL4JTest { .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) .build()) - .setInputType(InputType.convolutionalFlat(numRows, numColumns, numChannels)) - ; + .inputType(InputType.convolutionalFlat(numRows, numColumns, numChannels)); MultiLayerNetwork model = new MultiLayerNetwork(builder.build()); model.init(); @@ -229,9 +228,9 @@ public class DataSetIteratorTest extends BaseDL4JTest { Cifar10DataSetIterator cifar = new Cifar10DataSetIterator(batchSize); - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .layer(0, new ConvolutionLayer.Builder(5, 5).nIn(channels).nOut(6).weightInit(WeightInit.XAVIER) .activation(Activation.RELU).build()) .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) @@ -240,7 +239,7 @@ public class DataSetIteratorTest extends BaseDL4JTest { .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) .build()) - .setInputType(InputType.convolutionalFlat(height, width, channels)); + .inputType(InputType.convolutionalFlat(height, width, channels)); MultiLayerNetwork model = new MultiLayerNetwork(builder.build()); model.init(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java index 13ae46efb..12e17fa3a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java @@ -38,10 +38,9 @@ import org.deeplearning4j.earlystopping.scorecalc.*; import org.deeplearning4j.earlystopping.termination.*; import org.deeplearning4j.earlystopping.trainer.EarlyStoppingTrainer; import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.GradientNormalization; -import 
org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution; @@ -133,9 +132,9 @@ public class TestEarlyStopping extends BaseDL4JTest { String msg = i + " - " + sc.getClass().getSimpleName(); log.info("Starting test - {}", msg); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) - .updater(new Sgd(0.5)).weightInit(WeightInit.XAVIER).list() + .updater(new Sgd(0.5)).weightInit(WeightInit.XAVIER) .layer(new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()) .layer(new OutputLayer.Builder().nIn(4).nOut(3) .activation(Activation.SOFTMAX) @@ -219,9 +218,9 @@ public class TestEarlyStopping extends BaseDL4JTest { @Test public void testEarlyStoppingEveryNEpoch() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.01)).weightInit(WeightInit.XAVIER).list() + .updater(new Sgd(0.01)).weightInit(WeightInit.XAVIER) .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) @@ -248,9 +247,9 @@ public class TestEarlyStopping extends BaseDL4JTest { @Test public void testEarlyStoppingIrisMultiEpoch() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list() + .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER) .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) @@ -297,10 +296,10 @@ public class TestEarlyStopping extends BaseDL4JTest { //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(5.0)) //Intentionally huge LR - .weightInit(WeightInit.XAVIER).list() + .weightInit(WeightInit.XAVIER) .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); @@ -336,9 +335,9 @@ public class TestEarlyStopping extends BaseDL4JTest { //test termination after max time Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(1e-6)).weightInit(WeightInit.XAVIER).list() + .updater(new Sgd(1e-6)).weightInit(WeightInit.XAVIER) .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) @@ -379,9 +378,9 @@ public class TestEarlyStopping extends BaseDL4JTest { //Simulate this by setting LR = 0.0 Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new 
NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.0)).weightInit(WeightInit.XAVIER).list() + .updater(new Sgd(0.0)).weightInit(WeightInit.XAVIER) .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) @@ -419,9 +418,9 @@ public class TestEarlyStopping extends BaseDL4JTest { //Simulate this by setting LR = 0.0 Random rng = new Random(123); Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Nesterovs(0.0,0.9)).list() + .updater(new Nesterovs(0.0,0.9)) .layer(0, new DenseLayer.Builder().nIn(1).nOut(20) .weightInit(WeightInit.XAVIER).activation( Activation.TANH) @@ -466,9 +465,9 @@ public class TestEarlyStopping extends BaseDL4JTest { @Test public void testEarlyStoppingGetBestModel() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list() + .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER) .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) @@ -496,17 +495,17 @@ public class TestEarlyStopping extends BaseDL4JTest { MultiLayerNetwork mln = result.getBestModel(); assertEquals(net.getnLayers(), mln.getnLayers()); - assertEquals(net.conf().getOptimizationAlgo(), mln.conf().getOptimizationAlgo()); - BaseLayer bl = (BaseLayer) net.conf().getLayer(); - assertEquals(bl.getActivationFn().toString(), ((BaseLayer) mln.conf().getLayer()).getActivationFn().toString()); - assertEquals(bl.getIUpdater(), ((BaseLayer) mln.conf().getLayer()).getIUpdater()); + assertEquals(net.getNetConfiguration().getOptimizationAlgo(), mln.getNetConfiguration().getOptimizationAlgo()); + BaseLayer bl = (BaseLayer) net.getLayerConfiguration(); + assertEquals(bl.getActivationFn().toString(), ((BaseLayer) mln.getLayerConfiguration()).getActivationFn().toString()); + assertEquals(bl.getIUpdater(), ((BaseLayer) mln.getLayerConfiguration()).getIUpdater()); } @Test public void testListeners() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list() + .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER) .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) @@ -569,8 +568,8 @@ public class TestEarlyStopping extends BaseDL4JTest { Metric.MAE}) { log.info("Metric: " + metric); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(new DenseLayer.Builder().nIn(784).nOut(32).build()) .layer(new OutputLayer.Builder().nIn(32).nOut(784).activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE).build()) .build(); @@ -612,8 +611,8 @@ public class TestEarlyStopping extends 
BaseDL4JTest { Metric.MAE}) { log.info("Metric: " + metric); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(new AutoEncoder.Builder().nIn(784).nOut(32).build()) .build(); @@ -655,8 +654,8 @@ public class TestEarlyStopping extends BaseDL4JTest { Metric.MAE}) { log.info("Metric: " + metric); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(new VariationalAutoencoder.Builder() .nIn(784).nOut(32) .encoderLayerSizes(64) @@ -700,8 +699,8 @@ public class TestEarlyStopping extends BaseDL4JTest { for(boolean logProb : new boolean[]{false, true}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(new VariationalAutoencoder.Builder() .nIn(784).nOut(32) .encoderLayerSizes(64) @@ -747,8 +746,8 @@ public class TestEarlyStopping extends BaseDL4JTest { for(Evaluation.Metric metric : Evaluation.Metric.values()) { log.info("Metric: " + metric); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(new DenseLayer.Builder().nIn(784).nOut(32).build()) .layer(new OutputLayer.Builder().nIn(32).nOut(10).activation(Activation.SOFTMAX).build()) .build(); @@ -784,8 +783,8 @@ public class TestEarlyStopping extends BaseDL4JTest { @Test public void testEarlyStoppingListeners() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER) .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) .activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) @@ -827,19 +826,19 @@ public class TestEarlyStopping extends BaseDL4JTest { private int maxEpochEnd = -1; @Override - public void onEpochStart(Model model){ + public void onEpochStart(IModel model){ countEpochStart++; maxEpochStart = Math.max(maxEpochStart, BaseOptimizer.getEpochCount(model)); } @Override - public void onEpochEnd(Model model){ + public void onEpochEnd(IModel model){ countEpochEnd++; maxEpochEnd = Math.max(maxEpochEnd, BaseOptimizer.getEpochCount(model)); } @Override - public void iterationDone(Model model, int iteration, int epoch){ + public void iterationDone(IModel model, int iteration, int epoch){ iterCount++; } @@ -859,7 +858,7 @@ public class TestEarlyStopping extends BaseDL4JTest { DataSetIterator test = new SingletonDataSetIterator(ds); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(123) .weightInit(WeightInit.XAVIER) .updater(new Adam(0.1)) @@ -868,7 +867,7 @@ public class TestEarlyStopping extends BaseDL4JTest { .gradientNormalization(GradientNormalization .ClipElementWiseAbsoluteValue) .gradientNormalizationThreshold(1.0) - .list() + .layer(0, new LSTM.Builder() .nIn(10) .nOut(10) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java index 4209f8dd3..fb55e2957 100644 --- 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java @@ -76,7 +76,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { @Test public void testEarlyStoppingIris() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3) @@ -120,7 +120,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(5.0)) //Intentionally huge LR .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") @@ -156,7 +156,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { //test termination after max time Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(1e-6)).weightInit(WeightInit.XAVIER).graphBuilder() .addInputs("in") @@ -198,7 +198,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { //Simulate this by setting LR = 0.0 Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.0)).weightInit(WeightInit.XAVIER).graphBuilder() .addInputs("in") @@ -233,7 +233,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { @Test public void testListeners() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3) @@ -297,7 +297,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { Metric.MAE}) { log.info("Metric: " + metric); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("0", new DenseLayer.Builder().nIn(784).nOut(32).build(), "in") @@ -343,7 +343,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { Metric.MAE}) { log.info("Metric: " + metric); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("0", new AutoEncoder.Builder().nIn(784).nOut(32).build(), "in") @@ -388,7 +388,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { Metric.MAE}) { 
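For the ComputationGraph cases the only change is the entry point: new NeuralNetConfiguration.Builder() becomes NeuralNetConfiguration.builder(), while graphBuilder() and the rest of the graph DSL are used as before. A minimal sketch under that assumption (the layer and output names "0" and "in" are chosen here purely for illustration):

    // Sketch only; same graph DSL, new builder entry point.
    ComputationGraphConfiguration conf = NeuralNetConfiguration.builder()
            .updater(new Sgd(0.001))
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
            .setOutputs("0")
            .build();

    ComputationGraph cg = new ComputationGraph(conf);
    cg.init();
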
log.info("Metric: " + metric); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("0", new VariationalAutoencoder.Builder() @@ -435,7 +435,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { for(boolean logProb : new boolean[]{false, true}) { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .updater(new Adam(1e-5)) .graphBuilder() .addInputs("in") @@ -486,7 +486,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { for(Evaluation.Metric metric : Evaluation.Metric.values()) { log.info("Metric: " + metric); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("0", new DenseLayer.Builder().nIn(784).nOut(32).build(), "in") @@ -526,7 +526,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { @Test public void testEarlyStoppingListenersCG() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .updater(new Sgd(0.001)).weightInit(WeightInit.XAVIER) .graphBuilder() .addInputs("in") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java index 7b44d26c9..8f69cf1d9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java @@ -69,7 +69,7 @@ public class EvalTest extends BaseDL4JTest { public void testIris() { // Network config - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42) .updater(new Sgd(1e-6)).list() @@ -177,7 +177,7 @@ public class EvalTest extends BaseDL4JTest { rrdsi.reset(); Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)) .list() .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) @@ -295,7 +295,7 @@ public class EvalTest extends BaseDL4JTest { int tbpttLength = 10; int tsLength = 5 * tbpttLength + tbpttLength / 2; - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder() .seed(12345) .trainingWorkspaceMode(ws) .inferenceWorkspaceMode(ws) @@ -306,7 +306,7 @@ public class EvalTest extends BaseDL4JTest { .build()) .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .seed(12345) .trainingWorkspaceMode(ws) .inferenceWorkspaceMode(ws) @@ -314,7 +314,7 @@ public class EvalTest extends BaseDL4JTest { .layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).build()) .layer(new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut) .activation(Activation.SOFTMAX).build()) - .tBPTTLength(10) + .tbpttFwdLength(10).tbpttBackLength(10) .backpropType(BackpropType.TruncatedBPTT) .build(); @@ -371,7 +371,7 @@ public 
class EvalTest extends BaseDL4JTest { int tbpttLength = 10; int tsLength = 5 * tbpttLength + tbpttLength / 2; - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf1 = NeuralNetConfiguration.builder() .seed(12345) .trainingWorkspaceMode(ws) .inferenceWorkspaceMode(ws) @@ -384,7 +384,7 @@ public class EvalTest extends BaseDL4JTest { .setOutputs("1") .build(); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .seed(12345) .trainingWorkspaceMode(ws) .inferenceWorkspaceMode(ws) @@ -455,12 +455,12 @@ public class EvalTest extends BaseDL4JTest { DataSetIterator testData = new SequenceRecordReaderDataSetIterator(fsr, lsr, 1, -1, true, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123) .list() .layer(0, new LSTM.Builder().activation(Activation.TANH).nIn(3).nOut(3).build()) .layer(1, new RnnOutputLayer.Builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.XENT) .nIn(3).nOut(1).build()) - .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(10).tBPTTBackwardLength(10) + .backpropType(BackpropType.TruncatedBPTT).tbpttFwdLength(10).tbpttBackLength(10) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -473,7 +473,7 @@ public class EvalTest extends BaseDL4JTest { //Sanity check: https://github.com/eclipse/deeplearning4j/issues/5351 // Network config - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(42) .updater(new Sgd(1e-6)).list() @@ -503,7 +503,7 @@ public class EvalTest extends BaseDL4JTest { public void testMultiOutputEvalSimple(){ Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .graphBuilder() .addInputs("in") @@ -538,7 +538,7 @@ public class EvalTest extends BaseDL4JTest { public void testMultiOutputEvalCG(){ //Simple sanity check on evaluation - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("0", new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build(), "in") @@ -566,7 +566,7 @@ public class EvalTest extends BaseDL4JTest { @Test public void testInvalidEvaluation(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new DenseLayer.Builder().nIn(4).nOut(10).build()) @@ -622,7 +622,7 @@ public class EvalTest extends BaseDL4JTest { //Disable validation, and check same thing: - net.getLayerWiseConfigurations().setValidateOutputLayerConfig(false); + net.getConfiguration().setValidateOutputLayerConfig(false); net.evaluate(iter); net.evaluateROCMultiClass(iter, 0); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java index 70271cd95..aa9b2686f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java @@ -23,7 +23,6 @@ package org.deeplearning4j.eval; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.core.evaluation.EvaluationTools; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -48,7 +47,7 @@ public class EvaluationToolsTests extends BaseDL4JTest { DataSetIterator iter = new IrisDataSetIterator(150, 150); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) @@ -90,7 +89,7 @@ public class EvaluationToolsTests extends BaseDL4JTest { public void testRocMultiToHtml() throws Exception { DataSetIterator iter = new IrisDataSetIterator(150, 150); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/ROCTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/ROCTest.java index 5684a76d6..ca3ad1b54 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/ROCTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/ROCTest.java @@ -22,23 +22,19 @@ package org.deeplearning4j.eval; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.jupiter.api.Test; -import org.nd4j.evaluation.curves.PrecisionRecallCurve; import org.nd4j.evaluation.curves.RocCurve; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.*; @@ -86,7 +82,7 @@ public class ROCTest extends BaseDL4JTest { DataSetIterator iter = new IrisDataSetIterator(150, 150); Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).seed(12345) + NeuralNetConfiguration conf = 
NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER).seed(12345) .list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build()).layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java index b5e2b994e..92991d1cc 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/RegressionEvalTest.java @@ -23,7 +23,6 @@ package org.deeplearning4j.eval; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -49,7 +48,7 @@ public class RegressionEvalTest extends BaseDL4JTest { public void testRegressionEvalMethods() { //Basic sanity check - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.ZERO).list() .layer(0, new OutputLayer.Builder().activation(Activation.TANH) .lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(5).build()) .build(); @@ -71,7 +70,7 @@ public class RegressionEvalTest extends BaseDL4JTest { ComputationGraphConfiguration graphConf = - new NeuralNetConfiguration.Builder().weightInit(WeightInit.ZERO).graphBuilder() + NeuralNetConfiguration.builder().weightInit(WeightInit.ZERO).graphBuilder() .addInputs("in").addLayer("0", new OutputLayer.Builder() .lossFunction(LossFunctions.LossFunction.MSE) .activation(Activation.TANH).nIn(10).nOut(5).build(), "in") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java index 0a09599bb..be9568f89 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidConfigurations.java @@ -24,7 +24,6 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; @@ -41,7 +40,7 @@ import static org.junit.jupiter.api.Assertions.fail; public class TestInvalidConfigurations extends BaseDL4JTest { public static MultiLayerNetwork getDensePlusOutput(int nIn, int nOut) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(10).build()) .layer(1, new OutputLayer.Builder().nIn(10).nOut(nOut).build()).build(); @@ -52,7 +51,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } public static MultiLayerNetwork getLSTMPlusRnnOutput(int nIn, int nOut) { - 
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(10).build()) .layer(1, new RnnOutputLayer.Builder().nIn(10).nOut(nOut).build()).build(); @@ -63,10 +62,10 @@ public class TestInvalidConfigurations extends BaseDL4JTest { } public static MultiLayerNetwork getCnnPlusOutputLayer(int depthIn, int inH, int inW, int nOut) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new ConvolutionLayer.Builder().nIn(depthIn).nOut(5).build()) .layer(1, new OutputLayer.Builder().nOut(nOut).build()) - .setInputType(InputType.convolutional(inH, inW, depthIn)).build(); + .inputType(InputType.convolutional(inH, inW, depthIn)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -90,7 +89,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { @Test public void testDenseNout0() { try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(0).build()) .layer(1, new OutputLayer.Builder().nIn(10).nOut(10).build()).build(); @@ -147,7 +146,7 @@ public class TestInvalidConfigurations extends BaseDL4JTest { @Test public void testLSTMNOut0() { try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new GravesLSTM.Builder().nIn(10).nOut(0).build()) .layer(1, new RnnOutputLayer.Builder().nIn(10).nOut(10).build()).build(); @@ -178,10 +177,10 @@ public class TestInvalidConfigurations extends BaseDL4JTest { @Test public void testConvolutionalNOut0() { try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new ConvolutionLayer.Builder().nIn(5).nOut(0).build()) .layer(1, new OutputLayer.Builder().nOut(10).build()) - .setInputType(InputType.convolutional(10, 10, 5)).build(); + .inputType(InputType.convolutional(10, 10, 5)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -207,12 +206,12 @@ public class TestInvalidConfigurations extends BaseDL4JTest { //(10-3+2*0)/2+1 = 7/2 + 1 try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().convolutionMode(ConvolutionMode.Strict) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().convolutionMode(ConvolutionMode.Strict) .list() .layer(0, new ConvolutionLayer.Builder().kernelSize(3, 2).stride(2, 2).padding(0, 0).nOut(5) .build()) .layer(1, new OutputLayer.Builder().nOut(10).build()) - .setInputType(InputType.convolutional(hIn, wIn, depthIn)).build(); + .inputType(InputType.convolutional(hIn, wIn, depthIn)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -234,11 +233,11 @@ public class TestInvalidConfigurations extends BaseDL4JTest { int hIn = 10; int wIn = 10; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new ConvolutionLayer.Builder().kernelSize(7, 7).stride(1, 1).padding(0, 0).nOut(5) .build()) .layer(1, new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(hIn, wIn, depthIn)).build(); 
+ .inputType(InputType.convolutional(hIn, wIn, depthIn)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -265,8 +264,8 @@ public class TestInvalidConfigurations extends BaseDL4JTest { //Invalid: (10-3+0)/2+1 = 4.5 - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().convolutionMode(ConvolutionMode.Strict).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().convolutionMode(ConvolutionMode.Strict).list() .layer(0, new ConvolutionLayer.Builder().kernelSize(3, 3).stride(2, 2) .padding(0, 0).nIn(depthIn).nOut(5).build()) .layer(1, new OutputLayer.Builder().nIn(5 * 4 * 4).nOut(10).activation(Activation.SOFTMAX).build()) @@ -299,22 +298,22 @@ public class TestInvalidConfigurations extends BaseDL4JTest { //(10-3+2*0)/2+1 = 7/2 + 1 try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 3).stride(2, 2).padding(0, 0).nOut(5) .build()) .layer(1, new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(hIn, wIn, depthIn)).build(); + .inputType(InputType.convolutional(hIn, wIn, depthIn)).build(); } catch (Exception e) { fail("Did not expect exception with default (truncate)"); } try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().convolutionMode(ConvolutionMode.Strict) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().convolutionMode(ConvolutionMode.Strict) .list() .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 3).stride(2, 2).padding(0, 0).nOut(5) .build()) .layer(1, new OutputLayer.Builder().nOut(10).build()) - .setInputType(InputType.convolutional(hIn, wIn, depthIn)).build(); + .inputType(InputType.convolutional(hIn, wIn, depthIn)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -338,12 +337,12 @@ public class TestInvalidConfigurations extends BaseDL4JTest { //(10-3+2*0)/2+1 = 7/2 + 1 try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().convolutionMode(ConvolutionMode.Strict) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().convolutionMode(ConvolutionMode.Strict) .list() .layer(0, new SubsamplingLayer.Builder().kernelSize(2, 3).stride(2, 2).padding(0, 0) .build()) .layer(1, new OutputLayer.Builder().nOut(10).build()) - .setInputType(InputType.convolutional(hIn, wIn, depthIn)).build(); + .inputType(InputType.convolutional(hIn, wIn, depthIn)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java index 7d958355a..4e35f44eb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/exceptions/TestInvalidInput.java @@ -23,7 +23,6 @@ package org.deeplearning4j.exceptions; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.exception.DL4JException; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; @@ -43,7 +42,7 @@ public class TestInvalidInput extends BaseDL4JTest { @Test public void 
testInputNinMismatchDense() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()).build(); @@ -64,7 +63,7 @@ public class TestInvalidInput extends BaseDL4JTest { @Test public void testLabelsNOutMismatchOutputLayer() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()).build(); @@ -85,7 +84,7 @@ public class TestInvalidInput extends BaseDL4JTest { @Test public void testLabelsNOutMismatchRnnOutputLayer() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new LSTM.Builder().nIn(5).nOut(5).build()) .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build(); @@ -112,10 +111,10 @@ public class TestInvalidInput extends BaseDL4JTest { int w = 16; int d = 3; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new ConvolutionLayer.Builder().nIn(d).nOut(5).build()) .layer(1, new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(h, w, d)).build(); + .inputType(InputType.convolutional(h, w, d)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -139,10 +138,10 @@ public class TestInvalidInput extends BaseDL4JTest { int w = 16; int d = 3; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new ConvolutionLayer.Builder().nIn(d).nOut(5).build()) .layer(1, new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(h, w, d)).build(); + .inputType(InputType.convolutional(h, w, d)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -165,10 +164,10 @@ public class TestInvalidInput extends BaseDL4JTest { int w = 16; int d = 3; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new SubsamplingLayer.Builder().kernelSize(2, 2).build()) .layer(1, new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(h, w, d)).build(); + .inputType(InputType.convolutional(h, w, d)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -188,7 +187,7 @@ public class TestInvalidInput extends BaseDL4JTest { @Test public void testInputNinMismatchLSTM() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new GravesLSTM.Builder().nIn(5).nOut(5).build()) .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build(); @@ -209,7 +208,7 @@ public class TestInvalidInput extends BaseDL4JTest { @Test public void testInputNinMismatchBidirectionalLSTM() { - MultiLayerConfiguration conf = new 
NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new GravesBidirectionalLSTM.Builder().nIn(5).nOut(5).build()) .layer(1, new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build(); @@ -231,7 +230,7 @@ public class TestInvalidInput extends BaseDL4JTest { @Test public void testInputNinMismatchEmbeddingLayer() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new EmbeddingLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()).build(); @@ -257,7 +256,7 @@ public class TestInvalidInput extends BaseDL4JTest { for(String layerType : new String[]{"simple", "lstm", "graves"}) { - Layer l; + LayerConfiguration l; switch (layerType){ case "simple": l = new SimpleRnn.Builder().nIn(5).nOut(5).build(); @@ -272,7 +271,7 @@ public class TestInvalidInput extends BaseDL4JTest { throw new RuntimeException(); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(l) .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX).build()).build(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java index e375aa180..b83cc07c4 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/AttentionLayerTest.java @@ -23,7 +23,6 @@ package org.deeplearning4j.gradientcheck; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.AttentionVertex; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -84,7 +83,7 @@ public class AttentionLayerTest extends BaseDL4JTest { System.out.println("Starting test: " + name); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) @@ -98,7 +97,7 @@ public class AttentionLayerTest extends BaseDL4JTest { .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()) .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) + .inputType(InputType.recurrent(nIn)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -145,7 +144,7 @@ public class AttentionLayerTest extends BaseDL4JTest { System.out.println("Starting test: " + name); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) @@ -159,7 +158,7 @@ public class AttentionLayerTest extends BaseDL4JTest { .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()) .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) 
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) + .inputType(InputType.recurrent(nIn)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -185,7 +184,7 @@ public class AttentionLayerTest extends BaseDL4JTest { for (boolean inputMask : new boolean[]{false, true}) { for (boolean projectInput : new boolean[]{false, true}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) @@ -199,7 +198,7 @@ public class AttentionLayerTest extends BaseDL4JTest { .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()) .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) + .inputType(InputType.recurrent(nIn)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -240,7 +239,7 @@ public class AttentionLayerTest extends BaseDL4JTest { int nOut = 5; int layerSize = 8; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.IDENTITY) .updater(new NoOp()) @@ -251,7 +250,7 @@ public class AttentionLayerTest extends BaseDL4JTest { .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()) .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) + .inputType(InputType.recurrent(nIn)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -303,7 +302,7 @@ public class AttentionLayerTest extends BaseDL4JTest { System.out.println("Starting test: " + name); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.IDENTITY) .updater(new NoOp()) @@ -314,7 +313,7 @@ public class AttentionLayerTest extends BaseDL4JTest { .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()) .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) + .inputType(InputType.recurrent(nIn)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -361,7 +360,7 @@ public class AttentionLayerTest extends BaseDL4JTest { System.out.println("Starting test: " + name); - ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration graph = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) @@ -425,7 +424,7 @@ public class AttentionLayerTest extends BaseDL4JTest { System.out.println("Starting test: " + name); - ComputationGraphConfiguration graph = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration graph = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java index f45861f57..5e6ed72bd 100644 --- 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java @@ -25,8 +25,8 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -74,17 +74,17 @@ public class BNGradientCheckTest extends BaseDL4JTest { for (boolean useLogStd : new boolean[]{true, false}) { - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder().updater(new NoOp()) - .dataType(DataType.DOUBLE) - .seed(12345L) - .dist(new NormalDistribution(0, 1)).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3) - .activation(Activation.IDENTITY).build()) - .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).nOut(3).build()) - .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3).build()); + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = + NeuralNetConfiguration.builder().updater(new NoOp()) + .dataType(DataType.DOUBLE) + .seed(12345L) + .dist(new NormalDistribution(0, 1)).list() + .layer(0, new DenseLayer.Builder().nIn(4).nOut(3) + .activation(Activation.IDENTITY).build()) + .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).nOut(3).build()) + .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()) + .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nIn(3).nOut(3).build()); MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); @@ -119,7 +119,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { } for (boolean useLogStd : new boolean[]{true, false}) { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()).seed(12345L) .dist(new NormalDistribution(0, 2)).list() @@ -129,7 +129,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(hw, hw, depth)); + .inputType(InputType.convolutional(hw, hw, depth)); MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); @@ -188,7 +188,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { LossFunctions.LossFunction lf = lossFunctions[i]; Activation outputActivation = outputActivations[i]; - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .l2(l2vals[j]) 
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) @@ -203,9 +203,9 @@ public class BNGradientCheckTest extends BaseDL4JTest { .layer(4, new ActivationLayer.Builder().activation(afn).build()) .layer(5, new OutputLayer.Builder(lf).activation(outputActivation).nOut(nOut) .build()) - .setInputType(InputType.convolutional(hw, hw, depth)); + .inputType(InputType.convolutional(hw, hw, depth)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); @@ -294,23 +294,23 @@ public class BNGradientCheckTest extends BaseDL4JTest { LossFunctions.LossFunction lf = lossFunctions[i]; Activation outputActivation = outputActivations[i]; - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .l2(l2vals[j]) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) - .updater(new NoOp()) - .dist(new UniformDistribution(-2, 2)).seed(12345L).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(4) - .activation(afn).build()) - .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build()) - .layer(2, new DenseLayer.Builder().nIn(4).nOut(4).build()) - .layer(3, new BatchNormalization.Builder().useLogStd(useLogStd).build()) - .layer(4, new OutputLayer.Builder(lf) - .activation(outputActivation).nOut(nOut) - .build()); + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = + NeuralNetConfiguration.builder() + .dataType(DataType.DOUBLE) + .l2(l2vals[j]) + .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) + .updater(new NoOp()) + .dist(new UniformDistribution(-2, 2)).seed(12345L).list() + .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(4) + .activation(afn).build()) + .layer(1, new BatchNormalization.Builder().useLogStd(useLogStd).build()) + .layer(2, new DenseLayer.Builder().nIn(4).nOut(4).build()) + .layer(3, new BatchNormalization.Builder().useLogStd(useLogStd).build()) + .layer(4, new OutputLayer.Builder(lf) + .activation(outputActivation).nOut(nOut) + .build()); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); @@ -370,7 +370,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { INDArray labels = ds.getLabels(); for (boolean useLogStd : new boolean[]{true, false}) { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().updater(new NoOp()) .dataType(DataType.DOUBLE) .seed(12345L) .dist(new NormalDistribution(0, 1)).list() @@ -414,7 +414,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { } for (boolean useLogStd : new boolean[]{true, false}) { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().updater(new NoOp()) .dataType(DataType.DOUBLE) .seed(12345L) .dist(new NormalDistribution(0, 2)).list() @@ -424,7 +424,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { .layer(2, new ActivationLayer.Builder().activation(Activation.TANH).build()) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(hw, hw, depth)); + .inputType(InputType.convolutional(hw, hw, depth)); 
MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); @@ -457,7 +457,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { for (boolean useLogStd : new boolean[]{true, false}) { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed).updater(new NoOp()) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(seed).updater(new NoOp()) .dataType(DataType.DOUBLE) .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .setInputTypes(InputType.convolutional(height, width, channels)) @@ -526,7 +526,7 @@ public class BNGradientCheckTest extends BaseDL4JTest { LossFunctions.LossFunction lf = lossFunctions[i]; Activation outputActivation = outputActivations[i]; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) .updater(new NoOp()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java index b9f461775..0f474bb16 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN1DGradientCheckTest.java @@ -24,17 +24,14 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; -import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.util.Convolution1DUtils; -import org.deeplearning4j.util.ConvolutionUtils; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; @@ -45,8 +42,6 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.io.File; - import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -90,7 +85,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { } } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() @@ -103,10 +98,10 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { .build()) .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length)).build(); + .inputType(InputType.recurrent(convNIn, length)).build(); String json = conf.toJson(); - MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration c2 = 
NeuralNetConfiguration.fromJson(json); assertEquals(conf, c2); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -170,7 +165,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { } } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() @@ -183,10 +178,10 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { .build()) .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); + .inputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); String json = conf.toJson(); - MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json); assertEquals(conf, c2); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -251,7 +246,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { } } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() @@ -267,10 +262,10 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { .stride(stride).padding(padding).pnorm(pnorm).build()) .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); + .inputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); String json = conf.toJson(); - MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json); assertEquals(conf, c2); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -330,7 +325,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { } } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)).convolutionMode(ConvolutionMode.Same).list() @@ -344,10 +339,10 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { .stride(stride).padding(padding).pnorm(pnorm).build()) .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); + .inputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); String json = conf.toJson(); - MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json); assertEquals(conf, c2); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -393,7 +388,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { log.info("Starting test: " + s); Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH) @@ -413,7 +408,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { .layer(new 
GlobalPoolingLayer(PoolingType.AVG)) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length)).build(); + .inputType(InputType.recurrent(convNIn, length)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -481,7 +476,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { log.info("Starting test: " + s); Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH) @@ -501,7 +496,7 @@ public class CNN1DGradientCheckTest extends BaseDL4JTest { .build()) .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); + .inputType(InputType.recurrent(convNIn, length,RNNFormat.NCW)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java index 1f4a1ceec..ba60ca557 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNN3DGradientCheckTest.java @@ -24,7 +24,6 @@ import lombok.extern.java.Log; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -35,7 +34,6 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -112,7 +110,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % finalNOut}, 1.0); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) .dist(new NormalDistribution(0, 1)) @@ -131,10 +129,10 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { .inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, df == Convolution3D.DataFormat.NCDHW)) - .setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); + .inputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); String json = conf.toJson(); - MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json); assertEquals(conf, c2); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -215,7 +213,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % finalNOut}, 1.0); } - 
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) .dist(new NormalDistribution(0, 1)) @@ -235,10 +233,10 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { .inputPreProcessor(3, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, true)) - .setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build(); + .inputType(InputType.convolutional3D(depth, height, width, convNIn)).build(); String json = conf.toJson(); - MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json); assertEquals(conf, c2); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -310,7 +308,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % finalNOut}, 1.0); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) @@ -327,10 +325,10 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { .activation(Activation.SOFTMAX).nOut(finalNOut).build()) .inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth,convNOut, df)) - .setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); + .inputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); String json = conf.toJson(); - MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json); assertEquals(conf, c2); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -397,7 +395,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % finalNOut}, 1.0); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) .dist(new NormalDistribution(0, 1)) @@ -414,10 +412,10 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { .inputPreProcessor(2, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut, true)) - .setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); + .inputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); String json = conf.toJson(); - MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json); assertEquals(conf, c2); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -493,7 +491,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % finalNOut}, 1.0); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()).weightInit(WeightInit.LECUN_NORMAL) .dist(new NormalDistribution(0, 1)) @@ -513,10 +511,10 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { .inputPreProcessor(3, new Cnn3DToFeedForwardPreProcessor(outDepth, outHeight, outWidth, convNOut2, true)) - .setInputType(InputType.convolutional3D(depth, height, width, convNIn)).build(); + .inputType(InputType.convolutional3D(depth, 
height, width, convNIn)).build(); String json = conf.toJson(); - MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json); assertEquals(conf, c2); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -592,7 +590,7 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{j, j % finalNOut}, 1.0); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .weightInit(new NormalDistribution(0, 0.1)) @@ -607,10 +605,10 @@ public class CNN3DGradientCheckTest extends BaseDL4JTest { .build()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(finalNOut).build()) - .setInputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); + .inputType(InputType.convolutional3D(df, depth, height, width, convNIn)).build(); String json = conf.toJson(); - MultiLayerConfiguration c2 = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration c2 = NeuralNetConfiguration.fromJson(json); assertEquals(conf, c2); MultiLayerNetwork net = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index b737fcf79..bee788e55 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -26,8 +26,8 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; @@ -100,15 +100,15 @@ public class CNNGradientCheckTest extends BaseDL4JTest { LossFunctions.LossFunction lf = lossFunctions[i]; Activation outputActivation = outputActivations[i]; - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()) .weightInit(WeightInit.XAVIER).seed(12345L).list() .layer(0, new ConvolutionLayer.Builder(1, 1).nOut(6).activation(afn).build()) .layer(1, new OutputLayer.Builder(lf).activation(outputActivation).nOut(3).build()) - .setInputType(InputType.convolutionalFlat(1, 4, 1)); + .inputType(InputType.convolutionalFlat(1, 4, 1)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); @@ -186,7 +186,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { double l2 = l2vals[i]; double l1 = l1vals[i]; - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder 
= NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .l2(l2).l1(l1).l2Bias(biasL2[i]).l1Bias(biasL1[i]) .optimizationAlgo( @@ -198,9 +198,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(1, new OutputLayer.Builder(lf).activation(outputActivation).nOut(3) .weightInit(WeightInit.XAVIER).updater(new NoOp()).build()) - .setInputType(InputType.convolutionalFlat(1, 4, 1)); + .inputType(InputType.convolutionalFlat(1, 4, 1)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); @@ -269,8 +269,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % nOut}, 1.0); } - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)) @@ -281,7 +281,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(2 * 2 * 4) .nOut(nOut).build()) - .setInputType(InputType.convolutionalFlat(height, width, inputDepth)) + .inputType(InputType.convolutionalFlat(height, width, inputDepth)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -334,8 +334,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % nOut}, 1.0); } - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()).weightInit(new NormalDistribution(0, 1)) .list() @@ -349,7 +349,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX) .nOut(nOut).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)) + .inputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -403,8 +403,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1)) @@ -416,7 +416,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(8 * 8 * 3) .nOut(4).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)) + .inputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -472,8 +472,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % nOut}, 1.0); } - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new NoOp()) .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)) .list().layer(0, @@ -488,7 +488,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(3 
* 3 * 3) .nOut(4).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)) + .inputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -546,8 +546,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % nOut}, 1.0); } - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new NoOp()) .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)) .list().layer(0, @@ -562,7 +562,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(2 * 2 * 2) .nOut(4).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)) + .inputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -606,7 +606,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(new NoOp()) .dataType(DataType.DOUBLE) .activation(afn) .list() @@ -623,10 +623,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut) .build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)).build(); + .inputType(InputType.convolutional(height, width, inputDepth, format)).build(); - assertEquals(ConvolutionMode.Truncate, - ((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode()); + assertEquals(ConvolutionMode.Truncate, conf.getConvolutionMode()); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -673,7 +672,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % nOut}, 1.0); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new NoOp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(new NoOp()) .dataType(DataType.DOUBLE) .activation(afn) .list() @@ -689,10 +688,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(2 * 2 * 2).nOut(nOut) .build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)).build(); + .inputType(InputType.convolutional(height, width, inputDepth, format)).build(); - assertEquals(ConvolutionMode.Truncate, - ((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode()); + assertEquals(ConvolutionMode.Truncate,conf.getConvolutionMode()); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -744,7 +742,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.SIGMOID).convolutionMode(Same).list() @@ -760,7 +758,7 
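
Note: nearly every hunk in CNNGradientCheckTest applies the same two substitutions: MultiLayerConfiguration plus new NeuralNetConfiguration.Builder() becomes NeuralNetConfiguration plus NeuralNetConfiguration.builder(), and setInputType(...) becomes inputType(...). The following is a minimal sketch of that pattern, not taken from the patch itself; the layer sizes and the 4x4x1 input shape are arbitrary, and the imports mirror those already used in these test classes.

    import org.deeplearning4j.nn.conf.MultiLayerConfiguration;   // removed from the tests by this patch
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.conf.inputs.InputType;
    import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
    import org.deeplearning4j.nn.conf.layers.OutputLayer;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.nd4j.linalg.activations.Activation;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.learning.config.NoOp;
    import org.nd4j.linalg.lossfunctions.LossFunctions;

    public class BuilderMigrationSketch {
        public static void main(String[] args) {
            // Before this patch: MultiLayerConfiguration built via the inner Builder class,
            // an explicit list() stage, and setInputType(...).
            MultiLayerConfiguration before = new NeuralNetConfiguration.Builder()
                    .dataType(DataType.DOUBLE)
                    .updater(new NoOp())
                    .list()
                    .layer(new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(1).nOut(2).build())
                    .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nOut(3).build())
                    .setInputType(InputType.convolutionalFlat(4, 4, 1))
                    .build();

            // After this patch: the same network expressed as a NeuralNetConfiguration,
            // created through the builder() factory, with layers added directly and inputType(...).
            NeuralNetConfiguration after = NeuralNetConfiguration.builder()
                    .dataType(DataType.DOUBLE)
                    .updater(new NoOp())
                    .layer(new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(1).nOut(2).build())
                    .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX).nOut(3).build())
                    .inputType(InputType.convolutionalFlat(4, 4, 1))
                    .build();

            // Network construction and initialisation are unchanged by the migration.
            MultiLayerNetwork net = new MultiLayerNetwork(after);
            net.init();
        }
    }
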
@@ public class CNNGradientCheckTest extends BaseDL4JTest { .stride(1, 1).padding(0, 0).build()) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)).build(); + .inputType(InputType.convolutional(height, width, inputDepth, format)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -810,14 +808,14 @@ public class CNNGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % nOut}, 1.0); } - Layer convLayer = new ConvolutionLayer.Builder().name("layer 0").kernelSize(k, k).dataFormat(format) + LayerConfiguration convLayer = new ConvolutionLayer.Builder().name("layer 0").kernelSize(k, k).dataFormat(format) .stride(stride, stride).padding(0, 0).nIn(inputDepth).nOut(2).build(); - Layer poolLayer = new SubsamplingLayer.Builder() + LayerConfiguration poolLayer = new SubsamplingLayer.Builder() .poolingType(SubsamplingLayer.PoolingType.MAX).kernelSize(k, k).dataFormat(format) .stride(stride, stride).padding(0, 0).build(); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH).convolutionMode(Same).list() @@ -825,7 +823,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .layer(1, convFirst ? poolLayer : convLayer) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)) + .inputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -883,8 +881,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new NoOp()) .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)).list() .layer(0, new ConvolutionLayer.Builder(kernel, stride, padding).dataFormat(format) @@ -894,7 +892,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { padding).nIn(3).nOut(3).dataFormat(format).build())//output: (6-2+0)/1+1 = 5 .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(4).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)) + .inputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -971,7 +969,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{j, j % nOut}, 1.0); } - NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(act) @@ -981,11 +979,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .stride(s, s).dataFormat(format) .dilation(d, d) .convolutionMode(cm) - .nIn(inputDepth).nOut(nOut).build()); + .nIn(inputDepth).nOut(nOut).build()) - MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .layer(new 
OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(h, w, inputDepth, format)).build(); + .inputType(InputType.convolutional(h, w, inputDepth, format)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -1043,7 +1041,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % nOut}, 1.0); } - NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH) @@ -1054,11 +1052,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .stride(s, s) .dilation(d, d) .depthMultiplier(3).dataFormat(format) - .nIn(inputDepth).nOut(2).build()); + .nIn(inputDepth).nOut(2).build()) - MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(h, w, inputDepth, format)).build(); + .inputType(InputType.convolutional(h, w, inputDepth, format)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -1116,7 +1114,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % nOut}, 1.0); } - NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration.NeuralNetConfigurationBuilder b = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(Activation.TANH).convolutionMode(cm).list() @@ -1140,9 +1138,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .build()); } - MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + NeuralNetConfiguration conf = (NeuralNetConfiguration) b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(h, w, inputDepth, format)).build(); + .inputType(InputType.convolutional(h, w, inputDepth, format)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -1190,8 +1188,8 @@ public class CNNGradientCheckTest extends BaseDL4JTest { INDArray input = Nd4j.rand(DataType.DOUBLE, inShape); INDArray labels = TestUtils.randomOneHot(minibatchSize, nOut); - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .convolutionMode(ConvolutionMode.Same) @@ -1208,7 +1206,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { .build()) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(height, width, inputDepth, format)) + .inputType(InputType.convolutional(height, width, inputDepth, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -1277,7 +1275,7 @@ public class CNNGradientCheckTest extends BaseDL4JTest { labels.putScalar(new int[]{i, i % nOut}, 1.0); } - NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration.NeuralNetConfigurationBuilder b = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) 
                    .updater(new NoOp())
                    .activation(Activation.TANH)
@@ -1293,9 +1291,9 @@ public class CNNGradientCheckTest extends BaseDL4JTest {
                            .depthMultiplier(depthMultiplier)
                            .nIn(nIn).build()); // nOut = nIn * depthMultiplier
 
-            MultiLayerConfiguration conf = b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
+            NeuralNetConfiguration conf = (NeuralNetConfiguration) b.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nOut(nOut).build())
-                    .setInputType(InputType.convolutional(height, width, nIn, format)).build();
+                    .inputType(InputType.convolutional(height, width, nIn, format)).build();
 
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
            net.init();
diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java
index c0a6cad8e..8d9caef52 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CapsnetGradientCheckTest.java
@@ -24,7 +24,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
 
 import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.TestUtils;
-import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.distribution.UniformDistribution;
 import org.deeplearning4j.nn.conf.inputs.InputType;
@@ -44,8 +43,6 @@ import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.learning.config.NoOp;
 import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood;
 
-import java.util.Random;
-
 
 ////@Ignore
 public class CapsnetGradientCheckTest extends BaseDL4JTest {
@@ -80,12 +77,11 @@ public class CapsnetGradientCheckTest extends BaseDL4JTest {
                labels.putScalar(new int[]{i, i % capsule}, 1.0);
            }
 
-            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+            NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                    .dataType(DataType.DOUBLE)
                    .seed(123)
                    .updater(new NoOp())
-                    .weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6)))
-                    .list()
+                    .dist(new UniformDistribution(-6, 6))
                    .layer(new PrimaryCapsules.Builder(primaryCapsDim, primarpCapsChannel)
                            .kernelSize(3, 3)
                            .stride(2, 2)
@@ -94,7 +90,7 @@ public class CapsnetGradientCheckTest extends BaseDL4JTest {
                    .layer(new CapsuleStrengthLayer.Builder().build())
                    .layer(new ActivationLayer.Builder(new ActivationSoftmax()).build())
                    .layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build())
-                    .setInputType(InputType.convolutional(height, width, inputDepth))
+                    .inputType(InputType.convolutional(height, width, inputDepth))
                    .build();
 
            MultiLayerNetwork net = new MultiLayerNetwork(conf);
diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java
index 9aafd297c..5c124dfa0 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/DropoutGradientCheck.java
@@ -25,7 +25,6 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.deeplearning4j.TestUtils;
 import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
 import org.deeplearning4j.nn.conf.ConvolutionMode;
-import 
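
In the CapsnetGradientCheckTest hunk above, the explicit WeightInitDistribution wrapper and the list() call are dropped in favour of passing the distribution directly to the builder. A condensed fragment showing only that substitution; the capsule, activation and loss layers are elided, and height, width and inputDepth are the variables of the surrounding test:

    // was: .weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6)))
    //      .list()
    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
            .seed(123)
            .updater(new NoOp())
            .dist(new UniformDistribution(-6, 6))
            // ... capsule, activation and loss layers unchanged from the hunk above ...
            .inputType(InputType.convolutional(height, width, inputDepth))
            .build();
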
org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.dropout.*; @@ -92,7 +91,7 @@ public class DropoutGradientCheck extends BaseDL4JTest { continue; } - NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0,1)) .convolutionMode(ConvolutionMode.Same) @@ -104,18 +103,18 @@ public class DropoutGradientCheck extends BaseDL4JTest { if(cnn){ builder.layer(new ConvolutionLayer.Builder().kernelSize(3,3).stride(2,2).nOut(2).build()); builder.layer(new ConvolutionLayer.Builder().kernelSize(3,3).stride(2,2).nOut(2).build()); - builder.setInputType(InputType.convolutional(6,6,2)); + builder.inputType(InputType.convolutional(6,6,2)); } else { builder.layer(new DenseLayer.Builder().nOut(3).build()); builder.layer(new DenseLayer.Builder().nOut(3).build()); - builder.setInputType(InputType.feedForward(6)); + builder.inputType(InputType.feedForward(6)); } builder.layer(new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunction.MCXENT).build()); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); //Remove spatial dropout from output layer - can't be used for 2d input if(i == 4){ - conf.getConf(2).getLayer().setIDropout(null); + conf.getFlattenedLayerConfigurations().get(2).setIDropout(null); } MultiLayerNetwork mln = new MultiLayerNetwork(conf); @@ -149,7 +148,7 @@ public class DropoutGradientCheck extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); int mb = 3; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0,1)) .convolutionMode(ConvolutionMode.Same) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java index 36574096d..18d430044 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GlobalPoolingGradientCheckTests.java @@ -24,7 +24,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -72,7 +71,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { for (int miniBatchSize : minibatchSizes) { for (PoolingType pt : poolingTypes) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() @@ -127,7 +126,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { for (int miniBatchSize : minibatchSizes) { for (PoolingType pt : poolingTypes) { - 
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() @@ -138,7 +137,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { .layer(1, new GlobalPoolingLayer.Builder().poolingType(pt).build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(inputH, inputW, inputDepth, nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC)).build(); + .inputType(InputType.convolutional(inputH, inputW, inputDepth, nchw ? CNN2DFormat.NCHW : CNN2DFormat.NHWC)).build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); @@ -185,7 +184,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { for (PoolingType pt : poolingTypes) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() @@ -259,7 +258,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { stride = new int[] {inputH, 1}; } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1.0)).convolutionMode(ConvolutionMode.Same) @@ -270,7 +269,7 @@ public class GlobalPoolingGradientCheckTests extends BaseDL4JTest { .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(inputH, inputW, inputDepth)).build(); + .inputType(InputType.convolutional(inputH, inputW, inputDepth)).build(); MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java index 553477bd5..90f927d66 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java @@ -26,7 +26,6 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; @@ -74,7 +73,7 @@ public class GradientCheckTests extends BaseDL4JTest { public void testMinibatchApplication() { IrisDataSetIterator iter = new IrisDataSetIterator(30, 150); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().miniBatch(false) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().miniBatch(false) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new NoOp()) .list() @@ -164,7 +163,7 @@ public class GradientCheckTests extends BaseDL4JTest { LossFunction lf = lossFunctions[i]; Activation outputActivation = outputActivations[i]; - 
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()) .seed(12345L) @@ -253,8 +252,8 @@ public class GradientCheckTests extends BaseDL4JTest { double l2 = l2vals[k]; double l1 = l1vals[k]; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().l2(l2).l1(l1) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().l2(l2).l1(l1) .dataType(DataType.DOUBLE) .l2Bias(biasL2[k]).l1Bias(biasL1[k]) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) @@ -325,7 +324,7 @@ public class GradientCheckTests extends BaseDL4JTest { labels.putScalar(new int[] {i, r.nextInt(3)}, 1.0); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.2).l1(0.1) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().l2(0.2).l1(0.1) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345L) .list().layer(new EmbeddingLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) @@ -363,7 +362,7 @@ public class GradientCheckTests extends BaseDL4JTest { labels.putScalar(new int[] {i, r.nextInt(3)}, 1.0); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.2).l1(0.1) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().l2(0.2).l1(0.1) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345L) .list().layer(0, @@ -429,8 +428,8 @@ public class GradientCheckTests extends BaseDL4JTest { double l1 = l1vals[k]; Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .l2(l2).l1(l1) @@ -491,7 +490,7 @@ public class GradientCheckTests extends BaseDL4JTest { for(Activation a : new Activation[]{Activation.IDENTITY, Activation.TANH}) { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()) .seed(12345L) @@ -561,7 +560,7 @@ public class GradientCheckTests extends BaseDL4JTest { for (boolean maskArray : new boolean[]{false, true}) { for (int inputRank : new int[]{2, 3}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .seed(12345) .updater(new NoOp()) @@ -672,8 +671,8 @@ public class GradientCheckTests extends BaseDL4JTest { double l2 = l2vals[k]; double l1 = l1vals[k]; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().l2(l2).l1(l1) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().l2(l2).l1(l1) .dataType(DataType.DOUBLE) .l2Bias(biasL2[k]).l1Bias(biasL1[k]) .weightDecay(wdVals[k]).weightDecayBias(wdBias[k]) @@ -736,7 +735,7 @@ public class GradientCheckTests extends BaseDL4JTest { LossFunction lf = lossFunctions[i]; Activation outputActivation = outputActivations[i]; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).updater(new NoOp()) .seed(12345L) diff --git 
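
The ComputationGraph-based tests that follow keep ComputationGraphConfiguration as the result type; only the entry point changes, from new NeuralNetConfiguration.Builder() to NeuralNetConfiguration.builder(), while graphBuilder(), addInputs(...) and setOutputs(...) stay as they were. A minimal sketch under that assumption; the vertex names, layer sizes and the addLayer(...) signature are illustrative (standard DL4J usage) rather than copied from the patch:

    ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
            .dataType(DataType.DOUBLE)
            .updater(new NoOp())
            .graphBuilder()
            .addInputs("in")
            .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.TANH).build(), "in")
            .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "dense")
            .setOutputs("out")
            .build();

    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();
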
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java index 7718078a6..c121e8b14 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsComputationGraph.java @@ -36,9 +36,6 @@ import org.deeplearning4j.nn.conf.graph.rnn.ReverseTimeSeriesVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; -import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; -import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor; -import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; @@ -52,7 +49,6 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.util.Arrays; import java.util.Map; import java.util.Random; @@ -74,7 +70,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { @Test public void testBasicIris() { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)).updater(new NoOp()) @@ -120,7 +116,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { @Test public void testBasicIrisWithMerging() { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)).updater(new NoOp()) @@ -177,7 +173,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { for (ElementWiseVertex.Op op : ops) { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -235,7 +231,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { for (ElementWiseVertex.Op op : ops) { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -295,7 +291,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { for(boolean firstSmaller : new boolean[]{false, true}) { for (ElementWiseVertex.Op op : ops) { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + 
ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .updater(new NoOp()) .dataType(DataType.DOUBLE) .activation(Activation.TANH) @@ -343,7 +339,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { String msg = "testCnnDepthMerge - " + format; Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 0.1)) @@ -398,7 +394,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int outSize = 3; Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new UniformDistribution(0.2, 0.6)) @@ -457,7 +453,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int batchSize = 2; int timeSeriesLength = 4; int inLength = 3; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(1234) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(1234) .dataType(DataType.DOUBLE) .weightInit(new NormalDistribution(0, 1)) .updater(new NoOp()).graphBuilder().addInputs("input").setOutputs("out") @@ -493,7 +489,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { public void testLSTMWithLastTimeStepVertex() { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -545,7 +541,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int timeSeriesLength = 4; Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -595,7 +591,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int timeSeriesLength = 4; Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -654,7 +650,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { public void testMultipleInputsLayer() { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -697,7 +693,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { @Test public void testMultipleOutputsLayer() { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + 
ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -737,7 +733,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { @Test public void testMultipleOutputsMergeVertex() { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -786,7 +782,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int inW = 7; Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -836,7 +832,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { public void testBasicIrisTripletStackingL2Loss() { Nd4j.getRandom().setSeed(12345); ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -910,7 +906,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { for (boolean train : trainFirst) { for (double lambda : new double[] {0.0, 0.5, 2.0}) { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new GaussianDistribution(0, 1)) @@ -975,7 +971,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { for (boolean train : trainFirst) { for (double lambda : new double[] {0.0, 0.5, 2.0}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() @@ -986,7 +982,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { .alpha(1.0).lambda(lambda).gradientCheck(true) .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(inputH, inputW, inputDepth)).build(); + .inputType(InputType.convolutional(inputH, inputW, inputDepth)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -1029,7 +1025,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { @Test public void testBasicL2() { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -1081,7 +1077,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int layerSizes = 2; Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new 
NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -1136,7 +1132,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { public void testBasicStackUnstackDebug() { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -1196,7 +1192,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int layerSizes = 2; Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -1259,7 +1255,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { public void testBasicTwoOutputs() { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -1320,7 +1316,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int[][] definitions = {null,new int[]{1}}; for(int[] definition : definitions) { log.info("Testing definition {}",definition); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .activation(Activation.TANH).updater(new NoOp()).graphBuilder() @@ -1368,7 +1364,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { int w = 4; int dIn = 2; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)) @@ -1420,7 +1416,7 @@ public class GradientCheckTestsComputationGraph extends BaseDL4JTest { labels.putScalar(new int[] {i, r.nextInt(3)}, 1.0); } - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.2).l1(0.1) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().l2(0.2).l1(0.1) .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345L) .updater(new NoOp()).graphBuilder().addInputs("in") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java index 4efd20ee7..689720529 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java +++ 
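
Alongside the builder rename, the way a single layer's configuration is read back changes across these tests: conf.getConf(i).getLayer() becomes conf.getFlattenedLayerConfigurations().get(i), settings such as the convolution mode can now be read from the configuration itself, and on an initialised network net.getLayer(i).conf().getLayer() becomes net.getLayer(i).getLayerConfiguration(). A condensed fragment, assuming conf and net were built as in the surrounding tests:

    // was: ConvolutionMode cm = ((ConvolutionLayer) conf.getConf(0).getLayer()).getConvolutionMode();
    ConvolutionMode cm = conf.getConvolutionMode();
    LayerConfiguration first = conf.getFlattenedLayerConfigurations().get(0);

    // was: Layer outConf = net.getLayer(1).conf().getLayer();
    LayerConfiguration outConf = net.getLayer(1).getLayerConfiguration();
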
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTestsMasking.java @@ -24,7 +24,6 @@ import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; @@ -117,7 +116,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { maskArr.putScalar(new int[] {0, j}, mask[i][j] ? 1.0 : 0.0); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345L) .dataType(DataType.DOUBLE) .updater(new NoOp()) .list() @@ -158,7 +157,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { int testNum = 0; for (INDArray mask : masks) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new NoOp()) .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() @@ -238,7 +237,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { Activation a = act[i]; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new NoOp()) .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)).seed(12345) .list() @@ -332,7 +331,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { Activation a = act[i]; Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new NoOp()) .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)).seed(12345) .list() @@ -341,7 +340,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { .layer(1, new RnnOutputLayer.Builder().nIn(layerSize).nOut(nOut).lossFunction(lf) .activation(a).build()) .validateOutputLayerConfig(false) - .setInputType(InputType.recurrent(nIn,tsLength, RNNFormat.NCW)) + .inputType(InputType.recurrent(nIn,tsLength, RNNFormat.NCW)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -365,7 +364,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { //Check the equivalent compgraph: Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration cg = new NeuralNetConfiguration.Builder().updater(new NoOp()) + ComputationGraphConfiguration cg = NeuralNetConfiguration.builder().updater(new NoOp()) .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 2)).seed(12345) .graphBuilder().addInputs("in") @@ -397,7 +396,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { int mb = 4; int tsLength = 5; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .weightInit(new NormalDistribution(0,2)) .updater(new NoOp()) @@ -405,7 +404,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { .layer(new LSTM.Builder().nIn(3).nOut(3).build()) .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build()) .layer(new OutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.recurrent(3)) + 
.inputType(InputType.recurrent(3)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -452,7 +451,7 @@ public class GradientCheckTestsMasking extends BaseDL4JTest { int mb = 10; int tsLength = 5; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .weightInit(new NormalDistribution(0,2)) .updater(new NoOp()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java index 87ea20cf5..18769905c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LRNGradientCheckTests.java @@ -22,8 +22,8 @@ package org.deeplearning4j.gradientcheck; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -68,7 +68,7 @@ public class LRNGradientCheckTests extends BaseDL4JTest { labels.putScalar(i, r.nextInt(nOut), 1.0); } - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().updater(new NoOp()) .dataType(DataType.DOUBLE) .seed(12345L) .dist(new NormalDistribution(0, 2)).list() @@ -77,7 +77,7 @@ public class LRNGradientCheckTests extends BaseDL4JTest { .layer(1, new LocalResponseNormalization.Builder().build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(hw, hw, depth)); + .inputType(InputType.convolutional(hw, hw, depth)); MultiLayerNetwork mln = new MultiLayerNetwork(builder.build()); mln.init(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java index a2c7d7039..421b6a63d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LSTMGradientCheckTests.java @@ -22,8 +22,8 @@ package org.deeplearning4j.gradientcheck; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; @@ -70,8 +70,8 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { for (boolean graves : gravesLSTM) { - Layer l0; - Layer l1; + LayerConfiguration l0; + LayerConfiguration l1; if (graves) { l0 = new 
GravesLSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.SIGMOID) .dist(new NormalDistribution(0, 1.0)) @@ -88,8 +88,8 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { .updater(new NoOp()).build(); } - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(12345L) .dataType(DataType.DOUBLE) .list() .layer(0, l0).layer(1, @@ -179,11 +179,11 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { double l1 = l1vals[i]; Activation afn = activFns[i]; - NeuralNetConfiguration.Builder conf = - new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .seed(12345L) - .dist(new NormalDistribution(0, 1)).updater(new NoOp()); + NeuralNetConfiguration.NeuralNetConfigurationBuilder conf = + NeuralNetConfiguration.builder() + .dataType(DataType.DOUBLE) + .seed(12345L) + .dist(new NormalDistribution(0, 1)).updater(new NoOp()); if (l1 > 0.0) conf.l1(l1); @@ -194,17 +194,17 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { if (biasL1[i] > 0) conf.l1Bias(biasL1[i]); - Layer layer; + LayerConfiguration layer; if (graves) { layer = new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).activation(afn).build(); } else { layer = new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(afn).build(); } - NeuralNetConfiguration.ListBuilder conf2 = conf.list().layer(0, layer) + NeuralNetConfiguration.NeuralNetConfigurationBuilder conf2 = (NeuralNetConfigurationBuilder) conf + .layer(0, layer) .layer(1, new RnnOutputLayer.Builder(lf).activation(outputActivation) - .nIn(layerSize).nOut(nOut).build()) - ; + .nIn(layerSize).nOut(nOut).build()); MultiLayerNetwork mln = new MultiLayerNetwork(conf2.build()); mln.init(); @@ -249,14 +249,14 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { INDArray labels = TestUtils.randomOneHotTimeSeries(miniBatchSize[i], nOut, timeSeriesLength[i]); - Layer layer; + LayerConfiguration layer; if (graves) { layer = new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH).build(); } else { layer = new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH).build(); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345L) .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)) .updater(new NoOp()).list().layer(0, layer) @@ -309,8 +309,8 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { double l2 = l2vals[k]; double l1 = l1vals[k]; - NeuralNetConfiguration.Builder conf = - new NeuralNetConfiguration.Builder(); + NeuralNetConfiguration.NeuralNetConfigurationBuilder conf = + NeuralNetConfiguration.builder(); if (l1 > 0.0) conf.l1(l1); if (l2 > 0.0) @@ -320,10 +320,10 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { if (biasL1[k] > 0) conf.l1Bias(biasL1[k]); - MultiLayerConfiguration mlc = conf.seed(12345L) + NeuralNetConfiguration mlc = (NeuralNetConfiguration) conf.seed(12345L) .dataType(DataType.DOUBLE) .updater(new NoOp()) - .list().layer(0, + .layer(0, new GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(layerSize) .weightInit(new NormalDistribution(0, 1)) .activation(afn) @@ -380,7 +380,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { } } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345L) .dataType(DataType.DOUBLE) .list() .layer(0, new 
GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(layerSize) @@ -429,7 +429,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()).seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new NoOp()).seed(12345) .dataType(DataType.DOUBLE) .dist(new UniformDistribution(-2, 2)).list() .layer(0, new ConvolutionLayer.Builder(3, 3).nIn(2).nOut(3).stride(1, 1) @@ -440,7 +440,7 @@ public class LSTMGradientCheckTests extends BaseDL4JTest { .layer(3, new GravesLSTM.Builder().nIn(4).nOut(3).activation(Activation.TANH).build()) .layer(4, new RnnOutputLayer.Builder().lossFunction(LossFunction.MCXENT).nIn(3).nOut(nClasses) .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(6, 6, 2)).build(); + .inputType(InputType.convolutional(6, 6, 2)).build(); //Here: ConvolutionLayerSetup in config builder doesn't know that we are expecting time series input, not standard FF input -> override it here conf.getInputPreProcessors().put(0, new RnnToCnnPreProcessor(6, 6, 2)); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java index 74b142845..0cf7ebd1b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java @@ -26,7 +26,6 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.gradientcheck.sdlosscustom.SDLossMAE; import org.deeplearning4j.gradientcheck.sdlosscustom.SDLossMSE; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; @@ -183,7 +182,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { + minibatchSizes[j]; Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345) .updater(new NoOp()) @@ -347,7 +346,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { } Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345) .updater(new NoOp()) @@ -362,7 +361,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertSame(((LossLayer) net.getLayer(1).conf().getLayer()).getLossFn().getClass(), lossFunctions[i] + assertSame(((LossLayer) net.getLayer(1).getLayerConfiguration()).getLossFn().getClass(), lossFunctions[i] .getClass()); INDArray[] inOut = getFeaturesAndLabels(lossFunctions[i], minibatchSizes[j], 4, nOut[i], 12345); @@ -649,7 +648,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { + minibatchSizes[j] + "; weights = " + w; Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new 
NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(12345) .updater(new NoOp()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java index 477199be0..f47a4ee0e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/NoBiasGradientCheckTests.java @@ -22,7 +22,6 @@ package org.deeplearning4j.gradientcheck; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -69,7 +68,7 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { for (boolean denseHasBias : new boolean[]{true, false}) { for (boolean outHasBias : new boolean[]{true, false}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .seed(12345L) @@ -140,7 +139,7 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { for (boolean rnnOutHasBias : new boolean[]{true, false}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .seed(12345L) @@ -201,7 +200,7 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { for (boolean embeddingHasBias : new boolean[]{true, false}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .seed(12345L) @@ -267,8 +266,8 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { for(boolean cnnHasBias : new boolean[]{true, false}) { - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new NoOp()) .dataType(DataType.DOUBLE) .dist(new NormalDistribution(0, 1)) .list() @@ -285,7 +284,7 @@ public class NoBiasGradientCheckTests extends BaseDL4JTest { .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX) .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(height, width, inputDepth)) + .inputType(InputType.convolutionalFlat(height, width, inputDepth)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java index 0928b52de..7556178b9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/OutputLayerGradientChecks.java @@ -23,7 +23,6 @@ package org.deeplearning4j.gradientcheck; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import 
org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.layers.*; @@ -117,8 +116,8 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { Activation oa = maskType == 2 ? Activation.SIGMOID : Activation.SOFTMAX; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(12345L) .dataType(DataType.DOUBLE) .updater(new NoOp()) .list() @@ -223,8 +222,8 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { Activation oa = maskType == 3 ? Activation.SIGMOID : Activation.SOFTMAX; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(12345L) .dataType(DataType.DOUBLE) .updater(new NoOp()) .convolutionMode(ConvolutionMode.Same) @@ -370,8 +369,8 @@ public class OutputLayerGradientChecks extends BaseDL4JTest { Activation oa = maskType == 1 ? Activation.SOFTMAX : Activation.SIGMOID; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(12345L) .dataType(DataType.DOUBLE) .updater(new NoOp()) .convolutionMode(ConvolutionMode.Same) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java index 87a42e4e0..44e904d7e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/RnnGradientChecks.java @@ -22,7 +22,6 @@ package org.deeplearning4j.gradientcheck; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -108,7 +107,7 @@ public class RnnGradientChecks extends BaseDL4JTest { System.out.println("Starting test: " + name); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) @@ -187,7 +186,7 @@ public class RnnGradientChecks extends BaseDL4JTest { System.out.println("Starting test: " + name); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .updater(new NoOp()) .weightInit(WeightInit.XAVIER) @@ -263,7 +262,7 @@ public class RnnGradientChecks extends BaseDL4JTest { System.out.println("Starting test: " + name); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) @@ -275,7 +274,7 @@ public class RnnGradientChecks extends BaseDL4JTest { new LSTM.Builder().nOut(layerSize).build())) .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - 
.setInputType(InputType.recurrent(nIn)) + .inputType(InputType.recurrent(nIn)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -329,7 +328,7 @@ public class RnnGradientChecks extends BaseDL4JTest { System.out.println("Starting test: " + name); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new NoOp()) @@ -339,7 +338,7 @@ public class RnnGradientChecks extends BaseDL4JTest { .layer(new TimeDistributed(new DenseLayer.Builder().nOut(layerSize).activation(Activation.SOFTMAX).build())) .layer(new RnnOutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) + .inputType(InputType.recurrent(nIn)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java index 670987c78..212bd29da 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/UtilLayerGradientChecks.java @@ -23,7 +23,6 @@ package org.deeplearning4j.gradientcheck; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; @@ -127,9 +126,9 @@ public class UtilLayerGradientChecks extends BaseDL4JTest { String name = "mb=" + minibatch + ", maskType=" + maskType + ", inputRank=" + inputRank; System.out.println("*** Starting test: " + name); - Layer l1; - Layer l2; - Layer l3; + LayerConfiguration l1; + LayerConfiguration l2; + LayerConfiguration l3; InputType it; switch (inputRank){ case 2: @@ -163,7 +162,7 @@ public class UtilLayerGradientChecks extends BaseDL4JTest { } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new NoOp()) .activation(Activation.TANH) .dataType(DataType.DOUBLE) @@ -173,7 +172,7 @@ public class UtilLayerGradientChecks extends BaseDL4JTest { .layer(new MaskLayer()) .layer(l2) .layer(l3) - .setInputType(it) + .inputType(it) .build(); @@ -197,10 +196,10 @@ public class UtilLayerGradientChecks extends BaseDL4JTest { for( int minibatch : new int[]{1,5}) { - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .seed(12345) - .updater(Updater.NONE) + .updater(Updater.NONE.getIUpdaterWithDefaultConfig()) .list() .layer(new DenseLayer.Builder().nIn(10).nOut(10) .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java index 40041885e..233836066 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/VaeGradientCheckTests.java @@ -22,7 +22,6 @@ package org.deeplearning4j.gradientcheck; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -94,8 +93,8 @@ public class VaeGradientCheckTests extends BaseDL4JTest { } Activation afn = activFns[i]; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().l2(l2).l1(l1) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().l2(l2).l1(l1) .dataType(DataType.DOUBLE) .updater(new NoOp()) .l2Bias(biasL2[i]).l1Bias(biasL1[i]) @@ -170,7 +169,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest { Activation pzxAfn = pzxAfns[i]; Activation pxzAfn = pxzAfns[i]; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(l2) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().l2(l2) .dataType(DataType.DOUBLE) .l1(l1).l2Bias(biasL2[i]).l1Bias(biasL1[i]).updater(new NoOp()) .seed(12345L).weightInit(WeightInit.XAVIER).list() @@ -259,7 +258,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest { throw new RuntimeException(); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.2).l1(0.3) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().l2(0.2).l1(0.3) .dataType(DataType.DOUBLE) .updater(new NoOp()) .seed(12345L).dist(new NormalDistribution(0, 1)) @@ -303,7 +302,7 @@ public class VaeGradientCheckTests extends BaseDL4JTest { for (int numSamples : new int[]{1, 2}) { INDArray features = Nd4j.rand(DataType.DOUBLE, minibatch, 4); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.2).l1(0.3) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().l2(0.2).l1(0.3) .dataType(DataType.DOUBLE) .updater(new NoOp()) .seed(12345L).weightInit(WeightInit.XAVIER).list() diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java index 9ae3e598a..1eb72b1bd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/YoloGradientCheckTests.java @@ -109,7 +109,7 @@ public class YoloGradientCheckTests extends BaseDL4JTest { labels = yoloLabels(mb, c, h, w).permute(0,2,3,1); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .dataType(DataType.DOUBLE) .updater(new NoOp()) .activation(a) @@ -122,7 +122,7 @@ public class YoloGradientCheckTests extends BaseDL4JTest { .layer(new Yolo2OutputLayer.Builder() .boundingBoxPriors(bbPrior) .build()) - .setInputType(InputType.convolutional(h, w, depthIn, format)) + .inputType(InputType.convolutional(h, w, depthIn, format)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -227,7 +227,7 @@ public class YoloGradientCheckTests extends BaseDL4JTest { DataSetIterator iter = new RecordReaderDataSetIterator(rr,2,1,1,true); iter.setPreProcessor(new ImagePreProcessingScaler()); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = 
NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .convolutionMode(ConvolutionMode.Same) .updater(new NoOp()) @@ -240,7 +240,7 @@ public class YoloGradientCheckTests extends BaseDL4JTest { .layer(new Yolo2OutputLayer.Builder() .boundingBoxPriors(bbPriors) .build()) - .setInputType(InputType.convolutional(h,w,c)) + .inputType(InputType.convolutional(h,w,c)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java index 7862cb95f..6e0cbd770 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/ComputationGraphConfigurationTest.java @@ -57,7 +57,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { @Test public void testJSONBasic() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)).updater(new NoOp()) .graphBuilder().addInputs("input") @@ -79,7 +79,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { @Test public void testJSONBasic2() { ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("input") .addLayer("cnn1", @@ -115,7 +115,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { public void testJSONWithGraphNodes() { ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("input1", "input2") .addLayer("cnn1", @@ -149,7 +149,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { //Test no inputs for a layer: try { - new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1") + NeuralNetConfiguration.builder().graphBuilder().addInputs("input1") .addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1") .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build()).setOutputs("out") .build(); @@ -161,7 +161,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { // Use appendLayer on first layer try { - new NeuralNetConfiguration.Builder().graphBuilder() + NeuralNetConfiguration.builder().graphBuilder() .appendLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build()) .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build()).setOutputs("out") .build(); @@ -173,7 +173,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { //Test no network inputs try { - new NeuralNetConfiguration.Builder().graphBuilder() + NeuralNetConfiguration.builder().graphBuilder() .addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1") .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "dense1") .setOutputs("out").build(); @@ -185,7 +185,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { //Test no network outputs try { - new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1") + 
NeuralNetConfiguration.builder().graphBuilder().addInputs("input1") .addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1") .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "dense1").build(); fail("No exception thrown for invalid configuration"); @@ -196,7 +196,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { //Test: invalid input try { - new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1") + NeuralNetConfiguration.builder().graphBuilder().addInputs("input1") .addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1") .addLayer("out", new OutputLayer.Builder().nIn(2).nOut(2).build(), "thisDoesntExist") .setOutputs("out").build(); @@ -208,7 +208,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { //Test: graph with cycles try { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("input1") .addLayer("dense1", new DenseLayer.Builder().nIn(2).nOut(2).build(), "input1", "dense3") .addLayer("dense2", new DenseLayer.Builder().nIn(2).nOut(2).build(), "dense1") .addLayer("dense3", new DenseLayer.Builder().nIn(2).nOut(2).build(), "dense2") @@ -226,7 +226,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { //Test: input != inputType count mismatch try { - new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input1", "input2") + NeuralNetConfiguration.builder().graphBuilder().addInputs("input1", "input2") .setInputTypes(new InputType.InputTypeRecurrent(10, 12)) .addLayer("cnn1", new ConvolutionLayer.Builder(2, 2).stride(2, 2).nIn(1).nOut(5) @@ -259,7 +259,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { // using runtime/reflection subtype mechanism in ComputationGraphConfiguration.fromJson() //Check a standard GraphVertex implementation, plus a static inner graph vertex - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addVertex("test", new TestGraphVertex(3, 7), "in") .addVertex("test2", new StaticInnerGraphVertex(4, 5), "in").setOutputs("test", "test2").build(); @@ -282,7 +282,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { @Test public void testOutputOrderDoesntChangeWhenCloning() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("out1", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in") .addLayer("out2", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in") .addLayer("out3", new OutputLayer.Builder().nIn(1).nOut(1).build(), "in") @@ -299,7 +299,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { @Test public void testAllowDisconnectedLayers() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("bidirectional", new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build()), "in") @@ -321,7 +321,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { @Test public void 
testBidirectionalGraphSummary() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("bidirectional", new Bidirectional(new LSTM.Builder().activation(Activation.TANH).nOut(10).build()), "in") @@ -408,7 +408,7 @@ public class ComputationGraphConfigurationTest extends BaseDL4JTest { if(nOut[i] == 1 && lossLayer) continue; //nOuts are not availabel in loss layer, can't expect it to detect this case try { - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/JsonTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/JsonTest.java index 190a89746..3e43bfdbe 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/JsonTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/JsonTest.java @@ -98,7 +98,7 @@ public class JsonTest extends BaseDL4JTest { for (int i = 0; i < lossFunctions.length; i++) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(Updater.ADAM).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(Updater.ADAM.getIUpdaterWithDefaultConfig()) .layer(0, new DenseLayer.Builder().nIn(4).nOut(nOut[i]).activation(Activation.TANH).build()) .layer(1, new LossLayer.Builder().lossFunction(lossFunctions[i]) .activation(outputActivationFn[i]).build()) @@ -107,8 +107,8 @@ public class JsonTest extends BaseDL4JTest { String json = conf.toJson(); String yaml = conf.toYaml(); - MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); - MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(yaml); + NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(json); + NeuralNetConfiguration fromYaml = NeuralNetConfiguration.fromYaml(yaml); assertEquals(conf, fromJson); assertEquals(conf, fromYaml); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java index a10a9a3c7..700b70a6b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java @@ -20,14 +20,35 @@ package org.deeplearning4j.nn.conf; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.util.Arrays; +import java.util.Properties; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.exception.DL4JInvalidConfigException; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import 
org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; +import org.deeplearning4j.nn.conf.layers.LossLayer; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.conf.layers.PoolingType; +import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; +import org.deeplearning4j.nn.conf.layers.Upsampling2D; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.weightnoise.DropConnect; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -41,349 +62,349 @@ import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.io.*; -import java.util.Arrays; -import java.util.Properties; - -import static org.junit.jupiter.api.Assertions.*; - @Slf4j public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { - @TempDir - public File testDir; + @TempDir + public File testDir; - @Test - public void testJson() throws Exception { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() - .layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build()) - .inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build(); + private static NeuralNetConfiguration getConf() { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345L) + .layer(0, new DenseLayer.Builder().nIn(2).nOut(2) + .dist(new NormalDistribution(0, 1)).build()) + .layer(1, new OutputLayer.Builder().nIn(2).nOut(1) + .activation(Activation.TANH) + .dist(new NormalDistribution(0, 1)).lossFunction(LossFunctions.LossFunction.MSE) + .build()) + .build(); + return conf; + } - String json = conf.toJson(); - MultiLayerConfiguration from = MultiLayerConfiguration.fromJson(json); - assertEquals(conf.getConf(0), from.getConf(0)); + @Test + public void testJson() throws Exception { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build()) + .inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build(); - Properties props = new Properties(); - props.put("json", json); - String key = props.getProperty("json"); - assertEquals(json, key); - File f = new File(testDir, "props"); - f.deleteOnExit(); - BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f)); - props.store(bos, ""); - bos.flush(); - bos.close(); - BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f)); - Properties props2 = new Properties(); - props2.load(bis); - bis.close(); - assertEquals(props2.getProperty("json"), props.getProperty("json")); - String json2 = props2.getProperty("json"); - MultiLayerConfiguration conf3 = MultiLayerConfiguration.fromJson(json2); - assertEquals(conf.getConf(0), conf3.getConf(0)); + String json = conf.toJson(); + NeuralNetConfiguration from = NeuralNetConfiguration.fromJson(json); + assertEquals(conf.getConf(0), from.getConf(0)); + Properties props = new Properties(); + props.put("json", json); + String key = props.getProperty("json"); + assertEquals(json, key); + File f = new File(testDir, "props"); + f.deleteOnExit(); + BufferedOutputStream bos = new 
BufferedOutputStream(new FileOutputStream(f)); + props.store(bos, ""); + bos.flush(); + bos.close(); + BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f)); + Properties props2 = new Properties(); + props2.load(bis); + bis.close(); + assertEquals(props2.getProperty("json"), props.getProperty("json")); + String json2 = props2.getProperty("json"); + NeuralNetConfiguration conf3 = NeuralNetConfiguration.fromJson(json2); + assertEquals(conf.getConf(0), conf3.getConf(0)); + + } + + @Test + public void testConvnetJson() { + final int numRows = 76; + final int numColumns = 76; + int nChannels = 3; + int outputNum = 6; + int seed = 123; + + //setup the network + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) + .l1(1e-1).l2(2e-4).weightNoise(new DropConnect(0.5)).miniBatch(true) + .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) + .layer(0, + new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER) + .activation(Activation.RELU).build()) + .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[]{2, 2}) + .build()) + .layer(2, + new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER) + .activation(Activation.RELU).build()) + .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[]{2, 2}) + .build()) + .layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) + .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) + .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) + .build()) + + .inputType(InputType.convolutional(numRows, numColumns, nChannels)); + + NeuralNetConfiguration conf = builder.build(); + String json = conf.toJson(); + NeuralNetConfiguration conf2 = NeuralNetConfiguration.fromJson(json); + assertEquals(conf, conf2); + } + + @Test + public void testUpsamplingConvnetJson() { + final int numRows = 76; + final int numColumns = 76; + int nChannels = 3; + int outputNum = 6; + int seed = 123; + + //setup the network + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) + .l1(1e-1).l2(2e-4).dropOut(0.5).miniBatch(true) + .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) + .layer(new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER) + .activation(Activation.RELU).build()) + .layer(new Upsampling2D.Builder().size(2).build()) + .layer(2, + new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER) + .activation(Activation.RELU).build()) + .layer(new Upsampling2D.Builder().size(2).build()) + .layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) + .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) + .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) + .build()) + + .inputType(InputType.convolutional(numRows, numColumns, nChannels)); + + NeuralNetConfiguration conf = builder.build(); + String json = conf.toJson(); + NeuralNetConfiguration conf2 = NeuralNetConfiguration.fromJson(json); + assertEquals(conf, conf2); + } + + @Test + public void testGlobalPoolingJson() { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new NoOp()) + .dist(new NormalDistribution(0, 1.0)).seed(12345L) + .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(5).build()) + .layer(1, new 
GlobalPoolingLayer.Builder().poolingType(PoolingType.PNORM).pnorm(3).build()) + .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nOut(3).build()) + .inputType(InputType.convolutional(32, 32, 1)).build(); + + String str = conf.toJson(); + NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(str); + + assertEquals(conf, fromJson); + } + + @Test + public void testYaml() throws Exception { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build()) + .inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build(); + String json = conf.toYaml(); + NeuralNetConfiguration from = NeuralNetConfiguration.fromYaml(json); + assertEquals(conf.getConf(0), from.getConf(0)); + + Properties props = new Properties(); + props.put("json", json); + String key = props.getProperty("json"); + assertEquals(json, key); + File f = new File(testDir, "props"); + f.deleteOnExit(); + BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f)); + props.store(bos, ""); + bos.flush(); + bos.close(); + BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f)); + Properties props2 = new Properties(); + props2.load(bis); + bis.close(); + assertEquals(props2.getProperty("json"), props.getProperty("json")); + String yaml = props2.getProperty("json"); + NeuralNetConfiguration conf3 = NeuralNetConfiguration.fromYaml(yaml); + assertEquals(conf.getConf(0), conf3.getConf(0)); + + } + + @Test + public void testClone() { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(0, new DenseLayer.Builder().build()) + .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).build()) + .inputPreProcessor(1, new CnnToFeedForwardPreProcessor()).build(); + + NeuralNetConfiguration conf2 = conf.clone(); + + assertEquals(conf, conf2); + assertNotSame(conf, conf2); + assertNotSame(conf.getNetConfigurations(), conf2.getNetConfigurations()); + for (int i = 0; i < conf.getNetConfigurations().size(); i++) { + assertNotSame(conf.getConf(i), conf2.getConf(i)); + } + assertNotSame(conf.getInputPreProcessors(), conf2.getInputPreProcessors()); + for (Integer layer : conf.getInputPreProcessors().keySet()) { + assertNotSame(conf.getInputPreProcess(layer), conf2.getInputPreProcess(layer)); + } + } + + @Test + public void testRandomWeightInit() { + MultiLayerNetwork model1 = new MultiLayerNetwork(getConf()); + model1.init(); + + Nd4j.getRandom().setSeed(12345L); + MultiLayerNetwork model2 = new MultiLayerNetwork(getConf()); + model2.init(); + + float[] p1 = model1.params().data().asFloat(); + float[] p2 = model2.params().data().asFloat(); + System.out.println(Arrays.toString(p1)); + System.out.println(Arrays.toString(p2)); + + org.junit.jupiter.api.Assertions.assertArrayEquals(p1, p2, 0.0f); + } + + @Test + public void testTrainingListener() { + MultiLayerNetwork model1 = new MultiLayerNetwork(getConf()); + model1.init(); + model1.addListeners(new ScoreIterationListener(1)); + + MultiLayerNetwork model2 = new MultiLayerNetwork(getConf()); + model2.addListeners(new ScoreIterationListener(1)); + model2.init(); + + Layer[] l1 = model1.getLayers(); + for (int i = 0; i < l1.length; i++) { + assertTrue(l1[i].getListeners() != null && l1[i].getListeners().size() == 1); + } + + Layer[] l2 = model2.getLayers(); + for (int i = 0; i < l2.length; i++) { + assertTrue(l2[i].getListeners() != null && l2[i].getListeners().size() == 1); + } + } + 
+ @Test + public void testInvalidConfig() { + + try { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + fail("No exception thrown for invalid configuration"); + } catch (IllegalStateException e) { + //OK + log.error("", e); + } catch (Throwable e) { + log.error("", e); + fail("Unexpected exception thrown for invalid config"); } - @Test - public void testConvnetJson() { - final int numRows = 76; - final int numColumns = 76; - int nChannels = 3; - int outputNum = 6; - int seed = 123; - - //setup the network - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) - .l1(1e-1).l2(2e-4).weightNoise(new DropConnect(0.5)).miniBatch(true) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) - .build()) - .layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) - .build()) - .layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) - .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .build()) - - .setInputType(InputType.convolutional(numRows, numColumns, nChannels)); - - MultiLayerConfiguration conf = builder.build(); - String json = conf.toJson(); - MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json); - assertEquals(conf, conf2); + try { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) + .layer(1, new DenseLayer.Builder().nIn(3).nOut(4).build()) + .layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build()) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + fail("No exception thrown for invalid configuration"); + } catch (IllegalStateException e) { + //OK + log.info(e.toString()); + } catch (Throwable e) { + log.error("", e); + fail("Unexpected exception thrown for invalid config"); } - @Test - public void testUpsamplingConvnetJson() { - final int numRows = 76; - final int numColumns = 76; - int nChannels = 3; - int outputNum = 6; - int seed = 123; - - //setup the network - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) - .l1(1e-1).l2(2e-4).dropOut(0.5).miniBatch(true) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(new Upsampling2D.Builder().size(2).build()) - .layer(2, new ConvolutionLayer.Builder(3, 3).nOut(10).dropOut(0.5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(new Upsampling2D.Builder().size(2).build()) - .layer(4, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) - .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) - .build()) - - .setInputType(InputType.convolutional(numRows, numColumns, nChannels)); - - MultiLayerConfiguration conf = 
builder.build(); - String json = conf.toJson(); - MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json); - assertEquals(conf, conf2); + try { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) + .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()) + .layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build()) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + fail("No exception thrown for invalid configuration"); + } catch (IllegalStateException e) { + //OK + log.info(e.toString()); + } catch (Throwable e) { + log.error("", e); + fail("Unexpected exception thrown for invalid config"); } + } - @Test - public void testGlobalPoolingJson() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) - .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() - .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nOut(5).build()) - .layer(1, new GlobalPoolingLayer.Builder().poolingType(PoolingType.PNORM).pnorm(3).build()) - .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(3).build()) - .setInputType(InputType.convolutional(32, 32, 1)).build(); + @Test + public void testListOverloads() { - String str = conf.toJson(); - MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(str); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) + .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()) + .layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); - assertEquals(conf, fromJson); - } + DenseLayer dl = (DenseLayer) conf.getConf(0).getLayer(); + assertEquals(3, dl.getNIn()); + assertEquals(4, dl.getNOut()); + OutputLayer ol = (OutputLayer) conf.getConf(1).getLayer(); + assertEquals(4, ol.getNIn()); + assertEquals(5, ol.getNOut()); + + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().seed(12345) + .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()) + .layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()) + .build(); + MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); + net2.init(); + + NeuralNetConfiguration conf3 = NeuralNetConfiguration.builder().seed(12345) + .layer(new DenseLayer.Builder().nIn(3).nOut(4).build()) + .layer( + new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()) + .build(); + MultiLayerNetwork net3 = new MultiLayerNetwork(conf3); + net3.init(); + + assertEquals(conf, conf2); + assertEquals(conf, conf3); + } - @Test - public void testYaml() throws Exception { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() - .layer(0, new DenseLayer.Builder().dist(new NormalDistribution(1, 1e-1)).build()) - .inputPreProcessor(0, new CnnToFeedForwardPreProcessor()).build(); - String json = conf.toYaml(); - MultiLayerConfiguration from = MultiLayerConfiguration.fromYaml(json); - assertEquals(conf.getConf(0), from.getConf(0)); + @Test + public void testBiasLr() { + //setup the network + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) + .updater(new Adam(1e-2)) + .biasUpdater(new Adam(0.5)) + .layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).weightInit(WeightInit.XAVIER) + .activation(Activation.RELU).build()) + .layer(1, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) + .layer(2, new 
DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) + .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10) + .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()) + .inputType(InputType.convolutional(28, 28, 1)).build(); - Properties props = new Properties(); - props.put("json", json); - String key = props.getProperty("json"); - assertEquals(json, key); - File f = new File(testDir, "props"); - f.deleteOnExit(); - BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(f)); - props.store(bos, ""); - bos.flush(); - bos.close(); - BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f)); - Properties props2 = new Properties(); - props2.load(bis); - bis.close(); - assertEquals(props2.getProperty("json"), props.getProperty("json")); - String yaml = props2.getProperty("json"); - MultiLayerConfiguration conf3 = MultiLayerConfiguration.fromYaml(yaml); - assertEquals(conf.getConf(0), conf3.getConf(0)); + org.deeplearning4j.nn.conf.layers.BaseLayer l0 = (BaseLayer) conf.getConf(0).getLayer(); + org.deeplearning4j.nn.conf.layers.BaseLayer l1 = (BaseLayer) conf.getConf(1).getLayer(); + org.deeplearning4j.nn.conf.layers.BaseLayer l2 = (BaseLayer) conf.getConf(2).getLayer(); + org.deeplearning4j.nn.conf.layers.BaseLayer l3 = (BaseLayer) conf.getConf(3).getLayer(); - } + assertEquals(0.5, ((Adam) l0.getUpdaterByParam("b")).getLearningRate(), 1e-6); + assertEquals(1e-2, ((Adam) l0.getUpdaterByParam("W")).getLearningRate(), 1e-6); - @Test - public void testClone() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, new DenseLayer.Builder().build()) - .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).build()) - .inputPreProcessor(1, new CnnToFeedForwardPreProcessor()).build(); + assertEquals(0.5, ((Adam) l1.getUpdaterByParam("b")).getLearningRate(), 1e-6); + assertEquals(1e-2, ((Adam) l1.getUpdaterByParam("W")).getLearningRate(), 1e-6); - MultiLayerConfiguration conf2 = conf.clone(); + assertEquals(0.5, ((Adam) l2.getUpdaterByParam("b")).getLearningRate(), 1e-6); + assertEquals(1e-2, ((Adam) l2.getUpdaterByParam("W")).getLearningRate(), 1e-6); - assertEquals(conf, conf2); - assertNotSame(conf, conf2); - assertNotSame(conf.getConfs(), conf2.getConfs()); - for (int i = 0; i < conf.getConfs().size(); i++) { - assertNotSame(conf.getConf(i), conf2.getConf(i)); - } - assertNotSame(conf.getInputPreProcessors(), conf2.getInputPreProcessors()); - for (Integer layer : conf.getInputPreProcessors().keySet()) { - assertNotSame(conf.getInputPreProcess(layer), conf2.getInputPreProcess(layer)); - } - } - - @Test - public void testRandomWeightInit() { - MultiLayerNetwork model1 = new MultiLayerNetwork(getConf()); - model1.init(); - - Nd4j.getRandom().setSeed(12345L); - MultiLayerNetwork model2 = new MultiLayerNetwork(getConf()); - model2.init(); - - float[] p1 = model1.params().data().asFloat(); - float[] p2 = model2.params().data().asFloat(); - System.out.println(Arrays.toString(p1)); - System.out.println(Arrays.toString(p2)); - - org.junit.jupiter.api.Assertions.assertArrayEquals(p1, p2, 0.0f); - } - - @Test - public void testTrainingListener() { - MultiLayerNetwork model1 = new MultiLayerNetwork(getConf()); - model1.init(); - model1.addListeners( new ScoreIterationListener(1)); - - MultiLayerNetwork model2 = new MultiLayerNetwork(getConf()); - model2.addListeners( new ScoreIterationListener(1)); - model2.init(); - - Layer[] l1 = model1.getLayers(); - 
for (int i = 0; i < l1.length; i++) - assertTrue(l1[i].getListeners() != null && l1[i].getListeners().size() == 1); - - Layer[] l2 = model2.getLayers(); - for (int i = 0; i < l2.length; i++) - assertTrue(l2[i].getListeners() != null && l2[i].getListeners().size() == 1); - } + assertEquals(0.5, ((Adam) l3.getUpdaterByParam("b")).getLearningRate(), 1e-6); + assertEquals(1e-2, ((Adam) l3.getUpdaterByParam("W")).getLearningRate(), 1e-6); + } - private static MultiLayerConfiguration getConf() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(2) - .dist(new NormalDistribution(0, 1)).build()) - .layer(1, new OutputLayer.Builder().nIn(2).nOut(1) - .activation(Activation.TANH) - .dist(new NormalDistribution(0, 1)).lossFunction(LossFunctions.LossFunction.MSE).build()) - .build(); - return conf; - } - - @Test - public void testInvalidConfig() { - - try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() - .build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - fail("No exception thrown for invalid configuration"); - } catch (IllegalStateException e) { - //OK - log.error("",e); - } catch (Throwable e) { - log.error("",e); - fail("Unexpected exception thrown for invalid config"); - } - - try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() - .layer(1, new DenseLayer.Builder().nIn(3).nOut(4).build()) - .layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build()) - .build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - fail("No exception thrown for invalid configuration"); - } catch (IllegalStateException e) { - //OK - log.info(e.toString()); - } catch (Throwable e) { - log.error("",e); - fail("Unexpected exception thrown for invalid config"); - } - - try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() - .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()) - .layer(2, new OutputLayer.Builder().nIn(4).nOut(5).build()) - .build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - fail("No exception thrown for invalid configuration"); - } catch (IllegalStateException e) { - //OK - log.info(e.toString()); - } catch (Throwable e) { - log.error("",e); - fail("Unexpected exception thrown for invalid config"); - } - } - - @Test - public void testListOverloads() { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() - .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()) - .layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()) - .build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - DenseLayer dl = (DenseLayer) conf.getConf(0).getLayer(); - assertEquals(3, dl.getNIn()); - assertEquals(4, dl.getNOut()); - OutputLayer ol = (OutputLayer) conf.getConf(1).getLayer(); - assertEquals(4, ol.getNIn()); - assertEquals(5, ol.getNOut()); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list() - .layer(0, new DenseLayer.Builder().nIn(3).nOut(4).build()) - .layer(1, new OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()) - .build(); - MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); - net2.init(); - - MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder().seed(12345) - .list(new DenseLayer.Builder().nIn(3).nOut(4).build(), - new 
OutputLayer.Builder().nIn(4).nOut(5).activation(Activation.SOFTMAX).build()) - .build(); - MultiLayerNetwork net3 = new MultiLayerNetwork(conf3); - net3.init(); - - - assertEquals(conf, conf2); - assertEquals(conf, conf3); - } - - - @Test - public void testBiasLr() { - //setup the network - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new Adam(1e-2)) - .biasUpdater(new Adam(0.5)).list() - .layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).weightInit(WeightInit.XAVIER) - .activation(Activation.RELU).build()) - .layer(1, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) - .layer(2, new DenseLayer.Builder().nOut(100).activation(Activation.RELU).build()) - .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)).build(); - - org.deeplearning4j.nn.conf.layers.BaseLayer l0 = (BaseLayer) conf.getConf(0).getLayer(); - org.deeplearning4j.nn.conf.layers.BaseLayer l1 = (BaseLayer) conf.getConf(1).getLayer(); - org.deeplearning4j.nn.conf.layers.BaseLayer l2 = (BaseLayer) conf.getConf(2).getLayer(); - org.deeplearning4j.nn.conf.layers.BaseLayer l3 = (BaseLayer) conf.getConf(3).getLayer(); - - assertEquals(0.5, ((Adam)l0.getUpdaterByParam("b")).getLearningRate(), 1e-6); - assertEquals(1e-2, ((Adam)l0.getUpdaterByParam("W")).getLearningRate(), 1e-6); - - assertEquals(0.5, ((Adam)l1.getUpdaterByParam("b")).getLearningRate(), 1e-6); - assertEquals(1e-2, ((Adam)l1.getUpdaterByParam("W")).getLearningRate(), 1e-6); - - assertEquals(0.5, ((Adam)l2.getUpdaterByParam("b")).getLearningRate(), 1e-6); - assertEquals(1e-2, ((Adam)l2.getUpdaterByParam("W")).getLearningRate(), 1e-6); - - assertEquals(0.5, ((Adam)l3.getUpdaterByParam("b")).getLearningRate(), 1e-6); - assertEquals(1e-2, ((Adam)l3.getUpdaterByParam("W")).getLearningRate(), 1e-6); - } - - - @Test - public void testInvalidOutputLayer(){ + @Test + public void testInvalidOutputLayer() { /* Test case (invalid configs) 1. nOut=1 + softmax @@ -393,37 +414,44 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { 5. 
mcxent + sigmoid */ - LossFunctions.LossFunction[] lf = new LossFunctions.LossFunction[]{ - LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.MCXENT, LossFunctions.LossFunction.XENT, - LossFunctions.LossFunction.XENT, LossFunctions.LossFunction.MCXENT}; - int[] nOut = new int[]{1, 3, 3, 3, 3}; - Activation[] activations = new Activation[]{Activation.SOFTMAX, Activation.TANH, Activation.SOFTMAX, Activation.RELU, Activation.SIGMOID}; - for( int i=0; i r = net.getLayer(0).conf().getLayer().getRegularizationByParam("b"); + assertEquals(l1, TestUtils.getL1(net.getLayer(0).getLayerConfiguration().getRegularizationByParam("W")), 1e-4); + List r = net.getLayer(0).getLayerConfiguration().getRegularizationByParam("b"); assertEquals(0, r.size()); - r = net.getLayer(1).conf().getLayer().getRegularizationByParam("beta"); + r = net.getLayer(1).getLayerConfiguration().getRegularizationByParam("beta"); assertTrue(r == null || r.isEmpty()); - r = net.getLayer(1).conf().getLayer().getRegularizationByParam("gamma"); + r = net.getLayer(1).getLayerConfiguration().getRegularizationByParam("gamma"); assertTrue(r == null || r.isEmpty()); - r = net.getLayer(1).conf().getLayer().getRegularizationByParam("mean"); + r = net.getLayer(1).getLayerConfiguration().getRegularizationByParam("mean"); assertTrue(r == null || r.isEmpty()); - r = net.getLayer(1).conf().getLayer().getRegularizationByParam("var"); + r = net.getLayer(1).getLayerConfiguration().getRegularizationByParam("var"); assertTrue(r == null || r.isEmpty()); - assertEquals(l2, TestUtils.getL2(net.getLayer(2).conf().getLayer().getRegularizationByParam("W")), 1e-4); - r = net.getLayer(2).conf().getLayer().getRegularizationByParam("b"); + assertEquals(l2, TestUtils.getL2(net.getLayer(2).getLayerConfiguration().getRegularizationByParam("W")), 1e-4); + r = net.getLayer(2).getLayerConfiguration().getRegularizationByParam("b"); assertTrue(r == null || r.isEmpty()); } @@ -322,7 +322,7 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest { .nIn(10).nOut(5).updater(new Sgd(1e-1)) .lossFunction(LossFunctions.LossFunction.KL_DIVERGENCE).build(); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(42).layer(layer).build(); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(42).layer(layer).build(); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java index 37260087d..afbb64726 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java @@ -26,7 +26,6 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.constraint.MaxNormConstraint; import org.deeplearning4j.nn.conf.constraint.MinMaxNormConstraint; @@ -68,10 +67,10 @@ public class TestConstraints extends BaseDL4JTest { for (LayerConstraint lc : constraints) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Sgd(0.0)) .dist(new 
NormalDistribution(0, 5)) - .list() + .layer(new LSTM.Builder().nIn(12).nOut(10) .constrainRecurrent(lc).build()) .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(8).build()) @@ -81,7 +80,7 @@ public class TestConstraints extends BaseDL4JTest { net.init(); LayerConstraint exp = lc.clone(); - assertEquals(exp.toString(), net.getLayer(0).conf().getLayer().getConstraints().get(0).toString()); + assertEquals(exp.toString(), net.getLayer(0).getLayerConfiguration().getConstraints().get(0).toString()); INDArray input = Nd4j.rand(3, 12); INDArray labels = Nd4j.rand(3, 8); @@ -120,11 +119,11 @@ public class TestConstraints extends BaseDL4JTest { for (LayerConstraint lc : constraints) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Sgd(0.0)) .dist(new NormalDistribution(0, 5)) .biasInit(10.0) - .list() + .layer(new DenseLayer.Builder().nIn(12).nOut(10) .constrainBias(lc).build()) .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(8).build()) @@ -134,7 +133,7 @@ public class TestConstraints extends BaseDL4JTest { net.init(); LayerConstraint exp = lc.clone(); - assertEquals(exp.toString(), net.getLayer(0).conf().getLayer().getConstraints().get(0).toString()); + assertEquals(exp.toString(), net.getLayer(0).getLayerConfiguration().getConstraints().get(0).toString()); INDArray input = Nd4j.rand(3, 12); INDArray labels = Nd4j.rand(3, 8); @@ -173,10 +172,10 @@ public class TestConstraints extends BaseDL4JTest { for (LayerConstraint lc : constraints) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Sgd(0.0)) .dist(new NormalDistribution(0, 5)) - .list() + .layer(new DenseLayer.Builder().nIn(12).nOut(10) .constrainWeights(lc).build()) .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(8).build()) @@ -186,7 +185,7 @@ public class TestConstraints extends BaseDL4JTest { net.init(); LayerConstraint exp = lc.clone(); - assertEquals(exp.toString(), net.getLayer(0).conf().getLayer().getConstraints().get(0).toString()); + assertEquals(exp.toString(), net.getLayer(0).getLayerConfiguration().getConstraints().get(0).toString()); INDArray input = Nd4j.rand(3, 12); INDArray labels = Nd4j.rand(3, 8); @@ -225,11 +224,11 @@ public class TestConstraints extends BaseDL4JTest { for (LayerConstraint lc : constraints) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Sgd(0.0)) .dist(new NormalDistribution(0, 5)) .biasInit(0.2) - .list() + .layer(new DenseLayer.Builder().nIn(12).nOut(10) .constrainAllParameters(lc).build()) .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(8).build()) @@ -239,7 +238,7 @@ public class TestConstraints extends BaseDL4JTest { net.init(); LayerConstraint exp = lc.clone(); - assertEquals(exp.toString(), net.getLayer(0).conf().getLayer().getConstraints().get(0).toString()); + assertEquals(exp.toString(), net.getLayer(0).getLayerConfiguration().getConstraints().get(0).toString()); INDArray input = Nd4j.rand(3, 12); INDArray labels = Nd4j.rand(3, 8); @@ -286,11 +285,11 @@ public class TestConstraints extends BaseDL4JTest { for (LayerConstraint lc : constraints) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + 
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Sgd(0.0)) .dist(new NormalDistribution(0, 5)) .biasInit(0.2) - .list() + .layer(new DenseLayer.Builder().nIn(12).nOut(10) .constrainWeights(lc).constrainBias(lc).build()) .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(8).build()) @@ -300,7 +299,7 @@ public class TestConstraints extends BaseDL4JTest { net.init(); LayerConstraint exp = lc.clone(); - assertEquals(exp.toString(), net.getLayer(0).conf().getLayer().getConstraints().get(0).toString()); + assertEquals(exp.toString(), net.getLayer(0).getLayerConfiguration().getConstraints().get(0).toString()); INDArray input = Nd4j.rand(3, 12); INDArray labels = Nd4j.rand(3, 8); @@ -346,12 +345,12 @@ public class TestConstraints extends BaseDL4JTest { for(LayerConstraint lc : constraints){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .constrainWeights(lc) .updater(new Sgd(0.0)) .dist(new NormalDistribution(0,5)) .biasInit(1) - .list() + .layer(new DenseLayer.Builder().nIn(12).nOut(10).build()) .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(8).build()) .build(); @@ -360,8 +359,8 @@ public class TestConstraints extends BaseDL4JTest { net.init(); LayerConstraint exp = lc.clone(); - assertEquals(exp.toString(), net.getLayer(0).conf().getLayer().getConstraints().get(0).toString()); - assertEquals(exp.toString(), net.getLayer(1).conf().getLayer().getConstraints().get(0).toString()); + assertEquals(exp.toString(), net.getLayer(0).getLayerConfiguration().getConstraints().get(0).toString()); + assertEquals(exp.toString(), net.getLayer(1).getLayerConfiguration().getConstraints().get(0).toString()); INDArray input = Nd4j.rand(3, 12); INDArray labels = Nd4j.rand(3, 8); @@ -400,7 +399,7 @@ public class TestConstraints extends BaseDL4JTest { int nIn = 10; int lstmLayerSize = 32; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.RELU_UNIFORM) .updater(new RmsProp(learningRate)) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java index 5c06f2adc..26c266dc7 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java @@ -25,7 +25,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -60,21 +59,21 @@ public class TestDropout extends BaseDL4JTest { @Test public void testBasicConfig(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dropOut(0.6) - .list() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(new DenseLayer.Builder().nIn(10).nOut(10).dropOut(0.7).build()) 
.layer(new DenseLayer.Builder().nIn(10).nOut(10).dropOut(new AlphaDropout(0.5)).build()) .build(); - assertEquals(new Dropout(0.6), conf.getConf(0).getLayer().getIDropout()); - assertEquals(new Dropout(0.7), conf.getConf(1).getLayer().getIDropout()); - assertEquals(new AlphaDropout(0.5), conf.getConf(2).getLayer().getIDropout()); + assertEquals(new Dropout(0.6), conf.getFlattenedLayerConfigurations().get(0).getIDropout()); + assertEquals(new Dropout(0.7), conf.getFlattenedLayerConfigurations().get(1).getIDropout()); + assertEquals(new AlphaDropout(0.5), conf.getFlattenedLayerConfigurations().get(2).getIDropout()); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() - .dropOut(0.6) + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() + .dropOut( new Dropout(0.6)) .graphBuilder() .addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in") @@ -83,9 +82,9 @@ public class TestDropout extends BaseDL4JTest { .setOutputs("2") .build(); - assertEquals(new Dropout(0.6), ((LayerVertex)conf2.getVertices().get("0")).getLayerConf().getLayer().getIDropout()); - assertEquals(new Dropout(0.7), ((LayerVertex)conf2.getVertices().get("1")).getLayerConf().getLayer().getIDropout()); - assertEquals(new AlphaDropout(0.5), ((LayerVertex)conf2.getVertices().get("2")).getLayerConf().getLayer().getIDropout()); + assertEquals(new Dropout(0.6), ((LayerVertex)conf2.getVertices().get("0")).getNetConfiguration().getFirstLayer().getIDropout()); + assertEquals(new Dropout(0.7), ((LayerVertex)conf2.getVertices().get("1")).getNetConfiguration().getFirstLayer().getIDropout()); + assertEquals(new AlphaDropout(0.5), ((LayerVertex)conf2.getVertices().get("2")).getNetConfiguration().getFirstLayer().getIDropout()); } @Test @@ -94,8 +93,8 @@ public class TestDropout extends BaseDL4JTest { CustomDropout d1 = new CustomDropout(); CustomDropout d2 = new CustomDropout(); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(new DenseLayer.Builder().nIn(4).nOut(3).dropOut(d1).build()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).dropOut(d2).nIn(3).nOut(3).build()) .build(); @@ -129,7 +128,7 @@ public class TestDropout extends BaseDL4JTest { d1 = new CustomDropout(); d2 = new CustomDropout(); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(4).nOut(3).dropOut(d1).build(), "in") @@ -186,9 +185,9 @@ public class TestDropout extends BaseDL4JTest { for(IDropout id : dropouts) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dropOut(id) - .list() + .layer(new DenseLayer.Builder().nIn(4).nOut(3).build()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(3).nOut(3).build()) .build(); @@ -197,7 +196,7 @@ public class TestDropout extends BaseDL4JTest { TestUtils.testModelSerialization(net); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .dropOut(id) .graphBuilder() .addInputs("in") @@ -601,13 +600,13 @@ public class TestDropout extends BaseDL4JTest { @Test public void testSpatialDropoutJSON(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() + 
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(new DropoutLayer.Builder(new SpatialDropout(0.5)).build()) .build(); String asJson = conf.toJson(); - MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(asJson); + NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(asJson); assertEquals(conf, fromJson); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java index 046cf0f63..02babc8bc 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java @@ -70,7 +70,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest { public void testElementWiseVertexForwardAdd() { int batchsz = 24; int featuresz = 17; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder() + ComputationGraphConfiguration cgc = NeuralNetConfiguration.builder().graphBuilder() .addInputs("input1", "input2", "input3") .addLayer("denselayer", new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY) @@ -111,7 +111,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest { public void testElementWiseVertexForwardProduct() { int batchsz = 24; int featuresz = 17; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder() + ComputationGraphConfiguration cgc = NeuralNetConfiguration.builder().graphBuilder() .addInputs("input1", "input2", "input3") .addLayer("denselayer", new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY) @@ -152,7 +152,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest { public void testElementWiseVertexForwardSubtract() { int batchsz = 24; int featuresz = 17; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder() + ComputationGraphConfiguration cgc = NeuralNetConfiguration.builder().graphBuilder() .addInputs("input1", "input2") .addLayer("denselayer", new DenseLayer.Builder().nIn(featuresz).nOut(1).activation(Activation.IDENTITY) @@ -194,7 +194,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest { int featuresz = 17; int midsz = 13; int outputsz = 11; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + ComputationGraphConfiguration cgc = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .dataType(DataType.DOUBLE) .biasInit(0.0).updater(new Sgd()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() @@ -370,7 +370,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest { int featuresz = 17; int midsz = 13; int outputsz = 11; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + ComputationGraphConfiguration cgc = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .dataType(DataType.DOUBLE) .biasInit(0.0).updater(new Sgd()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() @@ -545,7 +545,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest { int featuresz = 17; int midsz = 13; int outputsz = 11; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + ComputationGraphConfiguration cgc = 
NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .dataType(DataType.DOUBLE) .biasInit(0.0).updater(new Sgd()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java index acab33814..cf0e743e6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java @@ -84,7 +84,7 @@ public class ShiftVertexTest extends BaseDL4JTest { INDArray input = Nd4j .create(new double[][] {{0.2, 0.3, 0.5}, {0.7, 1.1, 1.3}, {1.7, 1.9, 2.3}, {2.9, 3.1, 3.7}}); double sf = 4.1; - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("input") + ComputationGraphConfiguration cgc = NeuralNetConfiguration.builder().graphBuilder().addInputs("input") .addLayer("denselayer", new DenseLayer.Builder().nIn(input.columns()).nOut(1) .activation(Activation.IDENTITY).build(), @@ -138,7 +138,7 @@ public class ShiftVertexTest extends BaseDL4JTest { INDArray target = Nd4j.create(new double[][] {{0.05, 0.10, 0.15, 0.20, 0.25}, {0.30, 0.35, 0.40, 0.45, 0.50}, {0.55, 0.60, 0.65, 0.70, 0.75}, {0.80, 0.85, 0.90, 0.95, 0.99}}); - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + ComputationGraphConfiguration cgc = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .dataType(DataType.DOUBLE) .updater(new Sgd(0.01)) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java index 484da1ff9..e4e7ce73c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java @@ -199,8 +199,8 @@ public class LayerBuilderTest extends BaseDL4JTest { assertEquals(act, activationLayer.activationFn); } - private void checkSerialization(Layer layer) throws Exception { - NeuralNetConfiguration confExpected = new NeuralNetConfiguration.Builder().layer(layer).build(); + private void checkSerialization(LayerConfiguration layer) throws Exception { + NeuralNetConfiguration confExpected = NeuralNetConfiguration.builder().layer(layer).build(); NeuralNetConfiguration confActual; // check Java serialization @@ -212,21 +212,21 @@ public class LayerBuilderTest extends BaseDL4JTest { try (ByteArrayInputStream bis = new ByteArrayInputStream(data); ObjectInput in = new ObjectInputStream(bis)) { confActual = (NeuralNetConfiguration) in.readObject(); } - assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal Java serialization"); + assertEquals(confExpected.getFirstLayer(), confActual.getFirstLayer(), "unequal Java serialization"); // check JSON String json = confExpected.toJson(); confActual = NeuralNetConfiguration.fromJson(json); - assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal JSON serialization"); + assertEquals(confExpected.getFirstLayer(), confActual.getFirstLayer(), "unequal JSON serialization"); // check YAML String yaml = confExpected.toYaml(); confActual = 
NeuralNetConfiguration.fromYaml(yaml); - assertEquals(confExpected.getLayer(), confActual.getLayer(), "unequal YAML serialization"); + assertEquals(confExpected.getFirstLayer(), confActual.getFirstLayer(), "unequal YAML serialization"); // check the layer's use of callSuper on equals method - confActual.getLayer().setIDropout(new Dropout(new java.util.Random().nextDouble())); - assertNotEquals( confExpected.getLayer(), confActual.getLayer(), "broken equals method (missing callSuper?)"); + confActual.getFirstLayer().setIDropout(new Dropout(new java.util.Random().nextDouble())); + assertNotEquals( confExpected.getFirstLayer(), confActual.getFirstLayer(), "broken equals method (missing callSuper?)"); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java index be25a0ccd..db3731f6d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.conf.layers; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; @@ -53,7 +52,7 @@ public class LayerConfigTest extends BaseDL4JTest { String name1 = "genisys"; String name2 = "bill"; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).name(name1).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).name(name2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -67,7 +66,7 @@ public class LayerConfigTest extends BaseDL4JTest { @Test public void testActivationLayerwiseOverride() { //Without layerwise override: - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().activation(Activation.RELU) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -77,7 +76,7 @@ public class LayerConfigTest extends BaseDL4JTest { assertEquals("relu", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString()); //With - conf = new NeuralNetConfiguration.Builder().activation(Activation.RELU).list() + conf = NeuralNetConfiguration.builder().activation(Activation.RELU) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).activation(Activation.TANH).build()).build(); @@ -93,8 +92,8 @@ public class LayerConfigTest extends BaseDL4JTest { public void testWeightBiasInitLayerwiseOverride() { //Without layerwise override: final Distribution defaultDistribution = new NormalDistribution(0, 1.0); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dist(defaultDistribution).biasInit(1).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .dist(defaultDistribution).biasInit(1) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new 
DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -108,8 +107,8 @@ public class LayerConfigTest extends BaseDL4JTest { //With: final Distribution overriddenDistribution = new UniformDistribution(0, 1); - conf = new NeuralNetConfiguration.Builder() - .dist(defaultDistribution).biasInit(1).list() + conf = NeuralNetConfiguration.builder() + .dist(defaultDistribution).biasInit(1) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder().nIn(2).nOut(2) .dist(overriddenDistribution).biasInit(0).build()) @@ -132,7 +131,7 @@ public class LayerConfigTest extends BaseDL4JTest { // the global config, and check they actually work. //Learning rate without layerwise override: - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(0.3).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().learningRate(0.3) .layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -142,7 +141,7 @@ public class LayerConfigTest extends BaseDL4JTest { assertEquals(0.3, ((BaseLayer) conf.getConf(1).getLayer()).getLearningRate(), 0.0); //With: - conf = new NeuralNetConfiguration.Builder().learningRate(0.3).list() + conf = NeuralNetConfiguration.builder().learningRate(0.3) .layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).learningRate(0.2).build()).build(); @@ -153,7 +152,7 @@ public class LayerConfigTest extends BaseDL4JTest { assertEquals(0.2, ((BaseLayer) conf.getConf(1).getLayer()).getLearningRate(), 0.0); //L1 and L2 without layerwise override: - conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.2).list() + conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2) .layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); net = new MultiLayerNetwork(conf); @@ -165,7 +164,7 @@ public class LayerConfigTest extends BaseDL4JTest { assertEquals(0.2, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0); //L1 and L2 with layerwise override: - conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.2).list() + conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2) .layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).l1(0.9).build()) .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).l2(0.8).build()).build(); net = new MultiLayerNetwork(conf); @@ -181,7 +180,7 @@ public class LayerConfigTest extends BaseDL4JTest { @Test public void testDropoutLayerwiseOverride() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().dropOut(1.0) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -190,7 +189,7 @@ public class LayerConfigTest extends BaseDL4JTest { assertEquals(new Dropout(1.0), conf.getConf(0).getLayer().getIDropout()); assertEquals(new Dropout(1.0), conf.getConf(1).getLayer().getIDropout()); - conf = new NeuralNetConfiguration.Builder().dropOut(1.0).list() + conf = NeuralNetConfiguration.builder().dropOut(1.0) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new 
DenseLayer.Builder().nIn(2).nOut(2).dropOut(2.0).build()).build(); @@ -206,9 +205,9 @@ public class LayerConfigTest extends BaseDL4JTest { Map testMomentumAfter = new HashMap<>(); testMomentumAfter.put(0, 0.1); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))) - .list() + .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -220,8 +219,8 @@ public class LayerConfigTest extends BaseDL4JTest { Map testMomentumAfter2 = new HashMap<>(); testMomentumAfter2.put(0, 0.2); - conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter) )) - .list() + conf = NeuralNetConfiguration.builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter) )) + .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()).layer(1, new DenseLayer.Builder() .nIn(2).nOut(2).updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter2))).build()) .build(); @@ -234,7 +233,7 @@ public class LayerConfigTest extends BaseDL4JTest { @Test public void testUpdaterRhoRmsDecayLayerwiseOverride() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new AdaDelta(0.5, 0.9)).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new AdaDelta(0.5, 0.9)) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.01,0.9)).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -245,7 +244,7 @@ public class LayerConfigTest extends BaseDL4JTest { assertEquals(0.5, ((AdaDelta)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0); assertEquals(0.01, ((AdaDelta)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0); - conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).list() + conf = NeuralNetConfiguration.builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON)) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new AdaDelta(0.5,AdaDelta.DEFAULT_ADADELTA_EPSILON)).build()) .build(); @@ -262,9 +261,9 @@ public class LayerConfigTest extends BaseDL4JTest { @Test public void testUpdaterAdamParamsLayerwiseOverride() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Adam(1.0, 0.5, 0.5, 1e-8)) - .list() + .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Adam(1.0, 0.6, 0.7, 1e-8)).build()) .build(); @@ -281,9 +280,9 @@ public class LayerConfigTest extends BaseDL4JTest { public void testGradientNormalizationLayerwiseOverride() { //Learning rate without layerwise override: - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).list() + .gradientNormalizationThreshold(10) .layer(0, new 
DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -297,9 +296,9 @@ public class LayerConfigTest extends BaseDL4JTest { assertEquals(10, conf.getConf(1).getLayer().getGradientNormalizationThreshold(), 0.0); //With: - conf = new NeuralNetConfiguration.Builder() + conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).list() + .gradientNormalizationThreshold(10) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2) .gradientNormalization(GradientNormalization.None) @@ -323,7 +322,7 @@ public class LayerConfigTest extends BaseDL4JTest { double lr = 2; double lrDecayRate = 5; int iterations = 1; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().learningRate(lr) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().learningRate(lr) .updater(Updater.SGD) .learningRateDecayPolicy(LearningRatePolicy.Exponential).lrPolicyDecayRate(lrDecayRate).list() .layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) @@ -343,7 +342,7 @@ public class LayerConfigTest extends BaseDL4JTest { double lrDecayRate = 5; double power = 3; int iterations = 1; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().iterations(iterations).learningRate(lr) .learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(lrDecayRate) .lrPolicyPower(power).list().layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); @@ -365,7 +364,7 @@ public class LayerConfigTest extends BaseDL4JTest { double lrDecayRate = 5; double steps = 4; int iterations = 1; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().iterations(iterations).learningRate(lr) .learningRateDecayPolicy(LearningRatePolicy.Step).lrPolicyDecayRate(lrDecayRate) .lrPolicySteps(steps).list().layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); @@ -386,7 +385,7 @@ public class LayerConfigTest extends BaseDL4JTest { double lrDecayRate = 5; double power = 3; int iterations = 1; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().iterations(iterations).learningRate(lr) .learningRateDecayPolicy(LearningRatePolicy.Poly).lrPolicyDecayRate(lrDecayRate) .lrPolicyPower(power).list().layer(0, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); @@ -407,7 +406,7 @@ public class LayerConfigTest extends BaseDL4JTest { double lrDecayRate = 5; double steps = 4; int iterations = 1; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().iterations(iterations).learningRate(lr) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().iterations(iterations).learningRate(lr) .learningRateDecayPolicy(LearningRatePolicy.Sigmoid).lrPolicyDecayRate(lrDecayRate) .lrPolicySteps(steps).list().layer(0, new 
DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayerConfiguration.Builder().nIn(2).nOut(2).build()).build(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java index 4b60f98c4..65532a0bc 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java @@ -24,7 +24,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.distribution.Distribution; @@ -56,8 +55,8 @@ public class LayerConfigValidationTest extends BaseDL4JTest { @Test public void testDropConnect() { // Warning thrown only since some layers may not have l1 or l2 - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).weightNoise(new DropConnect(0.5)) - .list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)).weightNoise(new DropConnect(0.5)) + .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -67,8 +66,8 @@ public class LayerConfigValidationTest extends BaseDL4JTest { @Test public void testL1L2NotSet() { // Warning thrown only since some layers may not have l1 or l2 - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)) - .list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Sgd(0.3)) + .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -78,7 +77,7 @@ public class LayerConfigValidationTest extends BaseDL4JTest { //@Ignore //Old assumption: throw exception on l1 but no regularization. Current design: warn, not exception public void testRegNotSetL1Global() { assertThrows(IllegalStateException.class, () -> { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).l1(0.5).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Sgd(0.3)).l1(0.5) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -90,7 +89,7 @@ public class LayerConfigValidationTest extends BaseDL4JTest { //@Ignore //Old assumption: throw exception on l1 but no regularization. 
Current design: warn, not exception public void testRegNotSetL2Local() { assertThrows(IllegalStateException.class, () -> { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Sgd(0.3)) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -101,9 +100,9 @@ public class LayerConfigValidationTest extends BaseDL4JTest { @Test public void testWeightInitDistNotSet() { // Warning thrown only since global dist can be set with a different weight init locally - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Sgd(0.3)).dist(new GaussianDistribution(1e-3, 2)) - .list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new Sgd(0.3)).dist(new GaussianDistribution(1e-3, 2)) + .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -115,8 +114,8 @@ public class LayerConfigValidationTest extends BaseDL4JTest { Map testMomentumAfter = new HashMap<>(); testMomentumAfter.put(0, 0.1); - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new Nesterovs(1.0, new MapSchedule(ScheduleType.ITERATION, testMomentumAfter))) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -125,11 +124,11 @@ public class LayerConfigValidationTest extends BaseDL4JTest { @Test public void testCompGraphNullLayer() { - ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration.GraphBuilder gb = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.01)) .seed(42).miniBatch(false).l1(0.2).l2(0.2) /* Graph Builder */ - .updater(Updater.RMSPROP).graphBuilder().addInputs("in") + .updater(Updater.RMSPROP.getIUpdaterWithDefaultConfig()).graphBuilder().addInputs("in") .addLayer("L" + 1, new GravesLSTM.Builder().nIn(20).updater(Updater.RMSPROP).nOut(10) .weightInit(WeightInit.XAVIER) @@ -157,33 +156,33 @@ public class LayerConfigValidationTest extends BaseDL4JTest { double expectedL2 = 0.0; // Nesterovs Updater - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(0.9)) - .list().layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Nesterovs(0.9)) + .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Nesterovs(0.3, 0.4)).build()).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - BaseLayer layerConf = (BaseLayer) net.getLayer(0).conf().getLayer(); + BaseLayer layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration(); assertEquals(expectedMomentum, ((Nesterovs) layerConf.getIUpdater()).getMomentum(), 1e-3); assertNull(TestUtils.getL1Reg(layerConf.getRegularization())); 
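// Editor's illustration (not part of the patch): the hunks around this point migrate the
// tests from net.getLayer(i).conf().getLayer() to the net.getLayer(i).getLayerConfiguration()
// accessor, and from new NeuralNetConfiguration.Builder()...list() to NeuralNetConfiguration.builder().
// The snippet below is a minimal, hypothetical sketch of that usage pattern, assembled only from
// API names that appear in the surrounding diff; it assumes the usual imports already present in
// these tests (NeuralNetConfiguration, DenseLayer, BaseLayer, MultiLayerNetwork, Nesterovs).
NeuralNetConfiguration sketchConf = NeuralNetConfiguration.builder()
        .updater(new Nesterovs(0.9))
        .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).build())
        .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new Nesterovs(0.3, 0.4)).build())
        .build();
MultiLayerNetwork sketchNet = new MultiLayerNetwork(sketchConf);
sketchNet.init();
// Per-layer settings are now read from the layer's own configuration object
// rather than via conf().getLayer():
BaseLayer sketchLayer0 = (BaseLayer) sketchNet.getLayer(0).getLayerConfiguration();
double sketchMomentum = ((Nesterovs) sketchLayer0.getIUpdater()).getMomentum();
// End of editorial sketch; the patch content resumes below.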
assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3); - BaseLayer layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer(); + BaseLayer layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration(); assertEquals(0.4, ((Nesterovs) layerConf1.getIUpdater()).getMomentum(), 1e-3); // Adam Updater - conf = new NeuralNetConfiguration.Builder().updater(new Adam(0.3)) - .weightInit(new WeightInitDistribution(expectedDist)).list() + conf = NeuralNetConfiguration.builder().updater(new Adam(0.3)) + .weightInit(expectedDist) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).l2(0.5).l1(0.3).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build(); net = new MultiLayerNetwork(conf); net.init(); - layerConf = (BaseLayer) net.getLayer(0).conf().getLayer(); + layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration(); assertEquals(0.3, TestUtils.getL1(layerConf), 1e-3); assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3); - layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer(); + layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration(); assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3); assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3); assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn()); @@ -191,18 +190,18 @@ public class LayerConfigValidationTest extends BaseDL4JTest { assertNull(TestUtils.getL2Reg(layerConf1.getRegularization())); //RMSProp Updater - conf = new NeuralNetConfiguration.Builder().updater(new RmsProp(0.3)).list() + conf = NeuralNetConfiguration.builder().updater(new RmsProp(0.3)) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).build()) .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(0.3, 0.4, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()).build(); net = new MultiLayerNetwork(conf); net.init(); - layerConf = (BaseLayer) net.getLayer(0).conf().getLayer(); + layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration(); assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getIUpdater()).getRmsDecay(), 1e-3); assertNull(TestUtils.getL1Reg(layerConf.getRegularization())); assertNull(TestUtils.getL2Reg(layerConf.getRegularization())); - layerConf1 = (BaseLayer) net.getLayer(1).conf().getLayer(); + layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration(); assertEquals(0.4, ((RmsProp) layerConf1.getIUpdater()).getRmsDecay(), 1e-3); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java index 48112c682..d530e416d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CNNProcessorTest.java @@ -23,6 +23,7 @@ package org.deeplearning4j.nn.conf.preprocessor; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -233,64 +234,60 @@ public class CNNProcessorTest extends BaseDL4JTest { @Test public void testInvalidInputShape(){ - NeuralNetConfiguration.Builder builder = new 
NeuralNetConfiguration.Builder() - .seed(123) - .miniBatch(true) - .cacheMode(CacheMode.DEVICE) - .updater(new Nesterovs(0.9)) - .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); int[] kernelArray = new int[]{3,3}; int[] strideArray = new int[]{1,1}; int[] zeroPaddingArray = new int[]{0,0}; int processWidth = 4; - NeuralNetConfiguration.ListBuilder listBuilder = builder.list(); // Building the DL4J network - listBuilder = listBuilder.layer(0, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) + NeuralNetConfiguration conf =NeuralNetConfiguration.builder() + .seed(123) + .miniBatch(true) + .cacheMode(CacheMode.DEVICE) + .updater(new Nesterovs(0.9)) + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + // Building the DL4J network + .layer(0, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) .name("cnn1") .convolutionMode(ConvolutionMode.Strict) .nIn(2) // 2 input channels .nOut(processWidth) .weightInit(WeightInit.XAVIER_UNIFORM) .activation(Activation.RELU) - .biasInit(1e-2).build()); + .biasInit(1e-2).build()) - listBuilder = listBuilder.layer(1, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) + .layer(1, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) .name("cnn2") .convolutionMode(ConvolutionMode.Strict) .nOut(processWidth) .weightInit(WeightInit.XAVIER_UNIFORM) .activation(Activation.RELU) .biasInit(1e-2) - .build()); + .build()) - listBuilder = listBuilder.layer(2, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) + .layer(2, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) .name("cnn3") .convolutionMode(ConvolutionMode.Strict) .nOut(processWidth) .weightInit(WeightInit.XAVIER_UNIFORM) - .activation(Activation.RELU).build()); + .activation(Activation.RELU).build()) - listBuilder = listBuilder.layer(3, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) + .layer(3, new ConvolutionLayer.Builder(kernelArray, strideArray, zeroPaddingArray) .name("cnn4") .convolutionMode(ConvolutionMode.Strict) .nOut(processWidth) .weightInit(WeightInit.XAVIER_UNIFORM) - .activation(Activation.RELU).build()); + .activation(Activation.RELU).build()) - listBuilder = listBuilder - .layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.MSE) + .layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.MSE) .name("output") .nOut(1) .activation(Activation.TANH) - .build()); + .build()) - MultiLayerConfiguration conf = listBuilder - - - .setInputType(InputType.convolutional(20, 10, 2)) + .inputType(InputType.convolutional(20, 10, 2)) .build(); // For some reason, this model works diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java index 36bfbc95f..c5755753a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/CustomPreprocessorTest.java @@ -21,8 +21,6 @@ package org.deeplearning4j.nn.conf.preprocessor; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import 
org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -30,11 +28,6 @@ import org.deeplearning4j.nn.conf.preprocessor.custom.MyCustomPreprocessor; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.lossfunctions.LossFunctions; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.introspect.AnnotatedClass; -import com.fasterxml.jackson.databind.jsontype.NamedType; - -import java.util.Collection; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -44,8 +37,8 @@ public class CustomPreprocessorTest extends BaseDL4JTest { @Test public void testCustomPreprocessor() { //Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config actually works... - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(10) .activation(Activation.SOFTMAX).nOut(10).build()) @@ -57,10 +50,10 @@ public class CustomPreprocessorTest extends BaseDL4JTest { // System.out.println(json); - MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration confFromJson = NeuralNetConfiguration.fromJson(json); assertEquals(conf, confFromJson); - MultiLayerConfiguration confFromYaml = MultiLayerConfiguration.fromYaml(yaml); + NeuralNetConfiguration confFromYaml = NeuralNetConfiguration.fromYaml(yaml); assertEquals(conf, confFromYaml); assertTrue(confFromJson.getInputPreProcess(0) instanceof MyCustomPreprocessor); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java index 56c6cfb1d..1f279a762 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.conf.preprocessor; import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; @@ -58,14 +57,14 @@ public class TestPreProcessors extends BaseDL4JTest { int timeSeriesLength = timeSeriesLengths[x]; RnnToFeedForwardPreProcessor proc = new RnnToFeedForwardPreProcessor(); - NeuralNetConfiguration nnc = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration nnc = NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(layerSize) .nOut(layerSize).build()) .build(); - long numParams = nnc.getLayer().initializer().numParams(nnc); + long numParams = nnc.getFirstLayer().initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); - DenseLayer layer = (DenseLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true, params.dataType()); + DenseLayer layer = (DenseLayer) 
nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); INDArray activations3dc = Nd4j.create(new int[] {miniBatchSize, layerSize, timeSeriesLength}, 'c'); @@ -143,14 +142,14 @@ public class TestPreProcessors extends BaseDL4JTest { FeedForwardToRnnPreProcessor proc = new FeedForwardToRnnPreProcessor(); - NeuralNetConfiguration nnc = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration nnc = NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(layerSize) .nOut(layerSize).build()) .build(); - val numParams = nnc.getLayer().initializer().numParams(nnc); + val numParams = nnc.getFirstLayer().initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); - DenseLayer layer = (DenseLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true, params.dataType()); + DenseLayer layer = (DenseLayer) nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); INDArray rand = Nd4j.rand(miniBatchSize * timeSeriesLength, layerSize); @@ -227,16 +226,16 @@ public class TestPreProcessors extends BaseDL4JTest { InputPreProcessor proc = new CnnToRnnPreProcessor(inputHeight, inputWidth, nChannels); NeuralNetConfiguration nnc = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( inputWidth, inputHeight).nIn(cnnNChannelsIn) .nOut(nChannels).build()) .build(); - val numParams = nnc.getLayer().initializer().numParams(nnc); + val numParams = nnc.getFirstLayer().initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); ConvolutionLayer layer = - (ConvolutionLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true, params.dataType()); + (ConvolutionLayer) nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); INDArray activationsCnn = Nd4j.rand(miniBatchSize * timeSeriesLength, nChannels, @@ -309,16 +308,16 @@ public class TestPreProcessors extends BaseDL4JTest { InputPreProcessor proc = new RnnToCnnPreProcessor(inputHeight, inputWidth, nChannels); NeuralNetConfiguration nnc = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( inputWidth, inputHeight).nIn(cnnNChannelsIn) .nOut(nChannels).build()) .build(); - val numParams = nnc.getLayer().initializer().numParams(nnc); + val numParams = nnc.getFirstLayer().initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); ConvolutionLayer layer = - (ConvolutionLayer) nnc.getLayer().instantiate(nnc, null, 0, params, true, params.dataType()); + (ConvolutionLayer) nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); val shape_rnn = new long[] {miniBatchSize, nChannels * inputHeight * inputWidth, @@ -396,8 +395,8 @@ public class TestPreProcessors extends BaseDL4JTest { @Test public void testAutoAdditionOfPreprocessors() { //FF->RNN and RNN->FF - MultiLayerConfiguration conf1 = - new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf1 = + NeuralNetConfiguration.builder() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(5) .nOut(6).build()) .layer(1, new GravesLSTM.Builder().nIn(6).nOut(7).build()) @@ -412,12 +411,12 @@ public class TestPreProcessors extends BaseDL4JTest { //FF-> 
CNN, CNN-> FF, FF->RNN - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder().nOut(10) .kernelSize(5, 5).stride(1, 1).build()) .layer(1, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nOut(6).build()) .layer(2, new RnnOutputLayer.Builder().nIn(6).nOut(5).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); + .inputType(InputType.convolutionalFlat(28, 28, 1)).build(); //Expect preprocessors: 0: FF->CNN; 1: CNN->FF; 2: FF->RNN assertEquals(3, conf2.getInputPreProcessors().size()); assertTrue(conf2.getInputPreProcess(0) instanceof FeedForwardToCnnPreProcessor); @@ -425,12 +424,12 @@ public class TestPreProcessors extends BaseDL4JTest { assertTrue(conf2.getInputPreProcess(2) instanceof FeedForwardToRnnPreProcessor); //CNN-> FF, FF->RNN - InputType.convolutional instead of convolutionalFlat - MultiLayerConfiguration conf2a = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf2a = NeuralNetConfiguration.builder() .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder().nOut(10) .kernelSize(5, 5).stride(1, 1).build()) .layer(1, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nOut(6).build()) .layer(2, new RnnOutputLayer.Builder().nIn(6).nOut(5).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)).build(); + .inputType(InputType.convolutional(28, 28, 1)).build(); //Expect preprocessors: 1: CNN->FF; 2: FF->RNN assertEquals(2, conf2a.getInputPreProcessors().size()); assertTrue(conf2a.getInputPreProcess(1) instanceof CnnToFeedForwardPreProcessor); @@ -438,12 +437,12 @@ public class TestPreProcessors extends BaseDL4JTest { //FF->CNN and CNN->RNN: - MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf3 = NeuralNetConfiguration.builder().list() .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder().nOut(10) .kernelSize(5, 5).stride(1, 1).build()) .layer(1, new GravesLSTM.Builder().nOut(6).build()) .layer(2, new RnnOutputLayer.Builder().nIn(6).nOut(5).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); + .inputType(InputType.convolutionalFlat(28, 28, 1)).build(); //Expect preprocessors: 0: FF->CNN, 1: CNN->RNN; assertEquals(2, conf3.getInputPreProcessors().size()); assertTrue(conf3.getInputPreProcess(0) instanceof FeedForwardToCnnPreProcessor); @@ -452,8 +451,8 @@ public class TestPreProcessors extends BaseDL4JTest { @Test public void testCnnToDense() { - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() .list().layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( 4, 4) // 28*28*1 => 15*15*10 @@ -467,7 +466,7 @@ public class TestPreProcessors extends BaseDL4JTest { .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(200) .nOut(5).weightInit(WeightInit.RELU) .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) + .inputType(InputType.convolutionalFlat(28, 28, 1)) .build(); assertNotNull(conf.getInputPreProcess(0)); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java index d4bae91a6..4d4b36013 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java @@ -27,7 +27,6 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.layers.BaseLayer; @@ -65,9 +64,9 @@ public class TestWeightNoise extends BaseDL4JTest { }; for (IWeightNoise wn : weightNoises) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .weightNoise(wn) - .list() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(new DenseLayer.Builder().nIn(10).nOut(10).weightNoise(new DropConnect(0.25)).build()) .layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()) @@ -76,14 +75,14 @@ public class TestWeightNoise extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(wn, ((BaseLayer) net.getLayer(0).conf().getLayer()).getWeightNoise()); - assertEquals(new DropConnect(0.25), ((BaseLayer) net.getLayer(1).conf().getLayer()).getWeightNoise()); - assertEquals(wn, ((BaseLayer) net.getLayer(2).conf().getLayer()).getWeightNoise()); + assertEquals(wn, ((BaseLayer) net.getLayer(0).getLayerConfiguration()).getWeightNoise()); + assertEquals(new DropConnect(0.25), ((BaseLayer) net.getLayer(1).getLayerConfiguration()).getWeightNoise()); + assertEquals(wn, ((BaseLayer) net.getLayer(2).getLayerConfiguration()).getWeightNoise()); TestUtils.testModelSerialization(net); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .weightNoise(wn) .graphBuilder() .addInputs("in") @@ -96,9 +95,9 @@ public class TestWeightNoise extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(conf2); graph.init(); - assertEquals(wn, ((BaseLayer) graph.getLayer(0).conf().getLayer()).getWeightNoise()); - assertEquals(new DropConnect(0.25), ((BaseLayer) graph.getLayer(1).conf().getLayer()).getWeightNoise()); - assertEquals(wn, ((BaseLayer) graph.getLayer(2).conf().getLayer()).getWeightNoise()); + assertEquals(wn, ((BaseLayer) graph.getLayer(0).getLayerConfiguration()).getWeightNoise()); + assertEquals(new DropConnect(0.25), ((BaseLayer) graph.getLayer(1).getLayerConfiguration()).getWeightNoise()); + assertEquals(wn, ((BaseLayer) graph.getLayer(2).getLayerConfiguration()).getWeightNoise()); TestUtils.testModelSerialization(graph); @@ -144,8 +143,8 @@ public class TestWeightNoise extends BaseDL4JTest { List list = Arrays.asList(wn1, wn2, wn3); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).weightNoise(wn1).build()) .layer(new DenseLayer.Builder().nIn(10).nOut(10).weightNoise(wn2).build()) .layer(new 
OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).weightNoise(wn3).build()) @@ -168,7 +167,7 @@ public class TestWeightNoise extends BaseDL4JTest { wn3 = new CustomWeightNoise(); list = Arrays.asList(wn1, wn2, wn3); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("0", new DenseLayer.Builder().nIn(10).nOut(10).weightNoise(wn1).build(), "in") @@ -247,9 +246,9 @@ public class TestWeightNoise extends BaseDL4JTest { public void testDropConnectValues() { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .weightInit(WeightInit.ONES) - .list() + .layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index edad9fb7d..9002ba2af 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -30,7 +30,6 @@ import org.deeplearning4j.common.config.DL4JClassLoading; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.WorkspaceMode; @@ -82,7 +81,7 @@ import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; import org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM; import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.LSTM; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LearnedSelfAttentionLayer; import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization; import org.deeplearning4j.nn.conf.layers.LocallyConnected1D; @@ -141,7 +140,6 @@ import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor import org.deeplearning4j.nn.modelimport.keras.preprocessors.TensorFlowCnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.Disabled; @@ -216,7 +214,7 @@ public class DTypeTests extends BaseDL4JTest { continue; } - if (Layer.class.isAssignableFrom(clazz)) { + if (LayerConfiguration.class.isAssignableFrom(clazz)) { layerClasses.add(clazz); } else if (InputPreProcessor.class.isAssignableFrom(clazz)) { preprocClasses.add(clazz); @@ -258,9 +256,9 @@ public class DTypeTests extends BaseDL4JTest { } public static void logUsedClasses(MultiLayerNetwork net) { - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - for (NeuralNetConfiguration nnc : conf.getConfs()) { - Layer l = nnc.getLayer(); + NeuralNetConfiguration conf = net.getConfiguration(); + for (NeuralNetConfiguration nnc : conf.getNetConfigurations()) { + LayerConfiguration l 
= nnc.getFirstLayer(); seenLayers.add(l.getClass()); if (l instanceof BaseWrapperLayer) { BaseWrapperLayer bwl = (BaseWrapperLayer) l; @@ -283,7 +281,7 @@ public class DTypeTests extends BaseDL4JTest { for (GraphVertex gv : conf.getVertices().values()) { seenVertices.add(gv.getClass()); if (gv instanceof LayerVertex) { - seenLayers.add(((LayerVertex) gv).getLayerConf().getLayer().getClass()); + seenLayers.add(((LayerVertex) gv).getNetConfiguration().getFirstLayer().getClass()); InputPreProcessor ipp = ((LayerVertex) gv).getPreProcessor(); if (ipp != null) { seenPreprocs.add(ipp.getClass()); @@ -301,7 +299,7 @@ public class DTypeTests extends BaseDL4JTest { for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { Nd4j.setDefaultDataTypes(dt, dt); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .weightInit(WeightInit.XAVIER) .updater(new Adam(0.01)) @@ -384,7 +382,7 @@ public class DTypeTests extends BaseDL4JTest { for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}) { Nd4j.setDefaultDataTypes(dt, dt); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .weightInit(WeightInit.XAVIER) .updater(new Adam(0.01)) @@ -475,8 +473,8 @@ public class DTypeTests extends BaseDL4JTest { String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", outputLayer=" + outputLayer; - Layer ol; - Layer secondLast; + LayerConfiguration ol; + LayerConfiguration secondLast; switch (outputLayer) { case 0: ol = new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); @@ -503,7 +501,7 @@ public class DTypeTests extends BaseDL4JTest { } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(networkDtype) .convolutionMode(ConvolutionMode.Same) .updater(new Adam(1e-2)) @@ -531,7 +529,7 @@ public class DTypeTests extends BaseDL4JTest { .layer(new ActivationLayer(Activation.LEAKYRELU)) .layer(secondLast) .layer(ol) - .setInputType(InputType.convolutionalFlat(8, 8, 1)) + .inputType(InputType.convolutionalFlat(8, 8, 1)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -560,7 +558,7 @@ public class DTypeTests extends BaseDL4JTest { assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { - String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); + String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? 
"input" : net.getLayer(i - 1).getLayerConfiguration().getClass().getSimpleName()); assertEquals(networkDtype, ff.get(i).dataType(), msg); } @@ -601,8 +599,8 @@ public class DTypeTests extends BaseDL4JTest { String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", outputLayer=" + outputLayer; log.info(msg); - Layer ol; - Layer secondLast; + LayerConfiguration ol; + LayerConfiguration secondLast; switch (outputLayer) { case 0: ol = new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); @@ -621,7 +619,7 @@ public class DTypeTests extends BaseDL4JTest { } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(networkDtype) .convolutionMode(ConvolutionMode.Same) .updater(new Nesterovs(1e-2, 0.9)) @@ -636,7 +634,7 @@ public class DTypeTests extends BaseDL4JTest { .layer(new Upsampling3D.Builder().size(2).build()) .layer(secondLast) .layer(ol) - .setInputType(InputType.convolutional3D(Convolution3D.DataFormat.NCDHW, 8, 8, 8, 1)) + .inputType(InputType.convolutional3D(Convolution3D.DataFormat.NCDHW, 8, 8, 8, 1)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -664,7 +662,7 @@ public class DTypeTests extends BaseDL4JTest { assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { - String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); + String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).getLayerConfiguration().getClass().getSimpleName()); assertEquals(networkDtype, ff.get(i).dataType(), s); } @@ -712,8 +710,8 @@ public class DTypeTests extends BaseDL4JTest { String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", outputLayer=" + outputLayer + " at index " + outputLayer; - Layer ol; - Layer secondLast; + LayerConfiguration ol; + LayerConfiguration secondLast; switch (outputLayer) { case 0: ol = new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); @@ -732,7 +730,7 @@ public class DTypeTests extends BaseDL4JTest { } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .trainingWorkspaceMode(WorkspaceMode.NONE) .inferenceWorkspaceMode(WorkspaceMode.NONE) .dataType(networkDtype) @@ -749,7 +747,7 @@ public class DTypeTests extends BaseDL4JTest { .layer(new Upsampling1D.Builder(2).build()) .layer(secondLast) .layer(ol) - .setInputType(InputType.recurrent(5, 10,RNNFormat.NCW)) + .inputType(InputType.recurrent(5, 10,RNNFormat.NCW)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -774,7 +772,7 @@ public class DTypeTests extends BaseDL4JTest { assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { - String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); + String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? 
"input" : net.getLayer(i - 1).getLayerConfiguration().getClass().getSimpleName()); assertEquals(networkDtype, ff.get(i).dataType(), s); } @@ -814,7 +812,7 @@ public class DTypeTests extends BaseDL4JTest { String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(networkDtype) .convolutionMode(ConvolutionMode.Same) .updater(new Adam(1e-2)) @@ -822,7 +820,7 @@ public class DTypeTests extends BaseDL4JTest { .layer(new SpaceToBatchLayer.Builder().blocks(1, 1).build()) .layer(new SpaceToDepthLayer.Builder().blocks(2).build()) .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutional(28, 28, 5)) + .inputType(InputType.convolutional(28, 28, 5)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -840,7 +838,7 @@ public class DTypeTests extends BaseDL4JTest { assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { - String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); + String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).getLayerConfiguration().getClass().getSimpleName()); assertEquals(networkDtype, ff.get(i).dataType(), s); } @@ -878,8 +876,8 @@ public class DTypeTests extends BaseDL4JTest { String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", outputLayer=" + outputLayer; - Layer ol; - Layer secondLast; + LayerConfiguration ol; + LayerConfiguration secondLast; switch (outputLayer) { case 0: ol = new RnnOutputLayer.Builder().nOut(5).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build(); @@ -897,7 +895,7 @@ public class DTypeTests extends BaseDL4JTest { throw new RuntimeException(); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(networkDtype) .convolutionMode(ConvolutionMode.Same) .updater(new Adam(1e-2)) @@ -982,12 +980,12 @@ public class DTypeTests extends BaseDL4JTest { int width = 6; int inputDepth = 4; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(networkDtype) .seed(123) .updater(new NoOp()) - .weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6))) - .list() + .dist(new UniformDistribution(-6, 6)) + .layer(new PrimaryCapsules.Builder(primaryCapsDim, primarpCapsChannel) .kernelSize(3, 3) .stride(2, 2) @@ -996,7 +994,7 @@ public class DTypeTests extends BaseDL4JTest { .layer(new CapsuleStrengthLayer.Builder().build()) .layer(new ActivationLayer.Builder(new ActivationSoftmax()).build()) .layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build()) - .setInputType(InputType.convolutional(height, width, inputDepth)) + .inputType(InputType.convolutional(height, width, inputDepth)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -1013,7 +1011,7 @@ public class DTypeTests extends BaseDL4JTest { assertEquals(networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { - String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? 
"input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); + String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).getLayerConfiguration().getClass().getSimpleName()); assertEquals(networkDtype, ff.get(i).dataType(), s); } @@ -1052,11 +1050,11 @@ public class DTypeTests extends BaseDL4JTest { String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", test=" + test; - ComputationGraphConfiguration.GraphBuilder conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration.GraphBuilder conf = NeuralNetConfiguration.builder() .dataType(networkDtype) .seed(123) .updater(new NoOp()) - .weightInit(new WeightInitDistribution(new UniformDistribution(-6, 6))) + .dist(new UniformDistribution(-6, 6)) .graphBuilder() .addInputs("in") .setOutputs("out"); @@ -1144,7 +1142,7 @@ public class DTypeTests extends BaseDL4JTest { for (int test = 0; test < 8; test++) { String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", test=" + test; - ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration.GraphBuilder b = NeuralNetConfiguration.builder() .dataType(networkDtype) .seed(123) .updater(new NoOp()) @@ -1301,7 +1299,7 @@ public class DTypeTests extends BaseDL4JTest { for (int test = 0; test < 2; test++) { String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", test=" + test; - ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration.GraphBuilder b = NeuralNetConfiguration.builder() .dataType(networkDtype) .seed(123) .updater(new NoOp()) @@ -1395,7 +1393,7 @@ public class DTypeTests extends BaseDL4JTest { INDArray in = Nd4j.rand(networkDtype, new long[]{mb, nIn, tsLength}); INDArray labels = TestUtils.randomOneHot(mb, nOut); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(networkDtype) .activation(Activation.TANH) .updater(new NoOp()) @@ -1408,7 +1406,7 @@ public class DTypeTests extends BaseDL4JTest { .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()) .layer(new OutputLayer.Builder().nOut(nOut).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(nIn)) + .inputType(InputType.recurrent(nIn)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -1418,7 +1416,7 @@ public class DTypeTests extends BaseDL4JTest { assertEquals( networkDtype, out.dataType(), msg); List ff = net.feedForward(in); for (int i = 0; i < ff.size(); i++) { - String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? "input" : net.getLayer(i - 1).conf().getLayer().getClass().getSimpleName()); + String s = msg + " - layer " + (i - 1) + " - " + (i == 0 ? 
"input" : net.getLayer(i - 1).getLayerConfiguration().getClass().getSimpleName()); assertEquals(networkDtype, ff.get(i).dataType(), s); } @@ -1482,7 +1480,7 @@ public class DTypeTests extends BaseDL4JTest { System.out.println("Starting test: " + name); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .dataType(networkDtype) .activation(Activation.TANH) .updater(new NoOp()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java index 2d2379fdb..de8c16075 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java @@ -66,7 +66,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { int timeSeriesLength = 12; //4 layer network: 2 GravesLSTM + DenseLayerConfiguration + RnnOutputLayer. Hence also tests preprocessors. - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).graphBuilder() .addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(5).nOut(7) .activation(Activation.TANH) @@ -156,7 +156,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); int timeSeriesLength = 6; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(5).nOut(7) .activation(Activation.TANH) .dist(new NormalDistribution(0, 0.5)).build(), "in") @@ -211,7 +211,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { //4 layer network: 2 GravesLSTM + DenseLayerConfiguration + RnnOutputLayer. Hence also tests preprocessors. 
//Network architecture: lstm0 -> Dense -> RnnOutputLayer0 // and lstm1 -> Dense -> RnnOutputLayer1 - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).graphBuilder() .addInputs("in0", "in1") .addLayer("lstm0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(5).nOut(6) @@ -340,7 +340,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { int nIn = 5; int nOut = 4; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .trainingWorkspaceMode(WorkspaceMode.NONE).inferenceWorkspaceMode(WorkspaceMode.NONE) .graphBuilder() .addInputs("in") @@ -360,7 +360,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { .setOutputs("out").build(); assertEquals(BackpropType.Standard, conf.getBackpropType()); - ComputationGraphConfiguration confTBPTT = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration confTBPTT = NeuralNetConfiguration.builder().seed(12345) .trainingWorkspaceMode(WorkspaceMode.NONE).inferenceWorkspaceMode(WorkspaceMode.NONE) .graphBuilder() .addInputs("in") @@ -377,7 +377,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { .activation(Activation.SOFTMAX) .dist(new NormalDistribution(0, 0.5)).build(), "1") .setOutputs("out").backpropType(BackpropType.TruncatedBPTT) - .tBPTTForwardLength(timeSeriesLength).tBPTTBackwardLength(timeSeriesLength) + .tbpttFwdLength(timeSeriesLength).tbpttBackLength(timeSeriesLength) .setInputTypes(InputType.recurrent(nIn,timeSeriesLength,RNNFormat.NCW)) .build(); assertEquals(BackpropType.TruncatedBPTT, confTBPTT.getBackpropType()); @@ -456,7 +456,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { int nTimeSlices = 20; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() .addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7) @@ -473,7 +473,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { .dist(new NormalDistribution(0, 0.5)).build(), "1") .setOutputs("out").backpropType(BackpropType.TruncatedBPTT) .setInputTypes(InputType.recurrent(nIn,timeSeriesLength,RNNFormat.NCW)) - .tBPTTBackwardLength(timeSeriesLength).tBPTTForwardLength(timeSeriesLength).build(); + .tbpttBackLength(timeSeriesLength).tbpttFwdLength(timeSeriesLength).build(); Nd4j.getRandom().setSeed(12345); ComputationGraph graph = new ComputationGraph(conf); @@ -493,7 +493,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { int nIn = 5; int nOut = 4; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() .addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7) @@ -509,7 +509,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { .activation(Activation.SOFTMAX) .dist(new NormalDistribution(0, 0.5)).build(), "1") .setOutputs("out").backpropType(BackpropType.TruncatedBPTT) - 
.tBPTTBackwardLength(tbpttLength).tBPTTForwardLength(tbpttLength) + .tbpttBackLength(tbpttLength).tbpttFwdLength(tbpttLength) .setInputTypes(InputType.recurrent(nIn,timeSeriesLength, RNNFormat.NCW)) .build(); @@ -530,13 +530,13 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { @Test public void testTbpttMasking() { //Simple "does it throw an exception" type test... - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .graphBuilder().addInputs("in") .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE) .activation(Activation.IDENTITY).nIn(1).nOut(1).build(), "in") - .setOutputs("out").backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(8) + .setOutputs("out").backpropType(BackpropType.TruncatedBPTT).tbpttFwdLength(8) .setInputTypes(InputType.recurrent(1,1,RNNFormat.NCW)) - .tBPTTBackwardLength(8).build(); + .tbpttBackLength(8).build(); ComputationGraph net = new ComputationGraph(conf); net.init(); @@ -553,12 +553,12 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { public void checkMaskArrayClearance() { for (boolean tbptt : new boolean[] {true, false}) { //Simple "does it throw an exception" type test... - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .graphBuilder().addInputs("in") .addLayer("out", new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE) .activation(Activation.IDENTITY).nIn(1).nOut(1).build(), "in") .setOutputs("out").backpropType(tbptt ? BackpropType.TruncatedBPTT : BackpropType.Standard) - .tBPTTForwardLength(8).tBPTTBackwardLength(8).build(); + .tbpttFwdLength(8).tbpttBackLength(8).build(); ComputationGraph net = new ComputationGraph(conf); net.init(); @@ -616,7 +616,7 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { int nHiddenUnits = 17; try { - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("0", new org.deeplearning4j.nn.conf.layers.LSTM.Builder().nIn(nIn).nOut(nHiddenUnits).build(), "in") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java index 95691fed6..d83f4ac17 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java @@ -57,7 +57,7 @@ public class TestCompGraphCNN extends BaseDL4JTest { protected static ComputationGraphConfiguration getMultiInputGraphConfig() { ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("input") .setInputTypes(InputType.convolutional(32, 32, 3)) @@ -154,7 +154,7 @@ public class TestCompGraphCNN extends BaseDL4JTest { DataSet trainInput; ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .seed(123).graphBuilder().addInputs("input") .setInputTypes(InputType.convolutional(nChannels, imageWidth, diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java index 794538c36..f4da77575 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java @@ -24,7 +24,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.layers.variational.BernoulliReconstructionDistribution; @@ -60,7 +59,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest { for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .updater(new Adam(1e-3)) .weightInit(WeightInit.XAVIER) @@ -136,13 +135,13 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest { for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.NONE, WorkspaceMode.ENABLED}) { - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .seed(12345) .updater(new Adam(1e-3)) .weightInit(WeightInit.XAVIER) .inferenceWorkspaceMode(wsm) .trainingWorkspaceMode(wsm) - .list() + .layer(new VariationalAutoencoder.Builder() .nIn(784) .nOut(32) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index a6373c6a9..adf347260 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -98,7 +98,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { public File testDir; private static ComputationGraphConfiguration getIrisGraphConfiguration() { - return new NeuralNetConfiguration.Builder().seed(12345) + return NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() .addInputs("input") .addLayer("firstLayer", new DenseLayer.Builder().nIn(4).nOut(5).build(), "input") @@ -106,9 +106,9 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { .setOutputs("outputLayer").build(); } - private static MultiLayerConfiguration getIrisMLNConfiguration() { - return new NeuralNetConfiguration.Builder().seed(12345) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + private static NeuralNetConfiguration getIrisMLNConfiguration() { + return NeuralNetConfiguration.builder().seed(12345) + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .layer(0, new DenseLayer.Builder().nIn(4).nOut(5).build()) .layer(1, new OutputLayer.Builder().nIn(5).nOut(3).activation(Activation.SOFTMAX).build()).build(); } @@ -150,7 +150,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(configuration); graph.init(); - 
MultiLayerConfiguration mlc = getIrisMLNConfiguration(); + NeuralNetConfiguration mlc = getIrisMLNConfiguration(); MultiLayerNetwork net = new MultiLayerNetwork(mlc); net.init(); @@ -209,7 +209,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(configuration); graph.init(); - MultiLayerConfiguration mlc = getIrisMLNConfiguration(); + NeuralNetConfiguration mlc = getIrisMLNConfiguration(); MultiLayerNetwork net = new MultiLayerNetwork(mlc); net.init(); @@ -244,7 +244,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(configuration); graph.init(); - MultiLayerConfiguration mlc = getIrisMLNConfiguration(); + NeuralNetConfiguration mlc = getIrisMLNConfiguration(); MultiLayerNetwork net = new MultiLayerNetwork(mlc); net.init(); @@ -295,7 +295,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(configuration); graph.init(); - MultiLayerConfiguration mlnConfig = getIrisMLNConfiguration(); + NeuralNetConfiguration mlnConfig = getIrisMLNConfiguration(); MultiLayerNetwork net = new MultiLayerNetwork(mlnConfig); net.init(); @@ -332,7 +332,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { MultiDataSetIterator iter = new RecordReaderMultiDataSetIterator.Builder(10).addReader("iris", rr) .addInput("iris", 0, 3).addOutputOneHot("iris", 4, 3).build(); - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration config = NeuralNetConfiguration.builder() .updater(new Sgd(0.1)) .graphBuilder().addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", @@ -377,7 +377,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(configuration); graph.init(); - MultiLayerConfiguration mlc = getIrisMLNConfiguration(); + NeuralNetConfiguration mlc = getIrisMLNConfiguration(); MultiLayerNetwork net = new MultiLayerNetwork(mlc); net.init(); @@ -401,14 +401,14 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { public void testPreprocessorAddition() { //Also check that nIns are set automatically //First: check FF -> RNN - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf1 = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .setInputTypes(InputType.feedForward(5)) .addLayer("rnn", new GravesLSTM.Builder().nOut(5).build(), "in") .addLayer("out", new RnnOutputLayer.Builder().nOut(5).activation(Activation.SOFTMAX).build(), "rnn").setOutputs("out").build(); - assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("rnn")).getLayerConf().getLayer()) + assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("rnn")).getNetConfiguration().getFirstLayer()) .getNIn()); - assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("out")).getLayerConf().getLayer()) + assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("out")).getNetConfiguration().getFirstLayer()) .getNIn()); LayerVertex lv1 = (LayerVertex) conf1.getVertices().get("rnn"); @@ -417,15 +417,15 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { assertNull(lv2.getPreProcessor()); //Check RNN -> FF -> RNN - ComputationGraphConfiguration conf2 = new 
NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .setInputTypes(InputType.recurrent(5)) .addLayer("ff", new DenseLayer.Builder().nOut(5).build(), "in") .addLayer("out", new RnnOutputLayer.Builder().nOut(5).activation(Activation.SOFTMAX).build(), "ff") .setOutputs("out").build(); - assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("ff")).getLayerConf().getLayer()) + assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("ff")).getNetConfiguration().getFirstLayer()) .getNIn()); - assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("out")).getLayerConf().getLayer()) + assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("out")).getNetConfiguration().getFirstLayer()) .getNIn()); lv1 = (LayerVertex) conf2.getVertices().get("ff"); @@ -434,7 +434,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { assertTrue(lv2.getPreProcessor() instanceof FeedForwardToRnnPreProcessor); //CNN -> Dense - ComputationGraphConfiguration conf3 = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf3 = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .setInputTypes(InputType.convolutional(28, 28, 1)) .addLayer("cnn", new ConvolutionLayer.Builder().kernelSize(2, 2).padding(0, 0).stride(2, 2) .nOut(3).build(), "in") //(28-2+0)/2+1 = 14 @@ -460,11 +460,11 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { LayerVertex lv4 = (LayerVertex) conf3.getVertices().get("out"); assertNull(lv4.getPreProcessor()); //Check nIns: - assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getLayerConf().getLayer()).getNIn()); + assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getNetConfiguration().getFirstLayer()).getNIn()); //CNN->Dense, RNN->Dense, Dense->RNN ComputationGraphConfiguration conf4 = - new NeuralNetConfiguration.Builder().graphBuilder().addInputs("inCNN", "inRNN") + NeuralNetConfiguration.builder().graphBuilder().addInputs("inCNN", "inRNN") .setInputTypes(InputType.convolutional(28, 28, 1), InputType.recurrent(5)) .addLayer("cnn", new ConvolutionLayer.Builder().kernelSize(2, 2).padding(0, 0) .stride(2, 2).nOut(3).build(), "inCNN") //(28-2+0)/2+1 = 14 @@ -495,14 +495,14 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { LayerVertex lv5 = (LayerVertex) conf4.getVertices().get("out"); assertTrue(lv5.getPreProcessor() instanceof FeedForwardToRnnPreProcessor); //Check nIns: - assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getLayerConf().getLayer()).getNIn()); - assertEquals(5, ((FeedForwardLayer) lv4.getLayerConf().getLayer()).getNIn()); - assertEquals(20, ((FeedForwardLayer) lv5.getLayerConf().getLayer()).getNIn()); //10+10 out of the merge vertex -> 20 in to output layer vertex + assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getNetConfiguration().getFirstLayer()).getNIn()); + assertEquals(5, ((FeedForwardLayer) lv4.getNetConfiguration().getFirstLayer()).getNIn()); + assertEquals(20, ((FeedForwardLayer) lv5.getNetConfiguration().getFirstLayer()).getNIn()); //10+10 out of the merge vertex -> 20 in to output layer vertex //Input to 2 CNN layers: ComputationGraphConfiguration conf5 = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("input") .setInputTypes(InputType.convolutional(28, 28, 1)) @@ 
-575,7 +575,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { public void testCompGraphUnderscores() { //Problem: underscores in names could be problematic for ComputationGraphUpdater, HistogramIterationListener - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() .addInputs("input") .addLayer("first_layer", new DenseLayer.Builder().nIn(4).nOut(5).build(), "input") @@ -594,7 +594,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testPreTraining() { ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(1e-6)) .l2(2e-4).graphBuilder().addInputs("in") @@ -648,7 +648,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { int nIn = 5; int nOut = 6; ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).l2(0.01) + NeuralNetConfiguration.builder().seed(12345).l1(0.01).l2(0.01) .updater(new Sgd(0.1)) .activation(Activation.TANH).weightInit(WeightInit.XAVIER) .graphBuilder().addInputs("in") @@ -660,7 +660,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { .setOutputs("2").build(); ComputationGraphConfiguration confNoReg = - new NeuralNetConfiguration.Builder().seed(12345).updater(new Sgd(0.1)).activation(Activation.TANH) + NeuralNetConfiguration.builder().seed(12345).updater(new Sgd(0.1)).activation(Activation.TANH) .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(nIn).nOut(20).build(), "in") .addLayer("1", new DenseLayer.Builder().nIn(20).nOut(30).build(), "0") @@ -717,7 +717,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { INDArray outData = Nd4j.rand(3, 10); Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration standard = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + ComputationGraphConfiguration standard = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws) .seed(12345).graphBuilder().addInputs("in") .addLayer("l0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in") @@ -729,7 +729,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration external = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + ComputationGraphConfiguration external = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws) .seed(12345).graphBuilder().addInputs("in") .addLayer("l0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").setOutputs("l0") @@ -771,7 +771,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { for(WorkspaceMode ws : WorkspaceMode.values()) { // System.out.println("***** WORKSPACE: " + ws); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .updater(new Adam(0.01)) .trainingWorkspaceMode(ws) .inferenceWorkspaceMode(ws) @@ -819,7 +819,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { int nIn = 2; int nOut = 4; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = 
NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(nIn).nOut(4).activation(Activation.RELU).build(), "in") @@ -857,7 +857,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { expectedGradient.setGradientFor("output_W", Nd4j.ones(5, 3)); expectedGradient.setGradientFor("output_b", Nd4j.ones(1, 3)); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() .addInputs("input").addLayer("first", new DenseLayer.Builder().nIn(4).nOut(5).build(), "input") .addLayer("output", new OutputLayer.Builder().nIn(5).nOut(3).activation(Activation.SOFTMAX).build(), "first") @@ -893,7 +893,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { public void testCnnFlatInputType1() { //First: check conv input type. Expect: no preprocessor, nIn set appropriately - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .setInputTypes(InputType.convolutional(10, 8, 3)) .addLayer("layer", new ConvolutionLayer.Builder().kernelSize(2, 2).padding(0, 0).stride(1, 1) @@ -903,14 +903,14 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { .build(); LayerVertex lv = (LayerVertex) conf.getVertices().get("layer"); - FeedForwardLayer l = ((FeedForwardLayer) (lv).getLayerConf().getLayer()); + FeedForwardLayer l = ((FeedForwardLayer) (lv).getNetConfiguration().getFirstLayer()); assertEquals(3, l.getNIn()); assertNull(lv.getPreProcessor()); //Check the equivalent config, but with flat conv data input instead //In this case, the only difference should be the addition of a preprocessor //First: check conv input type. 
Expect: no preprocessor, nIn set appropriately - conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .setInputTypes(InputType.convolutionalFlat(10, 8, 3)) .addLayer("layer", new ConvolutionLayer.Builder().kernelSize(2, 2).padding(0, 0).stride(1, 1) @@ -920,7 +920,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { .build(); lv = (LayerVertex) conf.getVertices().get("layer"); - l = ((FeedForwardLayer) (lv).getLayerConf().getLayer()); + l = ((FeedForwardLayer) (lv).getNetConfiguration().getFirstLayer()); assertEquals(3, l.getNIn()); assertNotNull(lv.getPreProcessor()); InputPreProcessor preProcessor = lv.getPreProcessor(); @@ -932,7 +932,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { //Finally, check configuration with a subsampling layer - conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .setInputTypes(InputType.convolutionalFlat(10, 8, 3)) .addLayer("l0", new SubsamplingLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0) .build(), "in") @@ -945,7 +945,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { //Check subsampling layer: lv = (LayerVertex) conf.getVertices().get("l0"); - SubsamplingLayer sl = ((SubsamplingLayer) (lv).getLayerConf().getLayer()); + SubsamplingLayer sl = ((SubsamplingLayer) (lv).getNetConfiguration().getFirstLayer()); assertNotNull(lv.getPreProcessor()); preProcessor = lv.getPreProcessor(); assertTrue(preProcessor instanceof FeedForwardToCnnPreProcessor); @@ -955,7 +955,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { assertEquals(3, preproc.getNumChannels()); //Check dense layer lv = (LayerVertex) conf.getVertices().get("layer"); - l = ((FeedForwardLayer) (lv).getLayerConf().getLayer()); + l = ((FeedForwardLayer) (lv).getNetConfiguration().getFirstLayer()); assertEquals(3, l.getNIn()); assertNull(lv.getPreProcessor()); @@ -970,7 +970,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { graph.init(); Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration mlnConfig = getIrisMLNConfiguration(); + NeuralNetConfiguration mlnConfig = getIrisMLNConfiguration(); MultiLayerNetwork net = new MultiLayerNetwork(mlnConfig); net.init(); @@ -999,7 +999,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { for (OptimizationAlgorithm oa : oas) { // System.out.println(oa); ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().optimizationAlgo(oa).graphBuilder() + NeuralNetConfiguration.builder().optimizationAlgo(oa).graphBuilder() .addInputs("input") .addLayer("first", new DenseLayer.Builder().nIn(4).nOut(5).build(), "input") .addLayer("output", new OutputLayer.Builder().nIn(5).nOut(3).activation(Activation.SOFTMAX).build(), @@ -1016,7 +1016,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testIterationCountAndPersistence() throws IOException { Nd4j.getRandom().setSeed(123); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .graphBuilder().addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) @@ -1054,7 +1054,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void 
printSummary() { - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.IDENTITY); ComputationGraphConfiguration conf = overallConf.graphBuilder().addInputs("inCentre", "inRight") @@ -1095,7 +1095,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testFeedForwardIncludeNonLayerVertices() { - ComputationGraphConfiguration c = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration c = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(5).nOut(5).build(), "in") .addLayer("1", new DenseLayer.Builder().nIn(5).nOut(5).build(), "in") .addVertex("merge", new MergeVertex(), "0", "1") @@ -1123,7 +1123,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { //Users generally shouldn't do this, but multiple setOutputs calls should *replace* not *add* outputs - ComputationGraphConfiguration c = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration c = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("out", new OutputLayer.Builder().nIn(10).nOut(5).activation(Activation.SOFTMAX).build(), "in").setOutputs("out") .setOutputs("out").build(); @@ -1135,7 +1135,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { public void testDropoutValidation() { //At one point: this threw an exception due to incorrect validation for (boolean dropConnect : new boolean[]{false, true}) { - new NeuralNetConfiguration.Builder().weightNoise(new DropConnect(0.5)) + NeuralNetConfiguration.builder().weightNoise(new DropConnect(0.5)) .graphBuilder().setInputTypes(InputType.feedForward(1)).addInputs("input1") .addLayer("output", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(1).nOut(1) @@ -1151,7 +1151,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { //Don't care about this being valid ComputationGraphConfiguration c = - new NeuralNetConfiguration.Builder().l1(0.5).l2(0.6).graphBuilder() + NeuralNetConfiguration.builder().l1(0.5).l2(0.6).graphBuilder() .addInputs("in") .addLayer("sub1", new SubsamplingLayer.Builder(2, 2).build(), "in") .addLayer("sub2", new Subsampling1DLayer.Builder(2).build(), "sub1") @@ -1178,7 +1178,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testErrorNoOutputLayer() { - ComputationGraphConfiguration c = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration c = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").setOutputs("dense") .build(); @@ -1202,7 +1202,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { //When a vertex supports only one input, and gets multiple inputs - we should automatically add a merge //vertex - NeuralNetConfiguration nnc = new NeuralNetConfiguration(); + NeuralNetConfiguration nnc = NeuralNetConfiguration.builder().build(); nnc.setLayer(new DenseLayer.Builder().build()); GraphVertex[] singleInputVertices = new GraphVertex[]{new L2NormalizeVertex(), new LayerVertex(nnc, null), new PoolHelperVertex(), new PreprocessorVertex(), new ReshapeVertex(1, 1), @@ -1210,7 +1210,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { 
new DuplicateToTimeSeriesVertex("in1"), new LastTimeStepVertex("in1")}; for (GraphVertex gv : singleInputVertices) { - ComputationGraphConfiguration c = new NeuralNetConfiguration.Builder().graphBuilder() + ComputationGraphConfiguration c = NeuralNetConfiguration.builder().graphBuilder() .addInputs("in1", "in2").addVertex("gv", gv, "in1", "in2").setOutputs("gv").build(); boolean foundMerge = false; @@ -1238,7 +1238,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { int depth = 3; INDArray img = Nd4j.ones(minibatch, depth, height, width); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("input") .addLayer("L1", new ConvolutionLayer.Builder(new int[]{1, 1}, new int[]{1, 1}, new int[]{0, 0}).nIn(depth).nOut(depth) @@ -1262,7 +1262,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testEpochCounter() throws Exception { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .addLayer("out", new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in") @@ -1302,7 +1302,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { int V_HEIGHT = 130; int V_NFRAMES = 150; ComputationGraphConfiguration confForArchitecture = - new NeuralNetConfiguration.Builder().seed(12345).l2(0.001) //l2 regularization on all layers + NeuralNetConfiguration.builder().seed(12345).l2(0.001) //l2 regularization on all layers .updater(new AdaGrad(0.4)).graphBuilder() .addInputs("in") .addLayer("layer0", new ConvolutionLayer.Builder(10, 10).nIn(3) //3 channels: RGB @@ -1331,7 +1331,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { .inputPreProcessor("layer3", new CnnToFeedForwardPreProcessor(7, 7, 10)) .inputPreProcessor("layer4", new FeedForwardToRnnPreProcessor()) .backpropType(BackpropType.TruncatedBPTT) - .tBPTTForwardLength(V_NFRAMES / 5).tBPTTBackwardLength(V_NFRAMES / 5).build(); + .tbpttFwdLength(V_NFRAMES / 5).tbpttBackLength(V_NFRAMES / 5).build(); ComputationGraph modelExpectedArch = new ComputationGraph(confForArchitecture); modelExpectedArch.init(); ComputationGraph modelMow = new TransferLearning.GraphBuilder(modelExpectedArch).setFeatureExtractor("layer2").build(); @@ -1347,7 +1347,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { public void testInputClearance() throws Exception { //Activations should be cleared - if not, it's possible for out of (workspace) scope arrays to be around // which can cause a crash - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .convolutionMode(ConvolutionMode.Same) .graphBuilder() .addInputs("in") @@ -1383,7 +1383,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { for(boolean allowDisconnected : new boolean[]{false, true}) { try { - ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration.GraphBuilder b = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .addLayer("0", new DenseLayer.Builder().activation(Activation.SIGMOID).nOut(8).build(), "in") @@ -1414,7 +1414,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testLayerSize(){ - ComputationGraphConfiguration conf = new 
NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") @@ -1436,7 +1436,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { assertEquals(3, net.layerInputSize(0)); assertEquals(0, net.layerInputSize(1)); - assertEquals(((FeedForwardLayer)net.getLayer(2).conf().getLayer()).getNIn(), net.layerInputSize(2)); + assertEquals(((FeedForwardLayer)net.getLayer(2).getLayerConfiguration()).getNIn(), net.layerInputSize(2)); assertEquals(30, net.layerInputSize(3)); assertEquals(6, net.layerSize("0")); @@ -1446,14 +1446,14 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { assertEquals(3, net.layerInputSize("0")); assertEquals(0, net.layerInputSize("1")); - assertEquals(((FeedForwardLayer)net.getLayer(2).conf().getLayer()).getNIn(), net.layerInputSize("2")); + assertEquals(((FeedForwardLayer)net.getLayer(2).getLayerConfiguration()).getNIn(), net.layerInputSize("2")); assertEquals(30, net.layerInputSize("3")); } @Test public void testZeroParamNet() throws Exception { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("0", new SubsamplingLayer.Builder().kernelSize(2,2).stride(2,2).build(), "in") @@ -1494,7 +1494,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { final String inputName = "input"; final String outputName = "output"; final String scaleName = "scale"; - final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + final ComputationGraph graph = new ComputationGraph(NeuralNetConfiguration.builder() //.inferenceWorkspaceMode(WorkspaceMode.NONE) .graphBuilder() .addInputs(inputName) @@ -1535,7 +1535,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { DataSet all = new IrisDataSetIterator(150,150).next(); DataSetIterator iter = new IrisDataSetIterator(5,150); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .graphBuilder() .addInputs("in") @@ -1558,7 +1558,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { //Test for a simple net: - ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration.GraphBuilder builder = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in1", "in2") .layer("0", new DenseLayer.Builder().nOut(10).build(), "in1") @@ -1595,7 +1595,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testTopoSortSaving(){ - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in1", "in2") .addLayer("l0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in1") @@ -1694,7 +1694,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { //The fit methods should *not* do layerwise pretraining: - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") @@ -1742,7 +1742,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testAllowInputModification(){ - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = 
NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in1", "in2") @@ -1781,7 +1781,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testCompGraphDropoutOutputLayers(){ //https://github.com/deeplearning4j/deeplearning4j/issues/6326 - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .dropOut(0.8) .graphBuilder() .addInputs("in1", "in2") @@ -1819,7 +1819,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testCompGraphDropoutOutputLayers2() { //https://github.com/deeplearning4j/deeplearning4j/issues/6326 - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .dropOut(0.8) .graphBuilder() .addInputs("in1", "in2") @@ -1854,7 +1854,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testAddRemoveVertex() { - new NeuralNetConfiguration.Builder().graphBuilder() + NeuralNetConfiguration.builder().graphBuilder() .addVertex("toRemove", new ScaleVertex(0), "don't care") .addVertex("test", new ScaleVertex(0), "toRemove") .removeVertex("toRemove", true); @@ -1864,7 +1864,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testGetSetParamUnderscores(){ //Test get/set param with underscores in layer nome - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("layer_zero", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in") @@ -1890,7 +1890,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testOutputSpecificLayers(){ - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .graphBuilder() .addInputs("in") @@ -1918,7 +1918,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void singleInputElemVertex() { final InputType inputType = InputType.convolutional(10, 10, 2); - final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + final ComputationGraph graph = new ComputationGraph(NeuralNetConfiguration.builder() .graphBuilder() .setInputTypes(inputType) .addInputs("input") @@ -1935,7 +1935,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testCloneDropoutIndependence(){ - val modelConf = new NeuralNetConfiguration.Builder() + val modelConf = NeuralNetConfiguration.builder() .updater(new Adam(0.01)) .weightInit(WeightInit.XAVIER_UNIFORM) .biasInit(0) @@ -1968,8 +1968,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph cg2 = model.clone(); - IDropout d1 = model.getLayer(0).conf().getLayer().getIDropout(); - IDropout d2 = cg2.getLayer(0).conf().getLayer().getIDropout(); + IDropout d1 = model.getLayer(0).getLayerConfiguration().getIDropout(); + IDropout d2 = cg2.getLayer(0).getLayerConfiguration().getIDropout(); assertNotSame(d1, d2); //Should not be same object! 
assertEquals(d1, d2); //But should be equal @@ -1982,7 +1982,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { int hiddenSize = 100; int dataSize = 10; int seqLen = 5; - ComputationGraphConfiguration configuration = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration configuration = NeuralNetConfiguration.builder() .updater(new Adam()) .graphBuilder() .addInputs("x_emb") @@ -2021,7 +2021,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { //https://github.com/deeplearning4j/deeplearning4j/issues/6809#issuecomment-463892644 double lr = 1e-3; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .weightInit(WeightInit.XAVIER) .updater(new Adam(lr)) @@ -2121,7 +2121,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { int outputSize = 6; int layerSize = 3; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .seed(12345) .weightInit(WeightInit.XAVIER) @@ -2152,7 +2152,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testConv3dMergeVertex(){ - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addLayer("l0", new Convolution3D.Builder().kernelSize(2,2,2).stride(1,1,1).nIn(3).nOut(3).dataFormat(Convolution3D.DataFormat.NCDHW).build(), "in") @@ -2172,7 +2172,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testDualEmbedding(){ - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .addLayer("e1", new EmbeddingLayer.Builder().nIn(10).nOut(5).build(), "in") @@ -2191,7 +2191,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { @Test public void testMergeNchw() throws Exception { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .convolutionMode(ConvolutionMode.Same) .graphBuilder() .addInputs("in") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java index 0c17238db..ce8019133 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java @@ -42,7 +42,7 @@ public class TestSetGetParameters extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); //Create configuration. 
Doesn't matter if this doesn't actually work for forward/backward pass here - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).graphBuilder() .addInputs("in").addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in") .addLayer("1", new GravesLSTM.Builder().nIn(10).nOut(10).build(), "in") .addLayer("2", new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10).build(), "in") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java index 96e1dcf12..237e7550e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java @@ -68,7 +68,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { for (int nExamples : miniBatchSizes) { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.1)).seed(12345).graphBuilder().addInputs("in") .addLayer("0", new GravesLSTM.Builder().activation(Activation.TANH).nIn(2).nOut(2).build(), @@ -158,7 +158,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { for (int nExamples : miniBatchSizes) { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(new NormalDistribution(0,2)) .updater(new Sgd(0.1)).seed(12345).graphBuilder().addInputs("in") @@ -300,7 +300,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { INDArray labels = Nd4j.ones(miniBatch, nOut, tsLength); ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration.builder().seed(12345L) .graphBuilder() .addInputs("in").addLayer("0", new GravesLSTM.Builder().nIn(nIn).nOut(5) @@ -370,7 +370,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { INDArray input = Nd4j.rand(miniBatch, nIn, tsLength); ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration.builder().seed(12345L) .graphBuilder() .addInputs("in").addLayer("0", new GravesLSTM.Builder().nIn(nIn).nOut(5) @@ -391,7 +391,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { net.init(); ComputationGraphConfiguration conf2 = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration.builder().seed(12345L) .graphBuilder() .addInputs("in").addLayer("0", new GravesLSTM.Builder().nIn(nIn).nOut(5) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java index ba3eb90bb..3ca1aa8bd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/graphnodes/TestGraphNodes.java @@ -188,7 +188,7 @@ public class TestGraphNodes extends BaseDL4JTest { @Test public void 
testLastTimeStepVertex() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addVertex("lastTS", new LastTimeStepVertex("in"), "in") .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "lastTS").setOutputs("out") .build(); @@ -239,7 +239,7 @@ public class TestGraphNodes extends BaseDL4JTest { @Test public void testDuplicateToTimeSeriesVertex() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder() .addInputs("in2d", "in3d") .addVertex("duplicateTS", new DuplicateToTimeSeriesVertex("in3d"), "in2d") .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "duplicateTS") @@ -313,7 +313,7 @@ public class TestGraphNodes extends BaseDL4JTest { null, null); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in1", "in2") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in1", "in2") .addVertex("stack", new org.deeplearning4j.nn.conf.graph.StackVertex(), "in1", "in2") .addLayer("1", new EmbeddingLayer.Builder().nIn(5).nOut(5).build(), "stack") .addVertex("unstack1", new org.deeplearning4j.nn.conf.graph.UnstackVertex(0, 2), "1") @@ -540,7 +540,7 @@ public class TestGraphNodes extends BaseDL4JTest { public void testJSON() { //The config here is non-sense, but that doesn't matter for config -> json -> config test ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addVertex("v1", new ElementWiseVertex(ElementWiseVertex.Op.Add), "in") .addVertex("v2", new org.deeplearning4j.nn.conf.graph.MergeVertex(), "in", "in") .addVertex("v3", new PreprocessorVertex( @@ -565,7 +565,7 @@ public class TestGraphNodes extends BaseDL4JTest { int numLabelClasses = 10; int numInputs = 5; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .trainingWorkspaceMode(WorkspaceMode.NONE) .inferenceWorkspaceMode(WorkspaceMode.NONE) .seed(123) //Random number generator seed for improved repeatability. Optional. 
diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java index 14e169767..629fd7069 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ActivationLayerTest.java @@ -24,7 +24,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ActivationLayer; @@ -83,15 +82,17 @@ public class ActivationLayerTest extends BaseDL4JTest { DataSet next = iter.next(); // Run without separate activation layer - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10).activation(Activation.RELU) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .optimizationAlgo(OptimizationAlgorithm.LBFGS) + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .seed(123) + .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10).activation(Activation.RELU) .weightInit(WeightInit.XAVIER).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( + + .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER) .activation(Activation.SOFTMAX).nIn(10).nOut(10).build()) - .build(); + .build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); @@ -99,7 +100,7 @@ public class ActivationLayerTest extends BaseDL4JTest { // Run with separate activation layer - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10).activation(Activation.IDENTITY) @@ -152,7 +153,7 @@ public class ActivationLayerTest extends BaseDL4JTest { // Run without separate activation layer Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() .layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0) @@ -170,7 +171,7 @@ public class ActivationLayerTest extends BaseDL4JTest { // Run with separate activation layer Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() .layer(0, new AutoEncoder.Builder().nIn(nIn).nOut(layerSize).corruptionLevel(0.0) @@ -214,7 +215,7 @@ public class ActivationLayerTest extends BaseDL4JTest { DataSet next = iter.next(); // Run without separate activation layer - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + 
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() .layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20) @@ -222,7 +223,7 @@ public class ActivationLayerTest extends BaseDL4JTest { .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER) .activation(Activation.SOFTMAX).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); + .inputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); @@ -230,8 +231,8 @@ public class ActivationLayerTest extends BaseDL4JTest { // Run with separate activation layer - MultiLayerConfiguration conf2 = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = + NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .seed(123).list() .layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20) @@ -243,7 +244,7 @@ public class ActivationLayerTest extends BaseDL4JTest { .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) .nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); + .inputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork network2 = new MultiLayerNetwork(conf2); network2.init(); @@ -271,7 +272,7 @@ public class ActivationLayerTest extends BaseDL4JTest { @Test public void testActivationInheritance() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .weightInit(WeightInit.XAVIER) .activation(Activation.RATIONALTANH) @@ -287,19 +288,19 @@ public class ActivationLayerTest extends BaseDL4JTest { MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - assertNotNull(((ActivationLayer)network.getLayer(1).conf().getLayer()).getActivationFn()); + assertNotNull(((ActivationLayer)network.getLayer(1).getLayerConfiguration()).getActivationFn()); - assertTrue(((DenseLayer)network.getLayer(0).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer(1).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer(2).conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer(3).conf().getLayer()).getActivationFn() instanceof ActivationELU); - assertTrue(((OutputLayer)network.getLayer(4).conf().getLayer()).getActivationFn() instanceof ActivationSoftmax); + assertTrue(((DenseLayer)network.getLayer(0).getLayerConfiguration()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer)network.getLayer(1).getLayerConfiguration()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer)network.getLayer(2).getLayerConfiguration()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer)network.getLayer(3).getLayerConfiguration()).getActivationFn() instanceof ActivationELU); + assertTrue(((OutputLayer)network.getLayer(4).getLayerConfiguration()).getActivationFn() instanceof ActivationSoftmax); } @Test public void testActivationInheritanceCG() { - ComputationGraphConfiguration conf = new 
NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .weightInit(WeightInit.XAVIER) .activation(Activation.RATIONALTANH) @@ -317,13 +318,13 @@ public class ActivationLayerTest extends BaseDL4JTest { ComputationGraph network = new ComputationGraph(conf); network.init(); - assertNotNull(((ActivationLayer)network.getLayer("1").conf().getLayer()).getActivationFn()); + assertNotNull(((ActivationLayer)network.getLayer("1").getLayerConfiguration()).getActivationFn()); - assertTrue(((DenseLayer)network.getLayer("0").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer("1").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer("2").conf().getLayer()).getActivationFn() instanceof ActivationRationalTanh); - assertTrue(((ActivationLayer)network.getLayer("3").conf().getLayer()).getActivationFn() instanceof ActivationELU); - assertTrue(((OutputLayer)network.getLayer("4").conf().getLayer()).getActivationFn() instanceof ActivationSoftmax); + assertTrue(((DenseLayer)network.getLayer("0").getLayerConfiguration()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer)network.getLayer("1").getLayerConfiguration()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer)network.getLayer("2").getLayerConfiguration()).getActivationFn() instanceof ActivationRationalTanh); + assertTrue(((ActivationLayer)network.getLayer("3").getLayerConfiguration()).getActivationFn() instanceof ActivationELU); + assertTrue(((OutputLayer)network.getLayer("4").getLayerConfiguration()).getActivationFn() instanceof ActivationSoftmax); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/AutoEncoderTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/AutoEncoderTest.java index f841d1454..8b63b88b4 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/AutoEncoderTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/AutoEncoderTest.java @@ -47,7 +47,7 @@ public class AutoEncoderTest extends BaseDL4JTest { int in2Size = 15; int hiddenSize = 10; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .weightInit(WeightInit.XAVIER) .graphBuilder() .addInputs("in1", "in2") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java index bc1b2db87..3162ed209 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers; import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -81,12 +80,12 @@ public class BaseLayerTest extends BaseDL4JTest { int nIn = 2; int nOut = 2; - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + 
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(new ConvolutionLayer.Builder().nIn(nIn).nOut(nOut).build()).build(); - val numParams = conf.getLayer().initializer().numParams(conf); + val numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + return conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); } @@ -94,7 +93,7 @@ public class BaseLayerTest extends BaseDL4JTest { int nIn = 2; int nOut = 2; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(nOut).build()) .layer(1, new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()).build(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java index 7b55a4641..002495133 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java @@ -41,8 +41,8 @@ public class CacheModeTest extends BaseDL4JTest { @Test public void testConvCacheModeSimple(){ - MultiLayerConfiguration conf1 = getConf(CacheMode.NONE); - MultiLayerConfiguration conf2 = getConf(CacheMode.DEVICE); + NeuralNetConfiguration conf1 = getConf(CacheMode.NONE); + NeuralNetConfiguration conf2 = getConf(CacheMode.DEVICE); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); @@ -62,8 +62,8 @@ public class CacheModeTest extends BaseDL4JTest { assertEquals(net1.params(), net2.params()); } - private static MultiLayerConfiguration getConf(CacheMode cacheMode){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + private static NeuralNetConfiguration getConf(CacheMode cacheMode){ + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .inferenceWorkspaceMode(WorkspaceMode.ENABLED) .trainingWorkspaceMode(WorkspaceMode.ENABLED) @@ -73,7 +73,7 @@ public class CacheModeTest extends BaseDL4JTest { .layer(new ConvolutionLayer.Builder().nOut(3).build()) .layer(new ConvolutionLayer.Builder().nOut(3).build()) .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) + .inputType(InputType.convolutionalFlat(28, 28, 1)) .build(); return conf; @@ -84,8 +84,8 @@ public class CacheModeTest extends BaseDL4JTest { for(boolean graves : new boolean[]{true, false}) { - MultiLayerConfiguration conf1 = getConfLSTM(CacheMode.NONE, graves); - MultiLayerConfiguration conf2 = getConfLSTM(CacheMode.DEVICE, graves); + NeuralNetConfiguration conf1 = getConfLSTM(CacheMode.NONE, graves); + NeuralNetConfiguration conf2 = getConfLSTM(CacheMode.DEVICE, graves); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); @@ -106,8 +106,8 @@ public class CacheModeTest extends BaseDL4JTest { } } - private static MultiLayerConfiguration getConfLSTM(CacheMode cacheMode, boolean graves){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + private static NeuralNetConfiguration getConfLSTM(CacheMode cacheMode, boolean graves){ + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) 
.inferenceWorkspaceMode(WorkspaceMode.ENABLED) .trainingWorkspaceMode(WorkspaceMode.ENABLED) @@ -152,7 +152,7 @@ public class CacheModeTest extends BaseDL4JTest { } private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode){ - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .inferenceWorkspaceMode(WorkspaceMode.ENABLED) .trainingWorkspaceMode(WorkspaceMode.ENABLED) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java index 73bd4c333..9f5597199 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java @@ -52,7 +52,7 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest { private ComputationGraph getGraph(int numLabels, double lambda) { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .dist(new NormalDistribution(0, 1)).updater(new NoOp()) .graphBuilder().addInputs("input1") @@ -73,7 +73,7 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest { int nChannels = 1; // Number of input channels int outputNum = 10; // The number of possible outcomes - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) // Training iterations as above + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) // Training iterations as above .l2(0.0005).weightInit(WeightInit.XAVIER) .updater(new Nesterovs(0.01, 0.9)) .graphBuilder().addInputs("input") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java index 3aa7e37dd..716bbb8a9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java @@ -25,7 +25,6 @@ import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -76,7 +75,7 @@ public class DropoutLayerTest extends BaseDL4JTest { @Test public void testDropoutLayerWithoutTraining() throws Exception { - MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(3648) + NeuralNetConfiguration confIntegrated = NeuralNetConfiguration.builder().seed(3648) .list().layer(0, new ConvolutionLayer.Builder(1, 1).stride(1, 1).nIn(1).nOut(1).dropOut(0.25) .activation(Activation.IDENTITY).weightInit(WeightInit.XAVIER) @@ -85,7 +84,7 @@ public class DropoutLayerTest extends BaseDL4JTest { .activation(Activation.SOFTMAX) .weightInit(WeightInit.XAVIER).dropOut(0.25) .nOut(4).build()) - 
.setInputType(InputType.convolutionalFlat(2, 2, 1)).build(); + .inputType(InputType.convolutionalFlat(2, 2, 1)).build(); MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated); netIntegrated.init(); @@ -94,8 +93,8 @@ public class DropoutLayerTest extends BaseDL4JTest { netIntegrated.getLayer(1).setParam("W", Nd4j.eye(4)); netIntegrated.getLayer(1).setParam("b", Nd4j.zeros(4, 1)); - MultiLayerConfiguration confSeparate = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration confSeparate = + NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .seed(3648) .list().layer(0, @@ -109,7 +108,7 @@ public class DropoutLayerTest extends BaseDL4JTest { .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) .nOut(4).build()) - .setInputType(InputType.convolutionalFlat(2, 2, 1)).build(); + .inputType(InputType.convolutionalFlat(2, 2, 1)).build(); MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate); netSeparate.init(); @@ -137,8 +136,8 @@ public class DropoutLayerTest extends BaseDL4JTest { List<INDArray> actTestSeparate = netSeparate.feedForward(in.dup(), false); //Check masks: - INDArray maskIntegrated = ((Dropout)netIntegrated.getLayer(0).conf().getLayer().getIDropout()).getMask(); - INDArray maskSeparate = ((Dropout)netSeparate.getLayer(0).conf().getLayer().getIDropout()).getMask(); + INDArray maskIntegrated = ((Dropout)netIntegrated.getLayer(0).getLayerConfiguration().getIDropout()).getMask(); + INDArray maskSeparate = ((Dropout)netSeparate.getLayer(0).getLayerConfiguration().getIDropout()).getMask(); assertEquals(maskIntegrated, maskSeparate); @@ -156,7 +155,7 @@ public class DropoutLayerTest extends BaseDL4JTest { DataSet next = iter.next(); // Run without separate activation layer - MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration confIntegrated = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10) @@ -173,7 +172,7 @@ public class DropoutLayerTest extends BaseDL4JTest { netIntegrated.fit(next); // Run with separate activation layer - MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration confSeparate = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10).activation(Activation.RELU) @@ -229,7 +228,7 @@ public class DropoutLayerTest extends BaseDL4JTest { // Run without separate activation layer Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration confIntegrated = NeuralNetConfiguration.builder().seed(123) .list().layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20) .activation(Activation.TANH).weightInit(WeightInit.XAVIER) @@ -237,7 +236,7 @@ public class DropoutLayerTest extends BaseDL4JTest { .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).dropOut(0.5) .nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); + .inputType(InputType.convolutionalFlat(28, 28, 1)).build(); // Run with separate activation layer Nd4j.getRandom().setSeed(12345); @@ -248,14 +247,14 @@ public class DropoutLayerTest extends BaseDL4JTest { Map<Integer, InputPreProcessor> preProcessorMap = new
HashMap<>(); preProcessorMap.put(1, new CnnToFeedForwardPreProcessor(13, 13, 20)); - MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder().seed(123).list() + NeuralNetConfiguration confSeparate = NeuralNetConfiguration.builder().seed(123).list() .layer(0, new ConvolutionLayer.Builder(4, 4).stride(2, 2).nIn(1).nOut(20) .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()) .layer(1, new DropoutLayer.Builder(0.5).build()) .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()) .inputPreProcessors(preProcessorMap) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); + .inputType(InputType.convolutionalFlat(28, 28, 1)).build(); Nd4j.getRandom().setSeed(12345); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java index 0f506dbfe..1e83adaf2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java @@ -23,9 +23,11 @@ package org.deeplearning4j.nn.layers; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer.Builder; +import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; @@ -40,6 +42,7 @@ import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.List; +import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -54,19 +57,20 @@ public class FrozenLayerTest extends BaseDL4JTest { public void testFrozen() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.IDENTITY); FineTuneConfiguration finetune = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build(); - MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) - .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build()); + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork( + (NeuralNetConfiguration) ((NeuralNetConfigurationBuilder)overallConf).clone().list() + .layer(0, new Builder().nIn(4).nOut(3).build()) + .layer(1, new Builder().nIn(3).nOut(2).build()) + .layer(2, new Builder().nIn(2).nOut(3).build()) + .layer(3, new 
OutputLayer.Builder( + LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) + .build()) + .build()); modelToFineTune.init(); List<INDArray> ff = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false); @@ -77,12 +81,13 @@ public class FrozenLayerTest extends BaseDL4JTest { INDArray paramsLastTwoLayers = Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()); - MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build(), paramsLastTwoLayers); + MultiLayerNetwork notFrozen = new MultiLayerNetwork( + (NeuralNetConfiguration) overallConf.clone() + .layer(0, new Builder().nIn(2).nOut(3).build()) + .layer(1, new OutputLayer.Builder( + LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) + .build()) + .build(), paramsLastTwoLayers); // assertEquals(modelNow.getLayer(2).conf(), notFrozen.getLayer(0).conf()); //Equal, other than names // assertEquals(modelNow.getLayer(3).conf(), notFrozen.getLayer(1).conf()); //Equal, other than names @@ -109,16 +114,17 @@ public class FrozenLayerTest extends BaseDL4JTest { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.IDENTITY); - MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) - .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build()); + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork( + (NeuralNetConfiguration) overallConf + .layer(0, new Builder().nIn(4).nOut(3).build()) + .layer(1, new Builder().nIn(3).nOut(2).build()) + .layer(2, new Builder().nIn(2).nOut(3).build()) + .layer(3, new OutputLayer.Builder( + LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) + .build()) + .build()); modelToFineTune.init(); INDArray asFrozenFeatures = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false).get(2); @@ -127,18 +133,18 @@ public class FrozenLayerTest extends BaseDL4JTest { MultiLayerNetwork clonedModel = modelNow.clone(); //Check json - assertEquals(modelNow.getLayerWiseConfigurations().toJson(), clonedModel.getLayerWiseConfigurations().toJson()); + assertEquals(modelNow.getConfiguration().toJson(), clonedModel.getConfiguration().toJson()); //Check params assertEquals(modelNow.params(), clonedModel.params()); MultiLayerNetwork notFrozen = new MultiLayerNetwork( - overallConf.list().layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build(), + (NeuralNetConfiguration) overallConf.layer(0, new Builder().nIn(2).nOut(3).build()) + .layer(1, new OutputLayer.Builder( + LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nIn(3).nOut(3) +
.build()) + .build(), Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params())); int i = 0; @@ -161,7 +167,7 @@ public class FrozenLayerTest extends BaseDL4JTest { public void testFrozenCompGraph() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.IDENTITY); ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") @@ -212,7 +218,7 @@ public class FrozenLayerTest extends BaseDL4JTest { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.IDENTITY); ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") @@ -273,7 +279,7 @@ public class FrozenLayerTest extends BaseDL4JTest { public void testFrozenLayerInstantiation() { //We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if // they were initialized via the builder - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list() + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder().seed(12345).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) .weightInit(WeightInit.XAVIER).build()) .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) @@ -283,7 +289,7 @@ public class FrozenLayerTest extends BaseDL4JTest { .nOut(10).build()) .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10) .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())) .layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer( @@ -303,7 +309,7 @@ public class FrozenLayerTest extends BaseDL4JTest { String json = conf2.toJson(); - MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(json); assertEquals(conf2, fromJson); @@ -323,7 +329,7 @@ public class FrozenLayerTest extends BaseDL4JTest { //We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if // they were initialized via the builder - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() + ComputationGraphConfiguration conf1 = NeuralNetConfiguration.builder().seed(12345).graphBuilder() .addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) .weightInit(WeightInit.XAVIER).build(), "in") @@ -335,7 +341,7 @@ public class FrozenLayerTest extends BaseDL4JTest { "1") .setOutputs("2").build(); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder().seed(12345).graphBuilder() .addInputs("in") .addLayer("0", new 
org.deeplearning4j.nn.conf.layers.misc.FrozenLayer.Builder() .layer(new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java index dce5daebd..89c359ae7 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java @@ -22,17 +22,13 @@ package org.deeplearning4j.nn.layers; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration; -import org.deeplearning4j.nn.transferlearning.TransferLearning; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; @@ -42,8 +38,6 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.util.List; - import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -55,7 +49,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { public void testFrozenWithBackpropLayerInstantiation() { //We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if // they were initialized via the builder - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).list() + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder().seed(12345).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) .weightInit(WeightInit.XAVIER).build()) .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) @@ -65,7 +59,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { .nOut(10).build()) .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).list().layer(0, + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().seed(12345).list().layer(0, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(10).nOut(10) .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build())) .layer(1, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( @@ -85,7 +79,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { String json = conf2.toJson(); - MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(json); assertEquals(conf2, fromJson); @@ -105,7 +99,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { //We need to be able to instantitate frozen layers from JSON etc, and have them be the same as if // 
they were initialized via the builder - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() + ComputationGraphConfiguration conf1 = NeuralNetConfiguration.builder().seed(12345).graphBuilder() .addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) .weightInit(WeightInit.XAVIER).build(), "in") @@ -117,7 +111,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { "1") .setOutputs("2").build(); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345).graphBuilder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder().seed(12345).graphBuilder() .addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( new DenseLayer.Builder().nIn(10).nOut(10).activation(Activation.TANH) @@ -160,7 +154,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder() .seed(12345) .weightInit(WeightInit.XAVIER) .updater(new Sgd(2)) @@ -212,7 +206,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { String unfrozenLayer1 = unfrozenBranchName + "1"; String unfrozenBranch2 = unfrozenBranchName + "Output"; - ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration computationGraphConf = NeuralNetConfiguration.builder() .updater(new Sgd(2.0)) .seed(12345) .graphBuilder() @@ -258,7 +252,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1)); - MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration confSgd = NeuralNetConfiguration.builder() .seed(12345) .weightInit(WeightInit.XAVIER) .updater(new Sgd(2)) @@ -269,7 +263,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { .layer(3,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build()) .build(); - MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration confFrozen = NeuralNetConfiguration.builder() .seed(12345) .weightInit(WeightInit.XAVIER) .updater(new Sgd(2)) @@ -326,7 +320,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { String unfrozenLayer1 = unfrozenBranchName + "1"; String unfrozenBranch2 = unfrozenBranchName + "Output"; - ComputationGraphConfiguration computationGraphConf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration computationGraphConf = NeuralNetConfiguration.builder() .updater(new Sgd(2.0)) .seed(12345) .graphBuilder() @@ -347,7 +341,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { .setOutputs(frozenBranchOutput) .build(); - ComputationGraphConfiguration computationSgdGraphConf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration computationSgdGraphConf = NeuralNetConfiguration.builder() .updater(new Sgd(2.0)) .seed(12345) .graphBuilder() diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java index 
232a9a46e..0bdf441ac 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java @@ -34,7 +34,6 @@ import org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; @@ -57,7 +56,7 @@ public class OutputLayerTest extends BaseDL4JTest { @Test public void testSetParams() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) .updater(new Sgd(1e-1)) .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(4).nOut(3) @@ -65,12 +64,12 @@ public class OutputLayerTest extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - OutputLayer l = (OutputLayer) conf.getLayer().instantiate(conf, + OutputLayer l = (OutputLayer) conf.getFirstLayer().instantiate(conf, Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType()); params = l.params(); - l.setParams(params); + l.setParamsTable(params); assertEquals(params, l.params()); } @@ -94,7 +93,7 @@ public class OutputLayerTest extends BaseDL4JTest { } } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345L).list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize) .dist(new NormalDistribution(0, 1)).activation(Activation.TANH) .updater(new NoOp()).build()) @@ -118,7 +117,7 @@ public class OutputLayerTest extends BaseDL4JTest { //As above, but for RnnOutputLayer. Expect all activations etc. 
to be 3d - MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list() + NeuralNetConfiguration confRnn = NeuralNetConfiguration.builder().seed(12345L).list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize) .dist(new NormalDistribution(0, 1)).activation(Activation.TANH) .updater(new NoOp()).build()) @@ -175,7 +174,7 @@ public class OutputLayerTest extends BaseDL4JTest { } INDArray labels2d = proc.backprop(labels3d, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345L).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345L).list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize) .dist(new NormalDistribution(0, 1)) .activation(Activation.TANH).updater(new NoOp()).build()) @@ -192,7 +191,7 @@ public class OutputLayerTest extends BaseDL4JTest { INDArray out2d = mln.feedForward(input).get(2); INDArray out3d = proc.preProcess(out2d, miniBatchSize, LayerWorkspaceMgr.noWorkspaces()); - MultiLayerConfiguration confRnn = new NeuralNetConfiguration.Builder().seed(12345L).list() + NeuralNetConfiguration confRnn = NeuralNetConfiguration.builder().seed(12345L).list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize) .dist(new NormalDistribution(0, 1)) .activation(Activation.TANH).updater(new NoOp()).build()) @@ -271,8 +270,8 @@ public class OutputLayerTest extends BaseDL4JTest { int nOut = 6; int miniBatchSize = 3; - MultiLayerConfiguration conf1 = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf1 = + NeuralNetConfiguration.builder().seed(12345L) .updater(new NoOp()) .list() .layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) @@ -288,8 +287,8 @@ public class OutputLayerTest extends BaseDL4JTest { mln.init(); - MultiLayerConfiguration conf2 = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf2 = + NeuralNetConfiguration.builder().seed(12345L) .updater(new NoOp()) .list() .layer(new LSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) @@ -348,8 +347,8 @@ public class OutputLayerTest extends BaseDL4JTest { //Check that (A+identity) is equal to (identity+A), for activation A //i.e., should get same output and weight gradients for both - MultiLayerConfiguration conf1 = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf1 = + NeuralNetConfiguration.builder().seed(12345L) .updater(new NoOp()) .convolutionMode(ConvolutionMode.Same) .inferenceWorkspaceMode(ws) @@ -364,8 +363,8 @@ public class OutputLayerTest extends BaseDL4JTest { .build()) .build(); - MultiLayerConfiguration conf2 = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf2 = + NeuralNetConfiguration.builder().seed(12345L) .updater(new NoOp()) .convolutionMode(ConvolutionMode.Same) .inferenceWorkspaceMode(ws) @@ -438,7 +437,7 @@ public class OutputLayerTest extends BaseDL4JTest { //i.e., should get same output and weight gradients for both ComputationGraphConfiguration conf1 = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration.builder().seed(12345L) .updater(new NoOp()) .convolutionMode(ConvolutionMode.Same) .inferenceWorkspaceMode(ws) @@ -456,7 +455,7 @@ public class OutputLayerTest extends BaseDL4JTest { .build(); ComputationGraphConfiguration conf2 = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration.builder().seed(12345L) .updater(new NoOp()) 
.convolutionMode(ConvolutionMode.Same) .inferenceWorkspaceMode(ws) @@ -524,8 +523,8 @@ public class OutputLayerTest extends BaseDL4JTest { public void testCnnOutputLayerSoftmax(){ //Check that softmax is applied channels-wise - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(12345L) .updater(new NoOp()) .convolutionMode(ConvolutionMode.Same) .list() @@ -555,19 +554,19 @@ public class OutputLayerTest extends BaseDL4JTest { @Test public void testOutputLayerDefaults(){ - new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration.builder().list() .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder().nIn(10).nOut(10).build()) .build(); - new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration.builder().list() .layer(new org.deeplearning4j.nn.conf.layers.LossLayer.Builder().build()) .build(); - new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration.builder().list() .layer(new org.deeplearning4j.nn.conf.layers.CnnLossLayer.Builder().build()) .build(); - new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration.builder().list() .layer(new org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer.Builder().build()) .build(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java index 3e526e774..483e34572 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java @@ -32,8 +32,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Pair; -import java.util.Arrays; - import static org.junit.jupiter.api.Assertions.*; public class RepeatVectorTest extends BaseDL4JTest { @@ -42,10 +40,10 @@ public class RepeatVectorTest extends BaseDL4JTest { private Layer getRepeatVectorLayer() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123) .dataType(DataType.DOUBLE) .layer(new RepeatVector.Builder(REPEAT).build()).build(); - return conf.getLayer().instantiate(conf, null, 0, + return conf.getFirstLayer().instantiate(conf, null, 0, null, false, DataType.DOUBLE); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java index 4d46d5066..db7d4525c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java @@ -50,11 +50,11 @@ public class SeedTest extends BaseDL4JTest { .activation(Activation.SIGMOID).build(); NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder().layer(layerType).seed(123).build(); + NeuralNetConfiguration.builder().layer(layerType).seed(123).build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); 
layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams)); layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java index 868f34ba7..e17653219 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/TestDropout.java @@ -20,26 +20,20 @@ package org.deeplearning4j.nn.layers; -import lombok.val; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.distribution.UniformDistribution; -import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.lang.reflect.Field; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -55,7 +49,7 @@ public class TestDropout extends BaseDL4JTest { int nIn = 8; int nOut = 8; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Sgd()) .dropOut(0.5).list() .layer(0, new OutputLayer.Builder().activation(Activation.IDENTITY) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsNetMNISTTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsNetMNISTTest.java index 18c285baf..6b307a68c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsNetMNISTTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsNetMNISTTest.java @@ -25,7 +25,6 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.io.IOException; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ActivationLayer; @@ -55,7 +54,7 @@ public class CapsNetMNISTTest extends BaseDL4JTest { @Test public void testCapsNetOnMNIST(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(123) .updater(new Adam()) .list() @@ -72,7 +71,7 @@ public class CapsNetMNISTTest extends BaseDL4JTest { .layer(new CapsuleStrengthLayer.Builder().build()) .layer(new ActivationLayer.Builder(new ActivationSoftmax()).build()) .layer(new LossLayer.Builder(new LossNegativeLogLikelihood()).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) + .inputType(InputType.convolutionalFlat(28, 28, 1)) .build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleLayerTest.java index 70e503c42..4536b915b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleLayerTest.java @@ -26,7 +26,6 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.CapsuleLayer; @@ -81,11 +80,11 @@ public class CapsuleLayerTest extends BaseDL4JTest { @Test public void testLayer(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(123) .list() .layer(new CapsuleLayer.Builder(10, 16, 3).build()) - .setInputType(InputType.recurrent(10, 8)) + .inputType(InputType.recurrent(10, 8)) .build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleStrengthLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleStrengthLayerTest.java index fac472d68..388d380dc 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleStrengthLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/CapsuleStrengthLayerTest.java @@ -24,7 +24,6 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.CapsuleStrengthLayer; @@ -52,11 +51,11 @@ public class CapsuleStrengthLayerTest extends BaseDL4JTest { @Test public void testLayer(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(123) .list() .layer(new CapsuleStrengthLayer.Builder().build()) - .setInputType(InputType.recurrent(5, 8)) + .inputType(InputType.recurrent(5, 8)) .build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/PrimaryCapsulesTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/PrimaryCapsulesTest.java index 5840ec85f..12f63e7ec 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/PrimaryCapsulesTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/capsule/PrimaryCapsulesTest.java @@ -26,7 +26,6 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.PrimaryCapsules; @@ -106,7 
+105,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest { @Test public void testLayer(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(123) .list() .layer(new PrimaryCapsules.Builder(8, 10) @@ -114,7 +113,7 @@ public class PrimaryCapsulesTest extends BaseDL4JTest { .stride(4, 4) .useLeakyReLU(0.5) .build()) - .setInputType(InputType.convolutional(20, 20, 20)) + .inputType(InputType.convolutional(20, 20, 20)) .build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java index 7c07bfeb2..44ee236c8 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java @@ -758,8 +758,8 @@ public class ConvDataFormatTests extends BaseDL4JTest { } } - private MultiLayerNetwork getNetWithLayer(Layer layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) { - NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + private MultiLayerNetwork getNetWithLayer(LayerConfiguration layer, CNN2DFormat format, ConvolutionMode cm, InputType inputType) { + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = (NeuralNetConfiguration.NeuralNetConfigurationBuilder) NeuralNetConfiguration.builder() .dataType(this.dataType) .seed(12345) .convolutionMode(cm) @@ -774,7 +774,7 @@ public class ConvDataFormatTests extends BaseDL4JTest { .layer(layer) .layer(new OutputLayer.Builder().nOut(10) .activation(Activation.SOFTMAX).build()) - .setInputType(inputType != null ? inputType : InputType.convolutional(12, 12, 3, format)); + .inputType(inputType != null ? 
inputType : InputType.convolutional(12, 12, 3, format)); if(format == CNN2DFormat.NHWC && !(layer instanceof GlobalPoolingLayer)){ //Add a preprocessor due to the differences in how NHWC and NCHW activations are flattened @@ -799,7 +799,7 @@ public class ConvDataFormatTests extends BaseDL4JTest { } private MultiLayerNetwork getCnnLossNet(CNN2DFormat format, boolean setOnLayerAlso, ConvolutionMode cm){ - NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = (NeuralNetConfiguration.NeuralNetConfigurationBuilder) NeuralNetConfiguration.builder() .seed(12345) .convolutionMode(cm) .list() @@ -819,7 +819,7 @@ public class ConvDataFormatTests extends BaseDL4JTest { .activation(Activation.SOFTMAX).build()); } - builder.setInputType(InputType.convolutional(12, 12, 3, format)); + builder.inputType(InputType.convolutional(12, 12, 3, format)); MultiLayerNetwork net = new MultiLayerNetwork(builder.build()); net.init(); @@ -984,24 +984,24 @@ public class ConvDataFormatTests extends BaseDL4JTest { for(CNN2DFormat df : CNN2DFormat.values()) { for(int i = 0; i < 4; i++) { - NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder() - .list(); + NeuralNetConfiguration.NeuralNetConfigurationBuilder b = NeuralNetConfiguration.builder(); + switch (i){ case 0: b.layer(new ConvolutionLayer.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build()); - b.setInputType(InputType.convolutional(12,12,3,df)); + b.inputType(InputType.convolutional(12,12,3,df)); break; case 1: b.layer(new DepthwiseConvolution2D.Builder().kernelSize(2,2).nIn(3).nOut(3).dataFormat(df).build()); - b.setInputType(InputType.convolutional(12,12,3,df)); + b.inputType(InputType.convolutional(12,12,3,df)); break; case 2: b.layer(new Deconvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build()); - b.setInputType(InputType.convolutional(12,12,3,df)); + b.inputType(InputType.convolutional(12,12,3,df)); break; case 3: b.layer(new SeparableConvolution2D.Builder().dataFormat(df).kernelSize(2,2).nIn(3).nOut(3).build()); - b.setInputType(InputType.convolutional(12,12,3,df)); + b.inputType(InputType.convolutional(12,12,3,df)); break; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java index d282690bb..d4a685a3a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java @@ -34,8 +34,6 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; -import java.util.Arrays; - import static org.junit.jupiter.api.Assertions.*; /** @@ -86,15 +84,15 @@ public class Convolution3DTest extends BaseDL4JTest { } private Layer getConvolution3DLayer(ConvolutionMode mode) { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new Convolution3D.Builder().kernelSize(kernelSize).nIn(nChannelsIn).nOut(nChannelsOut) .dataFormat(Convolution3D.DataFormat.NCDHW).convolutionMode(mode).hasBias(false) .build()) .build(); - long numParams = 
conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.ones(1, numParams); - return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + return conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); } public INDArray getData() throws Exception { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java index 246dfee5b..1af476e5e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerSetupTest.java @@ -27,8 +27,8 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -71,10 +71,10 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { @Test public void testConvolutionLayerSetup() { - MultiLayerConfiguration.Builder builder = inComplete(); - builder.setInputType(InputType.convolutionalFlat(28, 28, 1)); - MultiLayerConfiguration completed = complete().build(); - MultiLayerConfiguration test = builder.build(); + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = inComplete(); + builder.inputType(InputType.convolutionalFlat(28, 28, 1)); + NeuralNetConfiguration completed = complete().build(); + NeuralNetConfiguration test = builder.build(); assertEquals(completed, test); } @@ -90,7 +90,7 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { int seed = 123; //setup the network - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .l1(1e-1).l2(2e-4).dropOut(0.5).miniBatch(true) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() .layer(0, new ConvolutionLayer.Builder(5, 5).nOut(5).dropOut(0.5).weightInit(WeightInit.XAVIER) @@ -106,7 +106,7 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { .nOut(outputNum).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX) .build()) - .setInputType(InputType.convolutional(numRows, numColumns, nChannels)); + .inputType(InputType.convolutional(numRows, numColumns, nChannels)); DataSet d = new DataSet(Nd4j.rand(10, nChannels, numRows, numColumns), FeatureUtil.toOutcomeMatrix(new int[] {1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, 6)); @@ -119,10 +119,10 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { @Test public void testMnistLenet() throws Exception { - MultiLayerConfiguration.Builder incomplete = incompleteMnistLenet(); - incomplete.setInputType(InputType.convolutionalFlat(28, 28, 1)); + NeuralNetConfiguration.NeuralNetConfigurationBuilder incomplete = incompleteMnistLenet(); + 
incomplete.inputType(InputType.convolutionalFlat(28, 28, 1)); - MultiLayerConfiguration testConf = incomplete.build(); + NeuralNetConfiguration testConf = incomplete.build(); assertEquals(800, ((FeedForwardLayer) testConf.getConf(4).getLayer()).getNIn()); assertEquals(500, ((FeedForwardLayer) testConf.getConf(5).getLayer()).getNIn()); @@ -141,9 +141,9 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { INDArray labels = Nd4j.rand(10, 2); DataSet next = new DataSet(in, labels); - NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLFW(); - builder.setInputType(InputType.convolutional(28, 28, 3)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = incompleteLFW(); + builder.inputType(InputType.convolutional(28, 28, 3)); + NeuralNetConfiguration conf = builder.build(); ConvolutionLayer layer2 = (ConvolutionLayer) conf.getConf(2).getLayer(); assertEquals(6, layer2.getNIn()); @@ -163,10 +163,10 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { reader.initialize(new FileSplit(new File(rootDir))); DataSetIterator recordReader = new RecordReaderDataSetIterator(reader, 10, 1, labels.size()); labels.remove("lfwtest"); - NeuralNetConfiguration.ListBuilder builder = (NeuralNetConfiguration.ListBuilder) incompleteLRN(); - builder.setInputType(InputType.convolutional(28, 28, 3)); + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = incompleteLRN(); + builder.inputType(InputType.convolutional(28, 28, 3)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); ConvolutionLayer layer2 = (ConvolutionLayer) conf.getConf(3).getLayer(); assertEquals(6, layer2.getNIn()); @@ -174,70 +174,70 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { } - public MultiLayerConfiguration.Builder incompleteLRN() { - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder().seed(3) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nOut(6).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}).build()) - .layer(2, new LocalResponseNormalization.Builder().build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nOut(6).build()) - .layer(4, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}).build()) - .layer(5, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(2) - .activation(Activation.SOFTMAX).build()); + public NeuralNetConfiguration.NeuralNetConfigurationBuilder incompleteLRN() { + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = + NeuralNetConfiguration.builder().seed(3) + .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() + .layer(0, new ConvolutionLayer.Builder( + new int[] {5, 5}).nOut(6).build()) + .layer(1, new SubsamplingLayer.Builder( + new int[] {2, 2}).build()) + .layer(2, new LocalResponseNormalization.Builder().build()) + .layer(3, new ConvolutionLayer.Builder( + new int[] {5, 5}).nOut(6).build()) + .layer(4, new SubsamplingLayer.Builder( + new int[] {2, 2}).build()) + .layer(5, new OutputLayer.Builder( + LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(2) + .activation(Activation.SOFTMAX).build()); return builder; } - public 
MultiLayerConfiguration.Builder incompleteLFW() { - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder().seed(3) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nOut(6).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}).build()) - .layer(2, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nOut(6).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}).build()) - .layer(4, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX) - .nOut(2).build()); + public NeuralNetConfiguration.NeuralNetConfigurationBuilder incompleteLFW() { + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = + NeuralNetConfiguration.builder().seed(3) + .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() + .layer(0, new ConvolutionLayer.Builder( + new int[] {5, 5}).nOut(6).build()) + .layer(1, new SubsamplingLayer.Builder( + new int[] {2, 2}).build()) + .layer(2, new ConvolutionLayer.Builder( + new int[] {5, 5}).nOut(6).build()) + .layer(3, new SubsamplingLayer.Builder( + new int[] {2, 2}).build()) + .layer(4, new OutputLayer.Builder( + LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).activation(Activation.SOFTMAX) + .nOut(2).build()); return builder; } - public MultiLayerConfiguration.Builder incompleteMnistLenet() { - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder().seed(3) - .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nIn(1).nOut(20).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}, new int[] {2, 2}).build()) - .layer(2, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - new int[] {5, 5}).nIn(20).nOut(50).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.SubsamplingLayer.Builder( - new int[] {2, 2}, new int[] {2, 2}).build()) - .layer(4, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nOut(500) - .build()) - .layer(5, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .activation(Activation.SOFTMAX).nOut(10) - .build()); + public NeuralNetConfiguration.NeuralNetConfigurationBuilder incompleteMnistLenet() { + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = + NeuralNetConfiguration.builder().seed(3) + .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() + .layer(0, new ConvolutionLayer.Builder( + new int[] {5, 5}).nIn(1).nOut(20).build()) + .layer(1, new SubsamplingLayer.Builder( + new int[] {2, 2}, new int[] {2, 2}).build()) + .layer(2, new ConvolutionLayer.Builder( + new int[] {5, 5}).nIn(20).nOut(50).build()) + .layer(3, new SubsamplingLayer.Builder( + new int[] {2, 2}, new int[] {2, 2}).build()) + .layer(4, new DenseLayer.Builder().nOut(500) + .build()) + .layer(5, new OutputLayer.Builder( + LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) + .activation(Activation.SOFTMAX).nOut(10) + .build()); return builder; } - public MultiLayerConfiguration mnistLenet() { - MultiLayerConfiguration builder = - new NeuralNetConfiguration.Builder().seed(3) + public NeuralNetConfiguration mnistLenet() { + 
NeuralNetConfiguration builder = + NeuralNetConfiguration.builder().seed(3) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( new int[] {5, 5}).nIn(1).nOut(6).build()) @@ -254,12 +254,12 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { return builder; } - public MultiLayerConfiguration.Builder inComplete() { + public NeuralNetConfiguration.NeuralNetConfigurationBuilder inComplete() { int nChannels = 1; int outputNum = 10; int seed = 123; - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] {10, 10}, new int[] {2, 2}).nIn(nChannels).nOut(6).build()) @@ -274,14 +274,14 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { } - public MultiLayerConfiguration.Builder complete() { + public NeuralNetConfiguration.NeuralNetConfigurationBuilder complete() { final int numRows = 28; final int numColumns = 28; int nChannels = 1; int outputNum = 10; int seed = 123; - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder(new int[] {10, 10}, new int[] {2, 2}).nIn(nChannels).nOut(6).build()) @@ -301,15 +301,15 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { @Test public void testDeconvolution() { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list() //out = stride * (in-1) + filter - 2*pad -> 2 * (28-1) + 2 - 0 = 56 -> 56x56x3 .layer(0, new Deconvolution2D.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(56-2+2*1)/2+1 = 29 -> 29x29x3 .layer(1, new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()) .layer(2, new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); + .inputType(InputType.convolutional(28, 28, 1)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); @@ -324,13 +324,13 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { @Test public void testSubSamplingWithPadding() { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list() .layer(0, new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14 .layer(1, new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()) //(14-2+2)/2+1 = 8 -> 8x8x3 .layer(2, new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); + .inputType(InputType.convolutional(28, 28, 1)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration 
conf = builder.build(); assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); @@ -345,13 +345,13 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { @Test public void testUpsampling() { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list() .layer(new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14 .layer(new Upsampling2D.Builder().size(3).build()) // 14 * 3 = 42! .layer(new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); + .inputType(InputType.convolutional(28, 28, 1)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); @@ -368,13 +368,13 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { int[] blocks = new int[] {2, 2}; - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list() .layer(new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14 .layer(new SpaceToBatchLayer.Builder(blocks).build()) // Divide space dimensions by blocks, i.e. 14/2 = 7 .layer(new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); + .inputType(InputType.convolutional(28, 28, 1)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); @@ -389,15 +389,15 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { int blocks = 2; - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list() //(28-2+0)/2+1 = 14 -> 14x14x3 out .layer(new ConvolutionLayer.Builder(2, 2).padding(0, 0).stride(2, 2).nIn(1).nOut(3).build()) // Divide space dimensions by blocks, i.e. 14/2 = 7 -> 7x7x12 out (3x2x2 depth) .layer(new SpaceToDepthLayer.Builder(blocks, SpaceToDepthLayer.DataFormat.NCHW).build()) .layer(new OutputLayer.Builder().nIn(3 * 2 * 2).nOut(3).activation(Activation.SOFTMAX).build()) // nIn of the next layer gets multiplied by 2*2. 
- .setInputType(InputType.convolutional(28, 28, 1)); + .inputType(InputType.convolutional(28, 28, 1)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); @@ -415,7 +415,7 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { DataSet next = iter.next(); // Run with separate activation layer - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .weightInit(WeightInit.XAVIER).list() .layer(0, new ConvolutionLayer.Builder(new int[] {1, 1}, new int[] {1, 1}).nIn(1).nOut(6) @@ -428,7 +428,7 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { .layer(5, new ActivationLayer.Builder().activation(Activation.RELU).build()) .layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); + .inputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); @@ -447,16 +447,16 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { @Test public void testSeparableConv2D() { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list() .layer( new SeparableConvolution2D.Builder(2, 2) .depthMultiplier(2) .padding(0, 0) .stride(2, 2).nIn(1).nOut(3).build()) //(28-2+0)/2+1 = 14 .layer( new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()) //(14-2+2)/2+1 = 8 -> 8x8x3 .layer(2, new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); + .inputType(InputType.convolutional(28, 28, 1)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); @@ -471,7 +471,7 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { @Test public void testDeconv2D() { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list() //out = stride * (in-1) + filter - 2*pad -> 2 * (28-1) + 2 - 0 = 56 -> 56x56x3 .layer( new Deconvolution2D.Builder(2, 2) .padding(0, 0) @@ -479,9 +479,9 @@ public class ConvolutionLayerSetupTest extends BaseDL4JTest { //(56-2+2*1)/2+1 = 29 -> 29x29x3 .layer( new SubsamplingLayer.Builder().kernelSize(2, 2).padding(1, 1).stride(2, 2).build()) .layer(2, new OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, 1)); + .inputType(InputType.convolutional(28, 28, 1)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); assertNotNull(conf.getInputPreProcess(2)); assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java index 6b68d6cea..0c58b8703 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java @@ -29,20 +29,18 @@ import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.*; -import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitNormal; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.jupiter.api.Test; -import org.nd4j.enums.RnnDataFormat; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationSoftmax; import org.nd4j.linalg.api.buffer.DataType; @@ -59,8 +57,6 @@ import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; -import java.io.File; -import java.util.Arrays; import java.util.List; import static org.junit.jupiter.api.Assertions.*; @@ -77,7 +73,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest { @Test public void testTwdFirstLayer() throws Exception { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(123) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4) .updater(new Nesterovs(0.9)).dropOut(0.5) .list().layer(0, @@ -94,10 +90,10 @@ public class ConvolutionLayerTest extends BaseDL4JTest { .dropOut(0.5).build()) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS) //output layer .nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)); + .inputType(InputType.convolutionalFlat(28, 28, 1)); DataSetIterator iter = new MnistDataSetIterator(10, 10); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); DataSet ds = iter.next(); @@ -118,21 +114,21 @@ public class ConvolutionLayerTest extends BaseDL4JTest { int kernelWidth = 3; DataSet trainInput; - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder() - .seed(123) - .list() - .layer(0, new ConvolutionLayer.Builder(kernelHeight, kernelWidth).stride(1, 1) - .nOut(2).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new SubsamplingLayer.Builder() - .poolingType(SubsamplingLayer.PoolingType.MAX) - .kernelSize(imageHeight - kernelHeight, 1).stride(1, 1).build()) - .layer(2, new 
OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(imageHeight, imageWidth, nChannels)); + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = + NeuralNetConfiguration.builder() + .seed(123) + .list() + .layer(0, new ConvolutionLayer.Builder(kernelHeight, kernelWidth).stride(1, 1) + .nOut(2).activation(Activation.RELU) + .weightInit(WeightInit.XAVIER).build()) + .layer(1, new SubsamplingLayer.Builder() + .poolingType(SubsamplingLayer.PoolingType.MAX) + .kernelSize(imageHeight - kernelHeight, 1).stride(1, 1).build()) + .layer(2, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) + .activation(Activation.SOFTMAX).build()) + .inputType(InputType.convolutionalFlat(imageHeight, imageWidth, nChannels)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); @@ -155,9 +151,9 @@ public class ConvolutionLayerTest extends BaseDL4JTest { long batchSize = 1; INDArray arr = Nd4j.randn(batchSize,vectorLength,timeSteps); - MultiLayerConfiguration build = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration build = NeuralNetConfiguration.builder().seed(seed) .activation(Activation.RELU) - .weightInit(new WeightInitNormal()) // better init + .weightInit(WeightInit.NORMAL) // better init .updater(new Adam(learningRate)) .list() // block 1 @@ -172,7 +168,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest { .layer(new RnnLossLayer.Builder().dataFormat(RNNFormat.NCW) .activation(new ActivationSoftmax()) .lossFunction(new LossMCXENT()).build()) - .setInputType(InputType.recurrent(vectorLength,timeSteps,RNNFormat.NCW)) + .inputType(InputType.recurrent(vectorLength,timeSteps,RNNFormat.NCW)) .build(); MultiLayerNetwork network = new MultiLayerNetwork(build); @@ -196,18 +192,18 @@ public class ConvolutionLayerTest extends BaseDL4JTest { int kernelWidth = imageWidth + 1; DataSet trainInput; - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder() - .seed(123) - .list() - .layer(0, new ConvolutionLayer.Builder(kernelHeight, kernelWidth) //(img-kernel+2*padding)/stride + 1: must be >= 1. Therefore: with p=0, kernel <= img size - .stride(1, 1).nOut(2).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(imageHeight, imageWidth, nChannels)); + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = + NeuralNetConfiguration.builder() + .seed(123) + .list() + .layer(0, new ConvolutionLayer.Builder(kernelHeight, kernelWidth) //(img-kernel+2*padding)/stride + 1: must be >= 1. 
Therefore: with p=0, kernel <= img size + .stride(1, 1).nOut(2).activation(Activation.RELU) + .weightInit(WeightInit.XAVIER).build()) + .layer(1, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) + .activation(Activation.SOFTMAX).build()) + .inputType(InputType.convolutionalFlat(imageHeight, imageWidth, nChannels)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); @@ -232,19 +228,19 @@ public class ConvolutionLayerTest extends BaseDL4JTest { int kernelWidth = imageWidth; DataSet trainInput; - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder() - .seed(123) - .list() - .layer(0, new ConvolutionLayer.Builder(kernelHeight, kernelWidth).stride(1, 0) - .nOut(2).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = + NeuralNetConfiguration.builder() + .seed(123) + .list() + .layer(0, new ConvolutionLayer.Builder(kernelHeight, kernelWidth).stride(1, 0) + .nOut(2).activation(Activation.RELU) + .weightInit(WeightInit.XAVIER).build()) + .layer(1, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) + .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(imageHeight, imageWidth, nChannels)); + .inputType(InputType.convolutional(imageHeight, imageWidth, nChannels)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); @@ -260,11 +256,11 @@ public class ConvolutionLayerTest extends BaseDL4JTest { public void testCNNBiasInit() { ConvolutionLayer cnn = new ConvolutionLayer.Builder().nIn(1).nOut(3).biasInit(1).build(); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(cnn).build(); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(cnn).build(); - val numParams = conf.getLayer().initializer().numParams(conf); + val numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); assertEquals(1, layer.getParam("b").size(0)); } @@ -321,11 +317,11 @@ public class ConvolutionLayerTest extends BaseDL4JTest { ConvolutionLayer layer = new ConvolutionLayer.Builder(kernelSize, stride, padding).nIn(nIn).nOut(nOut) .activation(Activation.SIGMOID).build(); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(layer).build(); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(layer).build(); - val numParams = conf.getLayer().initializer().numParams(conf); + val numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - return conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + return conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); } public Layer getMNISTConfig() { @@ -695,17 +691,17 @@ public class ConvolutionLayerTest extends BaseDL4JTest { int outputNum = 10; int seed = 123; - MultiLayerConfiguration.Builder conf = - new NeuralNetConfiguration.Builder().seed(seed) - 
.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list() - .layer(0, new ConvolutionLayer.Builder(new int[] {10, 10}).nOut(6).build()) - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, - new int[] {2, 2}).stride(1, 1).build()) - .layer(2, new OutputLayer.Builder( - LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) - .nOut(outputNum).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)); + NeuralNetConfiguration.NeuralNetConfigurationBuilder conf = + NeuralNetConfiguration.builder().seed(seed) + .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list() + .layer(0, new ConvolutionLayer.Builder(new int[] {10, 10}).nOut(6).build()) + .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, + new int[] {2, 2}).stride(1, 1).build()) + .layer(2, new OutputLayer.Builder( + LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) + .nOut(outputNum).weightInit(WeightInit.XAVIER) + .activation(Activation.SOFTMAX).build()) + .inputType(InputType.convolutionalFlat(28, 28, 1)); MultiLayerNetwork model = new MultiLayerNetwork(conf.build()); model.init(); @@ -718,14 +714,14 @@ public class ConvolutionLayerTest extends BaseDL4JTest { @Test public void test1dInputType(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .convolutionMode(ConvolutionMode.Same) .list() .layer(new Convolution1DLayer.Builder().nOut(3).kernelSize(2).activation(Activation.TANH).build()) .layer(new Subsampling1DLayer.Builder().kernelSize(2).stride(2).build()) .layer(new Upsampling1D.Builder().size(2).build()) .layer(new RnnOutputLayer.Builder().nOut(7).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.recurrent(10)) + .inputType(InputType.recurrent(10)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -751,7 +747,7 @@ public class ConvolutionLayerTest extends BaseDL4JTest { @Test public void testDeconvBadInput(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new Deconvolution2D.Builder().nIn(5).nOut(3).build()) .build(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java index e4921b555..c39a785c1 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/LocallyConnectedLayerTest.java @@ -25,8 +25,8 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -64,7 +64,7 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest { @Test public void test2dForward(){ - MultiLayerConfiguration.Builder builder = new 
NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(123) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4) .updater(new Nesterovs(0.9)).dropOut(0.5) .list() @@ -77,9 +77,9 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest { .build()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS) //output layer .nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 3)); + .inputType(InputType.convolutionalFlat(28, 28, 3)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); @@ -91,7 +91,7 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest { @Test public void test1dForward(){ - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(123) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).l2(2e-4) .updater(new Nesterovs(0.9)).dropOut(0.5) .list() @@ -104,9 +104,9 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest { .build()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS) //output layer .nOut(10).weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.recurrent(3, 8)); + .inputType(InputType.recurrent(3, 8)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); @@ -132,7 +132,7 @@ public class LocallyConnectedLayerTest extends BaseDL4JTest { for (int test = 0; test < 2; test++) { String msg = "Global dtype: " + globalDtype + ", network dtype: " + networkDtype + ", test=" + test; - ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration.GraphBuilder b = NeuralNetConfiguration.builder() .dataType(networkDtype) .seed(123) .updater(new NoOp()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java index 0ee4e322f..ed8e8c99d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java @@ -31,8 +31,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import java.util.Arrays; - import static org.junit.jupiter.api.Assertions.*; public class SpaceToDepthTest extends BaseDL4JTest { @@ -61,10 +59,10 @@ public class SpaceToDepthTest extends BaseDL4JTest { } private Layer getSpaceToDepthLayer() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new SpaceToDepthLayer.Builder(blockSize, dataFormat).build()).build(); - return conf.getLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); + return conf.getFirstLayer().instantiate(conf, null, 0, null, true, 
Nd4j.defaultFloatingPointType()); } @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java index 75434a4c3..9fda734eb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java @@ -24,8 +24,8 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; @@ -44,8 +44,6 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; -import java.util.Arrays; - import static org.junit.jupiter.api.Assertions.*; /** @@ -170,11 +168,11 @@ public class SubsamplingLayerTest extends BaseDL4JTest { ////////////////////////////////////////////////////////////////////////////////// private Layer getSubsamplingLayer(SubsamplingLayer.PoolingType pooling) { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new SubsamplingLayer.Builder(pooling, new int[] {2, 2}).build()).build(); - return conf.getLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); + return conf.getFirstLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } public INDArray getData() throws Exception { @@ -214,23 +212,23 @@ public class SubsamplingLayerTest extends BaseDL4JTest { int kernelWidth = 3; DataSet trainInput; - MultiLayerConfiguration.Builder builder = - new NeuralNetConfiguration.Builder().seed(123).list() - .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( - kernelHeight, kernelWidth).stride(1, 1).nOut(2) - .activation(Activation.RELU).weightInit( - WeightInit.XAVIER) - .build()) - .layer(1, new SubsamplingLayer.Builder() - .poolingType(SubsamplingLayer.PoolingType.MAX) - .kernelSize(imageHeight - kernelHeight + 2, 1) //imageHeight-kernelHeight+1 is ok: full height - .stride(1, 1).build()) - .layer(2, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).build()) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = + NeuralNetConfiguration.builder().seed(123).list() + .layer(0, new org.deeplearning4j.nn.conf.layers.ConvolutionLayer.Builder( + kernelHeight, kernelWidth).stride(1, 1).nOut(2) + .activation(Activation.RELU).weightInit( + WeightInit.XAVIER) + .build()) + .layer(1, new SubsamplingLayer.Builder() + .poolingType(SubsamplingLayer.PoolingType.MAX) + .kernelSize(imageHeight - kernelHeight + 2, 1) //imageHeight-kernelHeight+1 is ok: full height + .stride(1, 1).build()) + .layer(2, new OutputLayer.Builder().nOut(classes).weightInit(WeightInit.XAVIER) + 
.activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(imageHeight, imageWidth, nChannels)); + .inputType(InputType.convolutional(imageHeight, imageWidth, nChannels)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java index 35ba6d924..61f937cec 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java @@ -25,7 +25,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -79,7 +78,7 @@ public class TestConvolutionModes extends BaseDL4JTest { inputData.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 9), NDArrayIndex.interval(0, 9)).assign(origData); - Layer layer; + LayerConfiguration layer; if (isSubsampling) { layer = new SubsamplingLayer.Builder().kernelSize(3, 3).stride(3, 3).padding(0, 0) .build(); @@ -90,15 +89,15 @@ public class TestConvolutionModes extends BaseDL4JTest { MultiLayerNetwork net = null; try { - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .convolutionMode(cm).list() .layer(0, layer).layer(1, new OutputLayer.Builder() .activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT) .nOut(3).build()) - .setInputType(InputType.convolutional(inSize, inSize, + .inputType(InputType.convolutional(inSize, inSize, inDepth)) .build(); @@ -158,7 +157,7 @@ public class TestConvolutionModes extends BaseDL4JTest { inputData.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.interval(0, 9), NDArrayIndex.interval(0, 9)).assign(origData); - Layer layer; + LayerConfiguration layer; if (isSubsampling) { layer = new SubsamplingLayer.Builder().kernelSize(3, 3).stride(3, 3).padding(0, 0) .build(); @@ -169,7 +168,7 @@ public class TestConvolutionModes extends BaseDL4JTest { ComputationGraph net = null; try { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .weightInit(WeightInit.XAVIER).convolutionMode(cm).graphBuilder() .addInputs("in").addLayer("0", layer, "in") .addLayer("1", new OutputLayer.Builder() @@ -210,7 +209,7 @@ public class TestConvolutionModes extends BaseDL4JTest { @Test public void testGlobalLocalConfig() { for (ConvolutionMode cm : new ConvolutionMode[] {ConvolutionMode.Strict, ConvolutionMode.Truncate}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .convolutionMode(cm).list() .layer(0, new ConvolutionLayer.Builder().kernelSize(3, 
3).stride(3, 3).padding(0, 0) .nIn(3).nOut( @@ -258,7 +257,7 @@ public class TestConvolutionModes extends BaseDL4JTest { public void testGlobalLocalConfigCompGraph() { for (ConvolutionMode cm : new ConvolutionMode[] {ConvolutionMode.Strict, ConvolutionMode.Truncate, ConvolutionMode.Same}) { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .convolutionMode(cm).graphBuilder().addInputs("in") .addLayer("0", new ConvolutionLayer.Builder().kernelSize(3, 3).stride(3, 3).padding(0, 0) .nIn(3).nOut( @@ -288,28 +287,28 @@ public class TestConvolutionModes extends BaseDL4JTest { .activation(Activation.SOFTMAX).nOut(3).build(), "7") .setOutputs("8").build(); - assertEquals(cm, ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer()) + assertEquals(cm, ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("0")).getNetConfiguration().getFirstLayer()) .getConvolutionMode()); assertEquals(ConvolutionMode.Strict, - ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("1")).getLayerConf().getLayer()) + ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("1")).getNetConfiguration().getFirstLayer()) .getConvolutionMode()); assertEquals(ConvolutionMode.Truncate, - ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("2")).getLayerConf().getLayer()) + ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("2")).getNetConfiguration().getFirstLayer()) .getConvolutionMode()); assertEquals(ConvolutionMode.Same, - ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("3")).getLayerConf().getLayer()) + ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("3")).getNetConfiguration().getFirstLayer()) .getConvolutionMode()); - assertEquals(cm, ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("4")).getLayerConf().getLayer()) + assertEquals(cm, ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("4")).getNetConfiguration().getFirstLayer()) .getConvolutionMode()); assertEquals(ConvolutionMode.Strict, - ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("5")).getLayerConf().getLayer()) + ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("5")).getNetConfiguration().getFirstLayer()) .getConvolutionMode()); assertEquals(ConvolutionMode.Truncate, - ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("6")).getLayerConf().getLayer()) + ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("6")).getNetConfiguration().getFirstLayer()) .getConvolutionMode()); assertEquals(ConvolutionMode.Same, - ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("7")).getLayerConf().getLayer()) + ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("7")).getNetConfiguration().getFirstLayer()) .getConvolutionMode()); } } @@ -437,15 +436,15 @@ public class TestConvolutionModes extends BaseDL4JTest { int kH = 3; int kW = 3; - Layer[] l = new Layer[2]; + LayerConfiguration[] l = new LayerConfiguration[2]; l[0] = new ConvolutionLayer.Builder().nOut(4).kernelSize(kH, kW).stride(sH, sW).build(); l[1] = new SubsamplingLayer.Builder().kernelSize(kH, kW).stride(sH, sW).build(); for (int i = 0; i < l.length; i++) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().convolutionMode(ConvolutionMode.Same) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().convolutionMode(ConvolutionMode.Same) .list().layer(0, l[i]).layer(1, new 
OutputLayer.Builder().nOut(3).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(inH, inW, inDepth)).build(); + .inputType(InputType.convolutional(inH, inW, inDepth)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java index 277b43c31..5d74b94fa 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java @@ -36,8 +36,6 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import java.util.Arrays; - import static org.junit.jupiter.api.Assertions.*; /** @@ -106,10 +104,10 @@ public class Upsampling1DTest extends BaseDL4JTest { private Layer getUpsampling1DLayer() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new Upsampling1D.Builder(size).build()).build(); - return conf.getLayer().instantiate(conf, null, 0, + return conf.getFirstLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java index e1d46f911..bfb872ba8 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java @@ -36,8 +36,6 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import java.util.Arrays; - import static org.junit.jupiter.api.Assertions.*; /** @@ -110,10 +108,10 @@ public class Upsampling2DTest extends BaseDL4JTest { private Layer getUpsamplingLayer() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new Upsampling2D.Builder(size).build()).build(); - return conf.getLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); + return conf.getFirstLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } public INDArray getData() throws Exception { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java index 2f837fc2f..94994ea47 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomActivation.java @@ -21,21 +21,14 @@ package org.deeplearning4j.nn.layers.custom; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import 
org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.layers.custom.testclasses.CustomActivation; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.introspect.AnnotatedClass; -import com.fasterxml.jackson.databind.jsontype.NamedType; - -import java.util.Collection; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -46,7 +39,7 @@ public class TestCustomActivation extends BaseDL4JTest { public void testCustomActivationFn() { //Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config actually works... - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).activation(new CustomActivation()).build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(10).build()) .build(); @@ -56,10 +49,10 @@ public class TestCustomActivation extends BaseDL4JTest { // System.out.println(json); - MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration confFromJson = NeuralNetConfiguration.fromJson(json); assertEquals(conf, confFromJson); - MultiLayerConfiguration confFromYaml = MultiLayerConfiguration.fromYaml(yaml); + NeuralNetConfiguration confFromYaml = NeuralNetConfiguration.fromYaml(yaml); assertEquals(conf, confFromYaml); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java index a0de7f2df..4ef8fab18 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java @@ -22,10 +22,8 @@ package org.deeplearning4j.nn.layers.custom; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; -import org.deeplearning4j.nn.conf.layers.Layer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.custom.testclasses.CustomLayer; @@ -39,13 +37,6 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.introspect.AnnotatedClass; -import com.fasterxml.jackson.databind.jsontype.NamedType; - -import java.util.Collection; -import java.util.HashSet; -import java.util.Set; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -54,8 +45,8 @@ public class TestCustomLayers extends BaseDL4JTest { @Test public void 
testJsonMultiLayerNetwork() { - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new CustomLayer(3.14159)).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) @@ -67,10 +58,10 @@ public class TestCustomLayers extends BaseDL4JTest { // System.out.println(json); - MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration confFromJson = NeuralNetConfiguration.fromJson(json); assertEquals(conf, confFromJson); - MultiLayerConfiguration confFromYaml = MultiLayerConfiguration.fromYaml(yaml); + NeuralNetConfiguration confFromYaml = NeuralNetConfiguration.fromYaml(yaml); assertEquals(conf, confFromYaml); } @@ -78,7 +69,7 @@ public class TestCustomLayers extends BaseDL4JTest { public void testJsonComputationGraph() { //ComputationGraph with a custom layer; check JSON and YAML config actually works... - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder() .addInputs("in").addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in") .addLayer("1", new CustomLayer(3.14159), "0").addLayer("2", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX) @@ -103,7 +94,7 @@ public class TestCustomLayers extends BaseDL4JTest { public void checkInitializationFF() { //Actually create a network with a custom layer; check initialization and forward pass - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new DenseLayer.Builder().nIn(9).nOut(10).build()).layer(1, new CustomLayer(3.14159)) //hard-coded nIn/nOut of 10 .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(10).nOut(11).build()) .build(); @@ -125,8 +116,8 @@ public class TestCustomLayers extends BaseDL4JTest { @Test public void testCustomOutputLayerMLN() { //Second: let's create a MultiLayerCofiguration with one, and check JSON and YAML config actually works... 
- MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(12345).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new CustomOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX) @@ -138,10 +129,10 @@ public class TestCustomLayers extends BaseDL4JTest { // System.out.println(json); - MultiLayerConfiguration confFromJson = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration confFromJson = NeuralNetConfiguration.fromJson(json); assertEquals(conf, confFromJson); - MultiLayerConfiguration confFromYaml = MultiLayerConfiguration.fromYaml(yaml); + NeuralNetConfiguration confFromYaml = NeuralNetConfiguration.fromYaml(yaml); assertEquals(conf, confFromYaml); //Third: check initialization @@ -152,8 +143,8 @@ public class TestCustomLayers extends BaseDL4JTest { assertTrue(net.getLayer(1) instanceof CustomOutputLayerImpl); //Fourth: compare to an equivalent standard output layer (should be identical) - MultiLayerConfiguration conf2 = - new NeuralNetConfiguration.Builder().seed(12345).weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf2 = + NeuralNetConfiguration.builder().seed(12345).weightInit(WeightInit.XAVIER) .list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) @@ -178,7 +169,7 @@ public class TestCustomLayers extends BaseDL4JTest { @Test public void testCustomOutputLayerCG() { //Create a ComputationGraphConfiguration with custom output layer, and check JSON and YAML config actually works... - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .graphBuilder().addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").addLayer("1", new CustomOutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(10) @@ -205,7 +196,7 @@ public class TestCustomLayers extends BaseDL4JTest { assertTrue(net.getLayer(1) instanceof CustomOutputLayerImpl); //Fourth: compare to an equivalent standard output layer (should be identical) - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder().seed(12345) .graphBuilder().addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(10).nOut(10).build(), "in").addLayer("1", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(10).nOut(10) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java index 1eacc4d20..f3b201d63 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java @@ -57,9 +57,9 @@ public class CustomLayer extends FeedForwardLayer { ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(conf); return ret; } diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java index 88972c96a..b64a341d8 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java @@ -29,7 +29,6 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; -import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; @@ -56,9 +55,9 @@ public class CustomOutputLayer extends BaseOutputLayer { ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(conf); return ret; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java index 25c8074a8..382476fc9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java @@ -24,7 +24,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -52,11 +51,11 @@ public class DenseTest extends BaseDL4JTest { public void testDenseBiasInit() { DenseLayer build = new DenseLayer.Builder().nIn(1).nOut(3).biasInit(1).build(); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(build).build(); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(build).build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, Nd4j.defaultFloatingPointType()); + Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, Nd4j.defaultFloatingPointType()); assertEquals(1, layer.getParam("b").size(0)); } @@ -124,7 +123,7 @@ public class DenseTest extends BaseDL4JTest { int outputNum = 3; long seed = 6; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(seed) .updater(new Sgd(1e-3)).l1(0.3).l2(1e-3).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(numInputs).nOut(3) .activation(Activation.TANH).weightInit(WeightInit.XAVIER).build()) diff 
--git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java index 55c26b12b..60c4e3b0d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java @@ -25,7 +25,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -47,7 +46,6 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Random; @@ -60,7 +58,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { public void testEmbeddingLayerConfig() { for (boolean hasBias : new boolean[]{true, false}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().activation(Activation.TANH).list() .layer(0, new EmbeddingLayer.Builder().hasBias(hasBias).nIn(10).nOut(5).build()) .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()) .build(); @@ -71,8 +69,8 @@ public class EmbeddingLayerTest extends BaseDL4JTest { Layer l0 = net.getLayer(0); assertEquals(org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer.class, l0.getClass()); - assertEquals(10, ((FeedForwardLayer) l0.conf().getLayer()).getNIn()); - assertEquals(5, ((FeedForwardLayer) l0.conf().getLayer()).getNOut()); + assertEquals(10, ((FeedForwardLayer) l0.getLayerConfiguration()).getNIn()); + assertEquals(5, ((FeedForwardLayer) l0.getLayerConfiguration()).getNOut()); INDArray weights = l0.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray bias = l0.getParam(DefaultParamInitializer.BIAS_KEY); @@ -92,7 +90,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { int nout = 4; for (boolean hasBias : new boolean[]{true, false}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().activation(Activation.TANH).list() .layer(new EmbeddingSequenceLayer.Builder().hasBias(hasBias) .inputLength(inputLength).nIn(nIn).nOut(embeddingDim).build()) .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nout).activation(Activation.SOFTMAX).build()) @@ -104,8 +102,8 @@ public class EmbeddingLayerTest extends BaseDL4JTest { Layer l0 = net.getLayer(0); assertEquals(org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer.class, l0.getClass()); - assertEquals(10, ((FeedForwardLayer) l0.conf().getLayer()).getNIn()); - assertEquals(5, ((FeedForwardLayer) l0.conf().getLayer()).getNOut()); + assertEquals(10, ((FeedForwardLayer) l0.getLayerConfiguration()).getNIn()); + assertEquals(5, ((FeedForwardLayer) l0.getLayerConfiguration()).getNOut()); INDArray weights = l0.getParam(DefaultParamInitializer.WEIGHT_KEY); INDArray bias = 
l0.getParam(DefaultParamInitializer.BIAS_KEY); @@ -124,7 +122,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { int embeddingDim = 5; int nOut = 4; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().activation(Activation.TANH).list() .layer(new EmbeddingSequenceLayer.Builder().inputLength(inputLength) .hasBias(true).nIn(nClassesIn).nOut(embeddingDim).build()) .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()) @@ -155,12 +153,12 @@ public class EmbeddingLayerTest extends BaseDL4JTest { int embeddingDim = 5; int nOut = 4; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().activation(Activation.TANH).list() .layer(new EmbeddingSequenceLayer.Builder().inputLength(1) .hasBias(true).nIn(nClassesIn).nOut(embeddingDim).build()) .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()) .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().activation(Activation.TANH).list() .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()) .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()) .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) @@ -204,11 +202,11 @@ public class EmbeddingLayerTest extends BaseDL4JTest { int nClassesIn = 10; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().activation(Activation.TANH).list() .layer(0, new EmbeddingLayer.Builder().hasBias(true).nIn(nClassesIn).nOut(5).build()) .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()) .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().activation(Activation.TANH).list() .layer(0, new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()) .layer(1, new OutputLayer.Builder().nIn(5).nOut(4).activation(Activation.SOFTMAX).build()) .build(); @@ -247,12 +245,12 @@ public class EmbeddingLayerTest extends BaseDL4JTest { int nClassesIn = 10; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().activation(Activation.TANH).list() .layer(0, new EmbeddingLayer.Builder().hasBias(true).nIn(nClassesIn).nOut(5).build()).layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(4) .activation(Activation.SOFTMAX).build()) .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH) + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().activation(Activation.TANH) .weightInit(WeightInit.XAVIER).list() .layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(5).activation(Activation.IDENTITY).build()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(4) @@ -308,16 +306,16 @@ public class EmbeddingLayerTest extends BaseDL4JTest { int nOut = 4; int inputLength = 1; - 
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().activation(Activation.TANH).list() .layer(new EmbeddingSequenceLayer.Builder().inputLength(inputLength) .hasBias(true).nIn(nClassesIn).nOut(embeddingDim).build()) .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.recurrent(nClassesIn,inputLength,RNNFormat.NCW)) + .inputType(InputType.recurrent(nClassesIn,inputLength,RNNFormat.NCW)) .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH).list() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().activation(Activation.TANH).list() .layer(new DenseLayer.Builder().nIn(nClassesIn).nOut(embeddingDim).activation(Activation.IDENTITY).build()) .layer(new RnnOutputLayer.Builder().nIn(embeddingDim).nOut(nOut).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.recurrent(nClassesIn,inputLength,RNNFormat.NCW)) + .inputType(InputType.recurrent(nClassesIn,inputLength,RNNFormat.NCW)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -368,7 +366,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { int batchSize = 3; int timeSeriesLength = 8; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().activation(Activation.TANH) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().activation(Activation.TANH) .dataType(DataType.DOUBLE) .list() .layer(0, new EmbeddingLayer.Builder().hasBias(true).nIn(nClassesIn).nOut(5).build()) @@ -377,9 +375,9 @@ public class EmbeddingLayerTest extends BaseDL4JTest { .activation(Activation.SOFTMAX).build()) .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) .inputPreProcessor(1, new FeedForwardToRnnPreProcessor()) - .setInputType(InputType.recurrent(nClassesIn,timeSeriesLength, RNNFormat.NCW)) + .inputType(InputType.recurrent(nClassesIn,timeSeriesLength, RNNFormat.NCW)) .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().activation(Activation.TANH) + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .dataType(DataType.DOUBLE) .list() @@ -389,7 +387,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { .activation(Activation.SOFTMAX).build()) .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) .inputPreProcessor(1, new FeedForwardToRnnPreProcessor()) - .setInputType(InputType.recurrent(nClassesIn,timeSeriesLength, RNNFormat.NCW)) + .inputType(InputType.recurrent(nClassesIn,timeSeriesLength, RNNFormat.NCW)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -452,7 +450,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { for (int nExamples : miniBatchSizes) { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.1)).seed(12345).list() .layer(0, new EmbeddingLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses) @@ -463,13 +461,13 @@ public class EmbeddingLayerTest extends BaseDL4JTest { .nOut(4).build()) .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) .inputPreProcessor(2, new FeedForwardToRnnPreProcessor()) - .setInputType(InputType.recurrent(numInputClasses,timeSeriesLength, 
RNNFormat.NCW)) + .inputType(InputType.recurrent(numInputClasses,timeSeriesLength, RNNFormat.NCW)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.1)).seed(12345).list() .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5) @@ -480,7 +478,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { .nOut(4).build()) .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) .inputPreProcessor(2, new FeedForwardToRnnPreProcessor()) - .setInputType(InputType.recurrent(numInputClasses,timeSeriesLength, RNNFormat.NCW)) + .inputType(InputType.recurrent(numInputClasses,timeSeriesLength, RNNFormat.NCW)) .build(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); @@ -553,7 +551,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { el = new EmbeddingLayer.Builder().weightInit(new WordVectorsMockup()).build(); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345).list() .layer(el) .layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(3).nOut(3).build()) @@ -577,7 +575,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { esl = new EmbeddingSequenceLayer.Builder().weightInit(new WordVectorsMockup()).build(); } - conf = new NeuralNetConfiguration.Builder() + conf = NeuralNetConfiguration.builder() .seed(12345).list() .layer(esl) .layer(new GlobalPoolingLayer()) @@ -614,7 +612,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { for (int nExamples : miniBatchSizes) { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.1)).seed(12345).list() .layer(0, new EmbeddingSequenceLayer.Builder().hasBias(true).activation(Activation.TANH).nIn(numInputClasses) @@ -623,12 +621,12 @@ public class EmbeddingLayerTest extends BaseDL4JTest { .layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).build()) .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) .nOut(4).build()) - .setInputType(InputType.recurrent(numInputClasses,timeSeriesLength,RNNFormat.NCW)).build(); + .inputType(InputType.recurrent(numInputClasses,timeSeriesLength,RNNFormat.NCW)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.1)).seed(12345).list() .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(numInputClasses).nOut(5) @@ -637,7 +635,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { .layer(2, new LSTM.Builder().activation(Activation.TANH).nIn(4).nOut(3).dataFormat(RNNFormat.NCW).build()) .layer(3, new RnnOutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(3) .nOut(4).build()) - .setInputType(InputType.recurrent(numInputClasses,1,RNNFormat.NCW)).build(); + .inputType(InputType.recurrent(numInputClasses,1,RNNFormat.NCW)).build(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); 
@@ -722,7 +720,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { @Test public void testEmbeddingDefaultActivation(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new EmbeddingLayer.Builder().nIn(10).nOut(10).build()) .layer(new EmbeddingSequenceLayer.Builder().nIn(10).nOut(10).build()) @@ -747,7 +745,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { for (boolean seq : new boolean[]{false, true}) { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .list() .layer(seq ? @@ -758,7 +756,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.init(); Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .seed(12345) .list() .layer(seq ? @@ -769,7 +767,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net2.init(); Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf3 = NeuralNetConfiguration.builder() .seed(12345) .list() .layer(seq ? diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java index c4950d3c4..e6f85611a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java @@ -29,7 +29,6 @@ import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -129,14 +128,14 @@ public class BatchNormalizationTest extends BaseDL4JTest { b.lockGammaBeta(true).gamma(gamma).beta(beta); } BatchNormalization bN = b.build(); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer(bN).build(); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(bN).build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = null; if (numParams > 0) { params = Nd4j.create(1, numParams); } - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params == null ? Nd4j.defaultFloatingPointType() : params.dataType()); + Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params == null ? 
Nd4j.defaultFloatingPointType() : params.dataType()); if (numParams > 0) { layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams)); } @@ -365,7 +364,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { DataSet next = iter.next(); // Run with separate activation layer - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() .layer(0, new DenseLayer.Builder().nIn(28 * 28).nOut(10).weightInit(WeightInit.XAVIER) @@ -397,7 +396,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { DataSetIterator iter = new MnistDataSetIterator(2, 2); DataSet next = iter.next(); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) .list() .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER) @@ -406,7 +405,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { .layer(2, new ActivationLayer.Builder().activation(Activation.RELU).build()) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); + .inputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); @@ -422,7 +421,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { //Serialize the batch norm network (after training), and make sure we get same activations out as before // i.e., make sure state is properly stored - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .list() .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER) @@ -433,7 +432,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { .layer(4, new BatchNormalization.Builder().build()) .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); + .inputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -461,9 +460,9 @@ public class BatchNormalizationTest extends BaseDL4JTest { public void testGradientAndUpdaters() throws Exception { //Global mean/variance are part of the parameter vector. 
Expect 0 gradient, and no-op updater for these - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(Updater.RMSPROP).seed(12345).list() + .updater(Updater.RMSPROP.getIUpdaterWithDefaultConfig()).seed(12345).list() .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER) .activation(Activation.IDENTITY).build()) .layer(1, new BatchNormalization.Builder().build()) @@ -472,7 +471,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { .layer(4, new BatchNormalization.Builder().build()) .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); + .inputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -519,9 +518,9 @@ public class BatchNormalizationTest extends BaseDL4JTest { for(boolean useLogStd : new boolean[]{true, false}) { //First, Mnist data as 2d input (NOT taking into account convolution property) - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(Updater.RMSPROP).seed(12345) + .updater(Updater.RMSPROP.getIUpdaterWithDefaultConfig()).seed(12345) .list().layer(0, new BatchNormalization.Builder().nIn(10).nOut(10).eps(1e-5).decay(0.95) .useLogStd(useLogStd).build()) @@ -586,13 +585,13 @@ public class BatchNormalizationTest extends BaseDL4JTest { //Check that the internal global mean/variance estimate is approximately correct //First, Mnist data as 2d input (NOT taking into account convolution property) - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(Updater.RMSPROP).seed(12345).list() + .updater(Updater.RMSPROP.getIUpdaterWithDefaultConfig()).seed(12345).list() .layer(0, new BatchNormalization.Builder().nIn(3).nOut(3).eps(1e-5).decay(0.95).useLogStd(useLogStd).build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER) .activation(Activation.IDENTITY).nOut(10).build()) - .setInputType(InputType.convolutional(5, 5, 3)).build(); + .inputType(InputType.convolutional(5, 5, 3)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -649,24 +648,24 @@ public class BatchNormalizationTest extends BaseDL4JTest { //Check that the internal global mean/variance estimate is approximately correct //First, Mnist data as 2d input (NOT taking into account convolution property) - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(Updater.RMSPROP).seed(12345).list() + .updater(Updater.RMSPROP.getIUpdaterWithDefaultConfig()).seed(12345).list() .layer(0, new BatchNormalization.Builder().nIn(3).nOut(3).eps(1e-5).decay(0.95).useLogStd(false).build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER) .activation(Activation.IDENTITY).nOut(10).build()) - .setInputType(InputType.convolutional(5, 5, 
3)).build(); + .inputType(InputType.convolutional(5, 5, 3)).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .updater(Updater.RMSPROP).seed(12345).list() + .updater(Updater.RMSPROP.getIUpdaterWithDefaultConfig()).seed(12345).list() .layer(0, new BatchNormalization.Builder().nIn(3).nOut(3).eps(1e-5).decay(0.95).useLogStd(true).build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).weightInit(WeightInit.XAVIER) .activation(Activation.IDENTITY).nOut(10).build()) - .setInputType(InputType.convolutional(5, 5, 3)).build(); + .inputType(InputType.convolutional(5, 5, 3)).build(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); @@ -691,7 +690,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { @Test public void testBatchNorm() throws Exception { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .updater(new Adam(1e-3)) .activation(Activation.TANH) @@ -700,7 +699,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { .layer(new BatchNormalization()) .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(2, 2).build()) .layer(new OutputLayer.Builder().activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).nOut(10).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) + .inputType(InputType.convolutionalFlat(28, 28, 1)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -728,7 +727,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { for (boolean rnn : new boolean[]{true, false}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .weightInit(WeightInit.XAVIER) .convolutionMode(ConvolutionMode.Same) @@ -737,7 +736,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { new Convolution1DLayer.Builder().kernelSize(3).stride(1).nOut(3).build()) .layer(new BatchNormalization()) .layer(new RnnOutputLayer.Builder().nOut(3).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build()) - .setInputType(InputType.recurrent(3)) + .inputType(InputType.recurrent(3)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -757,7 +756,7 @@ public class BatchNormalizationTest extends BaseDL4JTest { @Test public void testInputValidation() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new BatchNormalization.Builder().nIn(10).nOut(10).build()) .build(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java index e876b736b..c2f8cb3c4 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/LocalResponseTest.java @@ -25,7 +25,6 @@ import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.Layer; import 
org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -110,7 +109,7 @@ public class LocalResponseTest extends BaseDL4JTest { @BeforeEach public void doBefore() { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new LocalResponseNormalization.Builder().k(2).n(5).alpha(1e-4).beta(0.75).build()) .build(); @@ -140,7 +139,7 @@ public class LocalResponseTest extends BaseDL4JTest { public void testRegularization() { // Confirm a structure with regularization true will not throw an error - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).l1(0.2) .l2(0.1).seed(123) .layer(new LocalResponseNormalization.Builder().k(2).n(5).alpha(1e-4).beta(0.75).build()) @@ -149,7 +148,7 @@ public class LocalResponseTest extends BaseDL4JTest { @Test public void testMultiCNNLayer() throws Exception { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list() .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(6).weightInit(WeightInit.XAVIER) .activation(Activation.RELU).build()) @@ -159,7 +158,7 @@ public class LocalResponseTest extends BaseDL4JTest { .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(10) .build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)).build(); + .inputType(InputType.convolutionalFlat(28, 28, 1)).build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); @@ -203,7 +202,7 @@ public class LocalResponseTest extends BaseDL4JTest { } LocalResponseNormalization lrn = new LocalResponseNormalization.Builder().build(); - NeuralNetConfiguration nnc = new NeuralNetConfiguration.Builder().layer(lrn).build(); + NeuralNetConfiguration nnc = NeuralNetConfiguration.builder().layer(lrn).build(); org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization layer = (org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization) lrn.instantiate(nnc, null, 0, null, false, Nd4j.defaultFloatingPointType()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java index c732ab366..558041072 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java @@ -34,7 +34,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import 
org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -88,7 +87,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { INDArray bbPrior = Nd4j.rand(b, 2).muliRowVector(Nd4j.create(new double[]{w, h})); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .l2(0.01) .list() .layer(new ConvolutionLayer.Builder().nIn(depth).nOut(depth).kernelSize(1,1).build()) @@ -177,7 +176,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { INDArray bbPrior = Nd4j.rand(b, 2).muliRowVector(Nd4j.create(new double[]{w, h})); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new ConvolutionLayer.Builder().nIn(1).nOut(1).kernelSize(1,1).build()) .layer(new Yolo2OutputLayer.Builder() @@ -335,7 +334,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { //Check IOU calculation - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new ConvolutionLayer.Builder().kernelSize(3,3).stride(1,1).nIn(3).nOut(3).build()) .layer(new Yolo2OutputLayer.Builder() @@ -495,7 +494,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { DataSetIterator iter = new RecordReaderDataSetIterator(rr,1,1,1,true); iter.setPreProcessor(new ImagePreProcessingScaler()); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .convolutionMode(ConvolutionMode.Same) .updater(new Adam(2e-3)) .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) @@ -510,7 +509,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { .layer(new Yolo2OutputLayer.Builder() .boundingBoxPriors(bbPriors) .build()) - .setInputType(InputType.convolutional(h,w,c)) + .inputType(InputType.convolutional(h,w,c)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java index 0eaa156f1..c989d0bf5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.ocnn; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.gradientcheck.GradientCheckUtil; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -35,7 +34,6 @@ import org.junit.jupiter.api.io.TempDir; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationReLU; import org.nd4j.linalg.activations.impl.ActivationSigmoid; -import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -43,10 +41,7 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; 
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; -import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.learning.config.NoOp; -import org.nd4j.linalg.schedule.ScheduleType; -import org.nd4j.linalg.schedule.StepSchedule; import java.io.File; import java.util.UUID; @@ -128,7 +123,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest { DataSet filtered = next.filterBy(new int[]{0, 1}); for (int i = 0; i < 10; i++) { network.setEpochCount(i); - network.getLayerWiseConfigurations().setEpochCount(i); + network.getConfiguration().setEpochCount(i); network.fit(filtered); } @@ -170,7 +165,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest { private MultiLayerNetwork getSingleLayer() { int numHidden = 2; - MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration configuration = NeuralNetConfiguration.builder() .seed(12345) .weightInit(WeightInit.XAVIER) .miniBatch(true) @@ -182,8 +177,9 @@ public class OCNNOutputLayerTest extends BaseDL4JTest { // 1e-2, // 0.1, // 20)).build()) - .list(new DenseLayer.Builder().activation(new ActivationReLU()) - .nIn(4).nOut(2).build(), + .layer(new DenseLayer.Builder().activation(new ActivationReLU()) + .nIn(4).nOut(2).build()) + .layer( new org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer.Builder() .nIn(2).activation(new ActivationSigmoid()).initialRValue(0.1) .nu(0.1) @@ -197,10 +193,11 @@ public class OCNNOutputLayerTest extends BaseDL4JTest { public MultiLayerNetwork getGradientCheckNetwork(int numHidden) { - MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration configuration = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .seed(42).updater(new NoOp()).miniBatch(false) - .list(new DenseLayer.Builder().activation(new ActivationIdentity()).nIn(4).nOut(4).build(), + .layer(new DenseLayer.Builder().activation(new ActivationIdentity()).nIn(4).nOut(4).build()) + .layer( new org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer.Builder().nIn(4) .nu(0.002).activation(new ActivationSigmoid()) .hiddenLayerSize(numHidden).build()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java index a7f3d1867..86d695f3d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingMaskingTests.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.layers.pooling; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.layers.*; @@ -59,7 +58,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { for (int miniBatchSize : minibatchSizes) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new NoOp()) .dist(new NormalDistribution(0, 1.0)).seed(12345L).list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) @@ -123,7 +122,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { new PoolingType[] 
{PoolingType.SUM, PoolingType.AVG, PoolingType.MAX, PoolingType.PNORM}; for (PoolingType pt : poolingTypes) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .convolutionMode(ConvolutionMode.Same).seed(12345L).list() .layer(0, new ConvolutionLayer.Builder().nIn(depthIn).nOut(depthOut).kernelSize(height, 2) .stride(height, 1).activation(Activation.TANH).build()) @@ -186,7 +185,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { new PoolingType[] {PoolingType.SUM, PoolingType.AVG, PoolingType.MAX, PoolingType.PNORM}; for (PoolingType pt : poolingTypes) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .convolutionMode(ConvolutionMode.Same).seed(12345L).list() .layer(0, new ConvolutionLayer.Builder().nIn(depthIn).nOut(depthOut).kernelSize(2, width) .stride(1, width).activation(Activation.TANH).build()) @@ -250,7 +249,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { new PoolingType[] {PoolingType.SUM, PoolingType.AVG, PoolingType.MAX, PoolingType.PNORM}; for (PoolingType pt : poolingTypes) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .convolutionMode(ConvolutionMode.Same).seed(12345L).list() .layer(0, new ConvolutionLayer.Builder().nIn(depthIn).nOut(depthOut).kernelSize(height, 2) .stride(height, 1).activation(Activation.TANH).build()) @@ -309,7 +308,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { new PoolingType[] {PoolingType.SUM, PoolingType.AVG, PoolingType.MAX, PoolingType.PNORM}; for (PoolingType pt : poolingTypes) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .convolutionMode(ConvolutionMode.Same).seed(12345L).list() .layer(0, new ConvolutionLayer.Builder().nIn(depthIn).nOut(depthOut).kernelSize(2, width) .stride(1, width).activation(Activation.TANH).build()) @@ -368,7 +367,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { new PoolingType[] {PoolingType.SUM, PoolingType.AVG, PoolingType.MAX, PoolingType.PNORM}; for (PoolingType pt : poolingTypes) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .convolutionMode(ConvolutionMode.Same).seed(12345L).list() .layer(0, new ConvolutionLayer.Builder().nIn(depthIn).nOut(depthOut).kernelSize(2, 2) .stride(1, 1).activation(Activation.TANH).build()) @@ -434,7 +433,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { for(PoolingType pt : PoolingType.values()) { //System.out.println("Net: " + networkDtype + ", mask: " + dt + ", pt=" + pt); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new GlobalPoolingLayer(pt)) .layer(new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build()) @@ -447,7 +446,7 @@ public class GlobalPoolingMaskingTests extends BaseDL4JTest { net.output(in, 
false, mask, null); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .list() .layer(new RnnOutputLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java index e785b36e5..8e329077c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java @@ -85,7 +85,7 @@ public class BidirectionalTest extends BaseDL4JTest { //Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params //Note that GravesBidirectionalLSTM implements ADD mode only - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .trainingWorkspaceMode(wsm) @@ -98,7 +98,7 @@ public class BidirectionalTest extends BaseDL4JTest { .nIn(10).nOut(10).build()) .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .trainingWorkspaceMode(wsm) @@ -189,7 +189,7 @@ public class BidirectionalTest extends BaseDL4JTest { //Bidirectional(GravesLSTM) and GravesBidirectionalLSTM should be equivalent, given equivalent params //Note that GravesBidirectionalLSTM implements ADD mode only - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf1 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .updater(new Adam()) @@ -204,7 +204,7 @@ public class BidirectionalTest extends BaseDL4JTest { .setOutputs("2") .build(); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .updater(new Adam()) @@ -288,7 +288,7 @@ public class BidirectionalTest extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .trainingWorkspaceMode(wsm) @@ -354,7 +354,7 @@ public class BidirectionalTest extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf1 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .trainingWorkspaceMode(wsm) @@ -422,7 +422,7 @@ public class BidirectionalTest extends BaseDL4JTest { INDArray in = Nd4j.rand(inshape); for (Bidirectional.Mode m : modes) { - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) @@ -436,7 +436,7 @@ public class BidirectionalTest extends BaseDL4JTest { MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); net1.init(); - 
MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) @@ -548,7 +548,7 @@ public class BidirectionalTest extends BaseDL4JTest { for (Bidirectional.Mode m : modes) { - ComputationGraphConfiguration conf1 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf1 = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) @@ -564,7 +564,7 @@ public class BidirectionalTest extends BaseDL4JTest { ComputationGraph net1 = new ComputationGraph(conf1); net1.init(); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) @@ -680,7 +680,7 @@ public class BidirectionalTest extends BaseDL4JTest { int in = 2; int out = 2; - ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration.GraphBuilder builder = NeuralNetConfiguration.builder() .updater(new Adam(0.01)) .activation(Activation.RELU) .graphBuilder() diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java index bd1291216..d51fc5280 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java @@ -24,7 +24,6 @@ import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.CacheMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; @@ -64,15 +63,15 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { int nIn = 13; int nHiddenUnits = 17; - final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + final NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) .nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build()) .build(); - val numParams = conf.getLayer().initializer().numParams(conf); + val numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); final GravesBidirectionalLSTM layer = - (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + (GravesBidirectionalLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); //Data: has shape [miniBatchSize,nIn,timeSeriesLength]; //Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; @@ -130,17 +129,17 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { INDArray inputData = (rnnDataFormat == RNNFormat.NCW)?Nd4j.ones(miniBatchSize, nIn, timeSeriesLength): Nd4j.ones(miniBatchSize, timeSeriesLength, nIn); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + 
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) .nOut(lstmNHiddenUnits).dataFormat(rnnDataFormat) .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); GravesBidirectionalLSTM lstm = - (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); - lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getLayer().initializer().numParams(conf))); + (GravesBidirectionalLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getFirstLayer().initializer().numParams(conf))); //Set input, do a forward pass: lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); assertNotNull(lstm.input()); @@ -202,21 +201,21 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final int miniBatchSize = 4; final int timeSeriesLength = 7; - final NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + final NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) .nOut(layerSize) .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); final GravesBidirectionalLSTM lstm = - (GravesBidirectionalLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + (GravesBidirectionalLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); final INDArray input = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); - final INDArray fwdPassFalse = LSTMHelpers.activateHelper(lstm, lstm.conf(), new ActivationSigmoid(), + final INDArray fwdPassFalse = LSTMHelpers.activateHelper(lstm, lstm.getNetConfiguration(), new ActivationSigmoid(), lstm.input(), lstm.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), lstm.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), @@ -224,7 +223,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { false, true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, null, true, null, CacheMode.NONE, LayerWorkspaceMgr.noWorkspaces(), true).fwdPassOutput; - final INDArray[] fwdPassTrue = LSTMHelpers.activateHelper(lstm, lstm.conf(), new ActivationSigmoid(), + final INDArray[] fwdPassTrue = LSTMHelpers.activateHelper(lstm, lstm.getNetConfiguration(), new ActivationSigmoid(), lstm.input(), lstm.getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), lstm.getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), @@ -260,16 +259,16 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); - final NeuralNetConfiguration confBidirectional = new NeuralNetConfiguration.Builder() + final NeuralNetConfiguration confBidirectional = NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder().nIn(nIn) 
.nOut(layerSize).dataFormat(rnnDataFormat) .dist(new UniformDistribution(-0.1, 0.1)).activation(Activation.TANH).build()) .build(); - long numParams = confBidirectional.getLayer().initializer().numParams(confBidirectional); + long numParams = confBidirectional.getFirstLayer().initializer().numParams(confBidirectional); INDArray params = Nd4j.create(1, numParams); - final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getLayer() + final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getFirstLayer() .instantiate(confBidirectional, null, 0, params, true, params.dataType()); @@ -280,7 +279,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { params = bidirectionalLSTM.params(); - bidirectionalLSTM.setParams(params); + bidirectionalLSTM.setParamsTable(params); final INDArray act2 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces()); @@ -300,31 +299,31 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); final NeuralNetConfiguration confBidirectional = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() .nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat) .dist(new UniformDistribution(-0.1, 0.1)) .activation(Activation.TANH).updater(new NoOp()).build()) .build(); - final NeuralNetConfiguration confForwards = new NeuralNetConfiguration.Builder() + final NeuralNetConfiguration confForwards = NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize).dataFormat(rnnDataFormat) .weightInit(WeightInit.ZERO).activation(Activation.TANH).build()) .build(); - long numParams = confForwards.getLayer().initializer().numParams(confForwards); + long numParams = confForwards.getFirstLayer().initializer().numParams(confForwards); INDArray params = Nd4j.create(1, numParams); - long numParamsBD = confBidirectional.getLayer().initializer().numParams(confBidirectional); + long numParamsBD = confBidirectional.getFirstLayer().initializer().numParams(confBidirectional); INDArray paramsBD = Nd4j.create(1, numParamsBD); - final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getLayer() + final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getFirstLayer() .instantiate(confBidirectional, null, 0, paramsBD, true, params.dataType()); final GravesLSTM forwardsLSTM = - (GravesLSTM) confForwards.getLayer().instantiate(confForwards, null, 0, params, true, params.dataType()); + (GravesLSTM) confForwards.getFirstLayer().instantiate(confForwards, null, 0, params, true, params.dataType()); bidirectionalLSTM.setBackpropGradientsViewArray( - Nd4j.create(1, confBidirectional.getLayer().initializer().numParams(confBidirectional))); + Nd4j.create(1, confBidirectional.getFirstLayer().initializer().numParams(confBidirectional))); forwardsLSTM.setBackpropGradientsViewArray( - Nd4j.create(1, confForwards.getLayer().initializer().numParams(confForwards))); + Nd4j.create(1, confForwards.getFirstLayer().initializer().numParams(confForwards))); final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(miniBatchSize, nIn, timeSeriesLength): @@ -501,7 +500,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { @Test public void testSerialization() { - final MultiLayerConfiguration conf1 = new 
NeuralNetConfiguration.Builder() + final NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new AdaGrad(0.1)) .l2(0.001) @@ -520,7 +519,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final String json1 = conf1.toJson(); - final MultiLayerConfiguration conf2 = MultiLayerConfiguration.fromJson(json1); + final NeuralNetConfiguration conf2 = NeuralNetConfiguration.fromJson(json1); final String json2 = conf1.toJson(); @@ -532,7 +531,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { public void testGateActivationFnsSanityCheck() { for (String gateAfn : new String[] {"sigmoid", "hardsigmoid"}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .seed(12345).list() .layer(0, new org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM.Builder() @@ -546,8 +545,8 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) net.getLayer(0).conf() - .getLayer()).getGateActivationFn().toString()); + assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) net.getLayer(0).getNetConfiguration() + .getFirstLayer()).getGateActivationFn().toString()); INDArray in = Nd4j.rand(3, 2, 5); INDArray labels = Nd4j.rand(3, 2, 5); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java index 679066755..2868c08d8 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java @@ -24,7 +24,6 @@ import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.common.config.DL4JClassLoading; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.gradient.Gradient; @@ -59,14 +58,14 @@ public class GravesLSTMTest extends BaseDL4JTest { int nHiddenUnits = 17; NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn) .nOut(nHiddenUnits).activation(Activation.TANH).build()) .build(); - val numParams = conf.getLayer().initializer().numParams(conf); + val numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - GravesLSTM layer = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + GravesLSTM layer = (GravesLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); //Data: has shape [miniBatchSize,nIn,timeSeriesLength]; //Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; @@ -104,16 +103,16 @@ public class GravesLSTMTest extends BaseDL4JTest { INDArray inputData = Nd4j.ones(miniBatchSize, nIn, timeSeriesLength); - NeuralNetConfiguration conf = new 
NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn) .nOut(lstmNHiddenUnits) .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) .build(); - val numParams = conf.getLayer().initializer().numParams(conf); + val numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - GravesLSTM lstm = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); - lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getLayer().initializer().numParams(conf))); + GravesLSTM lstm = (GravesLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getFirstLayer().initializer().numParams(conf))); //Set input, do a forward pass: lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); assertNotNull(lstm.input()); @@ -155,15 +154,15 @@ public class GravesLSTMTest extends BaseDL4JTest { int miniBatchSize = 4; int timeSeriesLength = 7; - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(layerSize) .dist(new UniformDistribution(0, 1)) .activation(Activation.TANH).build()) .build(); - val numParams = conf.getLayer().initializer().numParams(conf); + val numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - GravesLSTM lstm = (GravesLSTM) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + GravesLSTM lstm = (GravesLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); INDArray input = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); @@ -197,7 +196,7 @@ public class GravesLSTMTest extends BaseDL4JTest { public void testSingleExample() { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.1)).seed(12345).list() .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().activation(Activation.TANH) @@ -254,7 +253,7 @@ public class GravesLSTMTest extends BaseDL4JTest { public void testGateActivationFnsSanityCheck() { for (String gateAfn : new String[] {"sigmoid", "hardsigmoid"}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .seed(12345).list() .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder() @@ -268,7 +267,7 @@ public class GravesLSTMTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesLSTM) net.getLayer(0).conf().getLayer()) + assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesLSTM) net.getLayer(0).getLayerConfiguration()) .getGateActivationFn().toString()); INDArray in = Nd4j.rand(3, 2, 5); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java index f1fa71ab2..1a3bcbc65 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java @@ -23,13 +23,11 @@ package org.deeplearning4j.nn.layers.recurrent; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.deeplearning4j.optimize.api.TrainingListener; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -71,7 +69,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest { .nIn(2) .nOut(1).dataFormat(rnnDataFormat) .build(); - NeuralNetConfiguration conf = new NeuralNetConfiguration(); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().build(); conf.setLayer(underlying); INDArray params = Nd4j.zeros(1, 16); @@ -108,7 +106,7 @@ public class MaskZeroLayerTest extends BaseDL4JTest { @Test public void testSerialization(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer.Builder() .setMaskValue(0.0).setUnderlying(new LSTM.Builder().nIn(4).nOut(5).dataFormat(rnnDataFormat).build()).build()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java index 2b5280339..c6b315cb5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java @@ -238,14 +238,14 @@ public class RnnDataFormatTests extends BaseDL4JTest { return getNetWithLayer(new SimpleRnn.Builder().nOut(3).build(), format, lastTimeStep, maskZeros); } } - private MultiLayerNetwork getNetWithLayer(Layer layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) { + private MultiLayerNetwork getNetWithLayer(LayerConfiguration layer, RNNFormat format, boolean lastTimeStep, boolean maskZeros) { if (maskZeros){ layer = new MaskZeroLayer.Builder().setMaskValue(0.).setUnderlying(layer).build(); } if(lastTimeStep){ layer = new LastTimeStep(layer); } - NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = (NeuralNetConfiguration.NeuralNetConfigurationBuilder) NeuralNetConfiguration.builder() .seed(12345) .list() .layer(new LSTM.Builder() @@ -260,7 +260,7 @@ public class RnnDataFormatTests extends BaseDL4JTest { (lastTimeStep)?new OutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).build(): new RnnOutputLayer.Builder().activation(Activation.SOFTMAX).nOut(10).dataFormat(format).build() ) - .setInputType(InputType.recurrent(3, 12, format)); + .inputType(InputType.recurrent(3, 12, format)); MultiLayerNetwork net = new 
MultiLayerNetwork(builder.build()); net.init(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java index 4abcfa768..7755790e4 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestLastTimeStepLayer.java @@ -62,7 +62,7 @@ public class TestLastTimeStepLayer extends BaseDL4JTest { @Test public void testLastTimeStepVertex() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("lastTS", new LastTimeStep(new SimpleRnn.Builder() .nIn(5).nOut(6).dataFormat(rnnDataFormat).build()), "in") .setOutputs("lastTS") @@ -124,7 +124,7 @@ public class TestLastTimeStepLayer extends BaseDL4JTest { @Test public void testMaskingAndAllMasked(){ - ComputationGraphConfiguration.GraphBuilder builder = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration.GraphBuilder builder = NeuralNetConfiguration.builder() .optimizationAlgo(STOCHASTIC_GRADIENT_DESCENT) .weightInit(XAVIER_UNIFORM) .activation(TANH) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java index 951680ca7..3ea9cdbdb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRecurrentWeightInit.java @@ -40,7 +40,7 @@ public class TestRecurrentWeightInit extends BaseDL4JTest { for (boolean rwInit : new boolean[]{false, true}) { for (int i = 0; i < 3; i++) { - NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.NeuralNetConfigurationBuilder b = NeuralNetConfiguration.builder() .weightInit(new UniformDistribution(0, 1)) .list(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index b5fd0ac57..e2b6bc359 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -22,13 +22,13 @@ package org.deeplearning4j.nn.layers.recurrent; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.dropout.TestDropout; import org.deeplearning4j.nn.conf.layers.GravesLSTM; import org.deeplearning4j.nn.conf.layers.LSTM; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.RnnLossLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; 
@@ -67,7 +67,7 @@ public class TestRnnLayers extends BaseDL4JTest { int nIn = 12; int nOut = 3; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new NoOp()) .weightInit(WeightInit.XAVIER) .list() @@ -119,9 +119,9 @@ public class TestRnnLayers extends BaseDL4JTest { for(String s : layerTypes){ - Layer layer; - Layer layerD; - Layer layerD2; + LayerConfiguration layer; + LayerConfiguration layerD; + LayerConfiguration layerD2; TestDropout.CustomDropout cd = new TestDropout.CustomDropout(); switch (s){ case "graves": @@ -143,21 +143,21 @@ public class TestRnnLayers extends BaseDL4JTest { throw new RuntimeException(s); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .list() .layer(layer) .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); - MultiLayerConfiguration confD = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration confD = NeuralNetConfiguration.builder() .seed(12345) .list() .layer(layerD) .layer(new RnnOutputLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).nIn(10).nOut(10).dataFormat(rnnDataFormat).build()) .build(); - MultiLayerConfiguration confD2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration confD2 = NeuralNetConfiguration.builder() .seed(12345) .list() .layer(layerD2) @@ -214,9 +214,9 @@ public class TestRnnLayers extends BaseDL4JTest { for( int i=0; i<2; i++ ){ - NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.NeuralNetConfigurationBuilder lb = NeuralNetConfiguration.builder() + - .list() .layer(new SimpleRnn.Builder().nIn(5).nOut(5).dataFormat(rnnDataFormat).build()); switch (i){ @@ -230,7 +230,7 @@ public class TestRnnLayers extends BaseDL4JTest { throw new RuntimeException(); } - MultiLayerConfiguration conf = lb.build(); + NeuralNetConfiguration conf = lb.build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java index 9d77537c8..2abd86487 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestSimpleRnn.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.layers.recurrent; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; @@ -68,7 +67,7 @@ public class TestSimpleRnn extends BaseDL4JTest { in = Nd4j.rand(DataType.FLOAT, m, tsLength, nIn); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new NoOp()) .weightInit(WeightInit.XAVIER) .activation(Activation.TANH) @@ -126,7 +125,7 @@ public class TestSimpleRnn extends BaseDL4JTest { int nIn = 5; int layerSize = 6; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + 
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new NoOp()) .weightInit(WeightInit.XAVIER) .activation(Activation.TANH) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java index 90a05de95..5a31cf4df 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.layers.recurrent; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.WorkspaceMode; @@ -62,7 +61,7 @@ public class TestTimeDistributed extends BaseDL4JTest { public void testTimeDistributed(){ for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) { - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder() .trainingWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm) .seed(12345) @@ -72,10 +71,10 @@ public class TestTimeDistributed extends BaseDL4JTest { .layer(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build()) .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(3, rnnDataFormat)) + .inputType(InputType.recurrent(3, rnnDataFormat)) .build(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .trainingWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm) .seed(12345) @@ -85,7 +84,7 @@ public class TestTimeDistributed extends BaseDL4JTest { .layer(new TimeDistributed(new DenseLayer.Builder().nIn(3).nOut(3).activation(Activation.TANH).build(), rnnDataFormat)) .layer(new RnnOutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).dataFormat(rnnDataFormat) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(3, rnnDataFormat)) + .inputType(InputType.recurrent(3, rnnDataFormat)) .build(); MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); @@ -135,7 +134,7 @@ public class TestTimeDistributed extends BaseDL4JTest { for( int rnnType=0; rnnType<3; rnnType++ ) { for( int ffType=0; ffType<3; ffType++ ) { - Layer l0, l2; + LayerConfiguration l0, l2; switch (rnnType) { case 0: l0 = new LSTM.Builder().nOut(5).build(); @@ -153,7 +152,7 @@ public class TestTimeDistributed extends BaseDL4JTest { throw new RuntimeException("Not implemented: " + rnnType); } - Layer l1; + LayerConfiguration l1; switch (ffType){ case 0: l1 = new DenseLayer.Builder().nOut(5).build(); @@ -168,13 +167,13 @@ public class TestTimeDistributed extends BaseDL4JTest { throw new RuntimeException("Not implemented: " + ffType); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .list() .layer(l0) .layer(l1) .layer(l2) - .setInputType(InputType.recurrent(5, 9, rnnDataFormat)) + .inputType(InputType.recurrent(5, 9, rnnDataFormat)) .build(); 
BaseRecurrentLayer l0a; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java index 7b0f6c2cf..534af7bc2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.samediff; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -75,10 +74,10 @@ public class SameDiffCustomLayerTests extends BaseDL4JTest { @Test public void testInputValidationSameDiffLayer() { - final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().list() + final NeuralNetConfiguration config = NeuralNetConfiguration.builder().list() .layer(new ValidatingSameDiffLayer()) .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nOut(2).build()) - .setInputType(InputType.feedForward(2)) + .inputType(InputType.feedForward(2)) .build(); final MultiLayerNetwork net = new MultiLayerNetwork(config); @@ -95,7 +94,7 @@ public class SameDiffCustomLayerTests extends BaseDL4JTest { @Test public void testInputValidationSameDiffVertex(){ - final ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().graphBuilder() + final ComputationGraphConfiguration config = NeuralNetConfiguration.builder().graphBuilder() .addVertex("a", new ValidatingSameDiffVertex(), "input") .addLayer("output", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nOut(2).build(), "a") .addInputs("input") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java index f0d5d16ce..690c07f37 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java @@ -25,7 +25,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.gradientcheck.GradientCheckUtil; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -66,7 +65,7 @@ public class TestSameDiffConv extends BaseDL4JTest { int kH = 2; int kW = 3; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new SameDiffConv.Builder().nIn(nIn).nOut(nOut).kernelSize(kH, kW).build()) .build(); @@ -128,7 +127,7 @@ public class TestSameDiffConv extends BaseDL4JTest { + ", ConvolutionMode=" + cm + ", ActFn=" + a + ", hasBias=" + hasBias; log.info("Starting test: " + msg); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + 
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .seed(12345) .list() @@ -159,9 +158,9 @@ public class TestSameDiffConv extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertNotNull(net.paramTable()); + assertNotNull(net.getParamTable()); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .weightInit(WeightInit.XAVIER) .seed(12345) @@ -193,8 +192,8 @@ public class TestSameDiffConv extends BaseDL4JTest { //Check params: note that samediff/libnd4j conv params are [kH, kW, iC, oC] //DL4J are [nOut, nIn, kH, kW] - Map params1 = net.paramTable(); - Map params2 = net2.paramTable(); + Map params1 = net.getParamTable(); + Map params2 = net2.getParamTable(); for(Map.Entry e : params1.entrySet()){ if(e.getKey().endsWith("_W")){ INDArray p1 = e.getValue(); @@ -267,7 +266,7 @@ public class TestSameDiffConv extends BaseDL4JTest { int outH = cm == ConvolutionMode.Same ? imgH : (imgH-2); int outW = cm == ConvolutionMode.Same ? imgW : (imgW-2); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .seed(12345) .updater(new NoOp()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java index 5e1949f8a..64d59c84b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java @@ -25,7 +25,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.gradientcheck.GradientCheckUtil; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -64,7 +63,7 @@ public class TestSameDiffDense extends BaseDL4JTest { int nIn = 3; int nOut = 4; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut).build()) .build(); @@ -103,7 +102,7 @@ public class TestSameDiffDense extends BaseDL4JTest { for (Activation a : afns) { log.info("Starting test - " + a + ", workspace = " + wsm); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .inferenceWorkspaceMode(wsm) .trainingWorkspaceMode(wsm) .list() @@ -115,9 +114,9 @@ public class TestSameDiffDense extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertNotNull(net.paramTable()); + assertNotNull(net.getParamTable()); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .list() .layer(new DenseLayer.Builder().activation(a).nIn(nIn).nOut(nOut).build()) .build(); @@ -129,8 +128,8 @@ public class TestSameDiffDense extends BaseDL4JTest { //Check params: assertEquals(net2.params(), net.params()); - Map params1 = net.paramTable(); - Map params2 = 
net2.paramTable(); + Map params1 = net.getParamTable(); + Map params2 = net2.getParamTable(); assertEquals(params2, params1); INDArray in = Nd4j.rand(minibatch, nIn); @@ -176,7 +175,7 @@ public class TestSameDiffDense extends BaseDL4JTest { for (Activation a : afns) { log.info("Starting test - " + a + " - workspace=" + wsm); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .list() .layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut) @@ -194,9 +193,9 @@ public class TestSameDiffDense extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertNotNull(net.paramTable()); + assertNotNull(net.getParamTable()); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .seed(12345) .weightInit(WeightInit.XAVIER) .list() @@ -214,8 +213,8 @@ public class TestSameDiffDense extends BaseDL4JTest { //Check params: assertEquals(net2.params(), net.params()); - Map params1 = net.paramTable(); - Map params2 = net2.paramTable(); + Map params1 = net.getParamTable(); + Map params2 = net2.getParamTable(); assertEquals(params2, params1); INDArray in = Nd4j.rand(minibatch, nIn); @@ -264,7 +263,7 @@ public class TestSameDiffDense extends BaseDL4JTest { for (Activation a : afns) { log.info("Starting test - " + a + " - minibatch " + minibatch + ", workspaces: " + workspaces); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .list() @@ -278,7 +277,7 @@ public class TestSameDiffDense extends BaseDL4JTest { MultiLayerNetwork netSD = new MultiLayerNetwork(conf); netSD.init(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .list() .layer(new DenseLayer.Builder().activation(a).nIn(nIn).nOut(nOut).build()) .layer(new OutputLayer.Builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX) @@ -292,7 +291,7 @@ public class TestSameDiffDense extends BaseDL4JTest { //Check params: assertEquals(netStandard.params(), netSD.params()); - assertEquals(netStandard.paramTable(), netSD.paramTable()); + assertEquals(netStandard.getParamTable(), netSD.getParamTable()); INDArray in = Nd4j.rand(minibatch, nIn); INDArray l = TestUtils.randomOneHot(minibatch, nOut, 12345); @@ -352,7 +351,7 @@ public class TestSameDiffDense extends BaseDL4JTest { for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .trainingWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm) @@ -367,7 +366,7 @@ public class TestSameDiffDense extends BaseDL4JTest { MultiLayerNetwork netSD = new MultiLayerNetwork(conf); netSD.init(); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .seed(12345) .updater(new Adam(0.1)) .list() @@ -384,7 +383,7 @@ public class TestSameDiffDense extends BaseDL4JTest { //Check params: assertEquals(netStandard.params(), netSD.params()); - assertEquals(netStandard.paramTable(), netSD.paramTable()); + assertEquals(netStandard.getParamTable(), 
netSD.getParamTable()); DataSetIterator iter = new IrisDataSetIterator(150, 150); DataSet ds = iter.next(); @@ -422,7 +421,7 @@ public class TestSameDiffDense extends BaseDL4JTest { String msg = "workspaces: " + workspaces + ", " + a; Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .seed(12345) .updater(new NoOp()) @@ -433,7 +432,7 @@ public class TestSameDiffDense extends BaseDL4JTest { .layer(new SameDiffDense.Builder().nIn(nOut).nOut(nOut).activation(a).build()) .layer(new OutputLayer.Builder().nIn(nOut).nOut(nOut).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - //.setInputType(InputType.feedForward(nIn)) //TODO + //.inputType(InputType.feedForward(nIn)) //TODO .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java index 630ec1231..f70c4de92 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java @@ -65,7 +65,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest { for (Activation a : afns) { log.info("Starting test - " + a + " - minibatch " + minibatch + ", workspaces: " + workspaces); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) @@ -82,7 +82,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest { ComputationGraph netSD = new ComputationGraph(conf); netSD.init(); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .inferenceWorkspaceMode(workspaces ? 
WorkspaceMode.ENABLED : WorkspaceMode.NONE) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java index 4afbc7e37..8da331f8e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java @@ -62,7 +62,7 @@ public class TestSameDiffLambda extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .trainingWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm) .seed(12345) @@ -77,7 +77,7 @@ public class TestSameDiffLambda extends BaseDL4JTest { .build(); //Equavalent, not using SameDiff Lambda: - ComputationGraphConfiguration confStd = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration confStd = NeuralNetConfiguration.builder() .trainingWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm) .seed(12345) @@ -143,7 +143,7 @@ public class TestSameDiffLambda extends BaseDL4JTest { log.info("--- Workspace Mode: {} ---", wsm); Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .trainingWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm) .dataType(DataType.DOUBLE) @@ -160,7 +160,7 @@ public class TestSameDiffLambda extends BaseDL4JTest { .build(); //Equavalent, not using SameDiff Lambda: - ComputationGraphConfiguration confStd = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration confStd = NeuralNetConfiguration.builder() .trainingWorkspaceMode(wsm) .inferenceWorkspaceMode(wsm) .dataType(DataType.DOUBLE) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java index 2f0479b67..8ff1d6bc9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.layers.samediff; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.LossLayer; @@ -48,7 +47,7 @@ public class TestSameDiffOutput extends BaseDL4JTest { public void testOutputMSELossLayer(){ Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration confSD = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration confSD = NeuralNetConfiguration.builder() .seed(12345) .updater(new Adam(0.01)) .list() @@ -56,7 +55,7 @@ public class TestSameDiffOutput extends BaseDL4JTest { .layer(new SameDiffMSELossLayer()) .build(); - MultiLayerConfiguration confStd = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration confStd = NeuralNetConfiguration.builder() .seed(12345) .updater(new Adam(0.01)) .list() @@ -110,7 +109,7 @@ public class TestSameDiffOutput extends BaseDL4JTest { for(Activation a : new Activation[]{Activation.IDENTITY, Activation.TANH, 
Activation.SOFTMAX}) { log.info("Starting test: " + a); - MultiLayerConfiguration confSD = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration confSD = NeuralNetConfiguration.builder() .seed(12345) .updater(new Adam(0.01)) .list() @@ -118,7 +117,7 @@ public class TestSameDiffOutput extends BaseDL4JTest { .layer(new SameDiffMSEOutputLayer(5, 5, a, WeightInit.XAVIER)) .build(); - MultiLayerConfiguration confStd = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration confStd = NeuralNetConfiguration.builder() .seed(12345) .updater(new Adam(0.01)) .list() @@ -134,7 +133,7 @@ public class TestSameDiffOutput extends BaseDL4JTest { netSD.params().assign(netStd.params()); - assertEquals(netStd.paramTable(), netSD.paramTable()); + assertEquals(netStd.getParamTable(), netSD.getParamTable()); int minibatch = 2; INDArray in = Nd4j.rand(minibatch, 5); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java index 8864448b0..bc0677adb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/MinimalSameDiffDense.java @@ -94,5 +94,5 @@ public class MinimalSameDiffDense extends SameDiffLayer { //OPTIONAL methods: // public void setNIn(InputType inputType, boolean override) // public InputPreProcessor getPreProcessorForInputType(InputType inputType) -// public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) +// public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java index 0049696de..6fe2cf15e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java @@ -154,12 +154,13 @@ public class SameDiffConv extends SameDiffLayer { } @Override - public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) { + public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { + NeuralNetConfiguration clone = globalConfig.clone().build(); if (activation == null) { - activation = SameDiffLayerUtils.fromIActivation(globalConfig.getActivationFn()); + activation = SameDiffLayerUtils.fromIActivation(clone.getActivationFn()); } if (cm == null) { - cm = globalConfig.getConvolutionMode(); + cm = clone.getConvolutionMode(); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java index 3595282c0..d0a176d63 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java @@ -116,9 +116,10 @@ public class SameDiffDense extends SameDiffLayer { } @Override - public void 
applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) { + public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { + NeuralNetConfiguration clone = globalConfig.clone().build(); if(activation == null){ - activation = SameDiffLayerUtils.fromIActivation(globalConfig.getActivationFn()); + activation = SameDiffLayerUtils.fromIActivation(clone.getActivationFn()); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java index 41d149b3b..a93db0e56 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffMSEOutputLayer.java @@ -85,7 +85,7 @@ public class SameDiffMSEOutputLayer extends SameDiffOutputLayer { } @Override - public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig){ + public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig){ } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java index f535c81fa..3da4abed5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java @@ -21,10 +21,10 @@ package org.deeplearning4j.nn.layers.variational; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.conf.layers.variational.*; @@ -56,16 +56,16 @@ public class TestVAE extends BaseDL4JTest { @Test public void testInitialization() { - MultiLayerConfiguration mlc = - new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration mlc = + NeuralNetConfiguration.builder() .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .nIn(10).nOut(5).encoderLayerSizes(12).decoderLayerSizes(13) .build()) .build(); - NeuralNetConfiguration c = mlc.getConf(0); + LayerConfiguration c = mlc.getFirstLayer(); org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder vae = - (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) c.getLayer(); + (VariationalAutoencoder) c; long allParams = vae.initializer().numParams(c); @@ -94,14 +94,14 @@ public class TestVAE extends BaseDL4JTest { int[][] encLayerSizes = new int[][] {{12}, {12, 13}, {12, 13, 14}}; for (int i = 0; i < encLayerSizes.length; i++) { - MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder().list().layer(0, + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list().layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder().nIn(10) .nOut(5).encoderLayerSizes(encLayerSizes[i]).decoderLayerSizes(13).build()) .build(); - 
NeuralNetConfiguration c = mlc.getConf(0); + LayerConfiguration c = mlc.getConf(0); org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder vae = - (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) c.getLayer(); + (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) c; MultiLayerNetwork net = new MultiLayerNetwork(mlc); net.init(); @@ -120,14 +120,14 @@ public class TestVAE extends BaseDL4JTest { int inputSize = 3; - MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list() .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .nIn(inputSize).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build()) .build(); - NeuralNetConfiguration c = mlc.getConf(0); + LayerConfiguration c = mlc.getConf(0); org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder vae = - (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) c.getLayer(); + (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) c; long allParams = vae.initializer().numParams(c); @@ -158,14 +158,14 @@ public class TestVAE extends BaseDL4JTest { @Test public void testParamGradientOrderAndViews() { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list() .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build()) .build(); - NeuralNetConfiguration c = mlc.getConf(0); + LayerConfiguration c = mlc.getConf(0); org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder vae = - (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) c.getLayer(); + (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) c; MultiLayerNetwork net = new MultiLayerNetwork(mlc); net.init(); @@ -216,16 +216,16 @@ public class TestVAE extends BaseDL4JTest { //Idea: pretrain-specific parameters shouldn't change during backprop Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder().seed(12345).list() + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().seed(12345).list() .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .nIn(10).nOut(5).encoderLayerSizes(12, 13).decoderLayerSizes(14, 15).build()) .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(5).nOut(6) .activation(new ActivationTanH()).build()) .build(); - NeuralNetConfiguration c = mlc.getConf(0); + LayerConfiguration c = mlc.getConf(0); org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder vae = - (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) c.getLayer(); + (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) c; MultiLayerNetwork net = new MultiLayerNetwork(mlc); net.init(); @@ -268,7 +268,7 @@ public class TestVAE extends BaseDL4JTest { @Test public void testJsonYaml() { - MultiLayerConfiguration config = new NeuralNetConfiguration.Builder().seed(12345).list() + NeuralNetConfiguration config = NeuralNetConfiguration.builder().seed(12345).list() .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .reconstructionDistribution(new GaussianReconstructionDistribution(Activation.IDENTITY)) 
.nIn(3).nOut(4).encoderLayerSizes(5).decoderLayerSizes(6).build()) @@ -299,8 +299,8 @@ public class TestVAE extends BaseDL4JTest { String asJson = config.toJson(); String asYaml = config.toYaml(); - MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(asJson); - MultiLayerConfiguration fromYaml = MultiLayerConfiguration.fromYaml(asYaml); + NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(asJson); + NeuralNetConfiguration fromYaml = NeuralNetConfiguration.fromYaml(asYaml); assertEquals(config, fromJson); assertEquals(config, fromYaml); @@ -350,7 +350,7 @@ public class TestVAE extends BaseDL4JTest { throw new RuntimeException(); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.2).l1(0.3) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().l2(0.2).l1(0.3) .updater(new Sgd(1.0)) .seed(12345L).dist(new NormalDistribution(0, 1)) .list().layer(0, @@ -416,7 +416,7 @@ public class TestVAE extends BaseDL4JTest { for (int i = 0; i < reconstructionDistributions.length; i++) { INDArray data = Nd4j.rand(minibatch, inOutSize).muli(2).subi(1); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l2(0.2).l1(0.3) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().l2(0.2).l1(0.3) .updater(new Sgd(1.0)) .seed(12345L).dist(new NormalDistribution(0, 1)) .list().layer(0, @@ -456,7 +456,7 @@ public class TestVAE extends BaseDL4JTest { for(boolean ws : new boolean[]{false, true}) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345L) .trainingWorkspaceMode(ws ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) .inferenceWorkspaceMode(ws ? WorkspaceMode.ENABLED : WorkspaceMode.NONE) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java index 175292211..0b8b1877d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.misc; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.Updater; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; @@ -40,15 +39,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class CloseNetworkTests extends BaseDL4JTest { public static MultiLayerNetwork getTestNet() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Adam(1e-3)) - .list() + .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(3, 3).activation(Activation.TANH).build()) .layer(new BatchNormalization.Builder().nOut(5).build()) .layer(new SubsamplingLayer.Builder().build()) .layer(new DenseLayer.Builder().nOut(10).activation(Activation.RELU).build()) .layer(new OutputLayer.Builder().nOut(10).build()) - .setInputType(InputType.convolutional(28, 28, 1)) + .inputType(InputType.convolutional(28, 28, 1)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java index 44d1a2098..09dfb45ea 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.misc; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.EmbeddingLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -31,7 +30,6 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -50,8 +48,8 @@ public class LargeNetTest extends BaseDL4JTest { //More than 2.1 billion parameters //10M classes plus 300 vector size -> 3 billion elements - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .layer(new EmbeddingLayer.Builder().nIn(10_000_000).nOut(300).build()) .layer(new OutputLayer.Builder().nIn(300).nOut(10).activation(Activation.SOFTMAX).build()) .build(); @@ -82,7 +80,7 @@ public class LargeNetTest extends BaseDL4JTest { //More than 2.1 billion parameters //10M classes plus 300 vector size -> 3 billion elements - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("0", new EmbeddingLayer.Builder().nIn(10_000_000).nOut(300).build(), "in") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java index 77f3a2342..f6ddd312c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java @@ -22,11 +22,9 @@ package org.deeplearning4j.nn.misc; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; -import org.deeplearning4j.nn.conf.weightnoise.DropConnect; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.jupiter.api.Test; @@ -49,10 +47,10 @@ public class TestLrChanges extends BaseDL4JTest { @Test public void testChangeLrMLN(){ //First: Set LR for a *single* layer and compare vs. 
equivalent net config - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) - .list() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new Adam(0.1)).build()) .layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new RmsProp(0.01)).build()) .layer(new OutputLayer.Builder().nIn(10).nOut(10).updater(new NoOp()).lossFunction(LossFunctions.LossFunction.MSE).build()) @@ -66,10 +64,10 @@ public class TestLrChanges extends BaseDL4JTest { } - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) - .list() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new Adam(0.5)).build()) //0.5 LR .layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new RmsProp(0.01)).build()) .layer(new OutputLayer.Builder().nIn(10).nOut(10).updater(new NoOp()).lossFunction(LossFunctions.LossFunction.MSE).build()) @@ -116,10 +114,10 @@ public class TestLrChanges extends BaseDL4JTest { //Now: Set *all* LRs to say 0.3... - MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf3 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) - .list() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new Adam(0.3)).build()) //0.5 LR .layer(new DenseLayer.Builder().nIn(10).nOut(10).updater(new RmsProp(0.3)).build()) .layer(new OutputLayer.Builder().nIn(10).nOut(10).updater(new NoOp()).lossFunction(LossFunctions.LossFunction.MSE).build()) @@ -148,11 +146,11 @@ public class TestLrChanges extends BaseDL4JTest { @Test public void testChangeLSGD() { //Simple test for no updater nets - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .updater(new Sgd(0.1)) - .list() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(new OutputLayer.Builder().nIn(10).nOut(10).lossFunction(LossFunctions.LossFunction.MSE).build()) @@ -177,11 +175,11 @@ public class TestLrChanges extends BaseDL4JTest { @Test public void testChangeLrMLNSchedule(){ //First: Set LR for a *single* layer and compare vs. 
equivalent net config - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .updater(new Adam(0.1)) - .list() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()) @@ -195,11 +193,11 @@ public class TestLrChanges extends BaseDL4JTest { } - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .updater(new Adam(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 ))) - .list() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()) @@ -239,7 +237,7 @@ public class TestLrChanges extends BaseDL4JTest { @Test public void testChangeLrCompGraph(){ //First: Set LR for a *single* layer and compare vs. equivalent net config - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .graphBuilder() @@ -258,7 +256,7 @@ public class TestLrChanges extends BaseDL4JTest { } - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .graphBuilder() @@ -310,7 +308,7 @@ public class TestLrChanges extends BaseDL4JTest { //Now: Set *all* LRs to say 0.3... - MultiLayerConfiguration conf3 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf3 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .list() @@ -342,7 +340,7 @@ public class TestLrChanges extends BaseDL4JTest { @Test public void testChangeLrCompGraphSchedule(){ //First: Set LR for a *single* layer and compare vs. 
equivalent net config - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .updater(new Adam(0.1)) @@ -362,7 +360,7 @@ public class TestLrChanges extends BaseDL4JTest { } - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .updater(new Adam(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 ))) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestMemoryReports.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestMemoryReports.java index a7fcee172..b22bfec2f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestMemoryReports.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestMemoryReports.java @@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.*; import org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex; @@ -53,8 +52,8 @@ import static org.junit.jupiter.api.Assertions.assertEquals; public class TestMemoryReports extends BaseDL4JTest { - public static List> getTestLayers() { - List> l = new ArrayList<>(); + public static List> getTestLayers() { + List> l = new ArrayList<>(); l.add(new Pair<>(new ActivationLayer.Builder().activation(Activation.TANH).build(), InputType.feedForward(20))); l.add(new Pair<>(new DenseLayer.Builder().nIn(20).nOut(20).build(), InputType.feedForward(20))); l.add(new Pair<>(new DropoutLayer.Builder().nIn(20).nOut(20).build(), InputType.feedForward(20))); @@ -100,12 +99,12 @@ public class TestMemoryReports extends BaseDL4JTest { @Test public void testMemoryReportSimple() { - List> l = getTestLayers(); + List> l = getTestLayers(); - for (Pair p : l) { + for (Pair p : l) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list().layer(0, p.getFirst().clone()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(0, p.getFirst().clone()) .layer(1, p.getFirst().clone()).validateOutputLayerConfig(false).build(); MemoryReport mr = conf.getMemoryReport(p.getSecond()); @@ -128,12 +127,12 @@ public class TestMemoryReports extends BaseDL4JTest { @Test public void testMemoryReportSimpleCG() { - List> l = getTestLayers(); + List> l = getTestLayers(); - for (Pair p : l) { + for (Pair p : l) { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("0", p.getFirst().clone(), "in").addLayer("1", p.getFirst().clone(), "0") .setOutputs("1").validateOutputLayerConfig(false).build(); @@ -168,7 +167,7 @@ public class TestMemoryReports extends BaseDL4JTest { layerInputs = new String[] {"1"}; } - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs(inputs) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs(inputs) .allowDisconnected(true) .addVertex("gv", p.getFirst(), 
layerInputs).setOutputs("gv").build(); @@ -216,7 +215,7 @@ public class TestMemoryReports extends BaseDL4JTest { @Test public void validateSimple() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(0, new DenseLayer.Builder().nIn(10).nOut(20).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(27).build()).build(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java index fc8312630..fdfb16fcd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.misc; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; @@ -91,17 +90,17 @@ public class TestNetConversion extends BaseDL4JTest { private MultiLayerNetwork getNet1(boolean train) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .convolutionMode(ConvolutionMode.Same) .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .updater(new Sgd(0.1)) - .list() + .layer(new ConvolutionLayer.Builder().nIn(3).nOut(5).kernelSize(2, 2).stride(1, 1).build()) .layer(new SubsamplingLayer.Builder().kernelSize(2, 2).stride(1, 1).build()) .layer(new DenseLayer.Builder().nOut(32).build()) .layer(new OutputLayer.Builder().nOut(10).lossFunction(LossFunctions.LossFunction.MSE).build()) - .setInputType(InputType.convolutional(10, 10, 3)) + .inputType(InputType.convolutional(10, 10, 3)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -121,16 +120,16 @@ public class TestNetConversion extends BaseDL4JTest { private MultiLayerNetwork getNet2() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .convolutionMode(ConvolutionMode.Same) .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .updater(new Sgd(0.1)) - .list() + .layer(new GravesLSTM.Builder().nOut(8).build()) .layer(new LSTM.Builder().nOut(8).build()) .layer(new RnnOutputLayer.Builder().nOut(10).lossFunction(LossFunctions.LossFunction.MSE).build()) - .setInputType(InputType.recurrent(5)) + .inputType(InputType.recurrent(5)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java index 5b00685af..9649adffd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java @@ -93,24 +93,24 @@ public class WorkspaceTests extends BaseDL4JTest { int depthOut = 2; int nOut = 2; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) - .convolutionMode(ConvolutionMode.Same).seed(12345L).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) + 
.convolutionMode(ConvolutionMode.Same).seed(12345L) .layer(0, new ConvolutionLayer.Builder().nIn(depthIn).nOut(depthOut).kernelSize(2, 2) .stride(1, 1).activation(Activation.TANH).build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(nOut).build()) - .setInputType(InputType.convolutional(5, 5, 2)) + .inputType(InputType.convolutional(5, 5, 2)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf.clone()); net.init(); - net.getLayerWiseConfigurations().setInferenceWorkspaceMode(WorkspaceMode.ENABLED); - net.getLayerWiseConfigurations().setTrainingWorkspaceMode(WorkspaceMode.ENABLED); + net.getConfiguration().setInferenceWorkspaceMode(WorkspaceMode.ENABLED); + net.getConfiguration().setTrainingWorkspaceMode(WorkspaceMode.ENABLED); MultiLayerNetwork net2 = new MultiLayerNetwork(conf.clone()); net2.init(); - net2.getLayerWiseConfigurations().setInferenceWorkspaceMode(WorkspaceMode.NONE); - net2.getLayerWiseConfigurations().setTrainingWorkspaceMode(WorkspaceMode.NONE); + net2.getConfiguration().setInferenceWorkspaceMode(WorkspaceMode.NONE); + net2.getConfiguration().setTrainingWorkspaceMode(WorkspaceMode.NONE); INDArray in = Nd4j.rand(1, 2, 5, 5); @@ -120,7 +120,7 @@ public class WorkspaceTests extends BaseDL4JTest { public static ComputationGraph createNet() throws Exception { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .addLayer("0", new ConvolutionLayer.Builder().nOut(3) @@ -149,7 +149,7 @@ public class WorkspaceTests extends BaseDL4JTest { for (WorkspaceMode wm : WorkspaceMode.values()) { System.out.println(wm); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .trainingWorkspaceMode(wm) .inferenceWorkspaceMode(wm) .graphBuilder() @@ -184,15 +184,15 @@ public class WorkspaceTests extends BaseDL4JTest { public void testWithPreprocessorsMLN() { for (WorkspaceMode wm : WorkspaceMode.values()) { System.out.println(wm); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .trainingWorkspaceMode(wm) .inferenceWorkspaceMode(wm) - .list() + .layer(new GravesLSTM.Builder().nIn(10).nOut(5).build()) .layer(new GravesLSTM.Builder().nIn(5).nOut(8).build()) .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.SIGMOID).nOut(3).build()) .inputPreProcessor(0, new DupPreProcessor()) - .setInputType(InputType.recurrent(10)) + .inputType(InputType.recurrent(10)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -247,14 +247,14 @@ public class WorkspaceTests extends BaseDL4JTest { System.out.println("Starting test: " + ws + " - " + i); - NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.NeuralNetConfigurationBuilder b = NeuralNetConfiguration.builder() .weightInit(WeightInit.XAVIER) .activation(Activation.TANH) .inferenceWorkspaceMode(ws) .trainingWorkspaceMode(ws) .list(); - ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration.GraphBuilder gb = NeuralNetConfiguration.builder() .weightInit(WeightInit.XAVIER) .activation(Activation.TANH) .inferenceWorkspaceMode(ws) @@ -292,7 +292,7 @@ public class WorkspaceTests extends BaseDL4JTest { gb.addLayer("out", new 
RnnOutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1"); gb.setOutputs("out"); - MultiLayerConfiguration conf = b.build(); + NeuralNetConfiguration conf = b.build(); ComputationGraphConfiguration conf2 = gb.build(); @@ -320,14 +320,14 @@ public class WorkspaceTests extends BaseDL4JTest { System.out.println("Starting test: " + ws + " - " + i); - NeuralNetConfiguration.ListBuilder b = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.NeuralNetConfigurationBuilder b = NeuralNetConfiguration.builder() .weightInit(WeightInit.XAVIER) .activation(Activation.TANH) .inferenceWorkspaceMode(ws) .trainingWorkspaceMode(ws) .list(); - ComputationGraphConfiguration.GraphBuilder gb = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration.GraphBuilder gb = NeuralNetConfiguration.builder() .weightInit(WeightInit.XAVIER) .activation(Activation.TANH) .inferenceWorkspaceMode(ws) @@ -366,14 +366,14 @@ public class WorkspaceTests extends BaseDL4JTest { .nIn(10).nOut(10).build(), "1"); gb.setOutputs("out"); - MultiLayerConfiguration conf = b + NeuralNetConfiguration conf = b .backpropType(BackpropType.TruncatedBPTT) - .tBPTTLength(5) + .tbpttBackLength(5).tbpttFwdLength(5) .build(); ComputationGraphConfiguration conf2 = gb .backpropType(BackpropType.TruncatedBPTT) - .tBPTTForwardLength(5).tBPTTBackwardLength(5) + .tbpttFwdLength(5).tbpttBackLength(5) .build(); @@ -400,7 +400,7 @@ public class WorkspaceTests extends BaseDL4JTest { log.info("WorkspaceMode = " + ws); Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .weightInit(WeightInit.XAVIER) .seed(12345) .trainingWorkspaceMode(ws).inferenceWorkspaceMode(ws) @@ -429,7 +429,7 @@ public class WorkspaceTests extends BaseDL4JTest { public void testWorkspaceSetting() { for (WorkspaceMode wsm : WorkspaceMode.values()) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .weightInit(WeightInit.XAVIER) .seed(12345) .trainingWorkspaceMode(wsm).inferenceWorkspaceMode(wsm) @@ -441,7 +441,7 @@ public class WorkspaceTests extends BaseDL4JTest { assertEquals(wsm, conf.getInferenceWorkspaceMode()); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() .weightInit(WeightInit.XAVIER) .seed(12345) .trainingWorkspaceMode(wsm).inferenceWorkspaceMode(wsm) @@ -458,7 +458,7 @@ public class WorkspaceTests extends BaseDL4JTest { @Test public void testClearing() { for(WorkspaceMode wsm : WorkspaceMode.values()) { - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration config = NeuralNetConfiguration.builder() .updater(new Adam()) .inferenceWorkspaceMode(wsm) .trainingWorkspaceMode(wsm) @@ -501,7 +501,7 @@ public class WorkspaceTests extends BaseDL4JTest { MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(conf, wsName); - MultiLayerConfiguration netConf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration netConf = NeuralNetConfiguration.builder() .seed(12345) .weightInit(WeightInit.XAVIER) .list() @@ -556,7 +556,7 @@ public class WorkspaceTests extends BaseDL4JTest { final INDArray input = Nd4j.rand(1, 30); - final ComputationGraphConfiguration computationGraphConfiguration = new NeuralNetConfiguration.Builder() + final ComputationGraphConfiguration 
computationGraphConfiguration = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("state") .addLayer("value_output", new OutputLayer.Builder().nIn(30).nOut(1).activation(Activation.IDENTITY) @@ -578,7 +578,7 @@ public class WorkspaceTests extends BaseDL4JTest { INDArray input = Nd4j.rand(1, 30); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new OutputLayer.Builder().nIn(30).nOut(1).activation(Activation.IDENTITY).lossFunction(LossFunctions.LossFunction.MSE).build()) .build(); @@ -607,13 +607,13 @@ public class WorkspaceTests extends BaseDL4JTest { - MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .convolutionMode(ConvolutionMode.Same).seed(12345L).list() .layer(0, new ConvolutionLayer.Builder().nIn(1).nOut(2).kernelSize(2, 2) .stride(1, 1).activation(Activation.TANH).build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nOut(10).build()) - .setInputType(InputType.convolutional(5, 5, 1)) + .inputType(InputType.convolutional(5, 5, 1)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(mlc); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java index 695fdb70d..ca9c0f67c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/mkldnn/ValidateMKLDNN.java @@ -25,7 +25,6 @@ import org.deeplearning4j.LayerHelperValidationUtil; import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; @@ -74,7 +73,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { INDArray f = Nd4j.rand(DataType.FLOAT, inputSize); INDArray l = TestUtils.randomOneHot(minibatch, 10).castTo(DataType.FLOAT); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Adam(0.01)) .convolutionMode(cm) .seed(12345) @@ -98,7 +97,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { .nOut(3) .build()) .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutional(inputSize[2], inputSize[3], inputSize[1])) + .inputType(InputType.convolutional(inputSize[2], inputSize[3], inputSize[1])) .build(); MultiLayerNetwork netWith = new MultiLayerNetwork(conf.clone()); @@ -149,7 +148,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { INDArray f = Nd4j.rand(Nd4j.defaultFloatingPointType(), inputSize); INDArray l = TestUtils.randomOneHot(minibatch, 10); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .updater(new Adam(0.01)) .convolutionMode(cm) @@ -169,7 +168,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { .nOut(3) .build()) .layer(new 
OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutional(inputSize[2], inputSize[3], inputSize[1])) + .inputType(InputType.convolutional(inputSize[2], inputSize[3], inputSize[1])) .build(); MultiLayerNetwork netWith = new MultiLayerNetwork(conf.clone()); @@ -223,7 +222,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { INDArray f = Nd4j.rand(Nd4j.defaultFloatingPointType(), inputSize); INDArray l = TestUtils.randomOneHot(minibatch, 10).castTo(DataType.FLOAT); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Adam(0.01)) .convolutionMode(cm) .weightInit(new NormalDistribution(0,1)) @@ -242,7 +241,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { .k(k[i]) .cudnnAllowFallback(false).build()) .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.convolutional(inputSize[2], inputSize[3], inputSize[1])) + .inputType(InputType.convolutional(inputSize[2], inputSize[3], inputSize[1])) .build(); MultiLayerNetwork netWith = new MultiLayerNetwork(conf.clone()); @@ -292,7 +291,7 @@ public class ValidateMKLDNN extends BaseDL4JTest { INDArray dLdb = beta.ulike(); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .inferenceWorkspaceMode(WorkspaceMode.NONE) .trainingWorkspaceMode(WorkspaceMode.NONE) .list() diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java index 94f26b712..27efa9149 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.multilayer; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -33,11 +32,9 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.exception.ND4JArraySizeException; @@ -69,7 +66,7 @@ public class BackPropMLPTest extends BaseDL4JTest { @Test public void testMLP() { //Simple mini-batch test with multiple hidden layers - MultiLayerConfiguration conf = getIrisMLPSimpleConfig(new int[] {5, 4, 3}, Activation.SIGMOID); + NeuralNetConfiguration conf = getIrisMLPSimpleConfig(new int[] {5, 4, 3}, Activation.SIGMOID); // System.out.println(conf); MultiLayerNetwork network = new 
MultiLayerNetwork(conf); network.init(); @@ -83,7 +80,7 @@ public class BackPropMLPTest extends BaseDL4JTest { @Test public void testMLP2() { //Simple mini-batch test with multiple hidden layers - MultiLayerConfiguration conf = getIrisMLPSimpleConfig(new int[] {5, 15, 3}, Activation.TANH); + NeuralNetConfiguration conf = getIrisMLPSimpleConfig(new int[] {5, 15, 3}, Activation.TANH); // System.out.println(conf); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); @@ -322,9 +319,9 @@ public class BackPropMLPTest extends BaseDL4JTest { * Learning Rate = 0.1 * No regularization, no Adagrad, no momentum etc. One iteration. */ - private static MultiLayerConfiguration getIrisMLPSimpleConfig(int[] hiddenLayerSizes, + private static NeuralNetConfiguration getIrisMLPSimpleConfig(int[] hiddenLayerSizes, Activation activationFunction) { - NeuralNetConfiguration.ListBuilder lb = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder lb = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .seed(12345L).list(); for (int i = 0; i < hiddenLayerSizes.length; i++) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index 49d70647c..cad0cfd50 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -20,9 +20,31 @@ package org.deeplearning4j.nn.multilayer; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator; @@ -31,12 +53,29 @@ import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.exception.DL4JException; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.*; +import org.deeplearning4j.nn.conf.BackpropType; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.Updater; +import org.deeplearning4j.nn.conf.WorkspaceMode; import 
org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.ActivationLayer; +import org.deeplearning4j.nn.conf.layers.AutoEncoder; +import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BatchNormalization; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; +import org.deeplearning4j.nn.conf.layers.GravesLSTM; +import org.deeplearning4j.nn.conf.layers.LSTM; +import org.deeplearning4j.nn.conf.layers.LossLayer; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; +import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; @@ -58,6 +97,7 @@ import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -78,356 +118,349 @@ import org.nd4j.linalg.learning.config.Adam; import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; -import org.nd4j.common.primitives.Pair; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.util.*; - -import static org.junit.jupiter.api.Assertions.*; @Slf4j public class MultiLayerTest extends BaseDL4JTest { - private static OpExecutioner.ProfilingMode origMode; + private static OpExecutioner.ProfilingMode origMode; - @BeforeAll - public static void beforeClass(){ - origMode = Nd4j.getExecutioner().getProfilingMode(); + @BeforeAll + public static void beforeClass() { + origMode = Nd4j.getExecutioner().getProfilingMode(); + } + + @AfterAll + public static void afterClass() { + Nd4j.getExecutioner().setProfilingMode(origMode); + } + + private static NeuralNetConfiguration getConf() { + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(12345L) + .list().layer(0, + new DenseLayer.Builder().nIn(4).nOut(3) + + .dist(new NormalDistribution(0, 1)) + .build()) + .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( + LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nIn(3).nOut(3) + + .dist(new NormalDistribution(0, 1)).build()) + .build(); + return conf; + } + + public static float[] asFloat(INDArray arr) { + long len = arr.length(); + + float[] f = new float[(int) len]; + for (int i = 0; i < len; i++) { + f[i] = arr.getFloat(i); + } + return f; + } + + @BeforeEach + public void before() { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); + } + + @Override + public DataType getDataType() { + return DataType.FLOAT; + } + + @Test + public void testSetParams() { + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() + .list().layer(0, + new DenseLayer.Builder().nIn(4).nOut(3) + .activation(Activation.TANH).build()) + .layer(1, new 
DenseLayer.Builder().nIn(3).nOut(2).build()) + .build(); + + MultiLayerNetwork network3 = new MultiLayerNetwork(conf); + network3.init(); + + INDArray params = network3.params(); + INDArray weights = network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); + INDArray bias = network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY).dup(); + network3.setParameters(params); + assertEquals(weights, network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY)); + assertEquals(bias, network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY)); + INDArray params4 = network3.params(); + assertEquals(params, params4); + } + + @Test + public void testBatchNorm() { + Nd4j.getRandom().setSeed(123); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list() + .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(2, new BatchNormalization.Builder().nOut(2).build()) + .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( + LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER) + .activation(Activation.SOFTMAX).nIn(2).nOut(3).build()) + .build(); + + MultiLayerNetwork network = new MultiLayerNetwork(conf); + network.init(); + network.setListeners(new ScoreIterationListener(1)); + + DataSetIterator iter = new IrisDataSetIterator(150, 150); + + DataSet next = iter.next(); + next.normalizeZeroMeanZeroUnitVariance(); + SplitTestAndTrain trainTest = next.splitTestAndTrain(110); + network.setLabels(trainTest.getTrain().getLabels()); + network.init(); + for (int i = 0; i < 5; i++) { + network.fit(trainTest.getTrain()); } - @BeforeEach - public void before(){ - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); + } + + @Test + public void testBackProp() { + Nd4j.getRandom().setSeed(123); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list() + .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( + LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER) + .activation(Activation.SOFTMAX).nIn(2).nOut(3).build()) + .build(); + + MultiLayerNetwork network = new MultiLayerNetwork(conf); + network.init(); + network.setListeners(new ScoreIterationListener(1)); + + DataSetIterator iter = new IrisDataSetIterator(150, 150); + + DataSet next = iter.next(); + next.normalizeZeroMeanZeroUnitVariance(); + SplitTestAndTrain trainTest = next.splitTestAndTrain(110); + network.setInput(trainTest.getTrain().getFeatures()); + network.setLabels(trainTest.getTrain().getLabels()); + network.init(); + for (int i = 0; i < 5; i++) { + network.fit(trainTest.getTrain()); } - @AfterAll - public static void afterClass(){ - Nd4j.getExecutioner().setProfilingMode(origMode); + DataSet test = trainTest.getTest(); + Evaluation eval = new Evaluation(); + INDArray output = network.output(test.getFeatures()); + eval.eval(test.getLabels(), output); + log.info("Score " + eval.stats()); + } + + @Test + public void 
testGradientWithAsList() { + MultiLayerNetwork net1 = new MultiLayerNetwork(getConf()); + MultiLayerNetwork net2 = new MultiLayerNetwork(getConf()); + net1.init(); + net2.init(); + + DataSet x1 = new IrisDataSetIterator(1, 150).next(); + DataSet all = new IrisDataSetIterator(150, 150).next(); + DataSet x2 = all.asList().get(0); + + //x1 and x2 contain identical data + assertArrayEquals(asFloat(x1.getFeatures()), asFloat(x2.getFeatures()), 0.0f); + assertArrayEquals(asFloat(x1.getLabels()), asFloat(x2.getLabels()), 0.0f); + assertEquals(x1, x2); + + //Set inputs/outputs so gradient can be calculated: + net1.feedForward(x1.getFeatures()); + net2.feedForward(x2.getFeatures()); + ((BaseOutputLayer) net1.getLayer(1)).setLabels(x1.getLabels()); + ((BaseOutputLayer) net2.getLayer(1)).setLabels(x2.getLabels()); + + net1.gradient(); + net2.gradient(); + } + + /** + * This test intended only to test activateSelectedLayers method, it does not involves + * fully-working AutoEncoder. + */ + @Test + public void testSelectedActivations() { + // Train DeepAutoEncoder on very limited trainset + final int numRows = 28; + final int numColumns = 28; + int seed = 123; + int numSamples = 3; + int iterations = 1; + int listenerFreq = iterations / 5; + + log.info("Load data...."); + + float[][] trainingData = new float[numSamples][numColumns * numRows]; + Arrays.fill(trainingData[0], 0.95f); + Arrays.fill(trainingData[1], 0.5f); + Arrays.fill(trainingData[2], 0.05f); + + log.info("Build model...."); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(seed) + .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list() + .layer(0, new DenseLayer.Builder().nIn(numRows * numColumns).nOut(1000).build()) + .layer(1, new DenseLayer.Builder().nIn(1000).nOut(500).build()) + .layer(2, new DenseLayer.Builder().nIn(500).nOut(250).build()) + .layer(3, new DenseLayer.Builder().nIn(250).nOut(100).build()) + .layer(4, new DenseLayer.Builder().nIn(100).nOut(30).build()) //encoding stops + .layer(5, new DenseLayer.Builder().nIn(30).nOut(100).build()) //decoding starts + .layer(6, new DenseLayer.Builder().nIn(100).nOut(250).build()) + .layer(7, new DenseLayer.Builder().nIn(250).nOut(500).build()) + .layer(8, new DenseLayer.Builder().nIn(500).nOut(1000).build()) + .layer(9, + new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(1000) + .nOut(numRows * numColumns).activation(Activation.SOFTMAX).build()) + .build(); + + MultiLayerNetwork model = new MultiLayerNetwork(conf); + model.init(); + + model.addListeners(new ScoreIterationListener(listenerFreq)); + + log.info("Train model...."); + int cnt = 0; + while (cnt < numSamples) { + INDArray input = Nd4j.create(trainingData[cnt]).reshape(1, -1); + model.fit(new DataSet(input, input)); + cnt++; } + // Make two separate selective calls - @Override - public DataType getDataType(){ - return DataType.FLOAT; + log.info("Testing full cycle..."); + + List comparableResult = model.feedForward( + Nd4j.create(trainingData[0], 1, trainingData[0].length)); + + INDArray encodeResult = model.activateSelectedLayers(0, 4, + Nd4j.create(trainingData[0], 1, trainingData[0].length)); + + log.info("Compare feedForward results with selectedActivation"); + + assertEquals(comparableResult.get(5), encodeResult); + + INDArray decodeResults = model.activateSelectedLayers(5, 9, encodeResult); + + log.info("Decode results: " + decodeResults.columns() + " " + decodeResults); + log.info( + "Comparable results: " + comparableResult.get(10).columns() + " " + 
comparableResult.get( + 10)); + + assertEquals(comparableResult.get(10), decodeResults); + } + + @Test + public void testFeedForwardToLayer() { + + int nIn = 30; + int nOut = 25; + + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) + .updater(new Sgd(1e-3)) + .list().layer( + 0, new DenseLayer.Builder().nIn(nIn).nOut(600) + + .dist(new NormalDistribution(0, 1e-5)) + .build()) + .layer(1, new DenseLayer.Builder() + .nIn(600).nOut(250) + .dist(new NormalDistribution(0, 1e-5)) + .build()) + .layer(2, new DenseLayer.Builder() + .nIn(250).nOut(100) + .dist(new NormalDistribution(0, 1e-5)) + .build()) + .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( + LossFunctions.LossFunction.MCXENT).nIn(100).nOut(25) + .activation(Activation.SOFTMAX) + .weightInit(new NormalDistribution(0, 1e-5)).build()) + .build(); + + MultiLayerNetwork network = new MultiLayerNetwork(conf); + network.init(); + + INDArray input = Nd4j.rand(5, nIn); + + List activations = network.feedForward(input); + assertEquals(5, activations.size()); //4 layers + input + + List activationsAll = network.feedForwardToLayer(3, input); + assertEquals(activations, activationsAll); + + for (int i = 3; i >= 0; i--) { + List activationsPartial = network.feedForwardToLayer(i, input); + assertEquals(i + 2, + activationsPartial.size()); //i+2: for layer 3: input + activations of {0,1,2,3} -> 5 total = 3+2 + for (int j = 0; j <= i; j++) { + INDArray exp = activationsAll.get(j); + INDArray act = activationsPartial.get(j); + assertEquals(exp, act); + } } + } - @Test - public void testSetParams() { - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() - .list().layer(0, - new DenseLayer.Builder().nIn(4).nOut(3) - .activation(Activation.TANH).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) - .build(); - MultiLayerNetwork network3 = new MultiLayerNetwork(conf); - network3.init(); + @Test + public void testBackpropGradient() { + //Testing: MultiLayerNetwork.backpropGradient() + //i.e., specifically without an output layer - INDArray params = network3.params(); - INDArray weights = network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); - INDArray bias = network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY).dup(); - network3.setParameters(params); - assertEquals(weights, network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY)); - assertEquals(bias, network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY)); - INDArray params4 = network3.params(); - assertEquals(params, params4); + int nIn = 10; + int nOut = 40; + int miniBatch = 5; + + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .updater(new Sgd(0.1)).list() + .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).activation(Activation.RELU) + .weightInit(WeightInit.XAVIER).build()) + .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).activation(Activation.RELU) + .weightInit(WeightInit.XAVIER).build()) + .layer(2, new DenseLayer.Builder().nIn(30).nOut(nOut).activation(Activation.RELU) + .weightInit(WeightInit.XAVIER).build()) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + Nd4j.getRandom().setSeed(12345); + INDArray eps = Nd4j.rand(miniBatch, nOut); + INDArray input = Nd4j.rand(miniBatch, nIn); + + net.setInput(input); + net.feedForward(true, false); //Need to feed forward before backprop + + Pair pair = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); + 
INDArray epsOut = pair.getSecond(); + assertNotNull(epsOut); + assertArrayEquals(new long[]{miniBatch, nIn}, epsOut.shape()); + + Gradient g = pair.getFirst(); + Map gradMap = g.gradientForVariable(); + assertEquals(6, gradMap.size()); //3 layers, weight + bias gradients for each + + String[] expKeys = {"0_" + DefaultParamInitializer.WEIGHT_KEY, + "0_" + DefaultParamInitializer.BIAS_KEY, + "1_" + DefaultParamInitializer.WEIGHT_KEY, "2_" + DefaultParamInitializer.BIAS_KEY, + "2_" + DefaultParamInitializer.WEIGHT_KEY, "2_" + DefaultParamInitializer.BIAS_KEY}; + Set keys = gradMap.keySet(); + for (String s : expKeys) { + assertTrue(keys.contains(s)); } - @Test - public void testBatchNorm() { - Nd4j.getRandom().setSeed(123); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(2, new BatchNormalization.Builder().nOut(2).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).nIn(2).nOut(3).build()) - .build(); - - - MultiLayerNetwork network = new MultiLayerNetwork(conf); - network.init(); - network.setListeners(new ScoreIterationListener(1)); - - DataSetIterator iter = new IrisDataSetIterator(150, 150); - - DataSet next = iter.next(); - next.normalizeZeroMeanZeroUnitVariance(); - SplitTestAndTrain trainTest = next.splitTestAndTrain(110); - network.setLabels(trainTest.getTrain().getLabels()); - network.init(); - for( int i=0; i<5; i++ ) { - network.fit(trainTest.getTrain()); - } - - } - - @Test - public void testBackProp() { - Nd4j.getRandom().setSeed(123); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).seed(123).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(2, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).weightInit(WeightInit.XAVIER) - .activation(Activation.SOFTMAX).nIn(2).nOut(3).build()) - .build(); - - - MultiLayerNetwork network = new MultiLayerNetwork(conf); - network.init(); - network.setListeners(new ScoreIterationListener(1)); - - DataSetIterator iter = new IrisDataSetIterator(150, 150); - - DataSet next = iter.next(); - next.normalizeZeroMeanZeroUnitVariance(); - SplitTestAndTrain trainTest = next.splitTestAndTrain(110); - network.setInput(trainTest.getTrain().getFeatures()); - network.setLabels(trainTest.getTrain().getLabels()); - network.init(); - for( int i=0; i<5; i++ ) { - network.fit(trainTest.getTrain()); - } - - DataSet test = trainTest.getTest(); - Evaluation eval = new Evaluation(); - INDArray output = network.output(test.getFeatures()); - eval.eval(test.getLabels(), output); - log.info("Score " + eval.stats()); - } - - - - @Test - public void testGradientWithAsList() { - MultiLayerNetwork net1 = new MultiLayerNetwork(getConf()); - MultiLayerNetwork net2 = new MultiLayerNetwork(getConf()); - net1.init(); - net2.init(); - - DataSet x1 = new 
IrisDataSetIterator(1, 150).next(); - DataSet all = new IrisDataSetIterator(150, 150).next(); - DataSet x2 = all.asList().get(0); - - //x1 and x2 contain identical data - assertArrayEquals(asFloat(x1.getFeatures()), asFloat(x2.getFeatures()), 0.0f); - assertArrayEquals(asFloat(x1.getLabels()), asFloat(x2.getLabels()), 0.0f); - assertEquals(x1, x2); - - //Set inputs/outputs so gradient can be calculated: - net1.feedForward(x1.getFeatures()); - net2.feedForward(x2.getFeatures()); - ((BaseOutputLayer) net1.getLayer(1)).setLabels(x1.getLabels()); - ((BaseOutputLayer) net2.getLayer(1)).setLabels(x2.getLabels()); - - net1.gradient(); - net2.gradient(); - } - - /** - * This test intended only to test activateSelectedLayers method, it does not involves fully-working AutoEncoder. - */ - @Test - public void testSelectedActivations() { - // Train DeepAutoEncoder on very limited trainset - final int numRows = 28; - final int numColumns = 28; - int seed = 123; - int numSamples = 3; - int iterations = 1; - int listenerFreq = iterations / 5; - - log.info("Load data...."); - - float[][] trainingData = new float[numSamples][numColumns * numRows]; - Arrays.fill(trainingData[0], 0.95f); - Arrays.fill(trainingData[1], 0.5f); - Arrays.fill(trainingData[2], 0.05f); - - - - log.info("Build model...."); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed) - .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT).list() - .layer(0, new DenseLayer.Builder().nIn(numRows * numColumns).nOut(1000).build()) - .layer(1, new DenseLayer.Builder().nIn(1000).nOut(500).build()) - .layer(2, new DenseLayer.Builder().nIn(500).nOut(250).build()) - .layer(3, new DenseLayer.Builder().nIn(250).nOut(100).build()) - .layer(4, new DenseLayer.Builder().nIn(100).nOut(30).build()) //encoding stops - .layer(5, new DenseLayer.Builder().nIn(30).nOut(100).build()) //decoding starts - .layer(6, new DenseLayer.Builder().nIn(100).nOut(250).build()) - .layer(7, new DenseLayer.Builder().nIn(250).nOut(500).build()) - .layer(8, new DenseLayer.Builder().nIn(500).nOut(1000).build()) - .layer(9, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(1000) - .nOut(numRows * numColumns).activation(Activation.SOFTMAX).build()) - .build(); - - MultiLayerNetwork model = new MultiLayerNetwork(conf); - model.init(); - - model.addListeners(new ScoreIterationListener(listenerFreq)); - - log.info("Train model...."); - int cnt = 0; - while (cnt < numSamples) { - INDArray input = Nd4j.create(trainingData[cnt]).reshape(1, -1); - model.fit(new DataSet(input, input)); - cnt++; - } - // Make two separate selective calls - - log.info("Testing full cycle..."); - - List comparableResult = model.feedForward(Nd4j.create(trainingData[0], 1, trainingData[0].length)); - - INDArray encodeResult = model.activateSelectedLayers(0, 4, Nd4j.create(trainingData[0], 1, trainingData[0].length)); - - log.info("Compare feedForward results with selectedActivation"); - - assertEquals(comparableResult.get(5), encodeResult); - - INDArray decodeResults = model.activateSelectedLayers(5, 9, encodeResult); - - - log.info("Decode results: " + decodeResults.columns() + " " + decodeResults); - log.info("Comparable results: " + comparableResult.get(10).columns() + " " + comparableResult.get(10)); - - assertEquals(comparableResult.get(10), decodeResults); - } - - private static MultiLayerConfiguration getConf() { - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L) - .list().layer(0, - new 
DenseLayer.Builder().nIn(4).nOut(3) - - .dist(new NormalDistribution(0,1)) - .build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3) - - .dist(new NormalDistribution(0, 1)).build()) - .build(); - return conf; - } - - public static float[] asFloat(INDArray arr) { - long len = arr.length(); - - float[] f = new float[(int) len]; - for (int i = 0; i < len; i++) - f[i] = arr.getFloat(i); - return f; - } - - @Test - public void testFeedForwardToLayer() { - - int nIn = 30; - int nOut = 25; - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT) - .updater(new Sgd(1e-3)) - .list().layer( - 0, new DenseLayer.Builder().nIn(nIn).nOut(600) - - .dist(new NormalDistribution(0,1e-5)) - .build()) - .layer(1, new DenseLayer.Builder() - .nIn(600).nOut(250) - .dist(new NormalDistribution(0, 1e-5)) - .build()) - .layer(2, new DenseLayer.Builder() - .nIn(250).nOut(100) - .dist(new NormalDistribution(0, 1e-5)) - .build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).nIn(100).nOut(25) - .activation(Activation.SOFTMAX) - .weightInit(new NormalDistribution(0, 1e-5)).build()) - .build(); - - MultiLayerNetwork network = new MultiLayerNetwork(conf); - network.init(); - - - INDArray input = Nd4j.rand(5, nIn); - - List activations = network.feedForward(input); - assertEquals(5, activations.size()); //4 layers + input - - List activationsAll = network.feedForwardToLayer(3, input); - assertEquals(activations, activationsAll); - - for (int i = 3; i >= 0; i--) { - List activationsPartial = network.feedForwardToLayer(i, input); - assertEquals(i + 2, activationsPartial.size()); //i+2: for layer 3: input + activations of {0,1,2,3} -> 5 total = 3+2 - for (int j = 0; j <= i; j++) { - INDArray exp = activationsAll.get(j); - INDArray act = activationsPartial.get(j); - assertEquals(exp, act); - } - } - } - - - @Test - public void testBackpropGradient() { - //Testing: MultiLayerNetwork.backpropGradient() - //i.e., specifically without an output layer - - int nIn = 10; - int nOut = 40; - int miniBatch = 5; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .updater(new Sgd(0.1)).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(2, new DenseLayer.Builder().nIn(30).nOut(nOut).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - Nd4j.getRandom().setSeed(12345); - INDArray eps = Nd4j.rand(miniBatch, nOut); - INDArray input = Nd4j.rand(miniBatch, nIn); - - net.setInput(input); - net.feedForward(true, false); //Need to feed forward before backprop - - Pair pair = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces()); - INDArray epsOut = pair.getSecond(); - assertNotNull(epsOut); - assertArrayEquals(new long[] {miniBatch, nIn}, epsOut.shape()); - - Gradient g = pair.getFirst(); - Map gradMap = g.gradientForVariable(); - assertEquals(6, gradMap.size()); //3 layers, weight + bias gradients for each - - String[] expKeys = {"0_" + DefaultParamInitializer.WEIGHT_KEY, "0_" + DefaultParamInitializer.BIAS_KEY, - "1_" + DefaultParamInitializer.WEIGHT_KEY, "2_" + 
DefaultParamInitializer.BIAS_KEY, - "2_" + DefaultParamInitializer.WEIGHT_KEY, "2_" + DefaultParamInitializer.BIAS_KEY}; - Set keys = gradMap.keySet(); - for (String s : expKeys) { - assertTrue(keys.contains(s)); - } - /* System.out.println(pair); @@ -443,1092 +476,1114 @@ public class MultiLayerTest extends BaseDL4JTest { net.setParams(params); //params() may not be in-place System.out.println(Arrays.toString(params.get(NDArrayIndex.all(), NDArrayIndex.interval(0, 10)).dup().data().asFloat())); */ + } + + @Test + public void testLayerNames() { + int nIn = 10; + int nOut = 40; + + List layerNameList = new ArrayList<>(); + layerNameList.add("dnn1"); + layerNameList.add("dnn2"); + layerNameList.add("dnn3"); + + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .updater(new Sgd(0.1)).list() + .layer(0, + new DenseLayer.Builder().name("dnn1").nIn(nIn).nOut(20).activation(Activation.RELU) + .weightInit(WeightInit.XAVIER).build()) + .layer(1, new DenseLayer.Builder().name("dnn2").nIn(20).nOut(30).activation(Activation.RELU) + .weightInit(WeightInit.XAVIER).build()) + .layer(2, new DenseLayer.Builder().name("dnn3").nIn(30).nOut(nOut) + .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build()) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + assertEquals(layerNameList.get(0), net.getLayer(0).getLayerConfiguration().getLayerName()); + assertEquals(layerNameList, net.getLayerNames()); + BaseLayer b = (BaseLayer) net.getLayer(layerNameList.get(2)).getLayerConfiguration(); + assertEquals("softmax", b.getActivationFn().toString()); + } + + + @Test + public void testScoreExamples() { + Nd4j.getRandom().setSeed(12345); + int nIn = 5; + int nOut = 6; + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01) + .l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER) + .list() + .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) + .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()) + .layer(2, new OutputLayer.Builder() + .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) + .build(); + + NeuralNetConfiguration confNoReg = NeuralNetConfiguration.builder().seed(12345) + .updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() + .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) + .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()) + .layer(2, new OutputLayer.Builder() + .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + MultiLayerNetwork netNoReg = new MultiLayerNetwork(confNoReg); + netNoReg.init(); + netNoReg.setParameters(net.params().dup()); + + //Score single example, and compare to scoreExamples: + INDArray input = Nd4j.rand(3, nIn); + INDArray output = Nd4j.rand(3, nOut); + DataSet ds = new DataSet(input, output); + + INDArray scoresWithRegularization = net.scoreExamples(ds, true); + INDArray scoresNoRegularization = net.scoreExamples(ds, false); + + assertArrayEquals(new long[]{3, 1}, scoresWithRegularization.shape()); + assertArrayEquals(new long[]{3, 1}, scoresNoRegularization.shape()); + + for (int i = 0; i < 3; i++) { + DataSet singleEx = new DataSet(input.getRow(i, true), output.getRow(i, true)); + double score = net.score(singleEx); + double scoreNoReg = netNoReg.score(singleEx); + + double scoreUsingScoreExamples = scoresWithRegularization.getDouble(i); + double 
scoreUsingScoreExamplesNoReg = scoresNoRegularization.getDouble(i); + assertEquals(score, scoreUsingScoreExamples, 1e-4); + assertEquals(scoreNoReg, scoreUsingScoreExamplesNoReg, 1e-4); + assertTrue(scoreUsingScoreExamples + > scoreUsingScoreExamplesNoReg); //Regularization term increases score + + // System.out.println(score + "\t" + scoreUsingScoreExamples + "\t|\t" + scoreNoReg + "\t" + scoreUsingScoreExamplesNoReg); + } + } + + @Test + public void testDataSetScore() { + + Nd4j.getRandom().setSeed(12345); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .weightInit(WeightInit.XAVIER).seed(12345L).list() + .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.SIGMOID).build()) + .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nIn(3).nOut(3).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + INDArray in = Nd4j.create(new double[]{1.0, 2.0, 3.0, 4.0}, 1, 4); + INDArray out = Nd4j.create(new double[]{1, 0, 0}, 1, 3); + + double score = net.score(new DataSet(in, out)); + } + + @Test + public void testDataSetScoreCNN() { + + int miniBatch = 3; + int depth = 2; + int width = 3; + int height = 3; + int nOut = 2; + + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .seed(12345L).list().layer(0, new ConvolutionLayer.Builder(2, 2).nOut(1).build()) + .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nOut(2).build()) + .inputType(InputType.convolutionalFlat(height, width, depth)) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + Nd4j.getRandom().setSeed(12345); + Random r = new Random(12345); + INDArray input = Nd4j.rand(miniBatch, depth * width * height); + INDArray labels = Nd4j.create(miniBatch, nOut); + for (int i = 0; i < miniBatch; i++) { + labels.putScalar(new int[]{i, r.nextInt(nOut)}, 1.0); } - @Test - public void testLayerNames() { - int nIn = 10; - int nOut = 40; + double score = net.score(new DataSet(input, labels)); + } - List layerNameList = new ArrayList<>(); - layerNameList.add("dnn1"); - layerNameList.add("dnn2"); - layerNameList.add("dnn3"); + @Test + public void testPredict() throws Exception { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .updater(new Sgd(0.1)).list() - .layer(0, new DenseLayer.Builder().name("dnn1").nIn(nIn).nOut(20).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(1, new DenseLayer.Builder().name("dnn2").nIn(20).nOut(30).activation(Activation.RELU) - .weightInit(WeightInit.XAVIER).build()) - .layer(2, new DenseLayer.Builder().name("dnn3").nIn(30).nOut(nOut) - .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER).build()) - .build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); + Nd4j.getRandom().setSeed(12345); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .weightInit(WeightInit.XAVIER).seed(12345L).list() + .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build()) + .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nIn(50).nOut(10).build()) + .inputType(InputType.convolutional(28, 28, 1)).build(); - assertEquals(layerNameList.get(0), net.getLayer(0).conf().getLayer().getLayerName()); - assertEquals(layerNameList, net.getLayerNames()); - BaseLayer b = (BaseLayer) net.getLayer(layerNameList.get(2)).conf().getLayer(); - 
assertEquals("softmax", b.getActivationFn().toString()); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + DataSetIterator ds = new MnistDataSetIterator(10, 10); + net.fit(ds); + + DataSetIterator testDs = new MnistDataSetIterator(1, 1); + DataSet testData = testDs.next(); + testData.setLabelNames(Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")); + String actualLables = testData.getLabelName(0); + List prediction = net.predict(testData); + assertNotNull(actualLables); + assertNotNull(prediction.get(0)); + } + + @Test + //@Ignore + public void testCid() throws Exception { + System.out.println(EnvironmentUtils.buildCId()); + + Environment environment = EnvironmentUtils.buildEnvironment(); + environment.setSerialVersionID(EnvironmentUtils.buildCId()); + + Task task = TaskUtils.buildTask(Nd4j.create(new double[]{1, 2, 3, 4, 5, 6}, 1, 6)); + + Heartbeat.getInstance().reportEvent(Event.STANDALONE, environment, task); + + Thread.sleep(25000); + } + + @Test + public void testOutput() throws Exception { + Nd4j.getRandom().setSeed(12345); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .weightInit(WeightInit.XAVIER).seed(12345L).list() + .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build()) + .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nIn(50).nOut(10).build()) + .inputType(InputType.convolutional(28, 28, 1)).build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + DataSetIterator fullData = new MnistDataSetIterator(1, 2); + net.fit(fullData); + + fullData.reset(); + DataSet expectedSet = fullData.next(2); + INDArray expectedOut = net.output(expectedSet.getFeatures(), false); + + fullData.reset(); + + INDArray actualOut = net.output(fullData); + + assertEquals(expectedOut, actualOut); + } + + @Test + public void testGradientUpdate() throws Exception { + DataSetIterator iter = new IrisDataSetIterator(1, 1); + + Gradient expectedGradient = new DefaultGradient(); + expectedGradient.setGradientFor("0_W", Nd4j.ones(4, 5)); + expectedGradient.setGradientFor("0_b", Nd4j.ones(1, 5)); + expectedGradient.setGradientFor("1_W", Nd4j.ones(5, 3)); + expectedGradient.setGradientFor("1_b", Nd4j.ones(1, 3)); + + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new Sgd(1.0)) + .activation(Activation.RELU).weightInit(WeightInit.XAVIER) + .list().layer(0, new DenseLayer.Builder().name("dnn1").nIn(4).nOut(5).build()) + .layer(1, new OutputLayer.Builder().name("output").nIn(5).nOut(3) + .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER) + .build()) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + net.fit(iter.next()); + // TODO validate actual layer gradientView - issue getting var out of BaseLayer w/o adding MLN getter that gets confused with local gradient vars + Gradient actualGradient = net.gradient; + assertNotEquals(expectedGradient.getGradientFor("0_W"), actualGradient.getGradientFor("0_W")); + + net.update(expectedGradient); + actualGradient = net.gradient; + assertEquals(expectedGradient.getGradientFor("0_W"), actualGradient.getGradientFor("0_W")); + + // Update params with set + net.setParam("0_W", Nd4j.ones(4, 5)); + net.setParam("0_b", Nd4j.ones(1, 5)); + net.setParam("1_W", Nd4j.ones(5, 3)); + net.setParam("1_b", Nd4j.ones(1, 3)); + INDArray actualParams = net.params(); + + // Confirm params + assertEquals(expectedGradient.gradient(), actualParams); + + 
net.update(expectedGradient); + actualParams = net.params(); + assertEquals(Nd4j.ones(1, 43).addi(1), actualParams); + } + + + @Test + public void testCnnInvalidData() { + assertThrows(DL4JException.class, () -> { + int miniBatch = 3; + int depth = 2; + int width = 5; + int height = 5; + + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() + .layer(0, + new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0).nIn(2) + .nOut(2).build()) + .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nOut(2).build()) + .inputType(InputType.convolutional(height, width, depth)) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + INDArray inputWrongDepth = Nd4j.rand(miniBatch, 5, height, + width); //Order: examples, channels, height, width + net.feedForward(inputWrongDepth); + }); + } + + @Test + public void testApplyingPreTrainConfigAndParams() { + int nIn = 10; + int nOut = 10; + + // Test pretrain true + MultiLayerNetwork aePre = getAeModel(true, nIn, nOut); + int actualNP = (int) aePre.numParams(); + assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP); + INDArray params = aePre.params(); + assertEquals(params.length(), actualNP); // check num params + Map paramTable = aePre.getParamTable(); + assertTrue(paramTable.containsKey("0_vb")); // check vb exists for pretrain layer + aePre.setParam("0_vb", Nd4j.ones(10)); + params = aePre.getParam("0_vb"); + assertEquals(Nd4j.ones(1, 10), params); // check set params for vb + + // Test pretrain false, expect same for true because its not changed when applying update + MultiLayerNetwork aeNoPre = getAeModel(false, nIn, nOut); + actualNP = (int) aeNoPre.numParams(); + assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP); + params = aeNoPre.params(); + assertEquals(params.length(), actualNP); + paramTable = aePre.getParamTable(); + assertTrue(paramTable.containsKey("0_vb")); + } + + public MultiLayerNetwork getAeModel(boolean preTrain, int nIn, int nOut) { + NeuralNetConfiguration vae = NeuralNetConfiguration.builder() + .seed(42).updater(new NoOp()) + .weightInit(WeightInit.UNIFORM) + .layer(new AutoEncoder.Builder() + .activation(Activation.IDENTITY).nOut(nIn).build()) + .layer( + new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( + LossFunctions.LossFunction.COSINE_PROXIMITY) + .activation(Activation.IDENTITY).nOut(nOut) + .build()) + + .inputType(InputType.feedForward(nOut)).build(); + MultiLayerNetwork network = new MultiLayerNetwork(vae); + network.init(); + return network; + } + + + @Test + public void testIterationCountAndPersistence() throws IOException { + Nd4j.getRandom().setSeed(123); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) + .list() + .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( + LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) + .build()) + .build(); + + MultiLayerNetwork network = new MultiLayerNetwork(conf); + network.init(); + + DataSetIterator iter = new IrisDataSetIterator(50, 150); + + assertEquals(0, network.getConfiguration().getIterationCount()); + network.fit(iter); + assertEquals(3, network.getConfiguration().getIterationCount()); + iter.reset(); + network.fit(iter); + assertEquals(6, 
network.getConfiguration().getIterationCount()); + iter.reset(); + network.fit(iter.next()); + assertEquals(7, network.getConfiguration().getIterationCount()); + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ModelSerializer.writeModel(network, baos, true); + byte[] asBytes = baos.toByteArray(); + + ByteArrayInputStream bais = new ByteArrayInputStream(asBytes); + MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(bais, true); + assertEquals(7, net.getConfiguration().getIterationCount()); + } + + + @Test + public void testBiasL1L2() { + + Nd4j.getRandom().setSeed(123); + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .weightInit(WeightInit.XAVIER).activation(Activation.TANH).seed(123).list() + .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) + .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( + LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(10).nOut(10) + .build()) + .build(); + + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .l1Bias(0.1).l2Bias(0.2).weightInit(WeightInit.XAVIER).activation(Activation.TANH) + .seed(123).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) + .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( + LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(10).nOut(10) + .build()) + .build(); + + MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); + net1.init(); + + MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); + net2.init(); + + BaseLayer bl0 = (BaseLayer) net2.getLayer(0).getLayerConfiguration(); + assertEquals(0.1, TestUtils.getL1(bl0.getRegularizationBias()), 1e-6); + assertEquals(0.2, TestUtils.getL2(bl0.getRegularizationBias()), 1e-6); + + INDArray features = Nd4j.rand(10, 10); + INDArray labels = Nd4j.rand(10, 10); + + net2.setParams(net1.params().dup()); + + net1.setInput(features); + net1.setLabels(labels); + net2.setInput(features); + net2.setLabels(labels); + + net1.computeGradientAndScore(); + net2.computeGradientAndScore(); + + double r = net1.calcRegularizationScore(true); + assertEquals(0.0, r, 0.0); + + r = net2.calcRegularizationScore(true); + assertEquals(0.0, r, 0.0); + + double s1 = net1.score(); + double s2 = net2.score(); + assertEquals(s1, s2, 1e-6); //Biases initialized to 0 -> should initially have same score + + for (int i = 0; i < 10; i++) { + net1.fit(features, labels); } + net2.setParams(net1.params().dup()); + net1.computeGradientAndScore(); + net2.computeGradientAndScore(); - @Test - public void testScoreExamples() { - Nd4j.getRandom().setSeed(12345); - int nIn = 5; - int nOut = 6; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) - .l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) - .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) - .build(); + r = net1.calcRegularizationScore(true); + assertEquals(0.0, r, 0.0); - MultiLayerConfiguration confNoReg = new NeuralNetConfiguration.Builder().seed(12345) - .updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() - .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) - 
.layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() - .lossFunction(LossFunctions.LossFunction.MSE).nIn(30).nOut(nOut).build()) - .build(); + r = net2.calcRegularizationScore(true); + assertTrue(r > 0.0); + s1 = net1.score(); + s2 = net2.score(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); + assertNotEquals(s1, s2, 1e-6); //Scores should differ due to bias l1/l2 - MultiLayerNetwork netNoReg = new MultiLayerNetwork(confNoReg); - netNoReg.init(); - netNoReg.setParameters(net.params().dup()); - - //Score single example, and compare to scoreExamples: - INDArray input = Nd4j.rand(3, nIn); - INDArray output = Nd4j.rand(3, nOut); - DataSet ds = new DataSet(input, output); - - INDArray scoresWithRegularization = net.scoreExamples(ds, true); - INDArray scoresNoRegularization = net.scoreExamples(ds, false); - - assertArrayEquals(new long[] {3, 1}, scoresWithRegularization.shape()); - assertArrayEquals(new long[] {3, 1}, scoresNoRegularization.shape()); - - for (int i = 0; i < 3; i++) { - DataSet singleEx = new DataSet(input.getRow(i,true), output.getRow(i,true)); - double score = net.score(singleEx); - double scoreNoReg = netNoReg.score(singleEx); - - double scoreUsingScoreExamples = scoresWithRegularization.getDouble(i); - double scoreUsingScoreExamplesNoReg = scoresNoRegularization.getDouble(i); - assertEquals(score, scoreUsingScoreExamples, 1e-4); - assertEquals(scoreNoReg, scoreUsingScoreExamplesNoReg, 1e-4); - assertTrue(scoreUsingScoreExamples > scoreUsingScoreExamplesNoReg); //Regularization term increases score - - // System.out.println(score + "\t" + scoreUsingScoreExamples + "\t|\t" + scoreNoReg + "\t" + scoreUsingScoreExamplesNoReg); - } + for (int i = 0; i < 2; i++) { + assertEquals(0.0, net1.getLayer(i).calcRegularizationScore(true), 0.0); + assertTrue(net2.getLayer(i).calcRegularizationScore(true) > 0.0); } - - @Test - public void testDataSetScore() { - - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .weightInit(WeightInit.XAVIER).seed(12345L).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).activation(Activation.SIGMOID).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(3).nOut(3).build()) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - INDArray in = Nd4j.create(new double[] {1.0, 2.0, 3.0, 4.0}, 1, 4); - INDArray out = Nd4j.create(new double[] {1, 0, 0}, 1,3); - - double score = net.score(new DataSet(in, out)); - } - - @Test - public void testDataSetScoreCNN() { - - int miniBatch = 3; - int depth = 2; - int width = 3; - int height = 3; - int nOut = 2; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(12345L).list().layer(0, new ConvolutionLayer.Builder(2, 2).nOut(1).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(2).build()) - .setInputType(InputType.convolutionalFlat(height, width, depth)) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - Nd4j.getRandom().setSeed(12345); - Random r = new Random(12345); - INDArray input = Nd4j.rand(miniBatch, depth * width * height); - INDArray labels = Nd4j.create(miniBatch, nOut); - for (int i = 0; i < miniBatch; i++) { - labels.putScalar(new int[] {i, r.nextInt(nOut)}, 1.0); - } - - double score = net.score(new DataSet(input, labels)); - } - - @Test - public void 
testPredict() throws Exception { - - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .weightInit(WeightInit.XAVIER).seed(12345L).list() - .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(50).nOut(10).build()) - .setInputType(InputType.convolutional(28, 28, 1)).build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - DataSetIterator ds = new MnistDataSetIterator(10, 10); - net.fit(ds); - - DataSetIterator testDs = new MnistDataSetIterator(1, 1); - DataSet testData = testDs.next(); - testData.setLabelNames(Arrays.asList("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")); - String actualLables = testData.getLabelName(0); - List prediction = net.predict(testData); - assertNotNull(actualLables); - assertNotNull(prediction.get(0)); - } - - @Test - //@Ignore - public void testCid() throws Exception { - System.out.println(EnvironmentUtils.buildCId()); - - Environment environment = EnvironmentUtils.buildEnvironment(); - environment.setSerialVersionID(EnvironmentUtils.buildCId()); - - Task task = TaskUtils.buildTask(Nd4j.create(new double[] {1, 2, 3, 4, 5, 6}, 1,6)); - - Heartbeat.getInstance().reportEvent(Event.STANDALONE, environment, task); - - Thread.sleep(25000); - } - - @Test - public void testOutput() throws Exception { - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .weightInit(WeightInit.XAVIER).seed(12345L).list() - .layer(0, new DenseLayer.Builder().nIn(784).nOut(50).activation(Activation.RELU).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(50).nOut(10).build()) - .setInputType(InputType.convolutional(28, 28, 1)).build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - DataSetIterator fullData = new MnistDataSetIterator(1, 2); - net.fit(fullData); - - - fullData.reset(); - DataSet expectedSet = fullData.next(2); - INDArray expectedOut = net.output(expectedSet.getFeatures(), false); - - fullData.reset(); - - INDArray actualOut = net.output(fullData); - - assertEquals(expectedOut, actualOut); - } - - @Test - public void testGradientUpdate() throws Exception { - DataSetIterator iter = new IrisDataSetIterator(1, 1); - - Gradient expectedGradient = new DefaultGradient(); - expectedGradient.setGradientFor("0_W", Nd4j.ones(4, 5)); - expectedGradient.setGradientFor("0_b", Nd4j.ones(1, 5)); - expectedGradient.setGradientFor("1_W", Nd4j.ones(5, 3)); - expectedGradient.setGradientFor("1_b", Nd4j.ones(1, 3)); - - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Sgd(1.0)) - .activation(Activation.RELU).weightInit(WeightInit.XAVIER) - .list().layer(0, new DenseLayer.Builder().name("dnn1").nIn(4).nOut(5).build()) - .layer(1, new OutputLayer.Builder().name("output").nIn(5).nOut(3) - .activation(Activation.SOFTMAX).weightInit(WeightInit.XAVIER) - .build()) - .build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - net.fit(iter.next()); - // TODO validate actual layer gradientView - issue getting var out of BaseLayer w/o adding MLN getter that gets confused with local gradient vars - Gradient actualGradient = net.gradient; - assertNotEquals(expectedGradient.getGradientFor("0_W"), actualGradient.getGradientFor("0_W")); - - net.update(expectedGradient); - 
actualGradient = net.gradient; - assertEquals(expectedGradient.getGradientFor("0_W"), actualGradient.getGradientFor("0_W")); - - // Update params with set - net.setParam("0_W", Nd4j.ones(4, 5)); - net.setParam("0_b", Nd4j.ones(1, 5)); - net.setParam("1_W", Nd4j.ones(5, 3)); - net.setParam("1_b", Nd4j.ones(1, 3)); - INDArray actualParams = net.params(); - - // Confirm params - assertEquals(expectedGradient.gradient(), actualParams); - - net.update(expectedGradient); - actualParams = net.params(); - assertEquals(Nd4j.ones(1, 43).addi(1), actualParams); - } - - - @Test - public void testCnnInvalidData() { - assertThrows(DL4JException.class, () -> { - int miniBatch = 3; - int depth = 2; - int width = 5; - int height = 5; - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() - .layer(0, new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).padding(0, 0).nIn(2) - .nOut(2).build()) - .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nOut(2).build()) - .setInputType(InputType.convolutional(height, width, depth)) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - INDArray inputWrongDepth = Nd4j.rand(miniBatch, 5, height, width); //Order: examples, channels, height, width - net.feedForward(inputWrongDepth); - }); - } - - @Test - public void testApplyingPreTrainConfigAndParams() { - int nIn = 10; - int nOut = 10; - - // Test pretrain true - MultiLayerNetwork aePre = getAeModel(true, nIn, nOut); - int actualNP = (int)aePre.numParams(); - assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP); - INDArray params = aePre.params(); - assertEquals(params.length(), actualNP); // check num params - Map paramTable = aePre.paramTable(); - assertTrue(paramTable.containsKey("0_vb")); // check vb exists for pretrain layer - aePre.setParam("0_vb", Nd4j.ones(10)); - params = aePre.getParam("0_vb"); - assertEquals(Nd4j.ones(1,10), params); // check set params for vb - - - // Test pretrain false, expect same for true because its not changed when applying update - MultiLayerNetwork aeNoPre = getAeModel(false, nIn, nOut); - actualNP = (int)aeNoPre.numParams(); - assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP); - params = aeNoPre.params(); - assertEquals(params.length(), actualNP); - paramTable = aePre.paramTable(); - assertTrue(paramTable.containsKey("0_vb")); - } - - public MultiLayerNetwork getAeModel(boolean preTrain, int nIn, int nOut) { - MultiLayerConfiguration vae = new NeuralNetConfiguration.Builder() - .seed(42).updater(new NoOp()) - .weightInit(WeightInit.UNIFORM) - .list(new AutoEncoder.Builder() - .activation(Activation.IDENTITY).nOut(nIn).build(), - new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.COSINE_PROXIMITY) - .activation(Activation.IDENTITY).nOut(nOut) - .build()) - .setInputType(InputType.feedForward(nOut)).build(); - MultiLayerNetwork network = new MultiLayerNetwork(vae); - network.init(); - return network; - } - - - @Test - public void testIterationCountAndPersistence() throws IOException { - Nd4j.getRandom().setSeed(123); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).seed(123) - .list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - 
LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build(); - - - MultiLayerNetwork network = new MultiLayerNetwork(conf); - network.init(); - - DataSetIterator iter = new IrisDataSetIterator(50, 150); - - assertEquals(0, network.getLayerWiseConfigurations().getIterationCount()); - network.fit(iter); - assertEquals(3, network.getLayerWiseConfigurations().getIterationCount()); - iter.reset(); - network.fit(iter); - assertEquals(6, network.getLayerWiseConfigurations().getIterationCount()); - iter.reset(); - network.fit(iter.next()); - assertEquals(7, network.getLayerWiseConfigurations().getIterationCount()); - - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ModelSerializer.writeModel(network, baos, true); - byte[] asBytes = baos.toByteArray(); - - ByteArrayInputStream bais = new ByteArrayInputStream(asBytes); - MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(bais, true); - assertEquals(7, net.getLayerWiseConfigurations().getIterationCount()); - } - - - @Test - public void testBiasL1L2() { - - - Nd4j.getRandom().setSeed(123); - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .weightInit(WeightInit.XAVIER).activation(Activation.TANH).seed(123).list() - .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(10).nOut(10) - .build()) - .build(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .l1Bias(0.1).l2Bias(0.2).weightInit(WeightInit.XAVIER).activation(Activation.TANH) - .seed(123).list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(10).nOut(10) - .build()) - .build(); - - MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); - net1.init(); - - MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); - net2.init(); - - BaseLayer bl0 = (BaseLayer) net2.getLayer(0).conf().getLayer(); - assertEquals(0.1, TestUtils.getL1(bl0.getRegularizationBias()), 1e-6); - assertEquals(0.2, TestUtils.getL2(bl0.getRegularizationBias()), 1e-6); - - INDArray features = Nd4j.rand(10, 10); - INDArray labels = Nd4j.rand(10, 10); - - net2.setParams(net1.params().dup()); - - net1.setInput(features); - net1.setLabels(labels); - net2.setInput(features); - net2.setLabels(labels); - - net1.computeGradientAndScore(); - net2.computeGradientAndScore(); - - double r = net1.calcRegularizationScore(true); - assertEquals(0.0, r, 0.0); - - r = net2.calcRegularizationScore(true); - assertEquals(0.0, r, 0.0); - - - double s1 = net1.score(); - double s2 = net2.score(); - assertEquals(s1, s2, 1e-6); //Biases initialized to 0 -> should initially have same score - - for (int i = 0; i < 10; i++) { - net1.fit(features, labels); - } - - net2.setParams(net1.params().dup()); - net1.computeGradientAndScore(); - net2.computeGradientAndScore(); - - r = net1.calcRegularizationScore(true); - assertEquals(0.0, r, 0.0); - - r = net2.calcRegularizationScore(true); - assertTrue(r > 0.0); - - s1 = net1.score(); - s2 = net2.score(); - - assertNotEquals(s1, s2, 1e-6); //Scores should differ due to bias l1/l2 - - for (int i = 0; i < 2; i++) { - assertEquals(0.0, 
net1.getLayer(i).calcRegularizationScore(true), 0.0); - assertTrue(net2.getLayer(i).calcRegularizationScore(true) > 0.0); - } - } - - /* - Summary should pick up preprocessors set manually on inputs as well - */ - @Test - public void testSummary() { - int V_WIDTH = 130; - int V_HEIGHT = 130; - int V_NFRAMES = 150; - MultiLayerConfiguration confForArchitecture = - new NeuralNetConfiguration.Builder().seed(12345).l2(0.001) //l2 regularization on all layers - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .list() - .layer(0, new ConvolutionLayer.Builder(10, 10).nIn(3) //3 channels: RGB - .nOut(30).stride(4, 4).activation(Activation.RELU).weightInit( - WeightInit.RELU) - .updater(Updater.ADAGRAD).build()) //Output: (130-10+0)/4+1 = 31 -> 31*31*30 - .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) - .kernelSize(3, 3).stride(2, 2).build()) //(31-3+0)/2+1 = 15 - .layer(2, new ConvolutionLayer.Builder(3, 3).nIn(30).nOut(10).stride(2, 2) - .activation(Activation.RELU).weightInit(WeightInit.RELU) - .updater(Updater.ADAGRAD).build()) //Output: (15-3+0)/2+1 = 7 -> 7*7*10 = 490 - .layer(3, new DenseLayer.Builder().activation(Activation.RELU).nIn(490).nOut(50) - .weightInit(WeightInit.RELU).updater(Updater.ADAGRAD) - .gradientNormalization( - GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).build()) - .layer(4, new GravesLSTM.Builder().activation(Activation.SOFTSIGN).nIn(50) - .nOut(50).weightInit(WeightInit.XAVIER).updater(Updater.ADAGRAD) - .gradientNormalization( - GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10) - .build()) - .layer(5, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .activation(Activation.SOFTMAX).nIn(50).nOut(4) //4 possible shapes: circle, square, arc, line - .updater(Updater.ADAGRAD).weightInit(WeightInit.XAVIER) - .gradientNormalization( - GradientNormalization.ClipElementWiseAbsoluteValue) - .gradientNormalizationThreshold(10).build()) - .inputPreProcessor(0, new RnnToCnnPreProcessor(V_HEIGHT, V_WIDTH, 3)) - .inputPreProcessor(3, new CnnToFeedForwardPreProcessor(7, 7, 10)) - .inputPreProcessor(4, new FeedForwardToRnnPreProcessor()) - .backpropType(BackpropType.TruncatedBPTT) - .tBPTTForwardLength(V_NFRAMES / 5).tBPTTBackwardLength(V_NFRAMES / 5).build(); - MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(confForArchitecture); - modelExpectedArch.init(); - MultiLayerNetwork modelMow = new TransferLearning.Builder(modelExpectedArch).setFeatureExtractor(2).build(); + } + + /* + Summary should pick up preprocessors set manually on inputs as well + */ + @Test + public void testSummary() { + int V_WIDTH = 130; + int V_HEIGHT = 130; + int V_NFRAMES = 150; + NeuralNetConfiguration confForArchitecture = + NeuralNetConfiguration.builder().seed(12345).l2(0.001) //l2 regularization on all layers + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) + .list() + .layer(0, new ConvolutionLayer.Builder(10, 10).nIn(3) //3 channels: RGB + .nOut(30).stride(4, 4).activation(Activation.RELU).weightInit( + WeightInit.RELU) + .updater(Updater.ADAGRAD).build()) //Output: (130-10+0)/4+1 = 31 -> 31*31*30 + .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) + .kernelSize(3, 3).stride(2, 2).build()) //(31-3+0)/2+1 = 15 + .layer(2, new ConvolutionLayer.Builder(3, 3).nIn(30).nOut(10).stride(2, 2) + .activation(Activation.RELU).weightInit(WeightInit.RELU) + .updater(Updater.ADAGRAD).build()) //Output: (15-3+0)/2+1 = 
7 -> 7*7*10 = 490 + .layer(3, new DenseLayer.Builder().activation(Activation.RELU).nIn(490).nOut(50) + .weightInit(WeightInit.RELU).updater(Updater.ADAGRAD) + .gradientNormalization( + GradientNormalization.ClipElementWiseAbsoluteValue) + .gradientNormalizationThreshold(10).build()) + .layer(4, new GravesLSTM.Builder().activation(Activation.SOFTSIGN).nIn(50) + .nOut(50).weightInit(WeightInit.XAVIER).updater(Updater.ADAGRAD) + .gradientNormalization( + GradientNormalization.ClipElementWiseAbsoluteValue) + .gradientNormalizationThreshold(10) + .build()) + .layer(5, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nIn(50) + .nOut(4) //4 possible shapes: circle, square, arc, line + .updater(Updater.ADAGRAD).weightInit(WeightInit.XAVIER) + .gradientNormalization( + GradientNormalization.ClipElementWiseAbsoluteValue) + .gradientNormalizationThreshold(10).build()) + .inputPreProcessor(0, new RnnToCnnPreProcessor(V_HEIGHT, V_WIDTH, 3)) + .inputPreProcessor(3, new CnnToFeedForwardPreProcessor(7, 7, 10)) + .inputPreProcessor(4, new FeedForwardToRnnPreProcessor()) + .backpropType(BackpropType.TruncatedBPTT) + .tbpttFwdLength(V_NFRAMES / 5).tbpttBackLength(V_NFRAMES / 5).build(); + MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(confForArchitecture); + modelExpectedArch.init(); + MultiLayerNetwork modelMow = new TransferLearning.Builder( + modelExpectedArch).setFeatureExtractor(2).build(); // System.out.println(modelExpectedArch.summary()); // System.out.println(modelMow.summary()); // System.out.println(modelMow.summary(InputType.recurrent(V_HEIGHT*V_WIDTH*3))); + } + + @Test + public void testErrorNoOutputLayer() { + assertThrows(DL4JException.class, () -> { + NeuralNetConfiguration c = NeuralNetConfiguration.builder().list() + .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).build(); + + MultiLayerNetwork net = new MultiLayerNetwork(c); + net.init(); + + INDArray f = Nd4j.create(1, 10); + INDArray l = Nd4j.create(1, 10); + + net.setInput(f); + net.setLabels(l); + + net.computeGradientAndScore(); + }); + } + + + @Test + public void testSetParamTable() { + + Nd4j.getRandom().setSeed(123); + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder().seed(123).list() + .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(2, new LSTM.Builder().nIn(2).nOut(2).build()) + .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) + .build()) + .build(); + + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().seed(987).list() + .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(2, new LSTM.Builder().nIn(2).nOut(2).build()) + .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) + .build()) + .build(); + + MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); + net1.init(); + + MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); + net2.init(); + + assertNotEquals(net1.params(), net2.params()); + assertNotEquals(net1.getParamTable(), 
net2.getParamTable()); + + net1.setParamTable(net2.getParamTable()); + assertEquals(net1.params(), net2.params()); + assertEquals(net1.getParamTable(), net2.getParamTable()); + } + + + @Test + public void testCompareLayerMethods() { + //Simple test: compare .layer(int, ILayer) and .layer(ILayer) are identical + + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder().seed(123).list() + .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(2, new LSTM.Builder().nIn(2).nOut(2).build()) + .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) + .build()) + .build(); + + NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().seed(123).list() + .layer(new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) + .activation(Activation.TANH).build()) + .layer(new LSTM.Builder().nIn(2).nOut(2).build()) + .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) + .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) + .build()) + .build(); + + assertEquals(conf1, conf2); + } + + + @Test + public void testEpochCounter() throws Exception { + + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .list() + .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + assertEquals(0, net.getConfiguration().getEpochCount()); + + DataSetIterator iter = new IrisDataSetIterator(150, 150); + + for (int i = 0; i < 4; i++) { + assertEquals(i, net.getConfiguration().getEpochCount()); + net.fit(iter); + assertEquals(i + 1, net.getConfiguration().getEpochCount()); } - @Test - public void testErrorNoOutputLayer() { - assertThrows(DL4JException.class, () -> { - MultiLayerConfiguration c = new NeuralNetConfiguration.Builder().list() - .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).build(); + assertEquals(4, net.getConfiguration().getEpochCount()); - MultiLayerNetwork net = new MultiLayerNetwork(c); - net.init(); + MultiLayerNetwork restored = TestUtils.testModelSerialization(net); + assertEquals(4, restored.getConfiguration().getEpochCount()); + } - INDArray f = Nd4j.create(1, 10); - INDArray l = Nd4j.create(1, 10); + @Test + public void testInputClearance() throws Exception { + //Activations should be cleared - if not, it's possible for out of (workspace) scope arrays to be around + // which can cause a crash + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .convolutionMode(ConvolutionMode.Same) + .list() + .layer(new ConvolutionLayer.Builder().kernelSize(2, 2).stride(1, 1).nIn(1).nOut(1).build()) + .layer(new SubsamplingLayer.Builder().kernelSize(2, 2).stride(1, 1).build()) + .layer(new DenseLayer.Builder().nOut(10).build()) + .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()) + .inputType(InputType.convolutional(28, 28, 1)) + .build(); - net.setInput(f); - net.setLabels(l); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); - net.computeGradientAndScore(); - }); + INDArray content = Nd4j.create(1, 1, 28, 28); + + //Check output: + 
net.output(content); + for (org.deeplearning4j.nn.api.Layer l : net.getLayers()) { + assertNull(l.input()); } - - @Test - public void testSetParamTable() { - - Nd4j.getRandom().setSeed(123); - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(123).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(2, new LSTM.Builder().nIn(2).nOut(2).build()) - .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) - .build()) - .build(); - - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(987).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(2, new LSTM.Builder().nIn(2).nOut(2).build()) - .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) - .build()) - .build(); - - MultiLayerNetwork net1 = new MultiLayerNetwork(conf1); - net1.init(); - - MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); - net2.init(); - - assertNotEquals(net1.params(), net2.params()); - assertNotEquals(net1.paramTable(), net2.paramTable()); - - net1.setParamTable(net2.paramTable()); - assertEquals(net1.params(), net2.params()); - assertEquals(net1.paramTable(), net2.paramTable()); + //Check feedForward: + net.feedForward(content, false); + for (org.deeplearning4j.nn.api.Layer l : net.getLayers()) { + assertNull(l.input()); } + } - @Test - public void testCompareLayerMethods(){ - //Simple test: compare .layer(int, ILayer) and .layer(ILayer) are identical + @Test + public void testExternalErrors() { + //Simple test: same network, but in one case: one less layer (the OutputLayer), where the epsilons are passed in externally + // instead. 
Should get identical results - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(123).list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(2, new LSTM.Builder().nIn(2).nOut(2).build()) - .layer(3, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) - .build()) - .build(); + for (WorkspaceMode ws : WorkspaceMode.values()) { + log.info("Workspace mode: " + ws); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(123).list() - .layer(new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(new DenseLayer.Builder().nIn(3).nOut(2).weightInit(WeightInit.XAVIER) - .activation(Activation.TANH).build()) - .layer(new LSTM.Builder().nIn(2).nOut(2).build()) - .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) - .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).nIn(2).nOut(3) - .build()) - .build(); + Nd4j.getRandom().setSeed(12345); + INDArray inData = Nd4j.rand(3, 10); + INDArray outData = Nd4j.rand(3, 10); - assertEquals(conf1, conf2); + Nd4j.getRandom().setSeed(12345); + NeuralNetConfiguration standard = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) + .trainingWorkspaceMode(ws) + .inferenceWorkspaceMode(ws) + .seed(12345).list() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) + .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10) + .nOut(10).build()) + .build(); + MultiLayerNetwork s = new MultiLayerNetwork(standard); + s.init(); + + Nd4j.getRandom().setSeed(12345); + NeuralNetConfiguration external = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) + .trainingWorkspaceMode(ws) + .inferenceWorkspaceMode(ws) + .seed(12345).list() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) + .build(); + + MultiLayerNetwork e = new MultiLayerNetwork(external); + e.init(); + + s.setInput(inData); + s.setLabels(outData); + s.computeGradientAndScore(); + Gradient sGrad = s.gradient(); + + s.setInput(inData); + s.feedForward(true, false); //FF without clearing inputs as we need them later + + e.setInput(inData); + e.feedForward(true, false); //FF without clearing inputs as we need them later + + org.deeplearning4j.nn.layers.OutputLayer ol = (org.deeplearning4j.nn.layers.OutputLayer) s.getLayer( + 1); + Pair olPairStd = ol.backpropGradient(null, + LayerWorkspaceMgr.noWorkspaces()); + + INDArray olEpsilon = olPairStd.getSecond().detach(); + + e.setInput(inData); + e.feedForward(true, false); + Pair extErrorGrad = e.backpropGradient(olEpsilon, + LayerWorkspaceMgr.noWorkspaces()); + + int nParamsDense = 10 * 10 + 10; + assertEquals(sGrad.gradient() + .get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(0, nParamsDense)), + extErrorGrad.getFirst().gradient()); + + Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); } + } + @Test + public void testExternalErrors2() { + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); + int nIn = 4; + int nOut = 3; - @Test - public void testEpochCounter() throws Exception { - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() - .layer(new 
OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build()) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - assertEquals(0, net.getLayerWiseConfigurations().getEpochCount()); - - - DataSetIterator iter = new IrisDataSetIterator(150, 150); - - for( int i=0; i<4; i++ ){ - assertEquals(i, net.getLayerWiseConfigurations().getEpochCount()); - net.fit(iter); - assertEquals(i+1, net.getLayerWiseConfigurations().getEpochCount()); - } - - assertEquals(4, net.getLayerWiseConfigurations().getEpochCount()); - - MultiLayerNetwork restored = TestUtils.testModelSerialization(net); - assertEquals(4, restored.getLayerWiseConfigurations().getEpochCount()); - } - - @Test - public void testInputClearance() throws Exception { - //Activations should be cleared - if not, it's possible for out of (workspace) scope arrays to be around - // which can cause a crash - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .convolutionMode(ConvolutionMode.Same) - .list() - .layer(new ConvolutionLayer.Builder().kernelSize(2,2).stride(1,1).nIn(1).nOut(1).build()) - .layer(new SubsamplingLayer.Builder().kernelSize(2,2).stride(1,1).build()) - .layer(new DenseLayer.Builder().nOut(10).build()) - .layer(new OutputLayer.Builder().nOut(10).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28,28,1)) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - INDArray content = Nd4j.create(1,1,28,28); - - //Check output: - net.output(content); - for(org.deeplearning4j.nn.api.Layer l : net.getLayers()){ - assertNull(l.input()); - } - - //Check feedForward: - net.feedForward(content, false); - for(org.deeplearning4j.nn.api.Layer l : net.getLayers()){ - assertNull(l.input()); - } - } - - - @Test - public void testExternalErrors() { - //Simple test: same network, but in one case: one less layer (the OutputLayer), where the epsilons are passed in externally - // instead. 
Should get identical results - - for(WorkspaceMode ws : WorkspaceMode.values()) { - log.info("Workspace mode: " + ws); - - Nd4j.getRandom().setSeed(12345); - INDArray inData = Nd4j.rand(3, 10); - INDArray outData = Nd4j.rand(3, 10); - - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration standard = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .trainingWorkspaceMode(ws) - .inferenceWorkspaceMode(ws) - .seed(12345).list() - .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nIn(10) - .nOut(10).build()) - .build(); - MultiLayerNetwork s = new MultiLayerNetwork(standard); - s.init(); - - - Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration external = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) - .trainingWorkspaceMode(ws) - .inferenceWorkspaceMode(ws) - .seed(12345).list() - .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) - .build(); - - MultiLayerNetwork e = new MultiLayerNetwork(external); - e.init(); - - s.setInput(inData); - s.setLabels(outData); - s.computeGradientAndScore(); - Gradient sGrad = s.gradient(); - - s.setInput(inData); - s.feedForward(true, false); //FF without clearing inputs as we need them later - - e.setInput(inData); - e.feedForward(true, false); //FF without clearing inputs as we need them later - - org.deeplearning4j.nn.layers.OutputLayer ol = (org.deeplearning4j.nn.layers.OutputLayer) s.getLayer(1); - Pair olPairStd = ol.backpropGradient(null, LayerWorkspaceMgr.noWorkspaces()); - - INDArray olEpsilon = olPairStd.getSecond().detach(); - - e.setInput(inData); - e.feedForward(true, false); - Pair extErrorGrad = e.backpropGradient(olEpsilon, LayerWorkspaceMgr.noWorkspaces()); - - int nParamsDense = 10 * 10 + 10; - assertEquals(sGrad.gradient().get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nParamsDense)), - extErrorGrad.getFirst().gradient()); - - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - } - } - - @Test - public void testExternalErrors2(){ - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); - int nIn = 4; - int nOut = 3; - - for(WorkspaceMode ws : WorkspaceMode.values()) { + for (WorkspaceMode ws : WorkspaceMode.values()) { // System.out.println("***** WORKSPACE: " + ws); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .updater(new Adam(0.01)) - .trainingWorkspaceMode(ws) - .inferenceWorkspaceMode(ws) - .list() - .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.RELU).build()) - .layer(new ActivationLayer.Builder().activation(Activation.IDENTITY).build()) - .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) - .inputPreProcessor(1, new FeedForwardToRnnPreProcessor()) - .build(); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .updater(new Adam(0.01)) + .trainingWorkspaceMode(ws) + .inferenceWorkspaceMode(ws) + .list() + .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.RELU).build()) + .layer(new ActivationLayer.Builder().activation(Activation.IDENTITY).build()) + .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) + .inputPreProcessor(1, new FeedForwardToRnnPreProcessor()) + .build(); - MultiLayerNetwork graph = new MultiLayerNetwork(conf); - graph.init(); + MultiLayerNetwork graph = new MultiLayerNetwork(conf); + graph.init(); - final int minibatch = 5; - final int seqLen = 6; + final int minibatch = 5; + final int seqLen = 6; - INDArray param = 
Nd4j.create(new double[]{0.54, 0.31, 0.98, -0.30, -0.66, -0.19, -0.29, -0.62, 0.13, -0.32, 0.01, -0.03, 0.00, 0.00, 0.00}).reshape(1, -1); - graph.setParams(param); + INDArray param = Nd4j.create( + new double[]{0.54, 0.31, 0.98, -0.30, -0.66, -0.19, -0.29, -0.62, 0.13, -0.32, 0.01, + -0.03, 0.00, 0.00, 0.00}).reshape(1, -1); + graph.setParams(param); - INDArray input = Nd4j.rand(new int[]{minibatch, nIn, seqLen}, 12); - INDArray expected = Nd4j.ones(minibatch, nOut, seqLen); + INDArray input = Nd4j.rand(new int[]{minibatch, nIn, seqLen}, 12); + INDArray expected = Nd4j.ones(minibatch, nOut, seqLen); - graph.setInput(input); - INDArray output = graph.feedForward(false, false).get(2); - INDArray error = output.sub(expected); + graph.setInput(input); + INDArray output = graph.feedForward(false, false).get(2); + INDArray error = output.sub(expected); - for (org.deeplearning4j.nn.api.Layer l : graph.getLayers()) { - assertNotNull(l.input()); - assertFalse(l.input().isAttached()); - } + for (org.deeplearning4j.nn.api.Layer l : graph.getLayers()) { + assertNotNull(l.input()); + assertFalse(l.input().isAttached()); + } - // Compute Gradient - Pair gradient = graph.backpropGradient(error, LayerWorkspaceMgr.noWorkspaces()); - graph.getUpdater().update(graph, gradient.getFirst(), 0, 0, minibatch, LayerWorkspaceMgr.noWorkspaces()); + // Compute Gradient + Pair gradient = graph.backpropGradient(error, + LayerWorkspaceMgr.noWorkspaces()); + graph.getUpdater() + .update(graph, gradient.getFirst(), 0, 0, minibatch, LayerWorkspaceMgr.noWorkspaces()); - Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - } - - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED); + Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); } - @Test - public void testLayerSize(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.DISABLED); + } - .list() - .layer(new ConvolutionLayer.Builder().kernelSize(2,2).nOut(6).build()) - .layer(new SubsamplingLayer.Builder().kernelSize(2,2).build()) - .layer(new DenseLayer.Builder().nOut(30).build()) - .layer(new OutputLayer.Builder().nOut(13).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28,28,3)) - .build(); + @Test + public void testLayerSize() { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); + .list() + .layer(new ConvolutionLayer.Builder().kernelSize(2, 2).nOut(6).build()) + .layer(new SubsamplingLayer.Builder().kernelSize(2, 2).build()) + .layer(new DenseLayer.Builder().nOut(30).build()) + .layer(new OutputLayer.Builder().nOut(13).activation(Activation.SOFTMAX).build()) + .inputType(InputType.convolutional(28, 28, 3)) + .build(); - assertEquals(6, net.layerSize(0)); - assertEquals(0, net.layerSize(1)); - assertEquals(30, net.layerSize(2)); - assertEquals(13, net.layerSize(3)); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); - assertEquals(3, net.layerInputSize(0)); - assertEquals(0, net.layerInputSize(1)); - assertEquals(((FeedForwardLayer)net.getLayer(2).conf().getLayer()).getNIn(), net.layerInputSize(2)); - assertEquals(30, net.layerInputSize(3)); + assertEquals(6, net.layerSize(0)); + assertEquals(0, net.layerSize(1)); + assertEquals(30, net.layerSize(2)); + assertEquals(13, net.layerSize(3)); + + assertEquals(3, net.layerInputSize(0)); + assertEquals(0, net.layerInputSize(1)); + 
assertEquals(((FeedForwardLayer) net.getLayer(2).getLayerConfiguration()).getNIn(), + net.layerInputSize(2)); + assertEquals(30, net.layerInputSize(3)); + } + + + @Test + public void testZeroParamNet() throws Exception { + + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .list() + .layer(new SubsamplingLayer.Builder().kernelSize(2, 2).stride(2, 2).build()) + .layer(new LossLayer.Builder().activation(Activation.SIGMOID) + .lossFunction(LossFunctions.LossFunction.MSE).build()) + .inputType(InputType.convolutionalFlat(28, 28, 1)) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + DataSet ds = new MnistDataSetIterator(16, true, 12345).next(); + + INDArray out = net.output(ds.getFeatures()); + + INDArray labelTemp = Nd4j.create(out.shape()); + ds.setLabels(labelTemp); + + net.fit(ds); + + MultiLayerNetwork net2 = TestUtils.testModelSerialization(net); + INDArray out2 = net2.output(ds.getFeatures()); + assertEquals(out, out2); + } + + + @Test + public void testInputActivationGradient() { + Nd4j.setDataType(DataType.DOUBLE); + + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .dataType(DataType.DOUBLE) + .seed(12345) + .activation(Activation.TANH) + .list() + .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) + .layer( + new OutputLayer.Builder().nIn(10).nOut(10).lossFunction(LossFunctions.LossFunction.MSE) + .build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + INDArray in = Nd4j.rand(1, 10); + INDArray label = Nd4j.rand(1, 10); + + Pair p = net.calculateGradients(in, label, null, null); + + //Quick gradient check: + double eps = 1e-6; + double maxRelError = 1e-5; + for (int i = 0; i < 10; i++) { + double orig = in.getDouble(i); + in.putScalar(i, orig + eps); + double scorePlus = net.score(new DataSet(in, label)); + in.putScalar(i, orig - eps); + double scoreMinus = net.score(new DataSet(in, label)); + in.putScalar(i, orig); + + double expGrad = (scorePlus - scoreMinus) / (2.0 * eps); + double actGrad = p.getSecond().getDouble(i); + + double relError = (Math.abs(expGrad - actGrad)) / (Math.abs(expGrad) + Math.abs(actGrad)); + + String str = i + " - " + relError + " - exp=" + expGrad + ", act=" + actGrad; + assertTrue(relError < maxRelError, str); } + } - @Test - public void testZeroParamNet() throws Exception { + @Test + public void testNeuralNetConfigurationActivationTypes() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() - .layer(new SubsamplingLayer.Builder().kernelSize(2,2).stride(2,2).build()) - .layer(new LossLayer.Builder().activation(Activation.SIGMOID).lossFunction(LossFunctions.LossFunction.MSE).build()) - .setInputType(InputType.convolutionalFlat(28,28,1)) - .build(); + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder() + .list() + .layer(new LSTM.Builder().nOut(6).build()) + .layer(new LSTM.Builder().nOut(7).build()) + .layer(new GlobalPoolingLayer()) + .layer(new OutputLayer.Builder().nOut(8).activation(Activation.SOFTMAX).build()) + .inputType(InputType.recurrent(10)); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); + NeuralNetConfiguration conf = builder.build(); - DataSet ds = new MnistDataSetIterator(16, true, 12345).next(); + List outBuilder = builder.getLayerActivationTypes(); + List outConf = conf.getLayerActivationTypes(InputType.recurrent(10)); - INDArray out = net.output(ds.getFeatures()); + List exp = Arrays.asList( + InputType.recurrent(6), + 
InputType.recurrent(7), + InputType.feedForward(7), + InputType.feedForward(8) + ); - INDArray labelTemp = Nd4j.create(out.shape()); - ds.setLabels(labelTemp); + assertEquals(exp, outBuilder); + assertEquals(exp, outConf); + } - net.fit(ds); + @Test + public void testMultipleEpochsSimple() { + //Mainly a simple sanity check on the preconditions in the method... + DataSetIterator iter = new IrisDataSetIterator(10, 150); - MultiLayerNetwork net2 = TestUtils.testModelSerialization(net); - INDArray out2 = net2.output(ds.getFeatures()); - assertEquals(out, out2); - } + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .list() + .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build()) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + net.fit(iter, 3); - @Test - public void testInputActivationGradient(){ - Nd4j.setDataType(DataType.DOUBLE); + ComputationGraph g = net.toComputationGraph(); + g.fit(iter, 3); + } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .dataType(DataType.DOUBLE) - .seed(12345) - .activation(Activation.TANH) - .list() - .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) - .layer(new OutputLayer.Builder().nIn(10).nOut(10).lossFunction(LossFunctions.LossFunction.MSE).build()) - .build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); + @Test + public void testPretrainFitMethods() { - INDArray in = Nd4j.rand(1, 10); - INDArray label = Nd4j.rand(1, 10); + //The fit methods should *not* do layerwise pretraining: - Pair p = net.calculateGradients(in, label, null, null); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() - //Quick gradient check: - double eps = 1e-6; - double maxRelError = 1e-5; - for( int i=0; i<10; i++ ){ - double orig = in.getDouble(i); - in.putScalar(i, orig + eps); - double scorePlus = net.score(new DataSet(in, label)); - in.putScalar(i, orig - eps); - double scoreMinus = net.score(new DataSet(in, label)); - in.putScalar(i, orig); + .list() + .layer(new VariationalAutoencoder.Builder() + .nIn(10).nOut(10).encoderLayerSizes(10).decoderLayerSizes(10).build()) + .layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()) - double expGrad = (scorePlus - scoreMinus) / (2.0 * eps); - double actGrad = p.getSecond().getDouble(i); + .build(); - double relError = (Math.abs(expGrad - actGrad)) / (Math.abs(expGrad) + Math.abs(actGrad)); - - String str = i + " - " + relError + " - exp=" + expGrad + ", act=" + actGrad; - assertTrue(relError < maxRelError, str); - } - } - - - @Test - public void testMultiLayerConfigurationActivationTypes(){ - - NeuralNetConfiguration.ListBuilder builder = new NeuralNetConfiguration.Builder() - .list() - .layer(new LSTM.Builder().nOut(6).build()) - .layer(new LSTM.Builder().nOut(7).build()) - .layer(new GlobalPoolingLayer()) - .layer(new OutputLayer.Builder().nOut(8).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.recurrent(10)); - - MultiLayerConfiguration conf = builder.build(); - - List outBuilder = builder.getLayerActivationTypes(); - List outConf = conf.getLayerActivationTypes(InputType.recurrent(10)); - - List exp = Arrays.asList( - InputType.recurrent(6), - InputType.recurrent(7), - InputType.feedForward(7), - InputType.feedForward(8) - ); - - - assertEquals(exp, outBuilder); - assertEquals(exp, outConf); - } - - @Test - public void testMultipleEpochsSimple(){ - //Mainly a simple sanity check on the preconditions in the method... 
- DataSetIterator iter = new IrisDataSetIterator(10, 150); - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .list() - .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build()) - .build(); - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - net.fit(iter, 3); - - ComputationGraph g = net.toComputationGraph(); - g.fit(iter, 3); - } - - - @Test - public void testPretrainFitMethods(){ - - //The fit methods should *not* do layerwise pretraining: - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - - .list() - .layer(new VariationalAutoencoder.Builder() - .nIn(10).nOut(10).encoderLayerSizes(10).decoderLayerSizes(10).build()) - .layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build()) - - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - Set> exp = new HashSet<>(); - exp.add(MultiLayerNetwork.class); - - CheckModelsListener listener = new CheckModelsListener(); - net.setListeners(listener); - - INDArray f = Nd4j.create(1,10); - INDArray l = Nd4j.create(1,10); - DataSet ds = new DataSet(f,l); - MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f,l); - - DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds)); - net.fit(iter); - assertEquals(exp, listener.getModelClasses()); - - net.fit(ds); - assertEquals(exp, listener.getModelClasses()); - - net.fit(f, l); - assertEquals(exp, listener.getModelClasses()); - - net.fit(f, l, null, null); - assertEquals(exp, listener.getModelClasses()); - - net.fit(mds); - assertEquals(exp, listener.getModelClasses()); - - net.fit(new SingletonMultiDataSetIterator(mds)); - assertEquals(exp, listener.getModelClasses()); - } - - @Test - public void testINDArrayConfigCloning(){ - //INDArrays in config should be cloned to avoid threading issues - - int mb = 3; - int b = 4; - int c = 3; - int depth = b * (5 + c); - int w = 6; - int h = 6; - - INDArray bbPrior = Nd4j.rand(b, 2).muliRowVector(Nd4j.create(new double[]{w, h}).castTo(Nd4j.defaultFloatingPointType())); - - - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .l2(0.01) - .list() - .layer(new ConvolutionLayer.Builder().nIn(depth).nOut(depth).kernelSize(1,1).build()) - .layer(new Yolo2OutputLayer.Builder() - .boundingBoxPriors(bbPrior) - .build()) - .build(); - - MultiLayerConfiguration conf2 = conf.clone(); - - INDArray bb1 = ((Yolo2OutputLayer)conf.getConf(1).getLayer()).getBoundingBoxes(); - INDArray bb2 = ((Yolo2OutputLayer)conf2.getConf(1).getLayer()).getBoundingBoxes(); - assertNotSame(bb1, bb2); - - assertEquals(bb1, bb2); - } - - @Data - @EqualsAndHashCode(callSuper = false) - public static class CheckModelsListener extends BaseTrainingListener { - - private Set> modelClasses = new HashSet<>(); - - @Override - public void iterationDone(Model model, int iteration, int epoch) { - modelClasses.add(model.getClass()); - } - } - - - @Test - public void testMLNUpdaterBlocks(){ - //Check that setting learning rate results in correct rearrangement of updater state within updater blocks - //https://github.com/deeplearning4j/deeplearning4j/issues/6809#issuecomment-463892644 - - double lr = 1e-3; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .seed(12345) - .weightInit(WeightInit.XAVIER) - .updater(new Adam(lr)) - .list() - .layer(new DenseLayer.Builder().nIn(5).nOut(3).build()) - .layer(new DenseLayer.Builder().nIn(3).nOut(2).build()) - .layer(new 
OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(2).nOut(1) - .activation(Activation.SIGMOID).build()) - .build(); - - MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.init(); - - INDArray in = Nd4j.rand(1, 5); - INDArray lbl = Nd4j.rand(1,1); - - net.fit(new DataSet(in, lbl)); - - INDArray viewArray = net.getUpdater().getStateViewArray(); - INDArray viewArrayCopy = viewArray.dup(); - //Initially updater view array is set out like: - //[m0w, m0b, m1w, m1b, m2w, m2b][v0w, v0b, v1w, v1b, v2w, v2b] - long soFar = 0; - INDArray m0w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+5*3)).assign(0); //m0w - soFar += 5*3; - INDArray m0b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3)).assign(1); //m0b - soFar += 3; - INDArray m1w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3*2)).assign(2); //m1w - soFar += 3*2; - INDArray m1b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2)).assign(3); //m1b - soFar += 2; - INDArray m2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+ 2)).assign(4); //m2w - soFar += 2; - INDArray m2b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+1)).assign(5); //m2b - soFar += 1; - - INDArray v0w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+5*3)).assign(6); //v0w - soFar += 5*3; - INDArray v0b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3)).assign(7); //v0b - soFar += 3; - INDArray v1w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+3*2)).assign(8); //v1w - soFar += 3*2; - INDArray v1b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+2)).assign(9); //v1b - soFar += 2; - INDArray v2w = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+ 2)).assign(10); //v2w - soFar += 2; - INDArray v2b = viewArray.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(soFar, soFar+1)).assign(11); //v2b - soFar += 1; - - - net.setLearningRate(0, 0.0); - - //Expect new updater state to look like: - //[m0w, m0b][v0w,v0b], [m1w, m1b, m2w, m2b][v1w, v1b, v2w, v2b] - INDArray exp = Nd4j.concat(1, m0w, m0b, v0w, v0b, - m1w, m1b, m2w, m2b, v1w, v1b, v2w, v2b); - - INDArray act = net.getUpdater().getStateViewArray(); + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + Set> exp = new HashSet<>(); + exp.add(MultiLayerNetwork.class); + + CheckModelsListener listener = new CheckModelsListener(); + net.setListeners(listener); + + INDArray f = Nd4j.create(1, 10); + INDArray l = Nd4j.create(1, 10); + DataSet ds = new DataSet(f, l); + MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(f, l); + + DataSetIterator iter = new ExistingDataSetIterator(Collections.singletonList(ds)); + net.fit(iter); + assertEquals(exp, listener.getModelClasses()); + + net.fit(ds); + assertEquals(exp, listener.getModelClasses()); + + net.fit(f, l); + assertEquals(exp, listener.getModelClasses()); + + net.fit(f, l, null, null); + assertEquals(exp, listener.getModelClasses()); + + net.fit(mds); + assertEquals(exp, listener.getModelClasses()); + + net.fit(new SingletonMultiDataSetIterator(mds)); + assertEquals(exp, listener.getModelClasses()); + } + + @Test + public void testINDArrayConfigCloning() { + //INDArrays in config should be cloned to avoid threading issues + + int mb = 3; + int b 
= 4; + int c = 3; + int depth = b * (5 + c); + int w = 6; + int h = 6; + + INDArray bbPrior = Nd4j.rand(b, 2) + .muliRowVector(Nd4j.create(new double[]{w, h}).castTo(Nd4j.defaultFloatingPointType())); + + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .l2(0.01) + .list() + .layer(new ConvolutionLayer.Builder().nIn(depth).nOut(depth).kernelSize(1, 1).build()) + .layer(new Yolo2OutputLayer.Builder() + .boundingBoxPriors(bbPrior) + .build()) + .build(); + + NeuralNetConfiguration conf2 = conf.clone(); + + INDArray bb1 = ((Yolo2OutputLayer) conf.getConf(1).getLayer()).getBoundingBoxes(); + INDArray bb2 = ((Yolo2OutputLayer) conf2.getConf(1).getLayer()).getBoundingBoxes(); + assertNotSame(bb1, bb2); + + assertEquals(bb1, bb2); + } + + @Test + public void testMLNUpdaterBlocks() { + //Check that setting learning rate results in correct rearrangement of updater state within updater blocks + //https://github.com/deeplearning4j/deeplearning4j/issues/6809#issuecomment-463892644 + + double lr = 1e-3; + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .seed(12345) + .weightInit(WeightInit.XAVIER) + .updater(new Adam(lr)) + .list() + .layer(new DenseLayer.Builder().nIn(5).nOut(3).build()) + .layer(new DenseLayer.Builder().nIn(3).nOut(2).build()) + .layer(new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(2).nOut(1) + .activation(Activation.SIGMOID).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + INDArray in = Nd4j.rand(1, 5); + INDArray lbl = Nd4j.rand(1, 1); + + net.fit(new DataSet(in, lbl)); + + INDArray viewArray = net.getUpdater().getStateViewArray(); + INDArray viewArrayCopy = viewArray.dup(); + //Initially updater view array is set out like: + //[m0w, m0b, m1w, m1b, m2w, m2b][v0w, v0b, v1w, v1b, v2w, v2b] + long soFar = 0; + INDArray m0w = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 5 * 3)).assign(0); //m0w + soFar += 5 * 3; + INDArray m0b = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 3)).assign(1); //m0b + soFar += 3; + INDArray m1w = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 3 * 2)).assign(2); //m1w + soFar += 3 * 2; + INDArray m1b = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 2)).assign(3); //m1b + soFar += 2; + INDArray m2w = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 2)).assign(4); //m2w + soFar += 2; + INDArray m2b = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 1)).assign(5); //m2b + soFar += 1; + + INDArray v0w = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 5 * 3)).assign(6); //v0w + soFar += 5 * 3; + INDArray v0b = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 3)).assign(7); //v0b + soFar += 3; + INDArray v1w = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 3 * 2)).assign(8); //v1w + soFar += 3 * 2; + INDArray v1b = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 2)).assign(9); //v1b + soFar += 2; + INDArray v2w = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 2)).assign(10); //v2w + soFar += 2; + INDArray v2b = viewArray.get(NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(soFar, soFar + 1)).assign(11); //v2b + soFar += 1; + 
+ net.setLearningRate(0, 0.0); + + //Expect new updater state to look like: + //[m0w, m0b][v0w,v0b], [m1w, m1b, m2w, m2b][v1w, v1b, v2w, v2b] + INDArray exp = Nd4j.concat(1, m0w, m0b, v0w, v0b, + m1w, m1b, m2w, m2b, v1w, v1b, v2w, v2b); + + INDArray act = net.getUpdater().getStateViewArray(); // System.out.println(exp); // System.out.println(act); - assertEquals(exp, act); + assertEquals(exp, act); - //And set layer 1 LR: - net.setLearningRate(1, 0.2); - exp = Nd4j.concat(1, m0w, m0b, v0w, v0b, - m1w, m1b, v1w, v1b, - m2w, m2b, v2w, v2b); - assertEquals(exp, net.getUpdater().getStateViewArray()); + //And set layer 1 LR: + net.setLearningRate(1, 0.2); + exp = Nd4j.concat(1, m0w, m0b, v0w, v0b, + m1w, m1b, v1w, v1b, + m2w, m2b, v2w, v2b); + assertEquals(exp, net.getUpdater().getStateViewArray()); + //Set all back to original LR and check again: + net.setLearningRate(1, lr); + net.setLearningRate(0, lr); - //Set all back to original LR and check again: - net.setLearningRate(1, lr); - net.setLearningRate(0, lr); + exp = Nd4j.concat(1, m0w, m0b, m1w, m1b, m2w, m2b, v0w, v0b, v1w, v1b, v2w, v2b); + assertEquals(exp, net.getUpdater().getStateViewArray()); - exp = Nd4j.concat(1, m0w, m0b, m1w, m1b, m2w, m2b, v0w, v0b, v1w, v1b, v2w, v2b); - assertEquals(exp, net.getUpdater().getStateViewArray()); + //Finally, training sanity check (if things are wrong, we get -ve values in adam V, which causes NaNs) + net.getUpdater().getStateViewArray().assign(viewArrayCopy); + net.setLearningRate(0, 0.0); + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC); + net.fit(new DataSet(in, lbl)); + Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); + } - //Finally, training sanity check (if things are wrong, we get -ve values in adam V, which causes NaNs) - net.getUpdater().getStateViewArray().assign(viewArrayCopy); - net.setLearningRate(0, 0.0); + @Data + @EqualsAndHashCode(callSuper = false) + public static class CheckModelsListener extends BaseTrainingListener { - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.NAN_PANIC); - net.fit(new DataSet(in, lbl)); - Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); + private Set> modelClasses = new HashSet<>(); + + @Override + public void iterationDone(IModel model, int iteration, int epoch) { + modelClasses.add(model.getClass()); } + } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java index a12bd88f9..1a6175cde 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; @@ -67,8 +68,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest { int nIn = 8; int nOut = 25; int nHiddenUnits = 17; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = + 
NeuralNetConfiguration.builder() .list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder() .nIn(nIn).nOut(nHiddenUnits) @@ -112,7 +113,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { int nIn = 8; int nOut = 25; int[] nHiddenUnits = {17, 19, 23}; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(17) .activation(Activation.TANH).build()) .layer(1, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(17).nOut(19) @@ -160,8 +161,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); int timeSeriesLength = 6; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() .list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder() .nIn(5).nOut(7).activation(Activation.TANH) @@ -225,8 +226,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest { public void testRnnTimeStepLayers() { for( int layerType=0; layerType<3; layerType++ ) { - org.deeplearning4j.nn.conf.layers.Layer l0; - org.deeplearning4j.nn.conf.layers.Layer l1; + LayerConfiguration l0; + LayerConfiguration l1; String lastActKey; if(layerType == 0){ @@ -262,7 +263,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { int timeSeriesLength = 12; //4 layer network: 2 GravesLSTM + DenseLayerConfiguration + RnnOutputLayer. Hence also tests preprocessors. - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).list() .layer(0, l0) .layer(1, l1) .layer(2, new DenseLayer.Builder().nIn(8).nOut(9).activation(Activation.TANH) @@ -349,8 +350,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); int timeSeriesLength = 6; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() .list().layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder() .nIn(5).nOut(7).activation(Activation.TANH) @@ -408,7 +409,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { int nIn = 5; int nOut = 4; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .trainingWorkspaceMode(WorkspaceMode.NONE).inferenceWorkspaceMode(WorkspaceMode.NONE) .list() .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7) @@ -427,7 +428,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { .build(); assertEquals(BackpropType.Standard, conf.getBackpropType()); - MultiLayerConfiguration confTBPTT = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration confTBPTT = NeuralNetConfiguration.builder().seed(12345) .trainingWorkspaceMode(WorkspaceMode.NONE).inferenceWorkspaceMode(WorkspaceMode.NONE) .list() .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7) @@ -443,8 +444,9 @@ public class MultiLayerTestRNN extends BaseDL4JTest { .nIn(8).nOut(nOut).activation(Activation.SOFTMAX) .dist(new NormalDistribution(0, 0.5)) .build()) - .backpropType(BackpropType.TruncatedBPTT).tBPTTBackwardLength(timeSeriesLength) - .tBPTTForwardLength(timeSeriesLength).build(); + .backpropType(BackpropType.TruncatedBPTT) + 
.tbpttBackLength(timeSeriesLength)
+                        .tbpttFwdLength(timeSeriesLength).build();
 
         Nd4j.getRandom().setSeed(12345);
         MultiLayerNetwork mln = new MultiLayerNetwork(conf);
@@ -456,9 +458,9 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
 
         mlnTBPTT.clearTbpttState = false;
 
-        assertEquals(BackpropType.TruncatedBPTT, mlnTBPTT.getLayerWiseConfigurations().getBackpropType());
-        assertEquals(timeSeriesLength, mlnTBPTT.getLayerWiseConfigurations().getTbpttFwdLength());
-        assertEquals(timeSeriesLength, mlnTBPTT.getLayerWiseConfigurations().getTbpttBackLength());
+        assertEquals(BackpropType.TruncatedBPTT, mlnTBPTT.getConfiguration().getBackpropType());
+        assertEquals(timeSeriesLength, mlnTBPTT.getConfiguration().getTbpttFwdLength());
+        assertEquals(timeSeriesLength, mlnTBPTT.getConfiguration().getTbpttBackLength());
 
         INDArray inputData = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength);
         INDArray labels = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength);
@@ -520,8 +522,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
 
         int nTimeSlices = 5;
 
-        MultiLayerConfiguration conf =
-                        new NeuralNetConfiguration.Builder().seed(12345).list().layer(0,
+        NeuralNetConfiguration conf =
+                        NeuralNetConfiguration.builder().seed(12345).list().layer(0,
                                         new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                                         .activation(Activation.TANH)
                                                         .dist(new NormalDistribution(0, 0.5)).build())
@@ -602,7 +604,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
 
         int nTimeSlices = 20;
 
-        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
+        NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
                         .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                         .activation(Activation.TANH)
@@ -618,7 +620,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
                                         .dist(new NormalDistribution(0, 0.5))
                                         .build())
                         .backpropType(BackpropType.TruncatedBPTT)
-                        .tBPTTBackwardLength(timeSeriesLength).tBPTTForwardLength(timeSeriesLength).build();
+                        .tbpttBackLength(timeSeriesLength).tbpttFwdLength(timeSeriesLength).build();
 
         Nd4j.getRandom().setSeed(12345);
         MultiLayerNetwork mln = new MultiLayerNetwork(conf);
@@ -639,7 +641,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
         int nIn = 5;
         int nOut = 4;
 
-        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345)
+        NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345)
                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
                         .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7)
                                         .activation(Activation.TANH)
@@ -655,7 +657,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
                                         .dist(new NormalDistribution(0, 0.5))
                                         .build())
                         .backpropType(BackpropType.TruncatedBPTT)
-                        .tBPTTBackwardLength(tbpttLength).tBPTTForwardLength(tbpttLength).build();
+                        .tbpttBackLength(tbpttLength).tbpttFwdLength(tbpttLength).build();
 
         Nd4j.getRandom().setSeed(12345);
         MultiLayerNetwork mln = new MultiLayerNetwork(conf);
@@ -675,8 +677,8 @@ public class MultiLayerTestRNN extends BaseDL4JTest {
 
     @Test
     public void testRnnTimeStepWithPreprocessor() {
-        MultiLayerConfiguration conf =
-                        new NeuralNetConfiguration.Builder()
+        NeuralNetConfiguration conf =
+                        NeuralNetConfiguration.builder()
                                         .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                                         .list()
                                         .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10)
@@ -698,7 +700,7 @@ 
public class MultiLayerTestRNN extends BaseDL4JTest { @Test public void testRnnTimeStepWithPreprocessorGraph() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(10).nOut(10) @@ -727,7 +729,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { int nIn = 5; int nOut = 4; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).list() .layer(0, new org.deeplearning4j.nn.conf.layers.GravesLSTM.Builder().nIn(nIn).nOut(7) @@ -737,7 +739,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { .layer(2, new RnnOutputLayer.Builder(LossFunction.MSE).nIn(8).nOut(nOut) .activation(Activation.IDENTITY).build()) .backpropType(BackpropType.TruncatedBPTT) - .tBPTTBackwardLength(tbpttLength).tBPTTForwardLength(tbpttLength).build(); + .tbpttBackLength(tbpttLength).tbpttFwdLength(tbpttLength).build(); Nd4j.getRandom().setSeed(12345); MultiLayerNetwork mln = new MultiLayerNetwork(conf); @@ -764,7 +766,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { int nHiddenUnits = 17; try { - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .list() .layer(new org.deeplearning4j.nn.conf.layers.LSTM.Builder().nIn(nIn).nOut(nHiddenUnits).build()) .layer(new GlobalPoolingLayer()) @@ -783,7 +785,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { @Test public void testWrapperLayerGetPreviousState(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new FrozenLayer(new org.deeplearning4j.nn.conf.layers.LSTM.Builder() .nIn(5).nOut(5).build())) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java index c4c3067a9..1cca6ede8 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java @@ -66,11 +66,12 @@ public class TestMasking extends BaseDL4JTest { public void checkMaskArrayClearance() { for (boolean tbptt : new boolean[] {true, false}) { //Simple "does it throw an exception" type test... - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).list() .layer(0, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE) .activation(Activation.IDENTITY).nIn(1).nOut(1).build()) .backpropType(tbptt ? 
BackpropType.TruncatedBPTT : BackpropType.Standard) - .tBPTTForwardLength(8).tBPTTBackwardLength(8).build(); + + .tbpttFwdLength(8).tbpttBackLength(8).build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); @@ -149,7 +150,7 @@ public class TestMasking extends BaseDL4JTest { Activation a = act[i]; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new NoOp()) .dist(new NormalDistribution(0, 1)).seed(12345) .list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) @@ -195,7 +196,7 @@ public class TestMasking extends BaseDL4JTest { //Do the same for CompGraph - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().updater(new NoOp()) + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder().updater(new NoOp()) .dist(new NormalDistribution(0, 1)).seed(12345) .graphBuilder().addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(nIn).nOut(layerSize) @@ -237,7 +238,7 @@ public class TestMasking extends BaseDL4JTest { int nIn = 5; int nOut = 4; - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder().updater(new NoOp()) + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder().updater(new NoOp()) .dist(new NormalDistribution(0, 1)).seed(12345) .graphBuilder().addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(nIn).nOut(layerSize).activation(Activation.TANH) @@ -269,7 +270,7 @@ public class TestMasking extends BaseDL4JTest { int cnnStride1 = 1; int channels = 1; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .weightInit(WeightInit.XAVIER) .convolutionMode(ConvolutionMode.Same) @@ -304,7 +305,7 @@ public class TestMasking extends BaseDL4JTest { @Test public void testMaskingStackUnstack(){ - ComputationGraphConfiguration nnConfig = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration nnConfig = NeuralNetConfiguration.builder() .updater(new Adam(2e-2)) .graphBuilder() .setInputTypes( diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java index cb9536e3d..7b75bc97b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java @@ -21,7 +21,6 @@ package org.deeplearning4j.nn.multilayer; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.layers.*; @@ -40,7 +39,7 @@ public class TestSetGetParameters extends BaseDL4JTest { @Test public void testSetParameters() { //Set up a MLN, then do set(get) on parameters. Results should be identical compared to before doing this. 
- MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new DenseLayer.Builder().nIn(9).nOut(10) .dist(new NormalDistribution(0, 1)).build()) .layer(1, new DenseLayer.Builder().nIn(10).nOut(11) @@ -55,12 +54,12 @@ public class TestSetGetParameters extends BaseDL4JTest { net.init(); INDArray initParams = net.params().dup(); - Map initParams2 = net.paramTable(); + Map initParams2 = net.getParamTable(); net.setParams(net.params()); INDArray initParamsAfter = net.params(); - Map initParams2After = net.paramTable(); + Map initParams2After = net.getParamTable(); for (String s : initParams2.keySet()) { assertEquals(initParams2.get(s), initParams2After.get(s), "Params differ: " + s); @@ -79,7 +78,7 @@ public class TestSetGetParameters extends BaseDL4JTest { public void testSetParametersRNN() { //Set up a MLN, then do set(get) on parameters. Results should be identical compared to before doing this. - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new GravesLSTM.Builder().nIn(9).nOut(10) .dist(new NormalDistribution(0, 1)).build()) .layer(1, new GravesLSTM.Builder().nIn(10).nOut(11) @@ -92,12 +91,12 @@ public class TestSetGetParameters extends BaseDL4JTest { net.init(); INDArray initParams = net.params().dup(); - Map initParams2 = net.paramTable(); + Map initParams2 = net.getParamTable(); net.setParams(net.params()); INDArray initParamsAfter = net.params(); - Map initParams2After = net.paramTable(); + Map initParams2After = net.getParamTable(); for (String s : initParams2.keySet()) { assertEquals(initParams2.get(s), initParams2After.get(s), "Params differ: " + s); @@ -118,7 +117,7 @@ public class TestSetGetParameters extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); //Create configuration. 
Doesn't matter if this doesn't actually work for forward/backward pass here - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).list() .layer(0, new ConvolutionLayer.Builder().nIn(10).nOut(10).kernelSize(2, 2).stride(2, 2) .padding(2, 2).build()) .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build()) @@ -145,9 +144,9 @@ public class TestSetGetParameters extends BaseDL4JTest { assertSame(params, net3.params()); //Same object due to clone - Map paramsMap = net.paramTable(); - Map paramsMap2 = net2.paramTable(); - Map paramsMap3 = net3.paramTable(); + Map paramsMap = net.getParamTable(); + Map paramsMap2 = net2.getParamTable(); + Map paramsMap3 = net3.getParamTable(); for (String s : paramsMap.keySet()) { assertEquals(paramsMap.get(s), paramsMap2.get(s)); assertEquals(paramsMap.get(s), paramsMap3.get(s)); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java index 5d5daed14..7dc7480c6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.multilayer; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; @@ -48,7 +47,6 @@ import org.nd4j.linalg.learning.config.NoOp; import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; -import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Random; @@ -72,7 +70,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { for (int nExamples : miniBatchSizes) { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.1)).seed(12345).list() .layer(0, new GravesLSTM.Builder().activation(Activation.TANH).nIn(2).nOut(2).build()) @@ -160,7 +158,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { for (int nExamples : miniBatchSizes) { Nd4j.getRandom().setSeed(1234); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.1)).seed(12345).list() .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(2).nOut(2).build()) @@ -170,7 +168,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { .nOut(1).activation(Activation.TANH).build()) .inputPreProcessor(0, new RnnToFeedForwardPreProcessor()) .inputPreProcessor(2, new FeedForwardToRnnPreProcessor()) - .setInputType(InputType.recurrent(2,-1, RNNFormat.NCW)) + .inputType(InputType.recurrent(2,-1, RNNFormat.NCW)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -305,8 +303,8 @@ public class TestVariableLengthTS extends BaseDL4JTest { INDArray input = Nd4j.rand(miniBatch, nIn, tsLength); INDArray labels = 
Nd4j.ones(miniBatch, nOut, tsLength); - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(12345L).list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(5) .dist(new NormalDistribution(0, 1)) @@ -368,8 +366,8 @@ public class TestVariableLengthTS extends BaseDL4JTest { INDArray input = Nd4j.rand(miniBatch, nIn, tsLength); - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(12345L).list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(5) .dist(new NormalDistribution(0, 1)) @@ -384,8 +382,8 @@ public class TestVariableLengthTS extends BaseDL4JTest { MultiLayerNetwork mln = new MultiLayerNetwork(conf); mln.init(); - MultiLayerConfiguration conf2 = - new NeuralNetConfiguration.Builder().seed(12345L).list() + NeuralNetConfiguration conf2 = + NeuralNetConfiguration.builder().seed(12345L).list() .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(5) .dist(new NormalDistribution(0, 1)) @@ -440,7 +438,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { int layerSize = 3; int nOut = 3; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .activation(Activation.TANH).list() .layer(0, new GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(layerSize).build()) .layer(1, new GravesBidirectionalLSTM.Builder().nIn(layerSize).nOut(layerSize).build()) @@ -517,7 +515,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { // System.out.println("Starting test: bidirectional = " + bidirectional + ", poolingType = " + pt); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .activation(Activation.TANH).list().layer(0, bidirectional ? 
new GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(layerSize).build() : new GravesLSTM.Builder().nIn(nIn).nOut(layerSize).build()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java index 92b8375dd..fe80d1e24 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.rl; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -55,11 +54,11 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest { for (boolean regularization : new boolean[] {false, true}) { for (IUpdater u : new IUpdater[] {new Sgd(0.1), new Nesterovs(0.1), new Adam(0.1)}) { - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345).activation(Activation.TANH) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(12345).activation(Activation.TANH) .weightInit(WeightInit.XAVIER).updater(u) .l1(regularization ? 0.2 : 0.0) - .l2(regularization ? 0.3 : 0.0).list() + .l2(regularization ? 0.3 : 0.0) .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(10).build()) .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(2, new OutputLayer.Builder( @@ -125,8 +124,8 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest { net2GradUpd.getUpdater().getStateViewArray()); //Remove the next 2 lines: fails - as net 1 is 1 iteration ahead - net1GradCalc.getLayerWiseConfigurations().setIterationCount(0); - net2GradUpd.getLayerWiseConfigurations().setIterationCount(0); + net1GradCalc.getConfiguration().setIterationCount(0); + net2GradUpd.getConfiguration().setIterationCount(0); for (int i = 0; i < 100; i++) { net1GradCalc.fit(f, l); @@ -148,7 +147,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest { for (IUpdater u : new IUpdater[] {new Sgd(0.1), new Adam(0.1)}) { ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345).activation(Activation.TANH) + NeuralNetConfiguration.builder().seed(12345).activation(Activation.TANH) .weightInit(WeightInit.XAVIER).updater(u) .l1(regularization ? 0.2 : 0.0) .l2(regularization ? 
0.3 : 0.0).graphBuilder().addInputs("in") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java index ecda6b48a..d35b46911 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.transferlearning; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -76,7 +75,7 @@ public class TestFrozenLayers extends BaseDL4JTest { } Map paramsBefore = new LinkedHashMap<>(); - for(Map.Entry entry : transfer.paramTable().entrySet()){ + for(Map.Entry entry : transfer.getParamTable().entrySet()){ paramsBefore.put(entry.getKey(), entry.getValue().dup()); } @@ -86,7 +85,7 @@ public class TestFrozenLayers extends BaseDL4JTest { transfer.fit(f,l); } - for(Map.Entry entry : transfer.paramTable().entrySet()){ + for(Map.Entry entry : transfer.getParamTable().entrySet()){ String s = msg + " - " + entry.getKey(); if(entry.getKey().startsWith("5_")){ //Non-frozen layer @@ -152,7 +151,7 @@ public class TestFrozenLayers extends BaseDL4JTest { } public static MultiLayerNetwork getOriginalNet(int seed){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(seed) .weightInit(WeightInit.XAVIER) .activation(Activation.TANH) @@ -165,7 +164,7 @@ public class TestFrozenLayers extends BaseDL4JTest { .layer(new DenseLayer.Builder().nOut(64).build()) .layer(new DenseLayer.Builder().nIn(64).nOut(64).build()) .layer(new OutputLayer.Builder().nIn(64).nOut(10).lossFunction(LossFunctions.LossFunction.MSE).build()) - .setInputType(InputType.convolutionalFlat(28,28,1)) + .inputType(InputType.convolutionalFlat(28,28,1)) .build(); @@ -175,7 +174,7 @@ public class TestFrozenLayers extends BaseDL4JTest { } public static ComputationGraph getOriginalGraph(int seed){ - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .seed(seed) .weightInit(WeightInit.XAVIER) .activation(Activation.TANH) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java index 44c3bcb07..b328c8dff 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java @@ -23,12 +23,11 @@ package org.deeplearning4j.nn.transferlearning; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import 
org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.layers.DenseLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.FrozenLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -53,7 +52,7 @@ public class TestTransferLearningModelSerializer extends BaseDL4JTest { int nIn = 6; int nOut = 3; - MultiLayerConfiguration origConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration origConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.TANH).dropOut(0.5).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(5).build()) .layer(1, new DenseLayer.Builder().nIn(5).nOut(4).build()) @@ -71,9 +70,9 @@ public class TestTransferLearningModelSerializer extends BaseDL4JTest { assertTrue(withFrozen.getLayer(0) instanceof FrozenLayer); assertTrue(withFrozen.getLayer(1) instanceof FrozenLayer); - assertTrue(withFrozen.getLayerWiseConfigurations().getConf(0) + assertTrue(withFrozen.getConfiguration().getConf(0) .getLayer() instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); - assertTrue(withFrozen.getLayerWiseConfigurations().getConf(1) + assertTrue(withFrozen.getConfiguration().getConf(1) .getLayer() instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); MultiLayerNetwork restored = TestUtils.testModelSerialization(withFrozen); @@ -102,7 +101,7 @@ public class TestTransferLearningModelSerializer extends BaseDL4JTest { int nIn = 6; int nOut = 3; - ComputationGraphConfiguration origConf = new NeuralNetConfiguration.Builder().activation(Activation.TANH).graphBuilder().addInputs("in") + ComputationGraphConfiguration origConf = NeuralNetConfiguration.builder().activation(Activation.TANH).graphBuilder().addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(nIn).nOut(5).build(), "in") .addLayer("1", new DenseLayer.Builder().nIn(5).nOut(4).build(), "0") .addLayer("2", new DenseLayer.Builder().nIn(4).nOut(3).build(), "1") @@ -121,8 +120,8 @@ public class TestTransferLearningModelSerializer extends BaseDL4JTest { assertTrue(withFrozen.getLayer(1) instanceof FrozenLayer); Map m = withFrozen.getComputationGraphConfiguration().getVertices(); - Layer l0 = ((LayerVertex) m.get("0")).getLayerConf().getLayer(); - Layer l1 = ((LayerVertex) m.get("1")).getLayerConf().getLayer(); + LayerConfiguration l0 = ((LayerVertex) m.get("0")).getNetConfiguration().getFirstLayer(); + LayerConfiguration l1 = ((LayerVertex) m.get("1")).getNetConfiguration().getFirstLayer(); assertTrue(l0 instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); assertTrue(l1 instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java index efc821b6e..0f75f1426 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java @@ -25,6 +25,7 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; 
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint; import org.deeplearning4j.nn.conf.distribution.ConstantDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; @@ -63,7 +64,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { long rng = 12345L; DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); //original conf - ComputationGraphConfiguration confToChange = new NeuralNetConfiguration.Builder().seed(rng) + ComputationGraphConfiguration confToChange = NeuralNetConfiguration.builder().seed(rng) .optimizationAlgo(OptimizationAlgorithm.LBFGS).updater(new Nesterovs(0.01, 0.99)) .graphBuilder().addInputs("layer0In").setInputTypes(InputType.feedForward(4)) .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In") @@ -76,7 +77,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { .setOutputs("layer1").build(); //conf with learning parameters changed - ComputationGraphConfiguration expectedConf = new NeuralNetConfiguration.Builder().seed(rng) + ComputationGraphConfiguration expectedConf = NeuralNetConfiguration.builder().seed(rng) .updater(new RmsProp(0.2)) .graphBuilder().addInputs("layer0In") .setInputTypes(InputType.feedForward(4)) @@ -115,7 +116,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { public void testNoutChanges() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 2)); - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.IDENTITY); FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)) .activation(Activation.IDENTITY).build(); @@ -138,9 +139,9 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { //.setOutputs("layer3") .build(); - BaseLayer bl0 = ((BaseLayer) modelNow.getLayer("layer0").conf().getLayer()); - BaseLayer bl1 = ((BaseLayer) modelNow.getLayer("layer1").conf().getLayer()); - BaseLayer bl3 = ((BaseLayer) modelNow.getLayer("layer3").conf().getLayer()); + BaseLayer bl0 = ((BaseLayer) modelNow.getLayer("layer0").getLayerConfiguration()); + BaseLayer bl1 = ((BaseLayer) modelNow.getLayer("layer1").getLayerConfiguration()); + BaseLayer bl3 = ((BaseLayer) modelNow.getLayer("layer3").getLayerConfiguration()); assertEquals(bl0.getWeightInitFn(), new WeightInitDistribution(new NormalDistribution(1, 1e-1))); assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); @@ -182,7 +183,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { public void testRemoveAndAdd() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.IDENTITY); FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)) .activation(Activation.IDENTITY).build(); @@ -250,7 +251,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { DataSet randomData = new 
DataSet(Nd4j.rand(10, 28 * 28 * 3).reshape(10, 3, 28, 28), Nd4j.rand(10, 10)); ComputationGraph modelToFineTune = new ComputationGraph( - new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration.builder().seed(123) .weightInit(WeightInit.XAVIER) .updater(new Nesterovs(0.01, 0.9)).graphBuilder() .addInputs("layer0In") @@ -303,7 +304,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { modelToFineTune.init(); //this will override the learning configuration set in the model - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().seed(456).updater(new Sgd(0.001)); + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().seed(456).updater(new Sgd(0.001)); FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder().seed(456).updater(new Sgd(0.001)) .build(); @@ -399,7 +400,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { @Test public void testTransferGlobalPool() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new Adam(0.1)) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(new Adam(0.1)) .weightInit(WeightInit.XAVIER) .graphBuilder().addInputs("in") .addLayer("blstm1",new GravesBidirectionalLSTM.Builder().nIn(10).nOut(10) @@ -425,7 +426,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { .nIn(10).nOut(5).build(), "dense") .build(); - ComputationGraphConfiguration confExpected = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration confExpected = NeuralNetConfiguration.builder().seed(12345) .updater(new Sgd(0.01)) .weightInit(WeightInit.XAVIER) .graphBuilder().addInputs("in") @@ -452,7 +453,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { @Test public void testObjectOverrides(){ //https://github.com/deeplearning4j/deeplearning4j/issues/4368 - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .dropOut(0.5) .weightNoise(new DropConnect(0.5)) .l2(0.5) @@ -477,7 +478,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { .fineTuneConfiguration(ftc) .build(); - DenseLayer l = (DenseLayer) transfer.getLayer(0).conf().getLayer(); + DenseLayer l = (DenseLayer) transfer.getLayer(0).getLayerConfiguration(); assertNull(l.getIDropout()); assertNull(l.getWeightNoise()); @@ -494,7 +495,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { final String firstConv = "firstConv"; final String secondConv = "secondConv"; final INDArray input = Nd4j.create(6,6,6,6); - final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + final ComputationGraph graph = new ComputationGraph(NeuralNetConfiguration.builder() .weightInit(new ConstantDistribution(666)) .graphBuilder() .addInputs(inputName) @@ -541,7 +542,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { final String afterPoolName = "afterPool"; final String outputName = "output"; final INDArray input = Nd4j.create(new long[] {1, 2, 4, 4}); - final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + final ComputationGraph graph = new ComputationGraph(NeuralNetConfiguration.builder() .graphBuilder() .addInputs(inputName) .setOutputs(outputName) @@ -578,7 +579,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { @Test public void 
testTransferLearningSameDiffLayersGraph(){ - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") @@ -624,7 +625,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { @Test public void testTransferLearningSameDiffLayersGraphVertex(){ - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java index d30227339..ba201c62a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.BaseLayer; @@ -55,7 +56,7 @@ public class TransferLearningComplex extends BaseDL4JTest { // (b) Test global override (should be selective) - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Adam(1e-4)) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().updater(new Adam(1e-4)) .activation(Activation.LEAKYRELU).graphBuilder().addInputs("in1", "in2") .addLayer("A", new DenseLayer.Builder().nIn(10).nOut(9).build(), "in1") .addLayer("B", new DenseLayer.Builder().nIn(9).nOut(8).build(), "A") @@ -87,9 +88,9 @@ public class TransferLearningComplex extends BaseDL4JTest { Layer[] layers = graph2.getLayers(); for (Layer l : layers) { - String name = l.conf().getLayer().getLayerName(); + String name = l.getLayerConfiguration().getLayerName(); log.info(name + "\t frozen: " + (l instanceof FrozenLayer)); - if ("C".equals(l.conf().getLayer().getLayerName())) { + if ("C".equals(l.getLayerConfiguration().getLayerName())) { //Only C should be frozen in this config cFound = true; assertTrue(l instanceof FrozenLayer, name); @@ -98,7 +99,7 @@ public class TransferLearningComplex extends BaseDL4JTest { } //Also check config: - BaseLayer bl = ((BaseLayer) l.conf().getLayer()); + BaseLayer bl = ((BaseLayer) l.getLayerConfiguration()); assertEquals(new Adam(2e-2), bl.getIUpdater()); assertEquals(Activation.LEAKYRELU.getActivationFunction(), bl.getActivationFn()); } @@ -109,7 +110,7 @@ public class TransferLearningComplex extends BaseDL4JTest { @Test public void testSimplerMergeBackProp() { - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.9)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.9)) .activation(Activation.IDENTITY) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); @@ -191,7 +192,7 @@ public class TransferLearningComplex extends BaseDL4JTest { @Test public void testLessSimpleMergeBackProp() { - 
NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.9)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.9)) .activation(Activation.IDENTITY); /* @@ -248,7 +249,7 @@ public class TransferLearningComplex extends BaseDL4JTest { @Test public void testAddOutput() { - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.9)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.9)) .activation(Activation.IDENTITY); ComputationGraphConfiguration conf = overallConf.graphBuilder().addInputs("inCentre", "inRight") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java index d7e58be43..f606e6402 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.graph.SubsetVertex; import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer.Builder; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -41,6 +42,7 @@ import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.List; +import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -50,7 +52,7 @@ public class TransferLearningHelperTest extends BaseDL4JTest { @Test public void tesUnfrozenSubset() { - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().seed(124) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().seed(124) .activation(Activation.IDENTITY) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)); /* @@ -132,7 +134,7 @@ public class TransferLearningHelperTest extends BaseDL4JTest { @Test public void testFitUnFrozen() { - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.9)).seed(124) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.9)).seed(124) .activation(Activation.IDENTITY) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); @@ -185,11 +187,11 @@ public class TransferLearningHelperTest extends BaseDL4JTest { assertEquals(modelIdentical.getLayer("denseCentre2").params(), modelToTune.getLayer("denseCentre2").params()); assertEquals(modelIdentical.getLayer("denseCentre3").params(), modelToTune.getLayer("denseCentre3").params()); assertEquals(modelIdentical.getLayer("outCentre").params(), modelToTune.getLayer("outCentre").params()); - assertEquals(modelIdentical.getLayer("denseRight").conf().toJson(), - modelToTune.getLayer("denseRight").conf().toJson()); + assertEquals(modelIdentical.getLayer("denseRight").getNetConfiguration().toJson(), + 
modelToTune.getLayer("denseRight").getNetConfiguration().toJson()); assertEquals(modelIdentical.getLayer("denseRight").params(), modelToTune.getLayer("denseRight").params()); - assertEquals(modelIdentical.getLayer("denseRight0").conf().toJson(), - modelToTune.getLayer("denseRight0").conf().toJson()); + assertEquals(modelIdentical.getLayer("denseRight0").getNetConfiguration().toJson(), + modelToTune.getLayer("denseRight0").getNetConfiguration().toJson()); //assertEquals(modelIdentical.getLayer("denseRight0").params(),modelToTune.getLayer("denseRight0").params()); assertEquals(modelIdentical.getLayer("denseRight1").params(), modelToTune.getLayer("denseRight1").params()); assertEquals(modelIdentical.getLayer("outRight").params(), modelToTune.getLayer("outRight").params()); @@ -206,18 +208,19 @@ public class TransferLearningHelperTest extends BaseDL4JTest { public void testMLN() { DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3)); - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .activation(Activation.IDENTITY); - MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) - .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build()); + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork( + (NeuralNetConfiguration) overallConf.clone() + .layer(0, new Builder().nIn(4).nOut(3).build()) + .layer(1, new Builder().nIn(3).nOut(2).build()) + .layer(2, new Builder().nIn(2).nOut(3).build()) + .layer(3, new OutputLayer.Builder( + LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) + .build()) + .build()); modelToFineTune.init(); MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).setFeatureExtractor(1).build(); @@ -228,12 +231,13 @@ public class TransferLearningHelperTest extends BaseDL4JTest { INDArray paramsLastTwoLayers = Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()); - MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list() - .layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build(), paramsLastTwoLayers); + MultiLayerNetwork notFrozen = new MultiLayerNetwork( + (NeuralNetConfiguration) overallConf.clone().list() + .layer(0, new Builder().nIn(2).nOut(3).build()) + .layer(1, new OutputLayer.Builder( + LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) + .build()) + .build(), paramsLastTwoLayers); assertEquals(asFrozenFeatures, helper.featurize(randomData).getFeatures()); assertEquals(randomData.getLabels(), helper.featurize(randomData).getLabels()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java index 005f2158c..cda7da0b4 100644 --- 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java @@ -26,13 +26,13 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint; import org.deeplearning4j.nn.conf.distribution.ConstantDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.DenseLayer.Builder; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.RnnToCnnPreProcessor; @@ -54,6 +54,7 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; import com.fasterxml.jackson.core.JsonProcessingException; import java.util.Map; +import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import static org.junit.jupiter.api.Assertions.*; @@ -67,16 +68,17 @@ public class TransferLearningMLNTest extends BaseDL4JTest { Nd4j.getRandom().setSeed(rng); DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT, 10, 4), TestUtils.randomOneHot(DataType.FLOAT, 10, 3)); //original conf - NeuralNetConfiguration.Builder confToChange = - new NeuralNetConfiguration.Builder().seed(rng).optimizationAlgo(OptimizationAlgorithm.LBFGS) + NeuralNetConfiguration.NeuralNetConfigurationBuilder confToChange = + (NeuralNetConfiguration.NeuralNetConfigurationBuilder) NeuralNetConfiguration.builder().seed(rng).optimizationAlgo(OptimizationAlgorithm.LBFGS) .updater(new Nesterovs(0.01, 0.99)); - MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(confToChange.list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) - .build()) - .build()); + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork( + (NeuralNetConfiguration) confToChange.list() + .layer(0, new Builder().nIn(4).nOut(3).build()) + .layer(1, new OutputLayer.Builder( + LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) + .build()) + .build()); modelToFineTune.init(); //model after applying changes with transfer learning @@ -89,19 +91,19 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .build(); for (org.deeplearning4j.nn.api.Layer l : modelNow.getLayers()) { - BaseLayer bl = ((BaseLayer) l.conf().getLayer()); + BaseLayer bl = ((BaseLayer) l.getLayerConfiguration()); assertEquals(new RmsProp(0.5), bl.getIUpdater()); } - NeuralNetConfiguration.Builder confSet = new NeuralNetConfiguration.Builder().seed(rng) + NeuralNetConfiguration.NeuralNetConfigurationBuilder confSet = NeuralNetConfiguration.builder().seed(rng) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new RmsProp(0.5)).l2(0.4); - MultiLayerNetwork expectedModel = new MultiLayerNetwork(confSet.list() - .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(1, new 
org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) + MultiLayerNetwork expectedModel = new MultiLayerNetwork((NeuralNetConfiguration) confSet.list() + .layer(0, new Builder().nIn(4).nOut(3).build()) + .layer(1, new OutputLayer.Builder( + LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) .build()) .build()); expectedModel.init(); @@ -110,8 +112,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest { assertEquals(expectedModel.params(), modelNow.params()); //Check json - MultiLayerConfiguration expectedConf = expectedModel.getLayerWiseConfigurations(); - assertEquals(expectedConf.toJson(), modelNow.getLayerWiseConfigurations().toJson()); + NeuralNetConfiguration expectedConf = expectedModel.getConfiguration(); + assertEquals(expectedConf.toJson(), modelNow.getConfiguration().toJson()); //Check params after fit modelNow.fit(randomData); @@ -128,11 +130,11 @@ public class TransferLearningMLNTest extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT, 10, 4), TestUtils.randomOneHot(DataType.FLOAT,10, 2)); - NeuralNetConfiguration.Builder equivalentConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)); + NeuralNetConfiguration.NeuralNetConfigurationBuilder equivalentConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)); FineTuneConfiguration overallConf = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)) .build(); - MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(equivalentConf.list() + MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(equivalentConf .layer(0, new DenseLayer.Builder().nIn(4).nOut(5).build()) .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()) @@ -145,7 +147,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .nOutReplace(3, 2, WeightInit.XAVIER, WeightInit.XAVIER) .nOutReplace(0, 3, WeightInit.XAVIER, new NormalDistribution(1, 1e-1)).build(); - MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(equivalentConf.list() + MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(equivalentConf .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build()) .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build()) @@ -156,11 +158,11 @@ public class TransferLearningMLNTest extends BaseDL4JTest { modelExpectedArch.init(); //Will fail - expected because of dist and weight init changes - //assertEquals(modelExpectedArch.getLayerWiseConfigurations().toJson(), modelNow.getLayerWiseConfigurations().toJson()); + //assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson()); - BaseLayer bl0 = ((BaseLayer) modelNow.getLayerWiseConfigurations().getConf(0).getLayer()); - BaseLayer bl1 = ((BaseLayer) modelNow.getLayerWiseConfigurations().getConf(1).getLayer()); - BaseLayer bl3 = ((BaseLayer) modelNow.getLayerWiseConfigurations().getConf(3).getLayer()); + BaseLayer bl0 = ((BaseLayer) modelNow.getConfiguration().getConf(0).getLayer()); + BaseLayer bl1 = ((BaseLayer) modelNow.getConfiguration().getConf(1).getLayer()); + BaseLayer bl3 = ((BaseLayer) modelNow.getConfiguration().getConf(3).getLayer()); assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class); try { assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()), @@ -191,7 +193,7 @@ public class 
TransferLearningMLNTest extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT,10, 4), TestUtils.randomOneHot(DataType.FLOAT, 10, 3)); - NeuralNetConfiguration.Builder equivalentConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)); + NeuralNetConfiguration.NeuralNetConfigurationBuilder equivalentConf = (NeuralNetConfiguration.NeuralNetConfigurationBuilder) NeuralNetConfiguration.builder().updater(new Sgd(0.1)); FineTuneConfiguration overallConf = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build(); MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(//overallConf.list() @@ -248,8 +250,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest { int V_HEIGHT = 130; int V_NFRAMES = 150; - MultiLayerConfiguration confForArchitecture = - new NeuralNetConfiguration.Builder().seed(12345).l2(0.001) //l2 regularization on all layers + NeuralNetConfiguration confForArchitecture = + NeuralNetConfiguration.builder().seed(12345).l2(0.001) //l2 regularization on all layers .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new AdaGrad(0.4)).list() .layer(0, new ConvolutionLayer.Builder(10, 10).nIn(3) //3 channels: RGB @@ -277,13 +279,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .inputPreProcessor(3, new CnnToFeedForwardPreProcessor(7, 7, 10)) .inputPreProcessor(4, new FeedForwardToRnnPreProcessor()) .backpropType(BackpropType.TruncatedBPTT) - .tBPTTForwardLength(V_NFRAMES / 5).tBPTTBackwardLength(V_NFRAMES / 5).build(); + .tbpttFwdLength(V_NFRAMES / 5).tbpttBackLength(V_NFRAMES / 5).build(); MultiLayerNetwork modelExpectedArch = new MultiLayerNetwork(confForArchitecture); modelExpectedArch.init(); MultiLayerNetwork modelToTweak = new MultiLayerNetwork( - new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration.builder().seed(12345) .updater(new RmsProp(0.1)) .list() .layer(0, new ConvolutionLayer.Builder(10, 10) //Only keep the first layer the same @@ -324,8 +326,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .inputPreProcessor(4, new FeedForwardToRnnPreProcessor()) .backpropType(BackpropType.TruncatedBPTT) - .tBPTTForwardLength(V_NFRAMES / 5) - .tBPTTBackwardLength(V_NFRAMES / 5).build()); + .tbpttFwdLength(V_NFRAMES / 5) + .tbpttBackLength(V_NFRAMES / 5).build()); modelToTweak.init(); MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToTweak) @@ -355,18 +357,18 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .setInputPreProcessor(4, new FeedForwardToRnnPreProcessor()).build(); //modelNow should have the same architecture as modelExpectedArch - assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(0).toJson(), - modelNow.getLayerWiseConfigurations().getConf(0).toJson()); + assertEquals(modelExpectedArch.getConfiguration().getConf(0).toJson(), + modelNow.getConfiguration().getConf(0).toJson()); //some learning related info the subsampling layer will not be overwritten - //assertTrue(modelExpectedArch.getLayerWiseConfigurations().getConf(1).toJson().equals(modelNow.getLayerWiseConfigurations().getConf(1).toJson())); - assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(2).toJson(), - modelNow.getLayerWiseConfigurations().getConf(2).toJson()); - assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(3).toJson(), - modelNow.getLayerWiseConfigurations().getConf(3).toJson()); - assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(4).toJson(), - 
modelNow.getLayerWiseConfigurations().getConf(4).toJson()); - assertEquals(modelExpectedArch.getLayerWiseConfigurations().getConf(5).toJson(), - modelNow.getLayerWiseConfigurations().getConf(5).toJson()); + //assertTrue(modelExpectedArch.getConfiguration().getConf(1).toJson().equals(modelNow.getConfiguration().getConf(1).toJson())); + assertEquals(modelExpectedArch.getConfiguration().getConf(2).toJson(), + modelNow.getConfiguration().getConf(2).toJson()); + assertEquals(modelExpectedArch.getConfiguration().getConf(3).toJson(), + modelNow.getConfiguration().getConf(3).toJson()); + assertEquals(modelExpectedArch.getConfiguration().getConf(4).toJson(), + modelNow.getConfiguration().getConf(4).toJson()); + assertEquals(modelExpectedArch.getConfiguration().getConf(5).toJson(), + modelNow.getConfiguration().getConf(5).toJson()); assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); @@ -386,7 +388,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest { DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT, 10, 28 * 28 * 3).reshape(10, 3, 28, 28), TestUtils.randomOneHot(DataType.FLOAT,10, 10)); MultiLayerNetwork modelToFineTune = new MultiLayerNetwork( - new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration.builder().seed(123) .weightInit(WeightInit.XAVIER) .updater(new Nesterovs(0.01, 0.9)) .list() @@ -413,12 +415,12 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .nOut(100) .activation(Activation.SOFTMAX) .build()) - .setInputType(InputType.convolutionalFlat(28, 28, 3)) + .inputType(InputType.convolutionalFlat(28, 28, 3)) .build()); modelToFineTune.init(); INDArray asFrozenFeatures = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false).get(2); //10x20x12x12 - NeuralNetConfiguration.Builder equivalentConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.2)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder equivalentConf = (NeuralNetConfiguration.NeuralNetConfigurationBuilder) NeuralNetConfiguration.builder().updater(new Sgd(0.2)) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT); FineTuneConfiguration overallConf = new FineTuneConfiguration.Builder().updater(new Sgd(0.2)) @@ -444,7 +446,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .layer(5, new DenseLayer.Builder().activation(Activation.RELU).nOut(50).build()) .layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(10) .activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(12, 12, 20)).build()); + .inputType(InputType.convolutionalFlat(12, 12, 20)).build()); notFrozen.init(); assertArrayEquals(modelToFineTune.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); @@ -481,8 +483,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest { public void testFineTuneOverride() { //Check that fine-tune overrides are selective - i.e., if I only specify a new LR, only the LR should be modified - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Adam(1e-4)) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new Adam(1e-4)) .activation(Activation.TANH).weightInit(WeightInit.RELU) .l1(0.1).l2(0.2).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(5).build()).layer(1, @@ -501,13 +503,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest { //Check original net 
isn't modified: - BaseLayer l0 = (BaseLayer) net.getLayer(0).conf().getLayer(); + BaseLayer l0 = (BaseLayer) net.getLayer(0).getLayerConfiguration(); assertEquals(new Adam(1e-4), l0.getIUpdater()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(0.1, TestUtils.getL1(l0), 1e-6); - BaseLayer l1 = (BaseLayer) net.getLayer(1).conf().getLayer(); + BaseLayer l1 = (BaseLayer) net.getLayer(1).getLayerConfiguration(); assertEquals(new Adam(1e-4), l1.getIUpdater()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); @@ -516,19 +518,19 @@ public class TransferLearningMLNTest extends BaseDL4JTest { assertEquals(BackpropType.Standard, conf.getBackpropType()); //Check new net has only the appropriate things modified (i.e., LR) - l0 = (BaseLayer) net2.getLayer(0).conf().getLayer(); + l0 = (BaseLayer) net2.getLayer(0).getLayerConfiguration(); assertEquals(new Adam(2e-2), l0.getIUpdater()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(0.1, TestUtils.getL1(l0), 1e-6); - l1 = (BaseLayer) net2.getLayer(1).conf().getLayer(); + l1 = (BaseLayer) net2.getLayer(1).getLayerConfiguration(); assertEquals(new Adam(2e-2), l1.getIUpdater()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); assertEquals(0.2, TestUtils.getL2(l1), 1e-6); - assertEquals(BackpropType.TruncatedBPTT, net2.getLayerWiseConfigurations().getBackpropType()); + assertEquals(BackpropType.TruncatedBPTT, net2.getConfiguration().getBackpropType()); } @Test @@ -538,7 +540,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest { DataSet randomData = new DataSet(Nd4j.rand(DataType.FLOAT,10, 28 * 28 * 3).reshape(10, 3, 28, 28), TestUtils.randomOneHot(10, 10)); MultiLayerNetwork modelToFineTune = new MultiLayerNetwork( - new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration.builder().seed(123) .weightInit(WeightInit.XAVIER) .updater(new Nesterovs(0.01, 0.9)) .list() @@ -554,12 +556,12 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .layer(5, new DenseLayer.Builder().activation(Activation.RELU).nOut(250).build()) .layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(100).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 3)) //See note below + .inputType(InputType.convolutionalFlat(28, 28, 3)) //See note below .build()); modelToFineTune.init(); INDArray asFrozenFeatures = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false).get(2); //10x20x12x12 - NeuralNetConfiguration.Builder equivalentConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.2)); + NeuralNetConfiguration.NeuralNetConfigurationBuilder equivalentConf = (NeuralNetConfiguration.NeuralNetConfigurationBuilder) NeuralNetConfiguration.builder().updater(new Sgd(0.2)); FineTuneConfiguration overallConf = new FineTuneConfiguration.Builder().updater(new Sgd(0.2)).build(); MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(overallConf) @@ -610,7 +612,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest { @Test public void testObjectOverrides(){ //https://github.com/deeplearning4j/deeplearning4j/issues/4368 - 
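// The hunks in these test files all apply the same configuration-API migration, so a
// compact sketch of the new style may help while reviewing. It is only a sketch: it
// reuses names exactly as they appear in the added lines of this patch
// (NeuralNetConfiguration.builder(), .list(), getLayerConfiguration(), getConfiguration());
// the class and the values below are illustrative and not part of the change itself.
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BaseLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;

class BuilderMigrationSketch {
    static void newStyle() {
        // new NeuralNetConfiguration.Builder()...build() used to yield a MultiLayerConfiguration;
        // after the migration NeuralNetConfiguration.builder() returns the network configuration
        // itself, which is why MultiLayerConfiguration disappears from the test imports.
        NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
                .seed(12345)
                .updater(new Sgd(0.1))
                .activation(Activation.IDENTITY)
                .list()
                .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
                .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(3).nOut(3).build())
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // Per-layer settings are now read via getLayerConfiguration() instead of
        // conf().getLayer(), and the network-wide configuration via getConfiguration()
        // instead of getLayerWiseConfigurations().
        BaseLayer layer0 = (BaseLayer) net.getLayer(0).getLayerConfiguration();
        String json = net.getConfiguration().toJson();
    }
}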
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dropOut(0.5) .weightNoise(new DropConnect(0.5)) .l2(0.5) @@ -633,7 +635,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .fineTuneConfiguration(ftc) .build(); - DenseLayer l = (DenseLayer) transfer.getLayer(0).conf().getLayer(); + DenseLayer l = (DenseLayer) transfer.getLayer(0).getLayerConfiguration(); assertNull(l.getIDropout()); assertNull(l.getWeightNoise()); @@ -645,10 +647,10 @@ public class TransferLearningMLNTest extends BaseDL4JTest { @Test public void testTransferLearningSubsequent() { final INDArray input = Nd4j.create(6,6,6,6); - final MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder() + final MultiLayerNetwork net = new MultiLayerNetwork(NeuralNetConfiguration.builder() .weightInit(new ConstantDistribution(666)) .list() - .setInputType(InputType.inferInputTypes(input)[0]) + .inputType(InputType.inferInputTypes(input)[0]) .layer(new Convolution2D.Builder(3, 3).nOut(10).build()) .layer(new Convolution2D.Builder(1, 1).nOut(3).build()) .layer(new OutputLayer.Builder().nOut(2).lossFunction(LossFunctions.LossFunction.MSE) @@ -677,9 +679,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest { @Test public void testChangeNOutNIn() { INDArray input = Nd4j.create(new long[] {1, 2, 4, 4}); - MultiLayerNetwork net = new MultiLayerNetwork(new NeuralNetConfiguration.Builder() + MultiLayerNetwork net = new MultiLayerNetwork( NeuralNetConfiguration.builder() .list() - .setInputType(InputType.inferInputTypes(input)[0]) + .inputType(InputType.inferInputTypes(input)[0]) .layer(new Convolution2D.Builder(1, 1).nOut(10).build()) .layer(new SubsamplingLayer.Builder(1,1).build()) .layer(new Convolution2D.Builder(1, 1).nOut(7).build()) @@ -703,7 +705,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest { @Test public void testTransferLearningSameDiffLayers(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .activation(Activation.TANH) .updater(new Adam(0.01)) @@ -714,7 +716,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.MAX).build()) .layer(new OutputLayer.Builder().nOut(2).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) - .setInputType(InputType.recurrent(4)) + .inputType(InputType.recurrent(4)) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); @@ -733,8 +735,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest { net2.setParam("3_W", net.getParam("3_W")); net2.setParam("3_b", net.getParam("3_b")); - Map p1 = net.paramTable(); - Map p2 = net2.paramTable(); + Map p1 = net.getParamTable(); + Map p2 = net2.getParamTable(); for(String s : p1.keySet()){ INDArray i1 = p1.get(s); INDArray i2 = p2.get(s); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java index 02616d66d..63c936b17 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java @@ -46,15 +46,15 @@ public class TestGradientNormalization extends BaseDL4JTest { public void 
testRenormalizatonPerLayer() { Nd4j.getRandom().setSeed(12345); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(new DenseLayer.Builder().nIn(10).nOut(20) .updater(new NoOp()) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); INDArray gradArray = Nd4j.rand(1, 220).muli(10).subi(5); layer.setBackpropGradientsViewArray(gradArray); INDArray weightGrad = Shape.newShapeNoCopy(gradArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 200)), @@ -92,15 +92,15 @@ public class TestGradientNormalization extends BaseDL4JTest { public void testRenormalizationPerParamType() { Nd4j.getRandom().setSeed(12345); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(new DenseLayer.Builder().nIn(10).nOut(20) .updater(new NoOp()) .gradientNormalization(GradientNormalization.RenormalizeL2PerParamType).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(Nd4j.create(params.shape())); Updater updater = UpdaterCreator.getUpdater(layer); INDArray weightGrad = Nd4j.rand(10, 20); @@ -125,15 +125,15 @@ public class TestGradientNormalization extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); double threshold = 3; - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer( + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer( new DenseLayer.Builder().nIn(10).nOut(20).updater(new NoOp()) .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue) .gradientNormalizationThreshold(threshold).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); INDArray gradArray = Nd4j.rand(1, 220).muli(10).subi(5); layer.setBackpropGradientsViewArray(gradArray); INDArray weightGrad = Shape.newShapeNoCopy(gradArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 200)), @@ -181,15 +181,15 @@ public class TestGradientNormalization extends BaseDL4JTest { //t=0: small -> no clipping //t=1: large -> clipping - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer( + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer( new DenseLayer.Builder().nIn(10).nOut(20).updater(new NoOp()) .gradientNormalization(GradientNormalization.ClipL2PerLayer) .gradientNormalizationThreshold(threshold).build()) .build(); - val numParams = conf.getLayer().initializer().numParams(conf); + val 
numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); INDArray gradArray = Nd4j.rand(1, 220).muli(t == 0 ? 0.05 : 10).subi(t == 0 ? 0 : 5); layer.setBackpropGradientsViewArray(gradArray); INDArray weightGrad = @@ -236,15 +236,15 @@ public class TestGradientNormalization extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); double threshold = 3; - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().layer( + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer( new DenseLayer.Builder().nIn(10).nOut(20).updater(new NoOp()) .gradientNormalization(GradientNormalization.ClipL2PerParamType) .gradientNormalizationThreshold(threshold).build()) .build(); - val numParams = conf.getLayer().initializer().numParams(conf); + val numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(Nd4j.create(params.shape())); Updater updater = UpdaterCreator.getUpdater(layer); INDArray weightGrad = Nd4j.rand(10, 20).muli(0.05); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java index 462143897..cf73bb012 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java @@ -26,7 +26,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; @@ -89,15 +88,15 @@ public class TestUpdaters extends BaseDL4JTest { double rho = 0.85; - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut) .updater(new AdaDelta(rho, Nd4j.EPS_THRESHOLD)) .build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); @@ -157,13 +156,13 @@ public class TestUpdaters extends BaseDL4JTest { double epsilon = AdaGrad.DEFAULT_ADAGRAD_EPSILON; NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new AdaGrad(lr)) + NeuralNetConfiguration.builder().updater(new AdaGrad(lr)) .layer(new 
DenseLayer.Builder().nIn(nIn).nOut(nOut).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); @@ -201,13 +200,13 @@ public class TestUpdaters extends BaseDL4JTest { double epsilon = Adam.DEFAULT_ADAM_EPSILON; - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Adam(lr, beta1, beta2, Adam.DEFAULT_ADAM_EPSILON)) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Adam(lr, beta1, beta2, Adam.DEFAULT_ADAM_EPSILON)) .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); @@ -261,16 +260,16 @@ public class TestUpdaters extends BaseDL4JTest { double epsilon = Nadam.DEFAULT_NADAM_EPSILON; NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration.builder() .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut) .updater(Nadam.builder().learningRate(lr).beta1(beta1) .beta2(beta2).epsilon(epsilon).build()) .build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); @@ -353,14 +352,14 @@ public class TestUpdaters extends BaseDL4JTest { double beta2 = 0.888; double epsilon = AdaMax.DEFAULT_ADAMAX_EPSILON; - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new AdaMax(lr, beta1, beta2, AdaMax.DEFAULT_ADAMAX_EPSILON)) .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) 
layer.layerConf().getIUpdater().stateSize(numParams); @@ -410,13 +409,13 @@ public class TestUpdaters extends BaseDL4JTest { double mu = 0.6; NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Nesterovs(lr, mu)) + NeuralNetConfiguration.builder().updater(new Nesterovs(lr, mu)) .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); @@ -457,13 +456,13 @@ public class TestUpdaters extends BaseDL4JTest { NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new RmsProp(lr,rmsDecay, RmsProp.DEFAULT_RMSPROP_EPSILON)) + NeuralNetConfiguration.builder().updater(new RmsProp(lr,rmsDecay, RmsProp.DEFAULT_RMSPROP_EPSILON)) .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); @@ -504,13 +503,13 @@ public class TestUpdaters extends BaseDL4JTest { double lr = 0.05; NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Sgd(lr)) + NeuralNetConfiguration.builder().updater(new Sgd(lr)) .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); @@ -538,13 +537,13 @@ public class TestUpdaters extends BaseDL4JTest { double lr = 0.5; NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration.builder().updater(new NoOp()) .layer(new DenseLayer.Builder().nIn(nIn).nOut(nOut).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); @@ -574,7 +573,7 @@ public class 
TestUpdaters extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345L); double lr = 0.03; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(5).updater(new Sgd(lr)).build()) .layer(1, new DenseLayer.Builder().nIn(5).nOut(6) .updater(new NoOp()).build()) @@ -675,7 +674,7 @@ public class TestUpdaters extends BaseDL4JTest { int nIn = 4; int nOut = 8; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(lr,0.6)).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Nesterovs(lr,0.6)).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(5) .updater(org.deeplearning4j.nn.conf.Updater.SGD).build()) .layer(1, new DenseLayer.Builder().nIn(5).nOut(6) @@ -706,7 +705,7 @@ public class TestUpdaters extends BaseDL4JTest { int nIn = 4; int nOut = 8; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Nesterovs(lr,0.6)).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Nesterovs(lr,0.6)).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(5) .updater(org.deeplearning4j.nn.conf.Updater.SGD).build()) .layer(1, new DenseLayer.Builder().nIn(5).nOut(6) @@ -743,14 +742,14 @@ public class TestUpdaters extends BaseDL4JTest { gradient.setGradientFor(PretrainParamInitializer.VISIBLE_BIAS_KEY, vbiasGradient); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().updater(new Sgd(lr)).seed(42) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new Sgd(lr)).seed(42) .layer(new AutoEncoder.Builder() .lossFunction(LossFunctions.LossFunction.COSINE_PROXIMITY) .activation(Activation.IDENTITY).nIn(nIn).nOut(nOut).build()) .build(); - long numParams = conf.getLayer().initializer().numParams(conf); + long numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - BaseLayer layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); @@ -795,7 +794,7 @@ public class TestUpdaters extends BaseDL4JTest { gradientCopyPreUpdate.setFlattenedGradient(g); params = Nd4j.create(1, numParams); - layer = (BaseLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); updater = UpdaterCreator.getUpdater(layer); assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4); @@ -807,7 +806,7 @@ public class TestUpdaters extends BaseDL4JTest { List blocks; if (i == 0) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).name("l0") .updater(new Adam(0.5)).build()) .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).name("l1") @@ -827,7 +826,7 @@ public class TestUpdaters extends BaseDL4JTest { MultiLayerUpdater u = (MultiLayerUpdater) net.getUpdater(); blocks = u.getUpdaterBlocks(); } else { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + 
ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder().addInputs("in") .addLayer("l0", new DenseLayer.Builder().nIn(10).nOut(10) .updater(new Adam(0.5)).build(), "in") @@ -940,8 +939,8 @@ public class TestUpdaters extends BaseDL4JTest { public void testUpdaterBlockVae() { List blocks; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Adam(0.5)).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new Adam(0.5)).list() .layer(0, new VariationalAutoencoder.Builder().nIn(8).nOut(12) .encoderLayerSizes(10, 11).decoderLayerSizes(13, 14).build()) .build(); @@ -981,7 +980,7 @@ public class TestUpdaters extends BaseDL4JTest { public void testDivisionByMinibatch1(){ //No batch norm - should be single INDArray equal to flattened gradient view - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(new DenseLayer.Builder().nIn(10).nOut(10).build()) @@ -1008,7 +1007,7 @@ public class TestUpdaters extends BaseDL4JTest { //With batch norm - should be multiple 'division by minibatch' array segments //i.e., exclude batch norm mean/variance - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new DenseLayer.Builder().nIn(10).nOut(9).build()) .layer(new BatchNormalization.Builder().nOut(9).build()) @@ -1059,7 +1058,7 @@ public class TestUpdaters extends BaseDL4JTest { //With batch norm - should be multiple 'division by minibatch' array segments //i.e., exclude batch norm mean/variance - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new BatchNormalization.Builder().nOut(6).build()) .layer(new ConvolutionLayer.Builder().nIn(6).nOut(5).kernelSize(2,2).build()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java index 170c6bdc1..e5caf981f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java @@ -21,7 +21,6 @@ package org.deeplearning4j.nn.updater.custom; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -47,7 +46,7 @@ public class TestCustomUpdater extends BaseDL4JTest { double lr = 0.03; Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf1 = NeuralNetConfiguration.builder().seed(12345) .activation(Activation.TANH).updater(new CustomIUpdater(lr)) //Specify custom IUpdater .list().layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new OutputLayer.Builder().nIn(10).nOut(10) @@ -55,7 +54,7 @@ public class TestCustomUpdater extends BaseDL4JTest { .build(); Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf2 = 
NeuralNetConfiguration.builder().seed(12345) .activation(Activation.TANH).updater(new Sgd(lr)).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder() .nIn(10).nOut(10).lossFunction(LossFunctions.LossFunction.MSE).build()) @@ -80,7 +79,7 @@ public class TestCustomUpdater extends BaseDL4JTest { //Second: check JSON String asJson = conf1.toJson(); - MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(asJson); + NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(asJson); assertEquals(conf1, fromJson); Nd4j.getRandom().setSeed(12345); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java index 8b9b35e4f..b5becc819 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/weights/WeightInitIdentityTest.java @@ -48,7 +48,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest { final String inputName = "input"; final String conv = "conv"; final String output = "output"; - final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + final ComputationGraph graph = new ComputationGraph(NeuralNetConfiguration.builder() .graphBuilder() .addInputs(inputName) .setOutputs(output) @@ -76,7 +76,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest { final String inputName = "input"; final String conv = "conv"; final String output = "output"; - final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + final ComputationGraph graph = new ComputationGraph(NeuralNetConfiguration.builder() .graphBuilder() .setInputTypes(InputType.inferInputType(input)) .addInputs(inputName) @@ -103,7 +103,7 @@ public class WeightInitIdentityTest extends BaseDL4JTest { final String inputName = "input"; final String conv = "conv"; final String output = "output"; - final ComputationGraph graph = new ComputationGraph(new NeuralNetConfiguration.Builder() + final ComputationGraph graph = new ComputationGraph(NeuralNetConfiguration.builder() .graphBuilder() .setInputTypes(InputType.inferInputType(input)) .addInputs(inputName) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java index 8b73c10ee..692f0f44f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java @@ -24,7 +24,6 @@ import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.layers.OutputLayer; @@ -127,7 +126,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); INDArray currParams = layer.params(); sf.step(currParams, origGradient, step); - 
layer.setParams(currParams); + layer.setParamsTable(currParams); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); score2 = layer.score(); @@ -157,7 +156,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { INDArray currParams = layer.params(); sf.step(currParams, origGradient, step); - layer.setParams(currParams); + layer.setParamsTable(currParams); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); score2 = layer.score(); @@ -167,16 +166,16 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { private static OutputLayer getIrisLogisticLayerConfig(Activation activationFunction, int maxIterations, LossFunctions.LossFunction lossFunction) { NeuralNetConfiguration conf = - new NeuralNetConfiguration.Builder().seed(12345L).miniBatch(true) + NeuralNetConfiguration.builder().seed(12345L).miniBatch(true) .maxNumLineSearchIterations(maxIterations) .layer(new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(lossFunction) .nIn(4).nOut(3).activation(activationFunction) .weightInit(WeightInit.XAVIER).build()) .build(); - val numParams = conf.getLayer().initializer().numParams(conf); + val numParams = conf.getFirstLayer().initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - return (OutputLayer) conf.getLayer().instantiate(conf, null, 0, params, true, params.dataType()); + return (OutputLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); } /////////////////////////////////////////////////////////////////////////// @@ -239,8 +238,8 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { } - private static MultiLayerConfiguration getIrisMultiLayerConfig(Activation activationFunction, OptimizationAlgorithm optimizer) { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().optimizationAlgo(optimizer) + private static NeuralNetConfiguration getIrisMultiLayerConfig(Activation activationFunction, OptimizationAlgorithm optimizer) { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().optimizationAlgo(optimizer) .updater(new Adam(0.01)).seed(12345L).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER) .activation(activationFunction).build()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java index 73e1a7a56..7753fae33 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java @@ -21,13 +21,14 @@ package org.deeplearning4j.optimize.solver; import lombok.val; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.*; import org.deeplearning4j.nn.conf.CacheMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -134,8 +135,8 @@ public class TestOptimizers extends BaseDL4JTest { } } - private static MultiLayerConfiguration getMLPConfigIris(OptimizationAlgorithm oa) { - 
MultiLayerConfiguration c = new NeuralNetConfiguration.Builder().optimizationAlgo(oa) + private static NeuralNetConfiguration getMLPConfigIris(OptimizationAlgorithm oa) { + NeuralNetConfiguration c = NeuralNetConfiguration.builder().optimizationAlgo(oa) .updater(new AdaGrad(1e-1)).seed(12345L) .list().layer(0, new DenseLayer.Builder().nIn(4).nOut(3).weightInit(WeightInit.XAVIER) @@ -206,15 +207,15 @@ public class TestOptimizers extends BaseDL4JTest { System.out.println("---------\n Alg= " + oa + ", nIter= " + numLineSearchIter + ", nDimensions= " + nDimensions); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder().maxNumLineSearchIterations(numLineSearchIter) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().maxNumLineSearchIterations(numLineSearchIter) .updater(new Sgd(1e-2)) .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build(); - conf.addVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here + conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here Random rng = new DefaultRandom(12345L); org.nd4j.linalg.api.rng.distribution.Distribution dist = new org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution(rng, -10, 10); - Model m = new SphereFunctionModel(nDimensions, dist, conf); + IModel m = new SphereFunctionModel(nDimensions, dist, conf); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); double scoreBefore = m.score(); assertTrue(!Double.isNaN(scoreBefore) && !Double.isInfinite(scoreBefore)); @@ -246,7 +247,7 @@ public class TestOptimizers extends BaseDL4JTest { assertTrue( scoreAfter < scoreBefore, "Score did not improve after optimization (b= " + scoreBefore + " ,a= " + scoreAfter + ")"); } - private static ConvexOptimizer getOptimizer(OptimizationAlgorithm oa, NeuralNetConfiguration conf, Model m) { + private static ConvexOptimizer getOptimizer(OptimizationAlgorithm oa, NeuralNetConfiguration conf, IModel m) { switch (oa) { case STOCHASTIC_GRADIENT_DESCENT: return new StochasticGradientDescent(conf, new NegativeDefaultStepFunction(), null, m); @@ -269,12 +270,12 @@ public class TestOptimizers extends BaseDL4JTest { Random rng = new DefaultRandom(12345L); org.nd4j.linalg.api.rng.distribution.Distribution dist = new org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution(rng, -10, 10); - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .maxNumLineSearchIterations(maxNumLineSearchIter).updater(new Sgd(0.1)) .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build(); - conf.addVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here + conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here - Model m = new SphereFunctionModel(100, dist, conf); + IModel m = new SphereFunctionModel(100, dist, conf); if (i == 0) { m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); scores[0] = m.score(); //Before optimization @@ -404,13 +405,13 @@ public class TestOptimizers extends BaseDL4JTest { double[] scores = new double[nOptIter + 1]; for (int i = 0; i <= nOptIter; i++) { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .maxNumLineSearchIterations(maxNumLineSearchIter).miniBatch(false) .updater(new AdaGrad(1e-2)) .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build(); - conf.addVariable("W"); 
//Normally done by ParamInitializers, but obviously that isn't done here + conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here - Model m = new RastriginFunctionModel(10, conf); + IModel m = new RastriginFunctionModel(10, conf); int nParams = (int)m.numParams(); if (i == 0) { m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); @@ -587,15 +588,15 @@ public class TestOptimizers extends BaseDL4JTest { double[] scores = new double[nOptIter + 1]; for (int i = 0; i <= nOptIter; i++) { - NeuralNetConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .maxNumLineSearchIterations(maxNumLineSearchIter) .updater(new Sgd(1e-1)) .stepFunction(new org.deeplearning4j.nn.conf.stepfunctions.NegativeDefaultStepFunction()) .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()) .build(); - conf.addVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here + conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here - Model m = new RosenbrockFunctionModel(100, conf); + IModel m = new RosenbrockFunctionModel(100, conf); if (i == 0) { m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); scores[0] = m.score(); //Before optimization @@ -768,7 +769,7 @@ public class TestOptimizers extends BaseDL4JTest { * methods here. Classes extending this model for optimizer tests need only implement the score() and * gradient() methods. */ - private static abstract class SimpleOptimizableModel implements Model, Layer { + private static abstract class SimpleOptimizableModel implements IModel, Layer { private static final long serialVersionUID = 4409380971404019303L; protected INDArray parameters; protected INDArray gradientView; @@ -784,6 +785,16 @@ public class TestOptimizers extends BaseDL4JTest { this.conf = conf; } + /** + * Return the configuration of this layer + * + * @return the configuration + */ + @Override + public LayerConfiguration getLayerConfiguration() { + return this.conf.getFirstLayer(); + } + @Override public void addListeners(TrainingListener... 
listener) { // no-op @@ -791,7 +802,7 @@ public class TestOptimizers extends BaseDL4JTest { @Override public TrainingConfig getConfig() { - return conf.getLayer(); + return conf.getFirstLayer(); } /** @@ -896,12 +907,12 @@ public class TestOptimizers extends BaseDL4JTest { } @Override - public NeuralNetConfiguration conf() { + public NeuralNetConfiguration getNetConfiguration() { return conf; } @Override - public void setConf(NeuralNetConfiguration conf) { + public void setLayerConfiguration(NeuralNetConfiguration layerConfiguration) { throw new UnsupportedOperationException(); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java index 4c3760d95..6f422fda1 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java @@ -22,7 +22,6 @@ package org.deeplearning4j.optimizer.listener; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -51,7 +50,7 @@ public class TestCheckpointListener extends BaseDL4JTest { public File tempDir; private static Pair getNetAndData(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX) .lossFunction(LossFunctions.LossFunction.MCXENT).build()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java index 81786baa7..a1933c247 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java @@ -22,7 +22,6 @@ package org.deeplearning4j.optimizer.listener; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -51,7 +50,7 @@ public class TestFailureListener extends BaseDL4JTest { @Test public void testFailureIter5() throws Exception { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Adam(1e-4)) .list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) @@ -73,7 +72,7 @@ public class TestFailureListener extends BaseDL4JTest { @Test public void testFailureRandom_OR(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Adam(1e-4)) .list() .layer(0, new 
OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) @@ -101,7 +100,7 @@ public class TestFailureListener extends BaseDL4JTest { @Test public void testFailureRandom_AND() throws Exception { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Adam(1e-4)) .list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java index 55b1d39c8..b335d43a6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java @@ -22,14 +22,13 @@ package org.deeplearning4j.optimizer.listener; import lombok.Data; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.core.storage.listener.RoutingIterationListener; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.AutoEncoder; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -71,7 +70,7 @@ public class TestListeners extends BaseDL4JTest { public void testSettingListenersUnsupervised() { //Pretrain layers should get copies of the listeners, in addition to the - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new AutoEncoder.Builder().nIn(10).nOut(10).build()) .layer(1, new VariationalAutoencoder.Builder().nIn(10).nOut(10).build()).build(); @@ -95,7 +94,7 @@ public class TestListeners extends BaseDL4JTest { assertTrue(lArr[1] instanceof TestRoutingListener); - ComputationGraphConfiguration gConf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration gConf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("0", new AutoEncoder.Builder().nIn(10).nOut(10).build(), "in") .addLayer("1", new VariationalAutoencoder.Builder().nIn(10).nOut(10).build(), "0") .setOutputs("1").build(); @@ -151,7 +150,7 @@ public class TestListeners extends BaseDL4JTest { } @Override - public void iterationDone(Model model, int iteration, int epoch) {} + public void iterationDone(IModel model, int iteration, int epoch) {} } @@ -172,7 +171,7 @@ public class TestListeners extends BaseDL4JTest { DataSetIterator iter = new IrisDataSetIterator(10, 150); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new OutputLayer.Builder().nIn(4).nOut(3) .activation(Activation.SOFTMAX) @@ -208,7 +207,7 @@ public class TestListeners extends BaseDL4JTest { @Test public void testListenerCalls(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + 
NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); @@ -284,37 +283,37 @@ public class TestListeners extends BaseDL4JTest { @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { calls.add(new Triple<>(Call.ITER_DONE, iteration, epoch)); } @Override - public void onEpochStart(Model model) { + public void onEpochStart(IModel model) { calls.add(new Triple<>(Call.EPOCH_START, BaseOptimizer.getIterationCount(model), BaseOptimizer.getEpochCount(model))); } @Override - public void onEpochEnd(Model model) { + public void onEpochEnd(IModel model) { calls.add(new Triple<>(Call.EPOCH_END, BaseOptimizer.getIterationCount(model), BaseOptimizer.getEpochCount(model))); } @Override - public void onForwardPass(Model model, List activations) { + public void onForwardPass(IModel model, List activations) { calls.add(new Triple<>(Call.ON_FWD, BaseOptimizer.getIterationCount(model), BaseOptimizer.getEpochCount(model))); } @Override - public void onForwardPass(Model model, Map activations) { + public void onForwardPass(IModel model, Map activations) { calls.add(new Triple<>(Call.ON_FWD, BaseOptimizer.getIterationCount(model), BaseOptimizer.getEpochCount(model))); } @Override - public void onGradientCalculation(Model model) { + public void onGradientCalculation(IModel model) { calls.add(new Triple<>(Call.ON_GRAD, BaseOptimizer.getIterationCount(model), BaseOptimizer.getEpochCount(model))); } @Override - public void onBackwardPass(Model model) { + public void onBackwardPass(IModel model) { calls.add(new Triple<>(Call.ON_BWD, BaseOptimizer.getIterationCount(model), BaseOptimizer.getEpochCount(model))); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java index 97a1cb799..114d90887 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java @@ -20,9 +20,8 @@ package org.deeplearning4j.parallelism; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.api.Model; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -52,13 +51,13 @@ public class RandomTests extends BaseDL4JTest { */ @Test public void testModelInitialParamsEquality1() throws Exception { - final List models = new CopyOnWriteArrayList<>(); + final List models = new CopyOnWriteArrayList<>(); for (int i = 0; i < 4; i++) { Thread thread = new Thread(new Runnable() { @Override public void run() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(119) // Training iterations as above + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(119) // Training iterations as above .l2(0.0005) //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75) .weightInit(WeightInit.XAVIER) @@ -78,7 +77,7 @@ public class RandomTests extends BaseDL4JTest { .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()) .layer(5, new 
OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(10).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) //See note below + .inputType(InputType.convolutionalFlat(28, 28, 1)) //See note below .build(); MultiLayerNetwork network = new MultiLayerNetwork(conf); @@ -104,7 +103,7 @@ public class RandomTests extends BaseDL4JTest { public void testRngInitMLN() { Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).activation(Activation.TANH) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).activation(Activation.TANH) .weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(2, @@ -122,7 +121,7 @@ public class RandomTests extends BaseDL4JTest { assertEquals(net1.params(), net2.params()); - MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(json); Nd4j.getRandom().setSeed(987654321); MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java index 02e089090..c52f4943f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java @@ -24,7 +24,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.listener.SystemInfoFilePrintListener; import org.deeplearning4j.core.listener.SystemInfoPrintListener; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -57,7 +56,7 @@ public class TestSystemInfoPrintListener extends BaseDL4JTest { .build(); tmpFile.deleteOnExit(); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX).build()) .build(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java index 686501ff8..316ad2f46 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/MiscRegressionTests.java @@ -23,12 +23,11 @@ package org.deeplearning4j.regressiontest; import org.apache.commons.io.FileUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.layers.DenseLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import 
org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; import org.junit.jupiter.api.Test; import org.nd4j.common.io.ClassPathResource; @@ -55,7 +54,7 @@ public class MiscRegressionTests extends BaseDL4JTest { assertNotNull(gv); if(gv instanceof LayerVertex){ LayerVertex lv = (LayerVertex)gv; - Layer layer = lv.getLayerConf().getLayer(); + LayerConfiguration layer = lv.getNetConfiguration().getFirstLayer(); if(layer instanceof FrozenLayer) countFrozen++; } @@ -66,13 +65,13 @@ public class MiscRegressionTests extends BaseDL4JTest { @Test public void testFrozenNewFormat(){ - MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration configuration = NeuralNetConfiguration.builder() .list() .layer(0, new FrozenLayer(new DenseLayer.Builder().nIn(10).nOut(10).build())) .build(); String json = configuration.toJson(); - MultiLayerConfiguration fromJson = MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(json); assertEquals(configuration, fromJson); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java index 022545685..50c177332 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java @@ -23,7 +23,7 @@ package org.deeplearning4j.regressiontest; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.*; @@ -65,8 +65,8 @@ public class RegressionTest050 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(2, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); assertEquals("relu", l0.getActivationFn().toString()); @@ -99,8 +99,8 @@ public class RegressionTest050 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(2, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); assertTrue(l0.getActivationFn() instanceof ActivationLReLU); @@ -138,8 +138,8 @@ public class RegressionTest050 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(3, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(3, conf.getNetConfigurations().size()); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); assertEquals("tanh", l0.getActivationFn().toString()); diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java index 87a53e54a..9b0870b0f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java @@ -25,7 +25,7 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.graph.LayerVertex; @@ -67,8 +67,8 @@ public class RegressionTest060 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(2, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); assertEquals("relu", l0.getActivationFn().toString()); @@ -101,8 +101,8 @@ public class RegressionTest060 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(2, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); assertTrue(l0.getActivationFn() instanceof ActivationLReLU); @@ -144,8 +144,8 @@ public class RegressionTest060 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(3, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(3, conf.getNetConfigurations().size()); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); assertEquals("tanh", l0.getActivationFn().toString()); @@ -190,8 +190,8 @@ public class RegressionTest060 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(3, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(3, conf.getNetConfigurations().size()); GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer(); assertEquals("tanh", l0.getActivationFn().toString()); @@ -224,7 +224,7 @@ public class RegressionTest060 extends BaseDL4JTest { ComputationGraphConfiguration conf = net.getComputationGraphConfiguration(); assertEquals(3, conf.getVertices().size()); - GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer(); + GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getNetConfiguration().getFirstLayer(); assertEquals("tanh", l0.getActivationFn().toString()); assertEquals(3, l0.getNIn()); assertEquals(4, l0.getNOut()); @@ -232,14 +232,14 @@ public class RegressionTest060 extends BaseDL4JTest { assertEquals(1.5, 
l0.getGradientNormalizationThreshold(), 1e-5); GravesBidirectionalLSTM l1 = - (GravesBidirectionalLSTM) ((LayerVertex) conf.getVertices().get("1")).getLayerConf().getLayer(); + (GravesBidirectionalLSTM) ((LayerVertex) conf.getVertices().get("1")).getNetConfiguration().getFirstLayer(); assertEquals("softsign", l1.getActivationFn().toString()); assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNOut()); assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, l1.getGradientNormalization()); assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5); - RnnOutputLayer l2 = (RnnOutputLayer) ((LayerVertex) conf.getVertices().get("2")).getLayerConf().getLayer(); + RnnOutputLayer l2 = (RnnOutputLayer) ((LayerVertex) conf.getVertices().get("2")).getNetConfiguration().getFirstLayer(); assertEquals(4, l2.getNIn()); assertEquals(5, l2.getNOut()); assertEquals("softmax", l2.getActivationFn().toString()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java index 0dc3839bb..e21f75680 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java @@ -25,7 +25,7 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.graph.LayerVertex; @@ -68,8 +68,8 @@ public class RegressionTest071 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(2, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); assertEquals("relu", l0.getActivationFn().toString()); @@ -102,8 +102,8 @@ public class RegressionTest071 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(2, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); assertTrue(l0.getActivationFn() instanceof ActivationLReLU); @@ -145,8 +145,8 @@ public class RegressionTest071 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(3, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(3, conf.getNetConfigurations().size()); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); assertEquals("tanh", l0.getActivationFn().toString()); @@ -191,8 +191,8 @@ public class RegressionTest071 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = 
net.getLayerWiseConfigurations(); - assertEquals(3, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(3, conf.getNetConfigurations().size()); GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer(); assertEquals("tanh", l0.getActivationFn().toString()); @@ -224,7 +224,7 @@ public class RegressionTest071 extends BaseDL4JTest { ComputationGraphConfiguration conf = net.getComputationGraphConfiguration(); assertEquals(3, conf.getVertices().size()); - GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer(); + GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getNetConfiguration().getFirstLayer(); assertEquals("tanh", l0.getActivationFn().toString()); assertEquals(3, l0.getNIn()); assertEquals(4, l0.getNOut()); @@ -232,14 +232,14 @@ public class RegressionTest071 extends BaseDL4JTest { assertEquals(1.5, l0.getGradientNormalizationThreshold(), 1e-5); GravesBidirectionalLSTM l1 = - (GravesBidirectionalLSTM) ((LayerVertex) conf.getVertices().get("1")).getLayerConf().getLayer(); + (GravesBidirectionalLSTM) ((LayerVertex) conf.getVertices().get("1")).getNetConfiguration().getFirstLayer(); assertEquals("softsign", l1.getActivationFn().toString()); assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNOut()); assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, l1.getGradientNormalization()); assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5); - RnnOutputLayer l2 = (RnnOutputLayer) ((LayerVertex) conf.getVertices().get("2")).getLayerConf().getLayer(); + RnnOutputLayer l2 = (RnnOutputLayer) ((LayerVertex) conf.getVertices().get("2")).getNetConfiguration().getFirstLayer(); assertEquals(4, l2.getNIn()); assertEquals(5, l2.getNOut()); assertEquals("softmax", l2.getActivationFn().toString()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java index 6460582ba..06af06ff4 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java @@ -25,7 +25,7 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.graph.LayerVertex; @@ -67,8 +67,8 @@ public class RegressionTest080 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(2, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); assertTrue(l0.getActivationFn() instanceof ActivationReLU); @@ -106,8 +106,8 @@ public class RegressionTest080 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(2, conf.getConfs().size()); + 
NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); assertTrue(l0.getActivationFn() instanceof ActivationLReLU); @@ -155,8 +155,8 @@ public class RegressionTest080 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(3, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(3, conf.getNetConfigurations().size()); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); assertTrue(l0.getActivationFn() instanceof ActivationTanH); @@ -206,8 +206,8 @@ public class RegressionTest080 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - assertEquals(3, conf.getConfs().size()); + NeuralNetConfiguration conf = net.getConfiguration(); + assertEquals(3, conf.getNetConfigurations().size()); GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer(); assertTrue(l0.getActivationFn() instanceof ActivationTanH); @@ -240,7 +240,7 @@ public class RegressionTest080 extends BaseDL4JTest { ComputationGraphConfiguration conf = net.getComputationGraphConfiguration(); assertEquals(3, conf.getVertices().size()); - GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getLayerConf().getLayer(); + GravesLSTM l0 = (GravesLSTM) ((LayerVertex) conf.getVertices().get("0")).getNetConfiguration().getFirstLayer(); assertTrue(l0.getActivationFn() instanceof ActivationTanH); assertEquals(3, l0.getNIn()); assertEquals(4, l0.getNOut()); @@ -248,14 +248,14 @@ public class RegressionTest080 extends BaseDL4JTest { assertEquals(1.5, l0.getGradientNormalizationThreshold(), 1e-5); GravesBidirectionalLSTM l1 = - (GravesBidirectionalLSTM) ((LayerVertex) conf.getVertices().get("1")).getLayerConf().getLayer(); + (GravesBidirectionalLSTM) ((LayerVertex) conf.getVertices().get("1")).getNetConfiguration().getFirstLayer(); assertTrue(l1.getActivationFn() instanceof ActivationSoftSign); assertEquals(4, l1.getNIn()); assertEquals(4, l1.getNOut()); assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, l1.getGradientNormalization()); assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5); - RnnOutputLayer l2 = (RnnOutputLayer) ((LayerVertex) conf.getVertices().get("2")).getLayerConf().getLayer(); + RnnOutputLayer l2 = (RnnOutputLayer) ((LayerVertex) conf.getVertices().get("2")).getNetConfiguration().getFirstLayer(); assertEquals(4, l2.getNIn()); assertEquals(5, l2.getNOut()); assertTrue(l2.getActivationFn() instanceof ActivationSoftmax); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index f294e16a7..a847a85ef 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -86,30 +86,30 @@ public class RegressionTest100a extends BaseDL4JTest { File f = Resources.asFile("regression_testing/100a/GravesLSTMCharModelingExample_100a.bin"); MultiLayerNetwork net = MultiLayerNetwork.load(f, true); - GravesLSTM l0 = (GravesLSTM) net.getLayer(0).conf().getLayer(); + GravesLSTM l0 = (GravesLSTM) 
net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(200, l0.getNOut()); assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new RmsProp(0.1), l0.getIUpdater()); - GravesLSTM l1 = (GravesLSTM) net.getLayer(1).conf().getLayer(); + GravesLSTM l1 = (GravesLSTM) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(200, l1.getNOut()); assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l1)); assertEquals(new RmsProp(0.1), l1.getIUpdater()); - RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).conf().getLayer(); + RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(77, l2.getNOut()); assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new RmsProp(0.1), l0.getIUpdater()); - assertEquals(BackpropType.TruncatedBPTT, net.getLayerWiseConfigurations().getBackpropType()); - assertEquals(50, net.getLayerWiseConfigurations().getTbpttBackLength()); - assertEquals(50, net.getLayerWiseConfigurations().getTbpttFwdLength()); + assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); + assertEquals(50, net.getConfiguration().getTbpttBackLength()); + assertEquals(50, net.getConfiguration().getTbpttFwdLength()); INDArray outExp; File f2 = Resources.asFile("regression_testing/100a/GravesLSTMCharModelingExample_Output_100a.bin"); @@ -134,7 +134,7 @@ public class RegressionTest100a extends BaseDL4JTest { File f = Resources.asFile("regression_testing/100a/VaeMNISTAnomaly_100a.bin"); MultiLayerNetwork net = MultiLayerNetwork.load(f, true); - VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).conf().getLayer(); + VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationLReLU(), l0.getActivationFn()); assertEquals(32, l0.getNOut()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); @@ -171,7 +171,7 @@ public class RegressionTest100a extends BaseDL4JTest { int nBoxes = 5; int nClasses = 10; - ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getComputationGraphConfiguration().getVertices().get("convolution2d_9")).getLayerConf().getLayer(); + ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getComputationGraphConfiguration().getVertices().get("convolution2d_9")).getNetConfiguration().getFirstLayer(); assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); @@ -195,8 +195,8 @@ public class RegressionTest100a extends BaseDL4JTest { //Which means: the record output doesn't have this. 
To account for this, we'll manually set eps to 0.0 here //https://github.com/deeplearning4j/deeplearning4j/issues/5836#issuecomment-405526228 for(Layer l : net.getLayers()){ - if(l.conf().getLayer() instanceof BatchNormalization){ - BatchNormalization bn = (BatchNormalization) l.conf().getLayer(); + if(l.getLayerConfiguration() instanceof BatchNormalization){ + BatchNormalization bn = (BatchNormalization) l.getLayerConfiguration(); bn.setEps(0.0); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java index 35fb7391b..23ae5d5bd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java @@ -72,12 +72,12 @@ public class RegressionTest100b3 extends BaseDL4JTest { MultiLayerNetwork net = MultiLayerNetwork.load(f, true); // net = net.clone(); - DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer(); + DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(new WeightDecay(0.03, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new RmsProp(0.95), l0.getIUpdater()); - CustomLayer l1 = (CustomLayer) net.getLayer(1).conf().getLayer(); + CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction()); assertEquals(new RmsProp(0.95), l1.getIUpdater()); @@ -108,7 +108,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { List activations = net.feedForward(in); - assertEquals(dt, net.getLayerWiseConfigurations().getDataType()); + assertEquals(dt, net.getConfiguration().getDataType()); assertEquals(dt, net.params().dataType()); assertEquals( outExp, outAct, dtype); } @@ -121,30 +121,30 @@ public class RegressionTest100b3 extends BaseDL4JTest { File f = Resources.asFile("regression_testing/100b3/GravesLSTMCharModelingExample_100b3.bin"); MultiLayerNetwork net = MultiLayerNetwork.load(f, true); - LSTM l0 = (LSTM) net.getLayer(0).conf().getLayer(); + LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(200, l0.getNOut()); assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new Adam(0.005), l0.getIUpdater()); - LSTM l1 = (LSTM) net.getLayer(1).conf().getLayer(); + LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(200, l1.getNOut()); assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l1)); assertEquals(new Adam(0.005), l1.getIUpdater()); - RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).conf().getLayer(); + RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(77, l2.getNOut()); assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new Adam(0.005), l0.getIUpdater()); - assertEquals(BackpropType.TruncatedBPTT, 
net.getLayerWiseConfigurations().getBackpropType()); - assertEquals(50, net.getLayerWiseConfigurations().getTbpttBackLength()); - assertEquals(50, net.getLayerWiseConfigurations().getTbpttFwdLength()); + assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); + assertEquals(50, net.getConfiguration().getTbpttBackLength()); + assertEquals(50, net.getConfiguration().getTbpttFwdLength()); INDArray outExp; File f2 = Resources.asFile("regression_testing/100b3/GravesLSTMCharModelingExample_Output_100b3.bin"); @@ -169,7 +169,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { File f = Resources.asFile("regression_testing/100b3/VaeMNISTAnomaly_100b3.bin"); MultiLayerNetwork net = MultiLayerNetwork.load(f, true); - VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).conf().getLayer(); + VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationLReLU(), l0.getActivationFn()); assertEquals(32, l0.getNOut()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); @@ -206,7 +206,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { int nBoxes = 5; int nClasses = 10; - ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getComputationGraphConfiguration().getVertices().get("convolution2d_9")).getLayerConf().getLayer(); + ConvolutionLayer cl = (ConvolutionLayer)((LayerVertex)net.getComputationGraphConfiguration().getVertices().get("convolution2d_9")).getNetConfiguration().getFirstLayer(); assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java index 00e46bf0c..fbbe55592 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -91,12 +91,12 @@ public class RegressionTest100b4 extends BaseDL4JTest { MultiLayerNetwork net = MultiLayerNetwork.load(f, true); // net = net.clone(); - DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer(); + DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0)); assertEquals(new RmsProp(0.95), l0.getIUpdater()); - CustomLayer l1 = (CustomLayer) net.getLayer(1).conf().getLayer(); + CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction()); assertEquals(new RmsProp(0.95), l1.getIUpdater()); @@ -125,7 +125,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { INDArray outAct = net.output(in); assertEquals(dtype, outAct.dataType()); - assertEquals(dtype, net.getLayerWiseConfigurations().getDataType()); + assertEquals(dtype, net.getConfiguration().getDataType()); assertEquals(dtype, net.params().dataType()); boolean eq = outExp.equalsWithEps(outAct, 0.01); assertTrue(eq, "Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct); @@ -139,30 +139,30 @@ public class RegressionTest100b4 extends BaseDL4JTest { File f = 
Resources.asFile("regression_testing/100b4/GravesLSTMCharModelingExample_100b4.bin"); MultiLayerNetwork net = MultiLayerNetwork.load(f, true); - LSTM l0 = (LSTM) net.getLayer(0).conf().getLayer(); + LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(200, l0.getNOut()); assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new Adam(0.005), l0.getIUpdater()); - LSTM l1 = (LSTM) net.getLayer(1).conf().getLayer(); + LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(200, l1.getNOut()); assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new Adam(0.005), l1.getIUpdater()); - RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).conf().getLayer(); + RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(77, l2.getNOut()); assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new Adam(0.005), l2.getIUpdater()); - assertEquals(BackpropType.TruncatedBPTT, net.getLayerWiseConfigurations().getBackpropType()); - assertEquals(50, net.getLayerWiseConfigurations().getTbpttBackLength()); - assertEquals(50, net.getLayerWiseConfigurations().getTbpttFwdLength()); + assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); + assertEquals(50, net.getConfiguration().getTbpttBackLength()); + assertEquals(50, net.getConfiguration().getTbpttFwdLength()); INDArray outExp; File f2 = Resources.asFile("regression_testing/100b4/GravesLSTMCharModelingExample_Output_100b4.bin"); @@ -187,7 +187,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { File f = Resources.asFile("regression_testing/100b4/VaeMNISTAnomaly_100b4.bin"); MultiLayerNetwork net = MultiLayerNetwork.load(f, true); - VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).conf().getLayer(); + VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationLReLU(), l0.getActivationFn()); assertEquals(32, l0.getNOut()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); @@ -225,7 +225,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { int nClasses = 10; ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getComputationGraphConfiguration().getVertices() - .get("convolution2d_9")).getLayerConf().getLayer(); + .get("convolution2d_9")).getNetConfiguration().getFirstLayer(); assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); @@ -257,7 +257,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { File f = Resources.asFile("regression_testing/100b4/SyntheticCNN_100b4.bin"); MultiLayerNetwork net = MultiLayerNetwork.load(f, true); - ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).conf().getLayer(); + ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationReLU(), l0.getActivationFn()); assertEquals(4, l0.getNOut()); assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); @@ -268,7 +268,7 @@ public class RegressionTest100b4 extends 
BaseDL4JTest { assertArrayEquals(new int[]{1, 1}, l0.getDilation()); assertArrayEquals(new int[]{0, 0}, l0.getPadding()); - SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).conf().getLayer(); + SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationReLU(), l1.getActivationFn()); assertEquals(8, l1.getNOut()); assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); @@ -281,20 +281,20 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertEquals(ConvolutionMode.Same, l1.getConvolutionMode()); assertEquals(1, l1.getDepthMultiplier()); - SubsamplingLayer l2 = (SubsamplingLayer) net.getLayer(2).conf().getLayer(); + SubsamplingLayer l2 = (SubsamplingLayer) net.getLayer(2).getLayerConfiguration(); assertArrayEquals(new int[]{3, 3}, l2.getKernelSize()); assertArrayEquals(new int[]{2, 2}, l2.getStride()); assertArrayEquals(new int[]{1, 1}, l2.getDilation()); assertArrayEquals(new int[]{0, 0}, l2.getPadding()); assertEquals(PoolingType.MAX, l2.getPoolingType()); - ZeroPaddingLayer l3 = (ZeroPaddingLayer) net.getLayer(3).conf().getLayer(); + ZeroPaddingLayer l3 = (ZeroPaddingLayer) net.getLayer(3).getLayerConfiguration(); assertArrayEquals(new int[]{4, 4, 4, 4}, l3.getPadding()); - Upsampling2D l4 = (Upsampling2D) net.getLayer(4).conf().getLayer(); + Upsampling2D l4 = (Upsampling2D) net.getLayer(4).getLayerConfiguration(); assertArrayEquals(new int[]{3, 3}, l4.getSize()); - DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).conf().getLayer(); + DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration(); assertEquals(new ActivationReLU(), l5.getActivationFn()); assertEquals(16, l5.getNOut()); assertEquals(new WeightInitXavier(), l5.getWeightInitFn()); @@ -306,17 +306,17 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertArrayEquals(new int[]{0, 0}, l5.getPadding()); assertEquals(2, l5.getDepthMultiplier()); - SubsamplingLayer l6 = (SubsamplingLayer) net.getLayer(6).conf().getLayer(); + SubsamplingLayer l6 = (SubsamplingLayer) net.getLayer(6).getLayerConfiguration(); assertArrayEquals(new int[]{2, 2}, l6.getKernelSize()); assertArrayEquals(new int[]{2, 2}, l6.getStride()); assertArrayEquals(new int[]{1, 1}, l6.getDilation()); assertArrayEquals(new int[]{0, 0}, l6.getPadding()); assertEquals(PoolingType.MAX, l6.getPoolingType()); - Cropping2D l7 = (Cropping2D) net.getLayer(7).conf().getLayer(); + Cropping2D l7 = (Cropping2D) net.getLayer(7).getLayerConfiguration(); assertArrayEquals(new int[]{3, 3, 2, 2}, l7.getCropping()); - ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).conf().getLayer(); + ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration(); assertEquals(4, l8.getNOut()); assertEquals(new WeightInitXavier(), l8.getWeightInitFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8)); @@ -326,7 +326,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertArrayEquals(new int[]{1, 1}, l8.getDilation()); assertArrayEquals(new int[]{0, 0}, l8.getPadding()); - CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).conf().getLayer(); + CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration(); assertEquals(new WeightInitXavier(), l9.getWeightInitFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9)); assertEquals(new Adam(0.005), l9.getIUpdater()); @@ -361,7 +361,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { File f = 
Resources.asFile("regression_testing/100b4/SyntheticBidirectionalRNNGraph_100b4.bin"); ComputationGraph net = ComputationGraph.load(f, true); - Bidirectional l0 = (Bidirectional) net.getLayer("rnn1").conf().getLayer(); + Bidirectional l0 = (Bidirectional) net.getLayer("rnn1").getLayerConfiguration(); LSTM l1 = (LSTM) l0.getFwd(); assertEquals(16, l1.getNOut()); @@ -373,7 +373,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertEquals(new ActivationReLU(), l2.getActivationFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); - Bidirectional l3 = (Bidirectional) net.getLayer("rnn2").conf().getLayer(); + Bidirectional l3 = (Bidirectional) net.getLayer("rnn2").getLayerConfiguration(); SimpleRnn l4 = (SimpleRnn) l3.getFwd(); assertEquals(16, l4.getNOut()); @@ -387,12 +387,12 @@ public class RegressionTest100b4 extends BaseDL4JTest { MergeVertex mv = (MergeVertex) net.getVertex("concat"); - GlobalPoolingLayer gpl = (GlobalPoolingLayer) net.getLayer("pooling").conf().getLayer(); + GlobalPoolingLayer gpl = (GlobalPoolingLayer) net.getLayer("pooling").getLayerConfiguration(); assertEquals(PoolingType.MAX, gpl.getPoolingType()); assertArrayEquals(new int[]{2}, gpl.getPoolingDimensions()); assertTrue(gpl.isCollapseDimensions()); - OutputLayer outl = (OutputLayer) net.getLayer("out").conf().getLayer(); + OutputLayer outl = (OutputLayer) net.getLayer("out").getLayerConfiguration(); assertEquals(3, outl.getNOut()); assertEquals(new LossMCXENT(), outl.getLossFn()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java index 15a9c2bc3..979518196 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java @@ -73,12 +73,12 @@ public class RegressionTest100b6 extends BaseDL4JTest { MultiLayerNetwork net = MultiLayerNetwork.load(f, true); // net = net.clone(); - DenseLayer l0 = (DenseLayer) net.getLayer(0).conf().getLayer(); + DenseLayer l0 = (DenseLayer) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(new L2Regularization(0.03), TestUtils.getL2Reg(l0)); assertEquals(new RmsProp(0.95), l0.getIUpdater()); - CustomLayer l1 = (CustomLayer) net.getLayer(1).conf().getLayer(); + CustomLayer l1 = (CustomLayer) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(new ActivationSigmoid(), l1.getSecondActivationFunction()); assertEquals(new RmsProp(0.95), l1.getIUpdater()); @@ -107,7 +107,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { INDArray outAct = net.output(in); assertEquals(dtype, outAct.dataType()); - assertEquals(dtype, net.getLayerWiseConfigurations().getDataType()); + assertEquals(dtype, net.getConfiguration().getDataType()); assertEquals(dtype, net.params().dataType()); boolean eq = outExp.equalsWithEps(outAct, 0.01); assertTrue( eq, "Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct); @@ -121,30 +121,30 @@ public class RegressionTest100b6 extends BaseDL4JTest { File f = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_100b6.bin"); MultiLayerNetwork net = MultiLayerNetwork.load(f, true); - LSTM l0 = (LSTM) net.getLayer(0).conf().getLayer(); + LSTM l0 = (LSTM) 
net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(200, l0.getNOut()); assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new Adam(0.005), l0.getIUpdater()); - LSTM l1 = (LSTM) net.getLayer(1).conf().getLayer(); + LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(200, l1.getNOut()); assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new Adam(0.005), l1.getIUpdater()); - RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).conf().getLayer(); + RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(77, l2.getNOut()); assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new Adam(0.005), l2.getIUpdater()); - assertEquals(BackpropType.TruncatedBPTT, net.getLayerWiseConfigurations().getBackpropType()); - assertEquals(50, net.getLayerWiseConfigurations().getTbpttBackLength()); - assertEquals(50, net.getLayerWiseConfigurations().getTbpttFwdLength()); + assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); + assertEquals(50, net.getConfiguration().getTbpttBackLength()); + assertEquals(50, net.getConfiguration().getTbpttFwdLength()); INDArray outExp; File f2 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Output_100b6.bin"); @@ -169,7 +169,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { File f = Resources.asFile("regression_testing/100b6/VaeMNISTAnomaly_100b6.bin"); MultiLayerNetwork net = MultiLayerNetwork.load(f, true); - VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).conf().getLayer(); + VariationalAutoencoder l0 = (VariationalAutoencoder) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationLReLU(), l0.getActivationFn()); assertEquals(32, l0.getNOut()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); @@ -206,7 +206,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { int nClasses = 10; ConvolutionLayer cl = (ConvolutionLayer) ((LayerVertex) net.getComputationGraphConfiguration().getVertices() - .get("convolution2d_9")).getLayerConf().getLayer(); + .get("convolution2d_9")).getNetConfiguration().getFirstLayer(); assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); @@ -237,7 +237,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { File f = Resources.asFile("regression_testing/100b6/SyntheticCNN_100b6.bin"); MultiLayerNetwork net = MultiLayerNetwork.load(f, true); - ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).conf().getLayer(); + ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationReLU(), l0.getActivationFn()); assertEquals(4, l0.getNOut()); assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); @@ -248,7 +248,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertArrayEquals(new int[]{1, 1}, l0.getDilation()); assertArrayEquals(new int[]{0, 0}, l0.getPadding()); - SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).conf().getLayer(); + 
SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationReLU(), l1.getActivationFn()); assertEquals(8, l1.getNOut()); assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); @@ -261,20 +261,20 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertEquals(ConvolutionMode.Same, l1.getConvolutionMode()); assertEquals(1, l1.getDepthMultiplier()); - SubsamplingLayer l2 = (SubsamplingLayer) net.getLayer(2).conf().getLayer(); + SubsamplingLayer l2 = (SubsamplingLayer) net.getLayer(2).getLayerConfiguration(); assertArrayEquals(new int[]{3, 3}, l2.getKernelSize()); assertArrayEquals(new int[]{2, 2}, l2.getStride()); assertArrayEquals(new int[]{1, 1}, l2.getDilation()); assertArrayEquals(new int[]{0, 0}, l2.getPadding()); assertEquals(PoolingType.MAX, l2.getPoolingType()); - ZeroPaddingLayer l3 = (ZeroPaddingLayer) net.getLayer(3).conf().getLayer(); + ZeroPaddingLayer l3 = (ZeroPaddingLayer) net.getLayer(3).getLayerConfiguration(); assertArrayEquals(new int[]{4, 4, 4, 4}, l3.getPadding()); - Upsampling2D l4 = (Upsampling2D) net.getLayer(4).conf().getLayer(); + Upsampling2D l4 = (Upsampling2D) net.getLayer(4).getLayerConfiguration(); assertArrayEquals(new int[]{3, 3}, l4.getSize()); - DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).conf().getLayer(); + DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration(); assertEquals(new ActivationReLU(), l5.getActivationFn()); assertEquals(16, l5.getNOut()); assertEquals(new WeightInitXavier(), l5.getWeightInitFn()); @@ -286,17 +286,17 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertArrayEquals(new int[]{0, 0}, l5.getPadding()); assertEquals(2, l5.getDepthMultiplier()); - SubsamplingLayer l6 = (SubsamplingLayer) net.getLayer(6).conf().getLayer(); + SubsamplingLayer l6 = (SubsamplingLayer) net.getLayer(6).getLayerConfiguration(); assertArrayEquals(new int[]{2, 2}, l6.getKernelSize()); assertArrayEquals(new int[]{2, 2}, l6.getStride()); assertArrayEquals(new int[]{1, 1}, l6.getDilation()); assertArrayEquals(new int[]{0, 0}, l6.getPadding()); assertEquals(PoolingType.MAX, l6.getPoolingType()); - Cropping2D l7 = (Cropping2D) net.getLayer(7).conf().getLayer(); + Cropping2D l7 = (Cropping2D) net.getLayer(7).getLayerConfiguration(); assertArrayEquals(new int[]{3, 3, 2, 2}, l7.getCropping()); - ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).conf().getLayer(); + ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration(); assertEquals(4, l8.getNOut()); assertEquals(new WeightInitXavier(), l8.getWeightInitFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8)); @@ -306,7 +306,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertArrayEquals(new int[]{1, 1}, l8.getDilation()); assertArrayEquals(new int[]{0, 0}, l8.getPadding()); - CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).conf().getLayer(); + CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration(); assertEquals(new WeightInitXavier(), l9.getWeightInitFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9)); assertEquals(new Adam(0.005), l9.getIUpdater()); @@ -341,7 +341,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { File f = Resources.asFile("regression_testing/100b6/SyntheticBidirectionalRNNGraph_100b6.bin"); ComputationGraph net = ComputationGraph.load(f, true); - Bidirectional l0 = (Bidirectional) net.getLayer("rnn1").conf().getLayer(); + Bidirectional 
l0 = (Bidirectional) net.getLayer("rnn1").getLayerConfiguration(); LSTM l1 = (LSTM) l0.getFwd(); assertEquals(16, l1.getNOut()); @@ -353,7 +353,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertEquals(new ActivationReLU(), l2.getActivationFn()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); - Bidirectional l3 = (Bidirectional) net.getLayer("rnn2").conf().getLayer(); + Bidirectional l3 = (Bidirectional) net.getLayer("rnn2").getLayerConfiguration(); SimpleRnn l4 = (SimpleRnn) l3.getFwd(); assertEquals(16, l4.getNOut()); @@ -367,12 +367,12 @@ public class RegressionTest100b6 extends BaseDL4JTest { MergeVertex mv = (MergeVertex) net.getVertex("concat"); - GlobalPoolingLayer gpl = (GlobalPoolingLayer) net.getLayer("pooling").conf().getLayer(); + GlobalPoolingLayer gpl = (GlobalPoolingLayer) net.getLayer("pooling").getLayerConfiguration(); assertEquals(PoolingType.MAX, gpl.getPoolingType()); assertArrayEquals(new int[]{2}, gpl.getPoolingDimensions()); assertTrue(gpl.isCollapseDimensions()); - OutputLayer outl = (OutputLayer) net.getLayer("out").conf().getLayer(); + OutputLayer outl = (OutputLayer) net.getLayer("out").getLayerConfiguration(); assertEquals(3, outl.getNOut()); assertEquals(new LossMCXENT(), outl.getLossFn()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java index acb3963b1..b8b3cdad6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java @@ -85,9 +85,9 @@ public class CustomLayer extends FeedForwardLayer { //Initialize the layer parameters. For example, // Note that the entries in paramTable (2 entries here: a weight array of shape [nIn,nOut] and biases of shape [1,nOut] // are in turn a view of the 'layerParamsView' array. 
- Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); myCustomLayer.setParamTable(paramTable); - myCustomLayer.setConf(conf); + myCustomLayer.setLayerConfiguration(conf); return myCustomLayer; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java index 18c0ab8e0..42b91d908 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java @@ -56,7 +56,7 @@ public class CustomLayerImpl extends BaseLayer { //Generic paramete INDArray secondHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns)); IActivation activation1 = layerConf().getActivationFn(); - IActivation activation2 = ((CustomLayer) conf.getLayer()).getSecondActivationFunction(); + IActivation activation2 = ((CustomLayer) layerConfiguration.getFirstLayer()).getSecondActivationFunction(); //IActivation function instances modify the activation functions in-place activation1.getActivation(firstHalf, training); @@ -105,7 +105,7 @@ public class CustomLayerImpl extends BaseLayer { //Generic paramete INDArray epsilonSecondHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns)); IActivation activation1 = layerConf().getActivationFn(); - IActivation activation2 = ((CustomLayer) conf.getLayer()).getSecondActivationFunction(); + IActivation activation2 = ((CustomLayer) layerConfiguration.getFirstLayer()).getSecondActivationFunction(); //IActivation backprop method modifies the 'firstHalf' and 'secondHalf' arrays in-place, to contain dL/dz activation1.backprop(firstHalf, epsilonFirstHalf); @@ -127,7 +127,7 @@ public class CustomLayerImpl extends BaseLayer { //Generic paramete ret.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, weightGrad); ret.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, biasGrad); - INDArray epsilonNext = params.get(DefaultParamInitializer.WEIGHT_KEY).mmul(activationDerivative.transpose()).transpose(); + INDArray epsilonNext = paramsTable.get(DefaultParamInitializer.WEIGHT_KEY).mmul(activationDerivative.transpose()).transpose(); return new Pair<>(ret, epsilonNext); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java index 73610f45e..03b8192f4 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java @@ -23,7 +23,6 @@ package org.deeplearning4j.samediff; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -151,7 +150,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { //Create equivalent DL4J net - MultiLayerConfiguration mlc 
= new NeuralNetConfiguration.Builder() + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .weightInit(WeightInit.XAVIER).seed(12345) .l1(l1Val).l2(l2Val) @@ -165,7 +164,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(mlc); net.init(); - Map oldParams = net.paramTable(); + Map oldParams = net.getParamTable(); //Assign parameters so we have identical models at the start: w0.getArr().assign(net.getParam("0_W")); @@ -215,7 +214,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { //Check training with updater - mlc = new NeuralNetConfiguration.Builder() + mlc = NeuralNetConfiguration.builder() .dataType(DataType.DOUBLE) .weightInit(WeightInit.XAVIER).seed(12345) .l1(l1Val).l2(l2Val) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java index 8bfaa9eb2..49a9c7fa1 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java @@ -24,7 +24,6 @@ import org.apache.commons.io.FileUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -80,8 +79,8 @@ public class CrashReportingUtilTest extends BaseDL4JTest { int width = 28; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new NoOp()) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new NoOp()) .dist(new NormalDistribution(0, 1)) .list().layer(0, @@ -99,7 +98,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest { .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX) .nOut(10).build()) - .setInputType(InputType.convolutionalFlat(height, width, + .inputType(InputType.convolutionalFlat(height, width, inputDepth)) .build(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java index 2ff1c481d..02a1fdaf5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java @@ -20,12 +20,11 @@ package org.deeplearning4j.util; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.compress.utils.IOUtils; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.util.ModelGuesser; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -60,11 +59,11 @@ public class ModelGuesserTest extends BaseDL4JTest { public void testModelGuessFile() throws Exception { File f = 
Resources.asFile("modelimport/keras/examples/mnist_mlp/mnist_mlp_tf_keras_1_model.h5"); assertTrue(f.exists()); - Model guess1 = ModelGuesser.loadModelGuess(f.getAbsolutePath()); + IModel guess1 = ModelGuesser.loadModelGuess(f.getAbsolutePath()); Assertions.assertNotNull(guess1); f = Resources.asFile("modelimport/keras/examples/mnist_cnn/mnist_cnn_tf_keras_1_model.h5"); assertTrue(f.exists()); - Model guess2 = ModelGuesser.loadModelGuess(f.getAbsolutePath()); + IModel guess2 = ModelGuesser.loadModelGuess(f.getAbsolutePath()); Assertions.assertNotNull(guess2); } @@ -75,7 +74,7 @@ public class ModelGuesserTest extends BaseDL4JTest { assertTrue(f.exists()); try (InputStream inputStream = new FileInputStream(f)) { - Model guess1 = ModelGuesser.loadModelGuess(inputStream); + IModel guess1 = ModelGuesser.loadModelGuess(inputStream); Assertions.assertNotNull(guess1); } @@ -83,7 +82,7 @@ public class ModelGuesserTest extends BaseDL4JTest { assertTrue(f.exists()); try (InputStream inputStream = new FileInputStream(f)) { - Model guess1 = ModelGuesser.loadModelGuess(inputStream); + IModel guess1 = ModelGuesser.loadModelGuess(inputStream); Assertions.assertNotNull(guess1); } } @@ -101,7 +100,7 @@ public class ModelGuesserTest extends BaseDL4JTest { NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1); normalizer.fit(new DataSet(Nd4j.rand(2, 2), Nd4j.rand(2, 2))); ModelSerializer.addNormalizerToModel(tempFile, normalizer); - Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); + IModel model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); Normalizer normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath()); assertEquals(model, net); assertEquals(normalizer, normalizer1); @@ -119,7 +118,7 @@ public class ModelGuesserTest extends BaseDL4JTest { normalizer.fit(new DataSet(Nd4j.rand(2, 2), Nd4j.rand(2, 2))); ModelSerializer.writeModel(net, tempFile, true,normalizer); - Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); + IModel model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); Normalizer normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath()); assertEquals(model, net); assertEquals(normalizer, normalizer1); @@ -137,7 +136,7 @@ public class ModelGuesserTest extends BaseDL4JTest { NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1); normalizer.fit(new DataSet(Nd4j.rand(2, 2), Nd4j.rand(2, 2))); ModelSerializer.addNormalizerToModel(tempFile, normalizer); - Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); + IModel model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); try (InputStream inputStream = new FileInputStream(tempFile)) { Normalizer normalizer1 = ModelGuesser.loadNormalizer(inputStream); assertEquals(model, net); @@ -156,7 +155,7 @@ public class ModelGuesserTest extends BaseDL4JTest { ModelSerializer.writeModel(net, tempFile, true); MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); - assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson()); + assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); @@ -173,7 +172,7 @@ public class ModelGuesserTest extends BaseDL4JTest { try (InputStream inputStream = new FileInputStream(tempFile)) { MultiLayerNetwork network = (MultiLayerNetwork) 
ModelGuesser.loadModelGuess(inputStream); Assertions.assertNotNull(network); - assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson()); + assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @@ -187,7 +186,7 @@ public class ModelGuesserTest extends BaseDL4JTest { File f = getTempFile(resource); String configFilename = f.getAbsolutePath(); Object conf = ModelGuesser.loadConfigGuess(configFilename); - assertTrue(conf instanceof MultiLayerConfiguration); + assertTrue(conf instanceof NeuralNetConfiguration); ClassPathResource sequenceResource = new ClassPathResource("/keras/simple/mlp_fapi_multiloss_config.json"); File f2 = getTempFile(sequenceResource); @@ -212,7 +211,7 @@ public class ModelGuesserTest extends BaseDL4JTest { try (InputStream inputStream = new FileInputStream(f)) { Object conf = ModelGuesser.loadConfigGuess(inputStream); - assertTrue(conf instanceof MultiLayerConfiguration); + assertTrue(conf instanceof NeuralNetConfiguration); } ClassPathResource sequenceResource = new ClassPathResource("/keras/simple/mlp_fapi_multiloss_config.json"); @@ -249,7 +248,7 @@ public class ModelGuesserTest extends BaseDL4JTest { int nIn = 5; int nOut = 6; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01).l2(0.01) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01).l2(0.01) .updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java index e01d42f01..9f52ae300 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java @@ -26,7 +26,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -65,7 +64,7 @@ public class ModelSerializerTest extends BaseDL4JTest { int nIn = 5; int nOut = 6; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01) .l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() @@ -81,7 +80,7 @@ public class ModelSerializerTest extends BaseDL4JTest { MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile); - assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson()); + 
assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @@ -91,7 +90,7 @@ public class ModelSerializerTest extends BaseDL4JTest { int nIn = 5; int nOut = 6; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01) .l2(0.01).updater(new Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() @@ -125,7 +124,7 @@ public class ModelSerializerTest extends BaseDL4JTest { MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(fis); - assertEquals(network.getLayerWiseConfigurations().toJson(), net.getLayerWiseConfigurations().toJson()); + assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @@ -133,7 +132,7 @@ public class ModelSerializerTest extends BaseDL4JTest { @Test public void testWriteCGModel() throws Exception { - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration config = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)) .graphBuilder().addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", @@ -158,7 +157,7 @@ public class ModelSerializerTest extends BaseDL4JTest { @Test public void testWriteCGModelInputStream() throws Exception { - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration config = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)) .graphBuilder().addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", @@ -189,7 +188,7 @@ public class ModelSerializerTest extends BaseDL4JTest { } private ComputationGraph simpleComputationGraph() { - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration config = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(new Sgd(0.1)) .graphBuilder().addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", @@ -253,7 +252,7 @@ public class ModelSerializerTest extends BaseDL4JTest { @Test public void testInvalidLoading1() throws Exception { - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration config = NeuralNetConfiguration.builder() .graphBuilder().addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in") .addLayer("out",new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) @@ -282,7 +281,7 @@ public class ModelSerializerTest extends BaseDL4JTest { int nIn = 5; int nOut = 6; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01) .l2(0.01).updater(new 
Sgd(0.1)).activation(Activation.TANH).weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(20).build()) .layer(1, new DenseLayer.Builder().nIn(20).nOut(30).build()).layer(2, new OutputLayer.Builder() @@ -310,7 +309,7 @@ public class ModelSerializerTest extends BaseDL4JTest { int nIn = 5; int nOut = 6; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01) .list() .layer(new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()) .build(); @@ -357,7 +356,7 @@ public class ModelSerializerTest extends BaseDL4JTest { int nIn = 5; int nOut = 6; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01) .graphBuilder() .addInputs("in") .layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build(), "in") @@ -406,7 +405,7 @@ public class ModelSerializerTest extends BaseDL4JTest { int nIn = 5; int nOut = 6; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01) .graphBuilder() .addInputs("in") .layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).build(), "in") @@ -433,7 +432,7 @@ public class ModelSerializerTest extends BaseDL4JTest { int nIn = 5; int nOut = 6; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01) .list() .layer(0, new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build()) .build(); @@ -458,7 +457,7 @@ public class ModelSerializerTest extends BaseDL4JTest { int nIn = 5; int nOut = 6; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).l1(0.01) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).l1(0.01) .graphBuilder() .addInputs("in") .layer("0", new OutputLayer.Builder().nIn(nIn).nOut(nOut).activation(Activation.SOFTMAX).build(), "in") diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java index 9d6a27183..eef3472d2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelValidatorTests.java @@ -23,7 +23,6 @@ package org.deeplearning4j.util; import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -167,7 +166,7 @@ public class ModelValidatorTests extends BaseDL4JTest { assertFalse(vr6.isValid()); s = vr6.getIssues().get(0); assertEquals(1, vr6.getIssues().size()); - assertTrue(s.contains("JSON") && s.contains("valid") && s.contains("MultiLayerConfiguration"), s); + assertTrue(s.contains("JSON") && s.contains("valid") && s.contains("NeuralNetConfiguration"), s); assertEquals("MultiLayerNetwork", vr6.getFormatType()); assertEquals(MultiLayerNetwork.class, 
vr6.getFormatClass()); assertNotNull(vr6.getException()); @@ -296,7 +295,7 @@ public class ModelValidatorTests extends BaseDL4JTest { public static MultiLayerNetwork getSimpleNet(){ - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345) .updater(new Adam(0.01)) .list() diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java index 601237b53..810dbce85 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasLayer.java @@ -27,13 +27,12 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfigurationFactory; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; -import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasConvolutionUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils; import org.nd4j.common.util.ArrayUtil; @@ -57,7 +56,7 @@ public class KerasLayer { protected DimOrder dimOrder; // Keras layer backend dimension order protected List inboundLayerNames; // List of inbound layers protected List outboundLayerNames; //List of outbound layers - protected Layer layer; // Resulting DL4J layer + protected LayerConfiguration layer; // Resulting DL4J layer protected GraphVertex vertex; // Resulting DL4J vertex protected Map weights; // Weights protected double weightL1Regularization = 0.0; // L1 regularization @@ -302,7 +301,7 @@ public class KerasLayer { */ public void copyWeightsToLayer(org.deeplearning4j.nn.api.Layer layer) throws InvalidKerasConfigurationException { if (this.getNumParams() > 0) { - String dl4jLayerName = layer.conf().getLayer().getLayerName(); + String dl4jLayerName = layer.getLayerConfiguration().getLayerName(); String kerasLayerName = this.getLayerName(); String msg = "Error when attempting to copy weights from Keras layer " + kerasLayerName + " to DL4J layer " + dl4jLayerName; @@ -310,7 +309,7 @@ public class KerasLayer { if (getWeights() == null) throw new InvalidKerasConfigurationException(msg + "(weights is null)"); - Set paramsInLayer = new HashSet<>(layer.paramTable().keySet()); + Set paramsInLayer = new HashSet<>(layer.getParamTable().keySet()); Set paramsInKerasLayer = new HashSet<>(this.weights.keySet()); /* Check for parameters in layer for which we don't have weights. */ @@ -322,7 +321,7 @@ public class KerasLayer { } /* Check for parameters NOT in layer for which we DO have weights. 
*/ - paramsInKerasLayer.removeAll(layer.paramTable().keySet()); + paramsInKerasLayer.removeAll(layer.getParamTable().keySet()); if (!paramsInKerasLayer.isEmpty()) { String joinedParamsInKerasLayer = StringUtils.join(paramsInKerasLayer, ", "); throw new InvalidKerasConfigurationException( @@ -330,9 +329,9 @@ public class KerasLayer { } /* Copy weights. */ - for (String paramName : layer.paramTable().keySet()) { + for (String paramName : layer.getParamTable().keySet()) { try { - long[] dl4jWeights = layer.paramTable().get(paramName).shape(); + long[] dl4jWeights = layer.getParamTable().get(paramName).shape(); long[] kerasWeights = weights.get(paramName).shape(); INDArray variable = this.weights.get(paramName); if(!Arrays.equals(dl4jWeights,kerasWeights) && @@ -348,7 +347,7 @@ public class KerasLayer { log.error(e.getMessage()); throw new InvalidKerasConfigurationException(e.getMessage() + "\nTried to set weights for layer with name " + this.getLayerName() - + ", of " + layer.conf().getLayer().getClass() + ".\n" + + ", of " + layer.getLayerConfiguration().getClass() + ".\n" + "Failed to set weights for parameter " + paramName + "\n" + "Expected shape for this parameter: " + layer.getParam(paramName).shapeInfoToString() + ", \ngot: " + this.weights.get(paramName).shapeInfoToString()); @@ -372,11 +371,11 @@ public class KerasLayer { * @return DL4J ILayer * @see org.deeplearning4j.nn.api.Layer */ - public Layer getLayer() { + public LayerConfiguration getLayer() { return this.layer; } - public void setLayer(Layer layer){ + public void setLayer(LayerConfiguration layer){ this.layer = layer; } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java index ea0b99f0c..4ce518eac 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModel.java @@ -22,11 +22,10 @@ package org.deeplearning4j.nn.modelimport.keras; import lombok.Data; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.collections4.set.ListOrderedSet; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.graph.PreprocessorVertex; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration; @@ -44,13 +43,10 @@ import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder; import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils; import org.deeplearning4j.nn.modelimport.keras.utils.KerasOptimizerUtils; import org.deeplearning4j.util.ConvolutionUtils; -import org.nd4j.autodiff.samediff.internal.DependencyList; -import org.nd4j.autodiff.samediff.internal.DependencyTracker; import org.nd4j.common.primitives.Counter; import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.learning.config.IUpdater; import com.google.common.collect.Lists; -import org.tensorflow.framework.NodeDef; import java.io.IOException; import java.util.*; @@ -444,7 +440,7 @@ public class KerasModel { } KerasInput kerasInput = (KerasInput) layer; - Layer layer1 = layersOrdered.get(kerasLayerIdx + 1).layer; + 
LayerConfiguration layer1 = layersOrdered.get(kerasLayerIdx + 1).layer; //no dim order, try to pull it from the next layer if there is one if(ConvolutionUtils.layerHasConvolutionLayout(layer1)) { CNN2DFormat formatForLayer = ConvolutionUtils.getFormatForLayer(layer1); @@ -491,7 +487,7 @@ public class KerasModel { && !this.className.equals(config.getFieldNameClassFunctional())) throw new InvalidKerasConfigurationException( "Keras model class name " + this.className + " incompatible with ComputationGraph"); - NeuralNetConfiguration.Builder modelBuilder = new NeuralNetConfiguration.Builder(); + NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder(); if (optimizer != null) { modelBuilder.updater(optimizer); @@ -597,8 +593,8 @@ public class KerasModel { /* Whether to use standard backprop (or BPTT) or truncated BPTT. */ if (this.useTruncatedBPTT && this.truncatedBPTT > 0) - graphBuilder.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(truncatedBPTT) - .tBPTTBackwardLength(truncatedBPTT); + graphBuilder.backpropType(BackpropType.TruncatedBPTT).tbpttFwdLength(truncatedBPTT) + .tbpttBackLength(truncatedBPTT); else graphBuilder.backpropType(BackpropType.Standard); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModelImport.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModelImport.java index c9f3d15a0..850cdd7ad 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModelImport.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasModelImport.java @@ -23,7 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; @@ -341,12 +341,12 @@ public class KerasModelImport { * @throws IOException IO exception * @see MultiLayerNetwork */ - public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename, + public static NeuralNetConfiguration importKerasSequentialConfiguration(String modelJsonFilename, boolean enforceTrainingConfig) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { KerasSequentialModel kerasModel = new KerasSequentialModel().modelBuilder().modelJsonFilename(modelJsonFilename) .enforceTrainingConfig(enforceTrainingConfig).buildSequential(); - return kerasModel.getMultiLayerConfiguration(); + return kerasModel.getNeuralNetConfiguration(); } /** @@ -358,11 +358,11 @@ public class KerasModelImport { * @throws IOException IO exception * @see MultiLayerNetwork */ - public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename) + public static NeuralNetConfiguration importKerasSequentialConfiguration(String modelJsonFilename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { KerasSequentialModel kerasModel = new KerasSequentialModel().modelBuilder().modelJsonFilename(modelJsonFilename) 
.enforceTrainingConfig(false).buildSequential(); - return kerasModel.getMultiLayerConfiguration(); + return kerasModel.getNeuralNetConfiguration(); } private static File toTempFile(InputStream is) throws IOException { diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java index 696dc3df9..2a99d0c34 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/KerasSequentialModel.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.modelimport.keras; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; @@ -159,11 +158,11 @@ public class KerasSequentialModel extends KerasModel { } /** - * Configure a MultiLayerConfiguration from this Keras Sequential model configuration. + * Configure a NeuralNetConfiguration from this Keras Sequential model configuration. * - * @return MultiLayerConfiguration + * @return NeuralNetConfiguration */ - public MultiLayerConfiguration getMultiLayerConfiguration() + public NeuralNetConfiguration getNeuralNetConfiguration() throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { if (!this.className.equals(config.getFieldClassNameSequential())) throw new InvalidKerasConfigurationException( @@ -175,15 +174,15 @@ public class KerasSequentialModel extends KerasModel { throw new InvalidKerasConfigurationException( "MultiLayerNetwork expects only 1 output (found " + this.outputLayerNames.size() + ")"); - NeuralNetConfiguration.Builder modelBuilder = new NeuralNetConfiguration.Builder(); + NeuralNetConfiguration.NeuralNetConfigurationBuilder modelBuilder = NeuralNetConfiguration.builder(); if (optimizer != null) { modelBuilder.updater(optimizer); } - NeuralNetConfiguration.ListBuilder listBuilder = modelBuilder.list(); - //don't forcibly over ride for keras import - listBuilder.overrideNinUponBuild(false); + + //don't forcibly override for keras import + modelBuilder.overrideNinUponBuild(false); /* Add layers one at a time. 
*/ KerasLayer prevLayer = null; int layerIndex = 0; @@ -192,7 +191,7 @@ public class KerasSequentialModel extends KerasModel { int nbInbound = layer.getInboundLayerNames().size(); if (nbInbound != 1) throw new InvalidKerasConfigurationException( - "Layers in MultiLayerConfiguration must have exactly one inbound layer (found " + "Layers in NeuralNetConfiguration must have exactly one inbound layer (found " + nbInbound + " for layer " + layer.getLayerName() + ")"); if (prevLayer != null) { InputType[] inputTypes = new InputType[1]; @@ -201,39 +200,40 @@ public class KerasSequentialModel extends KerasModel { inputTypes[0] = this.outputTypes.get(prevLayer.getInboundLayerNames().get(0)); preprocessor = prevLayer.getInputPreprocessor(inputTypes); InputType outputType = preprocessor.getOutputType(inputTypes[0]); - layer.getLayer().setNIn(outputType,listBuilder.isOverrideNinUponBuild()); + layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild()); } else { inputTypes[0] = this.outputTypes.get(prevLayer.getLayerName()); preprocessor = layer.getInputPreprocessor(inputTypes); if(preprocessor != null) { InputType outputType = preprocessor.getOutputType(inputTypes[0]); - layer.getLayer().setNIn(outputType,listBuilder.isOverrideNinUponBuild()); + layer.getLayer().setNIn(outputType,modelBuilder.isOverrideNinUponBuild()); } else - layer.getLayer().setNIn(inputTypes[0],listBuilder.isOverrideNinUponBuild()); + layer.getLayer().setNIn(inputTypes[0],modelBuilder.isOverrideNinUponBuild()); } if (preprocessor != null) - listBuilder.inputPreProcessor(layerIndex, preprocessor); + modelBuilder.inputPreProcessor(layerIndex, preprocessor); } - listBuilder.layer(layerIndex++, layer.getLayer()); + modelBuilder.layer(layerIndex++, layer.getLayer()); } else if (layer.getVertex() != null) - throw new InvalidKerasConfigurationException("Cannot add vertex to MultiLayerConfiguration (class name " + throw new InvalidKerasConfigurationException("Cannot add vertex to NeuralNetConfiguration (class name " + layer.getClassName() + ", layer name " + layer.getLayerName() + ")"); prevLayer = layer; } /* Whether to use standard backprop (or BPTT) or truncated BPTT. 
*/ if (this.useTruncatedBPTT && this.truncatedBPTT > 0) - listBuilder.backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(truncatedBPTT) - .tBPTTBackwardLength(truncatedBPTT); + modelBuilder.backpropType(BackpropType.TruncatedBPTT) + .tbpttFwdLength(truncatedBPTT) + .tbpttBackLength(truncatedBPTT); else - listBuilder.backpropType(BackpropType.Standard); + modelBuilder.backpropType(BackpropType.Standard); - MultiLayerConfiguration build = listBuilder.build(); + NeuralNetConfiguration build = modelBuilder.build(); return build; @@ -256,7 +256,7 @@ public class KerasSequentialModel extends KerasModel { */ public MultiLayerNetwork getMultiLayerNetwork(boolean importWeights) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException { - MultiLayerNetwork model = new MultiLayerNetwork(getMultiLayerConfiguration()); + MultiLayerNetwork model = new MultiLayerNetwork(getNeuralNetConfiguration()); model.init(); if (importWeights) model = (MultiLayerNetwork) KerasModelUtils.copyWeightsToModel(model, this.layers); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java index 8e30f72f2..11bb40d58 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java @@ -26,9 +26,8 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; -import org.deeplearning4j.nn.modelimport.keras.layers.TFOpLayerImpl; import org.deeplearning4j.nn.params.EmptyParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; @@ -41,7 +40,7 @@ import java.util.List; import java.util.Map; -public class TFOpLayer extends Layer { +public class TFOpLayer extends LayerConfiguration { private final Map nodeDef; private final Map constants; @@ -90,7 +89,8 @@ public class TFOpLayer extends Layer { Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - TFOpLayerImpl tfOpLayerImpl = new TFOpLayerImpl(nodeDef, constants, conf, networkDataType); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + TFOpLayerImpl tfOpLayerImpl = new TFOpLayerImpl(nodeDef, constants, lconf, networkDataType); tfOpLayerImpl.setListeners(trainingListeners); tfOpLayerImpl.setIndex(layerIndex); return tfOpLayerImpl; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java index ba2b98db4..43ce8e985 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayerImpl.java @@ -26,6 +26,7 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.ArrayUtils; import 
org.deeplearning4j.common.config.DL4JClassLoading; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -56,7 +57,7 @@ public class TFOpLayerImpl extends AbstractLayer { private List inputNames; TFGraphRunnerService graphRunnerService; - public TFOpLayerImpl(Map nodeDef, Map constants, NeuralNetConfiguration conf, DataType dtype){ + public TFOpLayerImpl(Map nodeDef, Map constants, LayerConfiguration conf, DataType dtype){ super(conf, dtype); this.nodeDef = nodeDef; this.constants = constants; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java index e1c6be765..97ceac993 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java @@ -216,7 +216,7 @@ public class KerasLSTM extends KerasLayer { * * @return LSTM ILayer */ - public Layer getLSTMLayer() { + public LayerConfiguration getLSTMLayer() { return layer; } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java index ea71fc8d7..35a1aed01 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java @@ -29,7 +29,7 @@ import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; @@ -186,7 +186,7 @@ public class KerasSimpleRnn extends KerasLayer { * * @return SimpleRnn ILayer */ - public Layer getSimpleRnnLayer() { + public LayerConfiguration getSimpleRnnLayer() { return this.layer; } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java index ccbbbd9d6..3da1a4642 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/wrappers/KerasBidirectional.java @@ -24,10 +24,9 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; import org.deeplearning4j.nn.conf.layers.LSTM; -import 
org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; -import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; @@ -146,7 +145,7 @@ public class KerasBidirectional extends KerasLayer { break; case "SimpleRNN": kerasRnnlayer = new KerasSimpleRnn(innerRnnConfig, enforceTrainingConfig, previousLayers); - Layer rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer(); + LayerConfiguration rnnLayer = ((KerasSimpleRnn) kerasRnnlayer).getSimpleRnnLayer(); this.layer = new Bidirectional(mode, rnnLayer); layer.setLayerName(layerName); break; @@ -162,7 +161,7 @@ public class KerasBidirectional extends KerasLayer { * * @return ILayer, recurrent layer */ - public Layer getUnderlyingRecurrentLayer() { + public LayerConfiguration getUnderlyingRecurrentLayer() { return kerasRnnlayer.getLayer(); } @@ -240,7 +239,7 @@ public class KerasBidirectional extends KerasLayer { } - private Map getUnderlyingWeights(Layer l, Map weights, String direction) + private Map getUnderlyingWeights(LayerConfiguration l, Map weights, String direction) throws InvalidKerasConfigurationException { int keras1SubstringLength; if (kerasRnnlayer instanceof KerasLSTM) @@ -269,7 +268,7 @@ public class KerasBidirectional extends KerasLayer { weights = newWeights; } - Layer layerBefore = kerasRnnlayer.getLayer(); + LayerConfiguration layerBefore = kerasRnnlayer.getLayer(); kerasRnnlayer.setLayer(l); kerasRnnlayer.setWeights(weights); Map ret = kerasRnnlayer.getWeights(); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java index 536afb915..883ff4dd7 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasLayerUtils.java @@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.utils; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.conf.graph.ElementWiseVertex; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.config.Keras2LayerConfiguration; @@ -34,7 +34,6 @@ import org.deeplearning4j.nn.modelimport.keras.layers.KerasTFOpLayer; import org.deeplearning4j.nn.modelimport.keras.layers.advanced.activations.*; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.*; import org.deeplearning4j.nn.modelimport.keras.layers.core.*; -import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.Keras2DEmbedding; import org.deeplearning4j.nn.modelimport.keras.layers.embeddings.KerasEmbedding; import org.deeplearning4j.nn.modelimport.keras.layers.local.KerasLocallyConnected1D; import org.deeplearning4j.nn.modelimport.keras.layers.noise.KerasAlphaDropout; @@ -48,7 +47,6 @@ import 
org.deeplearning4j.nn.modelimport.keras.layers.pooling.KerasPooling3D; import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasLSTM; import org.deeplearning4j.nn.modelimport.keras.layers.recurrent.KerasSimpleRnn; import org.deeplearning4j.nn.modelimport.keras.layers.wrappers.KerasBidirectional; -import org.nd4j.common.primitives.Counter; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; @@ -154,7 +152,7 @@ public class KerasLayerUtils { * * @param layerConfig map containing Keras layer properties * @return KerasLayer - * @see Layer + * @see LayerConfiguration */ public static KerasLayer getKerasLayerFromConfig(Map layerConfig, KerasLayerConfiguration conf, @@ -174,7 +172,7 @@ public class KerasLayerUtils { * @param layerConfig map containing Keras layer properties * @param enforceTrainingConfig whether to enforce training-only configurations * @return KerasLayer - * @see Layer + * @see LayerConfiguration */ public static KerasLayer getKerasLayerFromConfig(Map layerConfig, boolean enforceTrainingConfig, diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java index 43f3b244f..969626676 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/utils/KerasModelUtils.java @@ -24,7 +24,7 @@ package org.deeplearning4j.nn.modelimport.keras.utils; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -55,7 +55,7 @@ public class KerasModelUtils { * @return DL4J Model interface * @throws InvalidKerasConfigurationException Invalid Keras config */ - public static Model copyWeightsToModel(Model model, Map kerasLayers) + public static IModel copyWeightsToModel(IModel model, Map kerasLayers) throws InvalidKerasConfigurationException { /* Get list if layers from model. */ Layer[] layersFromModel; @@ -67,7 +67,7 @@ public class KerasModelUtils { /* Iterate over layers in model, setting weights when relevant. 
*/ Set layerNames = new HashSet<>(kerasLayers.keySet()); for (org.deeplearning4j.nn.api.Layer layer : layersFromModel) { - String layerName = layer.conf().getLayer().getLayerName(); + String layerName = layer.getLayerConfiguration().getLayerName(); if (!kerasLayers.containsKey(layerName)) throw new InvalidKerasConfigurationException( "No weights found for layer in model (named " + layerName + ")"); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java index f50df5084..db0fc466b 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/FullModelComparisons.java @@ -48,7 +48,6 @@ import org.nd4j.common.resources.Resources; import java.io.File; import java.io.IOException; import java.io.InputStream; -import java.util.Arrays; import java.util.LinkedList; import java.util.List; @@ -88,7 +87,7 @@ public class FullModelComparisons extends BaseDL4JTest { // 1. ILayer LSTM firstLstm = (LSTM) model.getLayer(0); org.deeplearning4j.nn.conf.layers.LSTM firstConf = - (org.deeplearning4j.nn.conf.layers.LSTM) firstLstm.conf().getLayer(); + (org.deeplearning4j.nn.conf.layers.LSTM) firstLstm.getLayerConfiguration(); // "unit_forget_bias": true assertEquals(1.0, firstConf.getForgetGateBiasInit()); @@ -126,7 +125,7 @@ public class FullModelComparisons extends BaseDL4JTest { // 2. ILayer LSTM secondLstm = (LSTM) ((LastTimeStepLayer) model.getLayer(1)).getUnderlying(); org.deeplearning4j.nn.conf.layers.LSTM secondConf = - (org.deeplearning4j.nn.conf.layers.LSTM) secondLstm.conf().getLayer(); + (org.deeplearning4j.nn.conf.layers.LSTM) secondLstm.getLayerConfiguration(); // "unit_forget_bias": true assertEquals(1.0, secondConf.getForgetGateBiasInit()); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java index fc48183e2..0a1bcb4a9 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras1ModelConfigurationTest.java @@ -22,7 +22,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasModel; @@ -140,9 +140,9 @@ public class Keras1ModelConfigurationTest extends BaseDL4JTest { private void runSequentialConfigTest(String path, boolean training) throws Exception { try(InputStream is = Resources.asStream(path)) { - MultiLayerConfiguration config = + NeuralNetConfiguration config = new KerasModel().modelBuilder().modelJsonInputStream(is) - .enforceTrainingConfig(training).buildSequential().getMultiLayerConfiguration(); + 
.enforceTrainingConfig(training).buildSequential().getNeuralNetConfiguration(); MultiLayerNetwork model = new MultiLayerNetwork(config); model.init(); } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index 05f6162f3..9bb3e5b4a 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -23,7 +23,7 @@ package org.deeplearning4j.nn.modelimport.keras.configurations; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; @@ -42,7 +42,6 @@ import org.nd4j.common.resources.Resources; import java.io.File; import java.io.IOException; import java.io.InputStream; -import java.util.Arrays; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertArrayEquals; @@ -260,9 +259,9 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { @Test public void oneLstmLayerTest() throws Exception { try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/one_lstm_no_sequences_tf_keras_2.json")) { - MultiLayerConfiguration config = + NeuralNetConfiguration config = new KerasModel().modelBuilder().modelJsonInputStream(is) - .enforceTrainingConfig(false).buildSequential().getMultiLayerConfiguration(); + .enforceTrainingConfig(false).buildSequential().getNeuralNetConfiguration(); MultiLayerNetwork model = new MultiLayerNetwork(config); model.init(); INDArray input = Nd4j.create(DataType.FLOAT, 50, 1500, 500); //NWC format - [Minibatch, seqLength, channels] @@ -287,9 +286,9 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { private void runSequentialConfigTest(String path) throws Exception { try(InputStream is = Resources.asStream(path)) { - MultiLayerConfiguration config = + NeuralNetConfiguration config = new KerasModel().modelBuilder().modelJsonInputStream(is) - .enforceTrainingConfig(false).buildSequential().getMultiLayerConfiguration(); + .enforceTrainingConfig(false).buildSequential().getNeuralNetConfiguration(); MultiLayerNetwork model = new MultiLayerNetwork(config); model.init(); } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java index c45b3c52b..20721371b 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java @@ -20,12 +20,13 @@ package org.deeplearning4j.nn.modelimport.keras.configurations; +import java.util.List; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; 
import org.deeplearning4j.nn.conf.CNN2DFormat; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; @@ -34,7 +35,6 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Timeout; import org.nd4j.common.resources.Resources; -import org.nd4j.linalg.convolution.Convolution; import org.nd4j.linalg.factory.Nd4j; import java.io.IOException; @@ -57,12 +57,12 @@ public class KerasModelImportTest extends BaseDL4JTest { @Test public void testNCHWNWHCChangeImport() { MultiLayerNetwork model = loadModel("modelimport/keras/weights/conv2dnchw/simpleconv2d.hdf5"); - MultiLayerConfiguration multiLayerConfiguration = model.getLayerWiseConfigurations(); - ConvolutionLayer convolutionLayer = (ConvolutionLayer) multiLayerConfiguration.getConf(0).getLayer(); + List layerConfigs = model.getConfiguration().getFlattenedLayerConfigurations(); + ConvolutionLayer convolutionLayer = (ConvolutionLayer) layerConfigs.get(0); assertEquals(CNN2DFormat.NCHW,convolutionLayer.getCnn2dDataFormat()); - SubsamplingLayer subsamplingLayer = (SubsamplingLayer) multiLayerConfiguration.getConf(1).getLayer(); + SubsamplingLayer subsamplingLayer = (SubsamplingLayer) layerConfigs.get(1); assertEquals(CNN2DFormat.NHWC,subsamplingLayer.getCnn2dDataFormat()); - ConvolutionLayer convolutionLayer1 = (ConvolutionLayer) multiLayerConfiguration.getConf(2).getLayer(); + ConvolutionLayer convolutionLayer1 = (ConvolutionLayer) layerConfigs.get(2); assertEquals(CNN2DFormat.NHWC,convolutionLayer1.getCnn2dDataFormat()); model.output(Nd4j.zeros(1,1,28,28)); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java index 67caf1e3b..f5b7584d3 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasCustomLayerTest.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.modelimport.keras.e2e; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.io.FileUtils; import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -61,7 +62,7 @@ public class KerasCustomLayerTest extends BaseDL4JTest { cachedKerasFile.deleteOnExit(); } - org.deeplearning4j.nn.api.Model importedModel = + IModel importedModel = KerasModelImport.importKerasModelAndWeights(cachedKerasFile.getAbsolutePath()); ModelSerializer.writeModel(importedModel, outputPath, false); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index 9b6797c06..2fea0bb82 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ 
b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LossLayer; import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -38,7 +39,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive; import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel; -import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelBuilder; import org.deeplearning4j.nn.modelimport.keras.utils.KerasModelUtils; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -971,7 +971,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { if (net.getOutputLayer() instanceof IOutputLayer) { netToTest = net; } else { - org.deeplearning4j.nn.conf.layers.Layer l; + LayerConfiguration l; if (labels.rank() == 2) { l = new LossLayer.Builder() .lossFunction(LossFunctions.LossFunction.MSE) @@ -1000,11 +1000,11 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { for (Layer l : netToTest.getLayers()) { // Remove any dropout manually - until this is fixed: // https://github.com/eclipse/deeplearning4j/issues/4368 - l.conf().getLayer().setIDropout(null); + l.getLayerConfiguration().setIDropout(null); //Also swap out activation functions... this is a bit of a hack, but should make the net gradient checkable... 
- if (l.conf().getLayer() instanceof FeedForwardLayer) { - FeedForwardLayer ffl = (FeedForwardLayer) l.conf().getLayer(); + if (l.getLayerConfiguration() instanceof FeedForwardLayer) { + FeedForwardLayer ffl = (FeedForwardLayer) l.getLayerConfiguration(); IActivation activation = ffl.getActivationFn(); if (activation instanceof ActivationReLU || activation instanceof ActivationLReLU) { ffl.setActivationFn(new ActivationSoftPlus()); diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java index 19681185c..1d7144b20 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java @@ -24,7 +24,6 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.models.paragraphvectors.ParagraphVectorsTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.EmbeddingLayer; @@ -53,7 +52,6 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.File; import java.util.Collection; -import java.util.concurrent.Callable; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -189,7 +187,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest { INDArray w = vec.lookupTable().getWeights(); System.out.println(w); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .seed(12345).list() .layer(new EmbeddingLayer.Builder().weightInit(vec).build()) .layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(w.size(1)).nOut(3).build()) @@ -210,7 +208,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); - assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); + assertEquals(net.getConfiguration(), restored.getConfiguration()); assertTrue(net.params().equalsWithEps(restored.params(), 2e-3)); } } diff --git a/cavis-dnn/cavis-dnn-nn/build.gradle b/cavis-dnn/cavis-dnn-nn/build.gradle index 0e097093d..59ff712ab 100644 --- a/cavis-dnn/cavis-dnn-nn/build.gradle +++ b/cavis-dnn/cavis-dnn-nn/build.gradle @@ -58,5 +58,4 @@ dependencies { implementation "com.squareup.okhttp3:okhttp" implementation "com.squareup.okhttp3:logging-interceptor" } -sourceCompatibility = JavaVersion.VERSION_11 -targetCompatibility = JavaVersion.VERSION_11 + diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/Animal.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/Animal.java new file mode 100644 index 000000000..2b5ac714c --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/Animal.java @@ -0,0 +1,68 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. 
+ * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +public class Animal { + + private String animalString; + + protected Animal(AnimalBuilder b) { + this.animalString = b.animalString; + } + + public static AnimalBuilder builder() { + return new AnimalBuilderImpl(); + } + + public static abstract class AnimalBuilder> { + + private String animalString; + + public B animalString(String animalString) { + this.animalString = animalString; + return self(); + } + + protected abstract B self(); + + public abstract C build(); + + public String toString() { + return "Animal.AnimalBuilder(animalString=" + this.animalString + ")"; + } + } + + private static final class AnimalBuilderImpl extends + AnimalBuilder { + + private AnimalBuilderImpl() { + } + + protected AnimalBuilderImpl self() { + return this; + } + + public Animal build() { + return new Animal(this); + } + } +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IActivationFunction.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IActivationFunction.java new file mode 100644 index 000000000..18794b8fe --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IActivationFunction.java @@ -0,0 +1,57 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +/** + * Activation Function An activation function takes in weighted data (matrix multiplication between + * input data and weights) and outputs a non-linear transformation of the data. For example, output + * = max(0,weighted_data) is the rectified linear activation function (essentially set all negative + * values to zero). The difference between units and activation functions is that units can be more + * complex, that is, a unit can have multiple activation functions (for example LSTM units) or a + * slightly more complex structure (for example maxout units). + *
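+ * <p>
+ * As a minimal illustration of that rectified linear example, the non-linearity is nothing more
+ * than clamping negative weighted inputs to zero:
+ * <pre>{@code
+ * // ReLU: output = max(0, weighted_data), applied element-wise
+ * static double relu(double weightedInput) {
+ *     return Math.max(0.0, weightedInput);
+ * }
+ * }</pre>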

+ * The difference between linear and non-linear activation functions can be shown with the + * relationship of some weighted values: imagine four points A1, A2, B1 and B2. The pairs A1/A2 + * and B1/B2 lie close to each other, but A1 is distant from both B1 and B2, and vice versa; the + * same holds for A2. + *

+ * With a linear transformation the relationship between pairs might change. For example A1 and A2 + * might be far apart, but this implies that B1 and B2 are also far apart. The distance between the + * pairs might shrink, but if it does, then both B1 and B2 will be close to A1 and A2 at the same + * time. We can apply many linear transformations, but the relationship between A1 / A2 and B1 / B2 + * will always be similar. + *

+ * In contrast, with a non-linear activation function we can increase the distance between A1 and A2 + * while we decrease the distance between B1 and B2. We can make B1 close to A1, but B2 distant from + * A1. By applying non-linear functions, we create new relationships between the points. With every + * new non-linear transformation we can increase the complexity of the relationships. In deep + * learning, using non-linear activation functions creates increasingly complex features with every + * layer. + *

+ * In contrast, the features of 1000 layers of pure linear transformations can be reproduced by a + * single layer (because a chain of matrix multiplication can always be represented by a single + * matrix multiplication). This is why non-linear activation functions are so important in deep + * learning. + */ +public interface IActivationFunction { + +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java index f0c6a722a..2c31319fc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java @@ -21,66 +21,262 @@ package net.brutex.ai.dnn.api; +import java.util.Collection; +import java.util.Map; +import lombok.NonNull; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; +import org.deeplearning4j.optimize.api.ConvexOptimizer; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.common.primitives.Pair; +import org.nd4j.evaluation.IEvaluation; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; /** * A Neural Network is an instance of a {@link INeuralNetworkConfiguration}, that can be trained, * evaluated, saved, exported, etc. Its configuration state is defined with the - * {@link #setConfiguration(INeuralNetworkConfiguration)} and {@link #getConfiguration()} methods. - * - */ + * {@link #setNetConfiguration(NeuralNetConfiguration)} (INeuralNetworkConfiguration)} and + * {@link #getNetConfiguration()} methods. + **/ + public interface IModel { /** - * The configuration that defines this Neural Network + * This method returns updater state (if applicable), null otherwise * - * @param conf the configuration to use for this network + * @return */ - void setConfiguration(INeuralNetworkConfiguration conf); - INeuralNetworkConfiguration getConfiguration(); + INDArray updaterState(); /** - * Fit the model for one iteration on the provided data + * This method returns Optimizer used for training * - * @param features the examples to classify (one example in each row) - * @param labels the example labels(a binary outcome matrix) - * @param featuresMask The mask array for the features (used for variable length time series, etc). May be null. - * @param labelsMask The mask array for the labels (used for variable length time series, etc). May be null. 
+ * @return */ - void fit(INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask); + ConvexOptimizer getOptimizer(); /** * This method fits model with a given DataSet * - * @param dataSet the dataset to use for training + * @param dataSet */ void fit(DataSet dataSet); /** * This method fits model with a given MultiDataSet * - * @param dataSet the multi dataset to use for training + * @param dataSet */ void fit(MultiDataSet dataSet); /** - * The name of the Neural Network - * @return the name + * This method fits model with a given DataSetIterator + * + * @param iterator */ - String getName(); + void fit(DataSetIterator iterator); /** - * Set the name for this Neural Network - * @param name the name + * This method fits model with a given MultiDataSetIterator + * + * @param iterator */ - void setName(String name); + void fit(MultiDataSetIterator iterator); /** - * An implementation should provide a method to validate the network - * @return true if no errors found; false otherwise + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator */ - boolean isValid(); + T[] doEvaluation(DataSetIterator iterator, T... evaluations); + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + */ + T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations); + + NeuralNetConfiguration getNetConfiguration(); + + void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration); + + /** + * Init the model + */ + void init(); + + /** + * Get the number of parameters in this model + * @return number of parameters + */ + long numParams(); + + /** + * All models have a fit method + */ + @Deprecated + void fit(); + + /** + * Update layer weights and biases with gradient change + */ + void update(Gradient gradient); + + /** + * Perform one update applying the gradient + * + * @param gradient the gradient to apply + */ + void update(INDArray gradient, String paramType); + + + /** + * The score for the model + * + * @return the score for the model + */ + double score(); + + + /** + * Update the score + */ + void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr); + + /** + * Parameters of the model (if any) + * + * @return the parameters of the model + */ + INDArray params(); + + + /** + * the number of parameters for the model + * + * @return the number of parameters for the model + */ + long numParams(boolean backwards); + + /** + * Set the parameters for this model. This expects a linear ndarray which then be unpacked + * internally relative to the expected ordering of the model + * + * @param params the parameters for the model + */ + void setParams(INDArray params); + + /** + * Set the initial parameters array as a view of the full (backprop) network parameters NOTE: this + * is intended to be used internally in MultiLayerNetwork and ComputationGraph, not by users. + * + * @param params a 1 x nParams row vector that is a view of the larger (MLN/CG) parameters array + */ + void setParamsViewArray(INDArray params); + + + INDArray getGradientsViewArray(); + + /** + * Set the gradients array as a view of the full (backprop) network parameters NOTE: this is + * intended to be used internally in MultiLayerNetwork and ComputationGraph, not by users. 
+ * + * @param gradients a 1 x nParams row vector that is a view of the larger (MLN/CG) gradients + * array + */ + void setBackpropGradientsViewArray(INDArray gradients); + + /** + * Fit the model to the given data + * + * @param data the data to fit the model to + */ + void fit(INDArray data, LayerWorkspaceMgr workspaceMgr); + + + /** + * Get the gradient. Note that this method will not calculate the gradient, it will rather return + * the gradient that has been computed before. For calculating the gradient, see + * {@link IModel#computeGradientAndScore(LayerWorkspaceMgr)} } . + * + * @return the gradient for this model, as calculated before + */ + Gradient gradient(); + + /** + * Get the gradient and score + * + * @return the gradient and score + */ + Pair gradientAndScore(); + + /** + * The current inputs batch size + * + * @return the current inputs batch size + */ + int batchSize(); + + /** + * The input/feature matrix for the model + * + * @return the input/feature matrix for the model + */ + INDArray input(); + + /** + * Get a parameter array for a given parameter type key + * @param param the key of the parameter + * @return ndarray of parameters + */ + INDArray getParam(String param); + + + + /** + * Set the parameters for a given parameter type. + * @param key the param type key to set + * @param val the new parameters ndarray + */ + void setParam(String key, INDArray val); + + /** + * Clear input + */ + void clear(); + + + /** + * Apply any constraints to the model + */ + void applyConstraints(int iteration, int epoch); + + + void close(); + + /** + * Get the TrainingListeners + * @return training listener + */ + Collection getListeners(); + + /** + * Replace the TrainingListeners for this model + * @param listeners new listeners + */ + void setListeners(TrainingListener... listeners); + + /** + * Add TrainingListeners to the model + * @param listener listener to add + */ + void addListeners(TrainingListener... listener); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetwork.java deleted file mode 100644 index 48d6c561b..000000000 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetwork.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * - * ****************************************************************************** - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
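A minimal usage sketch of the reworked IModel contract, stitched together from calls that appear elsewhere in this patch series (the file paths and the DataSetIterator are placeholders, not part of the patch):

import java.io.File;
import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.nn.modelimport.keras.KerasModelImport;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

class IModelUsageSketch {
    static void run(DataSetIterator trainData) throws Exception {
        // Keras import is assigned to the IModel abstraction, as in KerasCustomLayerTest
        IModel model = KerasModelImport.importKerasModelAndWeights("/tmp/model.h5"); // placeholder path
        model.fit(trainData);                                  // IModel#fit(DataSetIterator)
        System.out.println("score=" + model.score() + ", params=" + model.numParams());
        // persist the model; 'false' skips the updater state, mirroring the test above
        ModelSerializer.writeModel(model, new File("/tmp/imported.zip"), false); // placeholder path
    }
}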
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - * - */ - -package net.brutex.ai.dnn.api; - -import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; -import org.deeplearning4j.optimize.api.ConvexOptimizer; -import org.nd4j.evaluation.IEvaluation; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.api.DataSet; -import org.nd4j.linalg.dataset.api.MultiDataSet; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; - -/** - * @author raver119 - */ -public interface INeuralNetwork { - - /** - * This method does initialization of model - *

- * PLEASE NOTE: All implementations should track own state, to avoid double spending - */ - void init(); - - /** - * This method returns model parameters as single INDArray - * - * @return - */ - INDArray params(); - - /** - * This method returns updater state (if applicable), null otherwise - * - * @return - */ - INDArray updaterState(); - - /** - * This method returns Optimizer used for training - * - * @return - */ - ConvexOptimizer getOptimizer(); - - /** - * This method fits model with a given DataSet - * - * @param dataSet - */ - void fit(DataSet dataSet); - - /** - * This method fits model with a given MultiDataSet - * - * @param dataSet - */ - void fit(MultiDataSet dataSet); - - /** - * This method fits model with a given DataSetIterator - * - * @param iterator - */ - void fit(DataSetIterator iterator); - - /** - * This method fits model with a given MultiDataSetIterator - * - * @param iterator - */ - void fit(MultiDataSetIterator iterator); - - /** - * This method executes evaluation of the model against given iterator and evaluation - * implementations - * - * @param iterator - */ - T[] doEvaluation(DataSetIterator iterator, T... evaluations); - - /** - * This method executes evaluation of the model against given iterator and evaluation - * implementations - * - * @param iterator - */ - T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations); - - /** - * A neural network is created from a configuration. - * @param conf the configuration to create the network from - */ - void setConfiguration(NeuralNetworkConfiguration conf); - - /** - * Return the configuration for this configuration - * @return - */ - NeuralNetworkConfiguration getConfiguration(); - -} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java index 81d447fa3..b317e4ab0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java @@ -21,10 +21,14 @@ package net.brutex.ai.dnn.api; +import java.io.Serializable; import java.util.List; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -public interface INeuralNetworkConfiguration { +public interface INeuralNetworkConfiguration extends Serializable, Cloneable { + INeuralNetworkConfiguration clone(); + void init(); } /** /** diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IUnit.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IUnit.java new file mode 100644 index 000000000..dd9643c6b --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IUnit.java @@ -0,0 +1,47 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +/** + * Unit A unit often refers to the activation function in a layer by which the inputs are + * transformed via a nonlinear activation function (for example by the logistic sigmoid function). + * Usually, a unit has several incoming connections and several outgoing connections. However, units + * can also be more complex, like long short-term memory (LSTM) units, which have multiple + * activation functions with a distinct layout of connections to the nonlinear activation functions, + * or maxout units, which compute the final output over an array of nonlinearly transformed input + * values. Pooling, convolution, and other input transforming functions are usually not referred to + * as units. + *

+ * Artificial Neuron The term artificial neuron—or most often just neuron—is an equivalent term to + * unit, but implies a close connection to neurobiology and the human brain while deep learning has + * very little to do with the brain (for example, it is now thought that biological neurons are more + * similar to entire multilayer perceptrons rather than a single unit in a neural network). The term + * neuron was encouraged after the last AI winter to differentiate the more successful neural + * network from the failing and abandoned perceptron. However, since the wild successes of deep + * learning after 2012, the media often picked up on the term “neuron” and sought to explain deep + * learning as mimicry of the human brain, which is very misleading and potentially dangerous for + * the perception of the field of deep learning. Now the term neuron is discouraged and the more + * descriptive term unit should be used instead. + */ +public interface IUnit { + +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/LayerType.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/LayerType.java new file mode 100644 index 000000000..ba432d132 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/LayerType.java @@ -0,0 +1,52 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +import lombok.Getter; +import org.deeplearning4j.nn.conf.layers.ActivationLayer; +import org.deeplearning4j.nn.conf.layers.BatchNormalization; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.DropoutLayer; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; +import org.deeplearning4j.nn.conf.layers.NoParamLayer; + +public enum LayerType { + CONV("CONV", "Convolutional", ConvolutionLayer.class), + ACT("ACT", "Activation", ActivationLayer.class), + POOL( "POOL", "Pooling/ Subsampling", NoParamLayer.class), + FC( "FC", "Fully Connected", FeedForwardLayer.class), + BN("BN", "Batch Normalization", BatchNormalization.class), + DO("DO", "Dropout", DropoutLayer.class), + UNKNOWN("UNKNOWN", "Type not specified", LayerConfiguration.class); + +@Getter + String description; + @Getter String name; + @Getter Class clazz; + + LayerType(String name, String description, Class clazz) { + this.name = name; + this.description = description; + this.clazz = clazz; + } +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java new file mode 100644 index 000000000..3e13e811a --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java @@ -0,0 +1,42 @@ +/* + * + * ****************************************************************************** + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + * + */ + +package net.brutex.ai.dnn.api; + +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; + +/** + * A fluent API to configure and create artificial neural networks + */ +public class NN { + + + public static NeuralNetConfigurationBuilder net() { + return NeuralNetConfiguration.builder(); + } + + void test() { + Dog.DogBuilder builder = Dog.builder() + .animalString("") + .dogString(""); + } +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java deleted file mode 100644 index 51de9f873..000000000 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/NeuralNetworkConfiguration.java +++ /dev/null @@ -1,705 +0,0 @@ -/* - * - * ****************************************************************************** - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - * - */ - -package net.brutex.ai.dnn.conf; - -import com.fasterxml.jackson.annotation.JsonIgnoreProperties; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; -import com.fasterxml.jackson.databind.node.ArrayNode; -import java.io.IOException; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Random; -import lombok.Getter; -import lombok.NonNull; -import lombok.Setter; -import lombok.Singular; -import lombok.extern.jackson.Jacksonized; -import lombok.extern.slf4j.Slf4j; -import net.brutex.ai.dnn.api.ILayerConfiguration; -import net.brutex.ai.dnn.api.INeuralNetworkConfiguration; -import org.deeplearning4j.nn.conf.BackpropType; -import org.deeplearning4j.nn.conf.CacheMode; -import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.WorkspaceMode; -import org.deeplearning4j.nn.conf.distribution.Distribution; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.BaseLayer; -import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; -import org.deeplearning4j.nn.conf.layers.Layer; -import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; -import org.deeplearning4j.nn.conf.memory.MemoryReport; -import 
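A minimal sketch of where the NN entry point above leads; the seed and layer values mirror the NeuralNetConfiguration.builder() usage in the Word2VecTestsSmall change, and the terminal build() call is assumed from the Lombok-style builder:

import net.brutex.ai.dnn.api.NN;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.activations.Activation;

class FluentApiSketch {
    static MultiLayerNetwork build() {
        // NN.net() simply hands back NeuralNetConfiguration.builder()
        NeuralNetConfiguration conf = NN.net()
                .seed(12345)
                .list()
                .layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(3).build())
                .build();                       // assumed terminal call on the builder
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        return net;
    }
}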
org.deeplearning4j.nn.conf.memory.NetworkMemoryReport; -import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.deeplearning4j.nn.weights.IWeightInit; -import org.deeplearning4j.nn.weights.WeightInit; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.lossfunctions.LossFunctions; -import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT; -import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; -import org.nd4j.linalg.lossfunctions.impl.LossMSE; -import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; - -/** - * The INeuralNetworkConfiguration is a sequential container for the different layers in your - * network (or other NeuralNetworkConfigurations). That said, NeuralNetworkConfigurations can be - * stacked.

- * It then “chains” outputs to inputs sequentially for each INeuralNetworkConfiguration, - * finally returning the output of the "top" configuration. Any settings made, are inherited and can - * be overridden on a "deeper" level. For this use case, you need to wrap the INeuralNetworkConfiguration - * into a BuildingBlockLayer - * - */ -@Jacksonized -@JsonIgnoreProperties(ignoreUnknown = true) -@lombok.Builder -@Slf4j -public class NeuralNetworkConfiguration extends NeuralNetConfiguration implements - INeuralNetworkConfiguration, Serializable, Cloneable { - - private static final int DEFAULT_TBPTT_LENGTH = 20; - @Getter protected final List confs = new ArrayList<>(); - /** - * hidden list of layers, that "flattens" all the layers of this network and applies - * inheritance. - */ - @lombok.Builder.ObtainVia(method = "calculateInnerLayers") - private final List innerLayerConfigurations; - @Getter @Setter @NonNull @Singular - protected List layers = new ArrayList<>(); - @Getter @Setter @NonNull @lombok.Builder.Default @Deprecated - protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; - @Getter @Setter @NonNull @lombok.Builder.Default @Deprecated - protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; - /** - * The type of backprop. Default setting is used for most networks (MLP, CNN etc), but - * optionally truncated BPTT can be used for training recurrent neural networks. If using - * TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength() - */ - @Getter @Setter @NonNull @lombok.Builder.Default - protected BackpropType backpropType = BackpropType.Standard; - @Getter - protected Map inputPreProcessors = new HashMap<>(); - /** - * When doing truncated BPTT: how many steps of forward pass should we do before doing - * (truncated) backprop?
Only applicable when doing - * backpropType(BackpropType.TruncatedBPTT)
Typically tBPTTForwardLength parameter is same - * as the tBPTTBackwardLength parameter, but may be larger than it in some circumstances (but - * never smaller)
Ideally your training data time series length should be divisible by this - * This is the k1 parameter on pg23 of - * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf - * - * @param forwardLength Forward length > 0, >= backwardLength - */ - @Getter @Setter protected int tbpttFwdLength = 20; - /** - * When doing truncated BPTT: how many steps of backward should we do?
Only applicable when - * doing backpropType(BackpropType.TruncatedBPTT)
This is the k2 parameter on pg23 of - * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf - * - * @param backwardLength <= forwardLength - */ - @Getter @Setter protected int tbpttBackLength = 20; - /** - * Creates and returns a copy of this object. - * - * @return a clone of this instance. - * @throws CloneNotSupportedException if the object's class does not support the {@code Cloneable} - * interface. Subclasses that override the {@code clone} method - * can also throw this exception to indicate that an instance - * cannot be cloned. - * @see Cloneable - */ - - //Nd4j.getRandom().setSeed(getConf(0).getSeed()); //TODO - //Counter for the number of parameter updates so far - // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted - // for Spark and model serialization - @Getter @Setter - protected int iterationCount = 0; - //Counter for the number of epochs completed so far. Used for per-epoch schedules - @Getter @Setter - protected int epochCount = 0; - protected double dampingFactor = 100; - @Getter @Setter //todo why? - private Layer layer; - /** - * A seed for this network, will be random if not specified. - */ - @Getter @Setter @NonNull @lombok.Builder.Default - private long seed = new Random().nextLong(); - /** - * The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified otherwise. - * This method defines how/if preOutput cache is handled: NONE: cache disabled (default value) - * HOST: Host memory will be used DEVICE: GPU memory will be used (on CPU backends effect will - * be the same as for HOST) - * - * Valid values are
- * CacheMode.NONE,
- * CacheMode.HOST or
- * CacheMode.DEVICE
- * @param cacheMode - */ - @NonNull @Getter @Setter - @lombok.Builder.Default private CacheMode cacheMode = CacheMode.NONE; - /** - * The list of layer configurations in this configuration. They will be indexed automatically - * as the layers get added starting with index 0. - */ - @Singular @Getter - private List layerConfigurations; - /** - * The name for this configuration. Defaults to "Anonymous INeuralNetworkConfiguration" if - * it is not specified. - */ - @lombok.Builder.Default @Getter - private String name = "Anonymous INeuralNetworkConfiguration"; - /** - * The {@link InputType} of the data for this network configuration - */ - private InputType inputType; - /** - * Set the DataType for the network parameters and activations for all layers in the network. - * Default: Float - * - * @param dataType Datatype to use for parameters and activations - */ - @Getter @Setter @lombok.Builder.Default @NonNull - private DataType dataType = DataType.FLOAT; - /** - * Whether to override the nIn configuration forcibly upon construction. Default value is true. - * @return builder pattern - */ - @Getter @Setter - @lombok.Builder.Default - private boolean overrideNinUponBuild = true; - /** - * Enabled by default. If enabled, the output layer configuration will be validated, to throw an - * exception on likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.
If - * disabled (false) no output layer validation will be performed.
Disabling this validation - * is not recommended, as the configurations that fail validation usually will not be able to - * learn correctly. However, the option to disable this validation is provided for advanced - * users when creating non-standard architectures. - * - * @param validate If true: validate output layer configuration. False: don't validate - */ - @Getter @Setter @lombok.Builder.Default - private boolean validateOutputLayerConfig=true; - /** - * Enabled by default. If enabled, an exception will be throw when using the (invalid) - * combination of truncated backpropagation through time (TBPTT) with either a - * GlobalPoolingLayer or LastTimeStepLayer.
It is possible to disable this validation to - * allow what is almost certainly an invalid configuration to be used, however this is not - * recommended. - * - * @param validate Whether TBPTT validation should be performed - */ - @Getter @Setter @lombok.Builder.Default - private boolean validateTbpttConfig=true; - - - - /** - * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} - * or {@link org.nd4j.linalg.learning.config.Nesterovs}
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param updater Updater to use - */ - @Getter @Setter @NonNull - private IUpdater updater; - - /** - * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc. - * See {@link GradientNormalization} for details
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param gradientNormalization Type of normalization to use. Defaults to None. - * @see GradientNormalization - */ - @Getter @Setter @NonNull @lombok.Builder.Default - private GradientNormalization gradientNormalization = GradientNormalization.None; - - /** - * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer, - * GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue
- * Not used otherwise.
- * L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - */ - @Getter @Setter - private double gradientNormalizationThreshold; - - - /** - * Weight initialization scheme to use, for initial weight values - * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - */ - @Getter @Setter - private IWeightInit weightInit; - - /** - * Activation function / neuron non-linearity
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - */ - @Getter @Setter - private IActivation activation; - - - - /** - * Create a neural net configuration from json - * - * @param json the neural net configuration from json - * @return {@link NeuralNetworkConfiguration} - */ - public static NeuralNetworkConfiguration fromJson(String json) { - NeuralNetworkConfiguration conf; - ObjectMapper mapper = NeuralNetworkConfiguration.mapper(); - try { - conf = mapper.readValue(json, NeuralNetworkConfiguration.class); - } catch (InvalidTypeIdException e) { - if (e.getMessage().contains("@class")) { - try { - //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format - return JsonMappers.getLegacyMapper().readValue(json, NeuralNetworkConfiguration.class); - } catch (InvalidTypeIdException e2) { - //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.ILayer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..." - //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work - String msg = e2.getMessage(); - if (msg != null && msg.contains("Could not resolve type id")) { - throw new RuntimeException( - "Error deserializing MultiLayerConfiguration - configuration may have a custom " + - "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" - + - " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", - e); - } - throw new RuntimeException(e2); - } catch (IOException e2) { - throw new RuntimeException(e2); - } - } - throw new RuntimeException(e); - } catch (IOException e) { - //Check if this exception came from legacy deserializer... - String msg = e.getMessage(); - if (msg != null && msg.contains("legacy")) { - throw new RuntimeException( - "Error deserializing MultiLayerConfiguration - configuration may have a custom " + - "layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be " - + - "deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)", - e); - } - throw new RuntimeException(e); - } - - //To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier) - // Previously: enumeration used for loss functions. 
Now: use classes - // IN the past, could have only been an OutputLayer or RnnOutputLayer using these enums - int layerCount = 0; - JsonNode confs = null; - for (NeuralNetworkConfiguration nnc : conf.getConfs()) { - Layer l = nnc.getLayer(); - if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) { - //lossFn field null -> may be an old config format, with lossFunction field being for the enum - //if so, try walking the JSON graph to extract out the appropriate enum value - - BaseOutputLayer ol = (BaseOutputLayer) l; - try { - JsonNode jsonNode = mapper.readTree(json); - if (confs == null) { - confs = jsonNode.get("confs"); - } - if (confs instanceof ArrayNode) { - ArrayNode layerConfs = (ArrayNode) confs; - JsonNode outputLayerNNCNode = layerConfs.get(layerCount); - if (outputLayerNNCNode == null) { - throw new RuntimeException("should never happen"); //return conf; //Should never happen... - } - JsonNode outputLayerNode = outputLayerNNCNode.get("layer"); - - JsonNode lossFunctionNode = null; - if (outputLayerNode.has("output")) { - lossFunctionNode = outputLayerNode.get("output").get("lossFunction"); - } else if (outputLayerNode.has("rnnoutput")) { - lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction"); - } - - if (lossFunctionNode != null) { - String lossFunctionEnumStr = lossFunctionNode.asText(); - LossFunctions.LossFunction lossFunction = null; - try { - lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr); - } catch (Exception e) { - log.warn( - "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", - e); - } - - if (lossFunction != null) { - switch (lossFunction) { - case MSE: - ol.setLossFn(new LossMSE()); - break; - case XENT: - ol.setLossFn(new LossBinaryXENT()); - break; - case NEGATIVELOGLIKELIHOOD: - ol.setLossFn(new LossNegativeLogLikelihood()); - break; - case MCXENT: - ol.setLossFn(new LossMCXENT()); - break; - - //Remaining: TODO - case SQUARED_LOSS: - case RECONSTRUCTION_CROSSENTROPY: - default: - log.warn( - "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", - lossFunction); - break; - } - } - } - - } else { - log.warn( - "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", - (confs != null ? confs.getClass() : null)); - } - } catch (IOException e) { - log.warn( - "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", - e); - break; - } - } - - //Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn") - //Try to load the old format if necessary, and create the appropriate IActivation instance - if ((l instanceof BaseLayer) && ((BaseLayer) l).getActivationFn() == null) { - try { - JsonNode jsonNode = mapper.readTree(json); - if (confs == null) { - confs = jsonNode.get("confs"); - } - if (confs instanceof ArrayNode) { - ArrayNode layerConfs = (ArrayNode) confs; - JsonNode outputLayerNNCNode = layerConfs.get(layerCount); - if (outputLayerNNCNode == null) { - throw new RuntimeException("Should never happen"); //return conf; //Should never happen... 
- } - JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); - - if (layerWrapperNode == null || layerWrapperNode.size() != 1) { - continue; - } - - JsonNode layerNode = layerWrapperNode.elements().next(); - JsonNode activationFunction = layerNode.get( - "activationFunction"); //Should only have 1 element: "dense", "output", etc - - if (activationFunction != null) { - IActivation ia = Activation.fromString(activationFunction.asText()) - .getActivationFunction(); - ((BaseLayer) l).setActivationFn(ia); - } - } - - } catch (IOException e) { - log.warn( - "ILayer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", - e); - } - } - - if (!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) { - return conf; - } - - layerCount++; - } - return conf; - } - - /** - * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied - * from handling of {@link Activation} above. - * - * @return True if all is well and layer iteration shall continue. False else-wise. - */ - private static boolean handleLegacyWeightInitFromJson(String json, Layer l, ObjectMapper mapper, - JsonNode confs, int layerCount) { - if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) { - try { - JsonNode jsonNode = mapper.readTree(json); - if (confs == null) { - confs = jsonNode.get("confs"); - } - if (confs instanceof ArrayNode) { - ArrayNode layerConfs = (ArrayNode) confs; - JsonNode outputLayerNNCNode = layerConfs.get(layerCount); - if (outputLayerNNCNode == null) { - return false; //Should never happen... - } - JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); - - if (layerWrapperNode == null || layerWrapperNode.size() != 1) { - return true; - } - - JsonNode layerNode = layerWrapperNode.elements().next(); - JsonNode weightInit = layerNode.get( - "weightInit"); //Should only have 1 element: "dense", "output", etc - JsonNode distribution = layerNode.get("dist"); - - Distribution dist = null; - if (distribution != null) { - dist = mapper.treeToValue(distribution, Distribution.class); - } - - if (weightInit != null) { - final IWeightInit wi = WeightInit.valueOf(weightInit.asText()) - .getWeightInitFunction(dist); - ((BaseLayer) l).setWeightInitFn(wi); - } - } - - } catch (IOException e) { - log.warn( - "ILayer with null WeightInit detected: " + l.getLayerName() + ", could not parse JSON", - e); - } - } - return true; - - } - - /** - * Object mapper for serialization of configurations - * - * @return - */ - public static ObjectMapper mapperYaml() { - return JsonMappers.getMapperYaml(); - } - - /** - * Object mapper for serialization of configurations - * - * @return - */ - public static ObjectMapper mapper() { - return JsonMappers.getMapper(); - } - - - - /** - * @return JSON representation of NN configuration - */ - public String toYaml() { - ObjectMapper mapper = NeuralNetConfiguration.mapperYaml(); - synchronized (mapper) { - try { - return mapper.writeValueAsString(this); - } catch (com.fasterxml.jackson.core.JsonProcessingException e) { - throw new RuntimeException(e); - } - } - } - - /** - * @return JSON representation of NN configuration - */ - public String toJson() { - ObjectMapper mapper = NeuralNetConfiguration.mapper(); - synchronized (mapper) { - //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally - //when writeValueAsString is used by multiple threads. This results in invalid JSON. 
See issue #3243 - try { - return mapper.writeValueAsString(this); - } catch (com.fasterxml.jackson.core.JsonProcessingException e) { - throw new RuntimeException(e); - } - } - } - - @Override - public String toString() { - return toJson(); - } - - public NeuralNetworkConfiguration getConf(int i) { - return confs.get(i); - } - - @Override - public NeuralNetworkConfiguration clone() { - - NeuralNetworkConfiguration clone = (NeuralNetworkConfiguration) super.clone(); - List confList = clone.getConfs(); - if (confList != null) { - List list = new ArrayList<>(); - for (NeuralNetworkConfiguration conf : confList) { - list.add(conf.clone()); - } - } - - if (clone.getInputPreProcessors() != null) { - Map map = new HashMap<>(); - for (Map.Entry entry : clone.getInputPreProcessors().entrySet()) { - map.put(entry.getKey(), entry.getValue().clone()); - } - clone.getInputPreProcessors().clear(); - clone.getInputPreProcessors().putAll(map); - } - - clone.setInferenceWorkspaceMode(this.inferenceWorkspaceMode); - clone.setTrainingWorkspaceMode(this.trainingWorkspaceMode); - clone.setCacheMode(this.cacheMode); - clone.setValidateOutputLayerConfig(this.validateOutputLayerConfig); - clone.setDataType(this.dataType); - - return clone; - - } - - public InputPreProcessor getInputPreProcess(int curr) { - return inputPreProcessors.get(curr); - } - - /** - * Get a {@link MemoryReport} for the given MultiLayerConfiguration. This is used to estimate the - * memory requirements for the given network configuration and input - * - * @param inputType Input types for the network - * @return Memory report for the network - */ - public NetworkMemoryReport getMemoryReport(InputType inputType) { - - Map memoryReportMap = new LinkedHashMap<>(); - int nLayers = confs.size(); - for (int i = 0; i < nLayers; i++) { - String layerName = confs.get(i).getLayer().getLayerName(); - if (layerName == null) { - layerName = String.valueOf(i); - } - - //Pass input type through preprocessor, if necessary - InputPreProcessor preproc = getInputPreProcess(i); - //TODO memory requirements for preprocessor - if (preproc != null) { - inputType = preproc.getOutputType(inputType); - } - - LayerMemoryReport report = confs.get(i).getLayer().getMemoryReport(inputType); - memoryReportMap.put(layerName, report); - - inputType = confs.get(i).getLayer().getOutputType(i, inputType); - } - - return new NetworkMemoryReport(memoryReportMap, MultiLayerConfiguration.class, - "MultiLayerNetwork", inputType); - } - - /** - * For the given input shape/type for the network, return a list of activation sizes for each - * layer in the network.
i.e., list.get(i) is the output activation sizes for layer i - * - * @param inputType Input type for the network - * @return A lits of activation types for the network, indexed by layer number - */ - public List getLayerActivationTypes(@NonNull InputType inputType) { - List out = new ArrayList<>(); - int nLayers = confs.size(); - for (int i = 0; i < nLayers; i++) { - InputPreProcessor preproc = getInputPreProcess(i); - if (preproc != null) { - inputType = preproc.getOutputType(inputType); - } - - inputType = confs.get(i).getLayer().getOutputType(i, inputType); - out.add(inputType); - } - return out; - } - - /** - * Defines some additional handy methods. Other than that, - * the builder is generated by lombok. - */ - public static class NeuralNetworkConfigurationBuilder { - - /** - * Specify the processors. These are used at each layer for doing things like normalization and - * shaping of input. - * - * @param processor what to use to preProcess the data. - * @return builder pattern - */ - public NeuralNetworkConfigurationBuilder inputPreProcessor(Integer layer, - InputPreProcessor processor) { - inputPreProcessors.put(layer, processor); - return this; - } - - /** - * Specify additional layer configurations - */ - @Deprecated - public NeuralNetworkConfigurationBuilder layersFromArray(Layer[] arrLayers) { - for(Layer l : arrLayers) { - layers.add( l ); - } - return this; - } - } - - -} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/DenseLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/DenseLayerConfiguration.java deleted file mode 100644 index d472d99b2..000000000 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/DenseLayerConfiguration.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * - * ****************************************************************************** - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - * - */ - -package net.brutex.ai.dnn.conf.layer; - -import lombok.Builder; -import lombok.experimental.SuperBuilder; -import org.deeplearning4j.nn.conf.layers.LayerValidation; - -/** - * The dense layer is a neural network layer that is connected deeply, which means each neuron in - * the dense layer receives input from all neurons of its previous layer. The dense layer is found - * to be the most commonly used layer in the models. - *

- * In the background, the dense layer performs a matrix-vector multiplication. The values used in - * the matrix are actually parameters that can be trained and updated with the help of - * backpropagation. - *
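
As a rough illustration of the matrix-vector multiplication described above, here is a minimal, self-contained Java sketch of a dense forward pass (hypothetical names, no DL4J types; not part of the original patch):

import java.util.Arrays;

/** Hypothetical sketch of a dense (fully connected) forward pass: y = W*x + b. */
public final class DenseForwardSketch {

    /** W has shape [out][in]; b and the returned y have length out; x has length in. */
    static double[] forward(double[][] W, double[] b, double[] x) {
        double[] y = new double[W.length];
        for (int i = 0; i < W.length; i++) {
            double sum = b[i];                      // bias term for output unit i
            for (int j = 0; j < x.length; j++) {
                sum += W[i][j] * x[j];              // weighted sum over all inputs
            }
            y[i] = sum;                             // activation function omitted for brevity
        }
        return y;
    }

    public static void main(String[] args) {
        double[][] W = {{1, 0, -1}, {0.5, 0.5, 0.5}};   // in = 3 inputs, out = 2 outputs
        double[] b = {0.1, -0.1};
        System.out.println(Arrays.toString(forward(W, b, new double[]{1, 2, 3}))); // [-1.9, 2.9]
    }
}

Such a layer carries in * out weights plus out bias values, which matches the in * out + out parameter count returned by FeedForwardLayerConfiguration.numParameters() further down in this patch.
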

- * The output generated by the dense layer is an ‘m’ dimensional vector. Thus, dense layer is - * basically used for changing the dimensions of the vector. Dense layers also applies operations - * like rotation, scaling, translation on the vector. - */ -@SuperBuilder -public class DenseLayerConfiguration extends FeedForwardLayerConfiguration { - - /** - * Decides whether we should include a bias vector for calculation purposes or not. - */ - @Builder.Default - boolean bias = true; - - - - /** - * An implementation to validate the network - * - * @return true if no errors found; false otherwise - */ - @Override - public boolean isValid() { - LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getName(), -99, getIn(), getOut()); - return super.isValid(); - } -} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FeedForwardLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FeedForwardLayerConfiguration.java deleted file mode 100644 index c86869d54..000000000 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/FeedForwardLayerConfiguration.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * - * ****************************************************************************** - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - * - */ - -package net.brutex.ai.dnn.conf.layer; - -import lombok.Getter; -import lombok.experimental.SuperBuilder; -import lombok.extern.slf4j.Slf4j; -import net.brutex.ai.dnn.api.ILayer; -import net.brutex.ai.dnn.api.ILayerConfiguration; -import net.brutex.ai.dnn.api.IModel; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.inputs.InputType.Type; - -/** - * A Feed Forward Layer Configuration - */ -@Slf4j -@SuperBuilder -public class FeedForwardLayerConfiguration extends AbstractLayerConfiguration implements ILayerConfiguration { - - @Getter private int in; - @Getter private int out; - - /** - * This Fast Forward ILayer will always output data as - * FF type. - * @return InputType for FF - **/ - @Getter - final InputType.Type outputType = InputType.Type.FF; - - @Getter - final InputType.Type inputType = InputType.Type.FF; - - /** - * Create and return an instance of a ILayerConfiguration. 
- * - * @param network the "holding" network for the instance - * @return the new layer instance - */ - //@Override - public ILayer instantiate(IModel network) { - //Let's do some verifications first - if(getInputType() != Type.FF) { - log.error("The {} layer configuration must use an InputType of {}, but found {}", - this.getClass().getSimpleName(), - Type.FF.name(), - getInputType().name()); - } - return null; - } - - /** - * Number of trainable parameter in this layer - * - * @return number of parameter - */ - @Override - public long numParameters() { - return in * out + out; //add one extra out for the bias - } - - /** - * An implementation should provide a method to validate the network - * - * @return true if no errors found; false otherwise - */ - @Override - public boolean isValid() { - boolean result = true; - if(getInputType() != Type.FF) { - log.error("The {} layer configuration must use an InputType of {}, but found {}", - this.getClass().getSimpleName(), - Type.FF.name(), - getInputType().name()); - result = false; - } - return result; - } -} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java index 0a605b94f..2b900a5ff 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java @@ -21,10 +21,15 @@ package net.brutex.ai.dnn.networks; +import java.util.Map; import lombok.Getter; +import lombok.NonNull; import lombok.Setter; -import net.brutex.ai.dnn.conf.NeuralNetworkConfiguration; -import net.brutex.ai.dnn.api.INeuralNetwork; +import net.brutex.ai.dnn.api.IModel; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.gradient.Gradient; +import org.nd4j.linalg.api.ndarray.INDArray; + /** * Artificial Neural Network An artificial neural network (1) takes some input data, and (2) @@ -41,13 +46,50 @@ import net.brutex.ai.dnn.api.INeuralNetwork; * predictions of the network and the desired values and then using this error signal to change the * weights (or parameters) so that predictions get more accurate. */ -public abstract class ArtificialNeuralNetwork implements INeuralNetwork { +public abstract class ArtificialNeuralNetwork implements IModel { /** * A neural network is created from a configuration. + * * @param conf The (new net.brutex.ai) configuration for the network */ @Getter - @Setter //TODO make this also final and @NonNull - private NeuralNetworkConfiguration configuration; + @Setter + @NonNull + private NeuralNetConfiguration netConfiguration; + + + /** + * Create a new network from configuration + * + * @param conf the configuration + */ + public ArtificialNeuralNetwork(NeuralNetConfiguration conf) { + this.netConfiguration = conf; + } + + /** + * Update all parameters (for all parameter types) with the given gradient. + * + * @param gradient the gradients to add + */ + public void update(Gradient gradient) { + for (String paramType : gradient.gradientForVariable().keySet()) { + update(gradient.getGradientFor(paramType), paramType); + } + } + + /** + * Update the parameters of a given type with a given gradient. 
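
A minimal sketch of the in-place semantics of these update methods (the actual implementation via setParam/getParam follows just below); it assumes the gradient array already holds the final update values, i.e. that sign and learning rate have been applied upstream by the updater. This is an illustrative aside, not part of the patch:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/** Hypothetical, self-contained illustration of additive parameter updates. */
public final class UpdateSemanticsSketch {
    public static void main(String[] args) {
        INDArray param = Nd4j.ones(1, 4);                 // stand-in for getParam("W")
        INDArray gradient = Nd4j.ones(1, 4).muli(-0.1);   // stand-in for the "W" entry of a Gradient
        param.addi(gradient);                             // param <- param + gradient, in place
        System.out.println(param);                        // every value is now 0.9
    }
}
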
+ * + * @param gradient the gradient to apply + * @param paramType + */ + public void update(INDArray gradient, String paramType) { + setParam(paramType, getParam(paramType).addi(gradient)); + } + + + + } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingConfiguration.java index 8f55745ed..d95c5aab6 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingConfiguration.java @@ -22,12 +22,12 @@ package org.deeplearning4j.earlystopping; import lombok.Data; import lombok.NoArgsConstructor; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver; import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator; import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition; import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition; import org.deeplearning4j.exception.DL4JInvalidConfigException; -import org.deeplearning4j.nn.api.Model; import org.nd4j.common.function.Supplier; import java.io.Serializable; @@ -37,7 +37,7 @@ import java.util.List; @Data @NoArgsConstructor -public class EarlyStoppingConfiguration implements Serializable { +public class EarlyStoppingConfiguration implements Serializable { private EarlyStoppingModelSaver modelSaver; private List epochTerminationConditions; @@ -89,7 +89,7 @@ public class EarlyStoppingConfiguration implements Serializable } - public static class Builder { + public static class Builder { private EarlyStoppingModelSaver modelSaver = new InMemoryModelSaver<>(); private List epochTerminationConditions = new ArrayList<>(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingModelSaver.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingModelSaver.java index a9793175a..9037e0792 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingModelSaver.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingModelSaver.java @@ -20,10 +20,10 @@ package org.deeplearning4j.earlystopping; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.saver.InMemoryModelSaver; import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver; import org.deeplearning4j.earlystopping.saver.LocalFileModelSaver; -import org.deeplearning4j.nn.api.Model; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; @@ -38,7 +38,7 @@ import java.io.Serializable; }) @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public interface EarlyStoppingModelSaver extends Serializable { +public interface EarlyStoppingModelSaver extends Serializable { /** Save the best model (so far) learned during early stopping training */ void saveBestModel(T net, double score) throws IOException; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingResult.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingResult.java index 6f44c7fdb..817f4c7db 100644 --- 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingResult.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/EarlyStoppingResult.java @@ -21,13 +21,13 @@ package org.deeplearning4j.earlystopping; import lombok.Data; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import java.io.Serializable; import java.util.Map; @Data -public class EarlyStoppingResult implements Serializable { +public class EarlyStoppingResult implements Serializable { public enum TerminationReason { Error, IterationTerminationCondition, EpochTerminationCondition } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/listener/EarlyStoppingListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/listener/EarlyStoppingListener.java index 191870de3..016b31881 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/listener/EarlyStoppingListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/listener/EarlyStoppingListener.java @@ -20,11 +20,11 @@ package org.deeplearning4j.earlystopping.listener; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; import org.deeplearning4j.earlystopping.EarlyStoppingResult; -import org.deeplearning4j.nn.api.Model; -public interface EarlyStoppingListener { +public interface EarlyStoppingListener { /**Method to be called when early stopping training is first started */ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/saver/InMemoryModelSaver.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/saver/InMemoryModelSaver.java index 4e63ef0c5..b24b47651 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/saver/InMemoryModelSaver.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/saver/InMemoryModelSaver.java @@ -21,11 +21,11 @@ package org.deeplearning4j.earlystopping.saver; import org.deeplearning4j.earlystopping.EarlyStoppingModelSaver; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import java.io.IOException; -public class InMemoryModelSaver implements EarlyStoppingModelSaver { +public class InMemoryModelSaver implements EarlyStoppingModelSaver { private transient T bestModel; private transient T latestModel; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/AutoencoderScoreCalculator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/AutoencoderScoreCalculator.java index 0c70667dd..69f1785e4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/AutoencoderScoreCalculator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/AutoencoderScoreCalculator.java @@ -20,9 +20,9 @@ package org.deeplearning4j.earlystopping.scorecalc; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -32,7 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import 
org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -public class AutoencoderScoreCalculator extends BaseScoreCalculator { +public class AutoencoderScoreCalculator extends BaseScoreCalculator { protected final Metric metric; protected RegressionEvaluation evaluation; @@ -48,7 +48,7 @@ public class AutoencoderScoreCalculator extends BaseScoreCalculator { } @Override - protected INDArray output(Model net, INDArray input, INDArray fMask, INDArray lMask) { + protected INDArray output(IModel net, INDArray input, INDArray fMask, INDArray lMask) { Layer l; if(net instanceof MultiLayerNetwork) { @@ -71,19 +71,19 @@ public class AutoencoderScoreCalculator extends BaseScoreCalculator { } @Override - protected INDArray[] output(Model network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) { + protected INDArray[] output(IModel network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) { return new INDArray[]{output(network, get0(input), get0(fMask), get0(lMask))}; } @Override - protected double scoreMinibatch(Model network, INDArray features, INDArray labels, INDArray fMask, + protected double scoreMinibatch(IModel network, INDArray features, INDArray labels, INDArray fMask, INDArray lMask, INDArray output) { evaluation.eval(features, output); return 0.0; //Not used } @Override - protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) { + protected double scoreMinibatch(IModel network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) { return scoreMinibatch(network, get0(features), get0(labels), get0(fMask), get0(lMask), get0(output)); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ClassificationScoreCalculator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ClassificationScoreCalculator.java index ae13edc79..b9884f68f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ClassificationScoreCalculator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ClassificationScoreCalculator.java @@ -20,13 +20,13 @@ package org.deeplearning4j.earlystopping.scorecalc; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.scorecalc.base.BaseIEvaluationScoreCalculator; -import org.deeplearning4j.nn.api.Model; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -public class ClassificationScoreCalculator extends BaseIEvaluationScoreCalculator { +public class ClassificationScoreCalculator extends BaseIEvaluationScoreCalculator { protected final Evaluation.Metric metric; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/DataSetLossCalculator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/DataSetLossCalculator.java index e8d403a7f..2f6199449 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/DataSetLossCalculator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/DataSetLossCalculator.java @@ -21,7 +21,7 @@ package org.deeplearning4j.earlystopping.scorecalc; import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator; -import org.deeplearning4j.nn.api.Model; +import 
net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.api.ndarray.INDArray; @@ -31,7 +31,7 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import com.fasterxml.jackson.annotation.JsonProperty; -public class DataSetLossCalculator extends BaseScoreCalculator { +public class DataSetLossCalculator extends BaseScoreCalculator { @JsonProperty private boolean average; @@ -70,12 +70,12 @@ public class DataSetLossCalculator extends BaseScoreCalculator { } @Override - protected INDArray output(Model network, INDArray input, INDArray fMask, INDArray lMask) { + protected INDArray output(IModel network, INDArray input, INDArray fMask, INDArray lMask) { return output(network, arr(input), arr(fMask), arr(lMask))[0]; } @Override - protected INDArray[] output(Model network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) { + protected INDArray[] output(IModel network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) { if(network instanceof MultiLayerNetwork){ INDArray out = ((MultiLayerNetwork) network).output(input[0], false, get0(fMask), get0(lMask)); return new INDArray[]{out}; @@ -87,7 +87,7 @@ public class DataSetLossCalculator extends BaseScoreCalculator { } @Override - protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) { + protected double scoreMinibatch(IModel network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) { if(network instanceof MultiLayerNetwork){ return ((MultiLayerNetwork) network).score(new DataSet(get0(features), get0(labels), get0(fMask), get0(lMask)), false) * features[0].size(0); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ROCScoreCalculator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ROCScoreCalculator.java index 27fdbd8aa..ca3e5ab1c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ROCScoreCalculator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ROCScoreCalculator.java @@ -20,8 +20,8 @@ package org.deeplearning4j.earlystopping.scorecalc; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.scorecalc.base.BaseIEvaluationScoreCalculator; -import org.deeplearning4j.nn.api.Model; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.ROC; import org.nd4j.evaluation.classification.ROCBinary; @@ -29,7 +29,7 @@ import org.nd4j.evaluation.classification.ROCMultiClass; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -public class ROCScoreCalculator extends BaseIEvaluationScoreCalculator { +public class ROCScoreCalculator extends BaseIEvaluationScoreCalculator { public enum ROCType {ROC, BINARY, MULTICLASS} public enum Metric {AUC, AUPRC} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/RegressionScoreCalculator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/RegressionScoreCalculator.java index 5dab31e29..3ffd58a6a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/RegressionScoreCalculator.java +++ 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/RegressionScoreCalculator.java @@ -20,13 +20,13 @@ package org.deeplearning4j.earlystopping.scorecalc; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.scorecalc.base.BaseIEvaluationScoreCalculator; -import org.deeplearning4j.nn.api.Model; import org.nd4j.evaluation.regression.RegressionEvaluation; import org.nd4j.evaluation.regression.RegressionEvaluation.Metric; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -public class RegressionScoreCalculator extends BaseIEvaluationScoreCalculator { +public class RegressionScoreCalculator extends BaseIEvaluationScoreCalculator { protected final Metric metric; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ScoreCalculator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ScoreCalculator.java index 8e994a678..a9568d2d9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ScoreCalculator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/ScoreCalculator.java @@ -20,7 +20,7 @@ package org.deeplearning4j.earlystopping.scorecalc; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonSubTypes; import com.fasterxml.jackson.annotation.JsonTypeInfo; @@ -34,7 +34,7 @@ import java.io.Serializable; @JsonSubTypes.Type(value = DataSetLossCalculatorCG.class, name = "MaxEpochsTerminationCondition"), }) -public interface ScoreCalculator extends Serializable { +public interface ScoreCalculator extends Serializable { /** Calculate the score for the given MultiLayerNetwork */ double calculateScore(T network); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconErrorScoreCalculator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconErrorScoreCalculator.java index 687eb9969..4b2f1eb9f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconErrorScoreCalculator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconErrorScoreCalculator.java @@ -20,9 +20,9 @@ package org.deeplearning4j.earlystopping.scorecalc; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -32,7 +32,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -public class VAEReconErrorScoreCalculator extends BaseScoreCalculator { +public class VAEReconErrorScoreCalculator extends BaseScoreCalculator { protected final Metric metric; protected RegressionEvaluation evaluation; @@ -54,7 +54,7 @@ public class VAEReconErrorScoreCalculator extends BaseScoreCalculator { } @Override - protected INDArray output(Model net, INDArray input, INDArray fMask, INDArray lMask) { + protected INDArray output(IModel net, INDArray input, INDArray fMask, INDArray lMask) { Layer l; if(net instanceof 
MultiLayerNetwork) { MultiLayerNetwork network = (MultiLayerNetwork)net; @@ -74,19 +74,19 @@ public class VAEReconErrorScoreCalculator extends BaseScoreCalculator { } @Override - protected INDArray[] output(Model network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) { + protected INDArray[] output(IModel network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) { return new INDArray[]{output(network, get0(input), get0(fMask), get0(lMask))}; } @Override - protected double scoreMinibatch(Model network, INDArray features, INDArray labels, INDArray fMask, + protected double scoreMinibatch(IModel network, INDArray features, INDArray labels, INDArray fMask, INDArray lMask, INDArray output) { evaluation.eval(features, output); return 0.0; //Not used } @Override - protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) { + protected double scoreMinibatch(IModel network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) { return scoreMinibatch(network, get0(features), get0(labels), get0(fMask), get0(lMask), get0(output)); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconProbScoreCalculator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconProbScoreCalculator.java index 0ed2aef4b..0328d7e66 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconProbScoreCalculator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/VAEReconProbScoreCalculator.java @@ -20,16 +20,16 @@ package org.deeplearning4j.earlystopping.scorecalc; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -public class VAEReconProbScoreCalculator extends BaseScoreCalculator { +public class VAEReconProbScoreCalculator extends BaseScoreCalculator { protected final int reconstructionProbNumSamples; protected final boolean logProb; @@ -73,17 +73,17 @@ public class VAEReconProbScoreCalculator extends BaseScoreCalculator { } @Override - protected INDArray output(Model network, INDArray input, INDArray fMask, INDArray lMask) { + protected INDArray output(IModel network, INDArray input, INDArray fMask, INDArray lMask) { return null; //Not used } @Override - protected INDArray[] output(Model network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) { + protected INDArray[] output(IModel network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) { return null; //Not used } @Override - protected double scoreMinibatch(Model net, INDArray features, INDArray labels, INDArray fMask, + protected double scoreMinibatch(IModel net, INDArray features, INDArray labels, INDArray fMask, INDArray lMask, INDArray output) { Layer l; if(net instanceof MultiLayerNetwork) { @@ -108,7 +108,7 @@ public class VAEReconProbScoreCalculator extends BaseScoreCalculator { } @Override - protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] 
output) { + protected double scoreMinibatch(IModel network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) { return 0; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/base/BaseIEvaluationScoreCalculator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/base/BaseIEvaluationScoreCalculator.java index 89dd780dc..7a064c151 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/base/BaseIEvaluationScoreCalculator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/base/BaseIEvaluationScoreCalculator.java @@ -22,7 +22,7 @@ package org.deeplearning4j.earlystopping.scorecalc.base; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.evaluation.IEvaluation; @@ -30,7 +30,7 @@ import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -public abstract class BaseIEvaluationScoreCalculator implements ScoreCalculator { +public abstract class BaseIEvaluationScoreCalculator implements ScoreCalculator { protected MultiDataSetIterator iterator; protected DataSetIterator iter; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/base/BaseScoreCalculator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/base/BaseScoreCalculator.java index d0407b2e9..ce01ebfcd 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/base/BaseScoreCalculator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/scorecalc/base/BaseScoreCalculator.java @@ -22,14 +22,14 @@ package org.deeplearning4j.earlystopping.scorecalc.base; import lombok.NonNull; import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -public abstract class BaseScoreCalculator implements ScoreCalculator { +public abstract class BaseScoreCalculator implements ScoreCalculator { protected MultiDataSetIterator mdsIterator; protected DataSetIterator iterator; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java index 4d6ff7675..db65ca7bb 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java @@ -20,13 +20,13 @@ package org.deeplearning4j.earlystopping.trainer; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; import 
org.deeplearning4j.earlystopping.EarlyStoppingResult; import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener; import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator; import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition; import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.TrainingListener; @@ -47,7 +47,7 @@ import java.util.Iterator; import java.util.LinkedHashMap; import java.util.Map; -public abstract class BaseEarlyStoppingTrainer implements IEarlyStoppingTrainer { +public abstract class BaseEarlyStoppingTrainer implements IEarlyStoppingTrainer { private static final Logger log = LoggerFactory.getLogger(BaseEarlyStoppingTrainer.class); @@ -337,7 +337,7 @@ public abstract class BaseEarlyStoppingTrainer implements IEarl } //Trigger epoch listener methods manually - these won't be triggered due to not calling fit(DataSetIterator) etc - protected void triggerEpochListeners(boolean epochStart, Model model, int epochNum){ + protected void triggerEpochListeners(boolean epochStart, IModel model, int epochNum){ Collection listeners; if(model instanceof MultiLayerNetwork){ MultiLayerNetwork n = ((MultiLayerNetwork) model); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingTrainer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingTrainer.java index f4df7a3d4..8c36c07d2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingTrainer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/EarlyStoppingTrainer.java @@ -25,8 +25,7 @@ import org.deeplearning4j.datasets.iterator.impl.SingletonDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.SingletonMultiDataSetIterator; import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -39,7 +38,7 @@ public class EarlyStoppingTrainer extends BaseEarlyStoppingTrainer earlyStoppingConfiguration, - MultiLayerConfiguration configuration, DataSetIterator train) { + NeuralNetConfiguration configuration, DataSetIterator train) { this(earlyStoppingConfiguration, new MultiLayerNetwork(configuration), train); net.init(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/IEarlyStoppingTrainer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/IEarlyStoppingTrainer.java index fd86168c6..718e10d0d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/IEarlyStoppingTrainer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/IEarlyStoppingTrainer.java @@ -20,11 +20,11 @@ package org.deeplearning4j.earlystopping.trainer; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.EarlyStoppingResult; import 
org.deeplearning4j.earlystopping.listener.EarlyStoppingListener; -import org.deeplearning4j.nn.api.Model; -public interface IEarlyStoppingTrainer { +public interface IEarlyStoppingTrainer { /** Conduct early stopping training */ EarlyStoppingResult fit(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java index 696e92bc2..d106f827f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java @@ -23,6 +23,7 @@ package org.deeplearning4j.gradientcheck; import lombok.*; import lombok.experimental.Accessors; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.common.function.Consumer; @@ -83,7 +84,7 @@ public class GradientCheckUtil { if(outputLayer instanceof BaseOutputLayer){ BaseOutputLayer o = (BaseOutputLayer)outputLayer; lfn = ((org.deeplearning4j.nn.conf.layers.BaseOutputLayer)o.layerConf()).getLossFn(); - afn = o.layerConf().getActivationFn(); + afn = o.getLayerConfiguration().getActivationFn(); } else if(outputLayer instanceof LossLayer){ LossLayer o = (LossLayer) outputLayer; lfn = o.layerConf().getLossFn(); @@ -204,7 +205,7 @@ public class GradientCheckUtil { + "DataTypeUtil.setDTypeForContext(DataType.DOUBLE); before using GradientCheckUtil"); } - DataType netDataType = c.net.getLayerWiseConfigurations().getDataType(); + DataType netDataType = c.net.getNetConfiguration().getDataType(); if (netDataType != DataType.DOUBLE) { throw new IllegalStateException("Cannot perform gradient check: Network datatype is not set to double precision (" + "is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil"); @@ -218,9 +219,9 @@ public class GradientCheckUtil { //Check network configuration: int layerCount = 0; - for (NeuralNetConfiguration n : c.net.getLayerWiseConfigurations().getConfs()) { - if (n.getLayer() instanceof BaseLayer) { - BaseLayer bl = (BaseLayer) n.getLayer(); + for (LayerConfiguration n : c.net.getNetConfiguration().getFlattenedLayerConfigurations()) { + if (n instanceof BaseLayer) { + BaseLayer bl = (BaseLayer) n; IUpdater u = bl.getIUpdater(); if (u instanceof Sgd) { //Must have LR of 1.0 @@ -228,7 +229,7 @@ public class GradientCheckUtil { if (lr != 1.0) { throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer " + layerCount + "; got " + u + " with lr=" + lr + " for layer \"" - + n.getLayer().getLayerName() + "\""); + + n.getLayerName() + "\""); } } else if (!(u instanceof NoOp)) { throw new IllegalStateException( @@ -238,7 +239,7 @@ public class GradientCheckUtil { IActivation activation = bl.getActivationFn(); if (activation != null) { if (!VALID_ACTIVATION_FUNCTIONS.contains(activation.getClass())) { - log.warn("Layer " + layerCount + " is possibly using an unsuitable activation function: " + log.warn("LayerConfiguration " + layerCount + " is possibly using an unsuitable activation function: " + activation.getClass() + ". 
Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not " + "contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)"); @@ -246,10 +247,10 @@ public class GradientCheckUtil { } } - if (n.getLayer().getIDropout() != null && c.callEachIter == null) { + if (n.getIDropout() != null && c.callEachIter == null) { throw new IllegalStateException("When gradient checking dropout, need to reset RNG seed each iter, or no" + " dropout should be present during gradient checks - got dropout = " - + n.getLayer().getIDropout() + " for layer " + layerCount); + + n.getIDropout() + " for layer " + layerCount); } } @@ -277,7 +278,7 @@ public class GradientCheckUtil { val nParams = originalParams.length(); - Map paramTable = c.net.paramTable(); + Map paramTable = c.net.getParamTable(); List paramNames = new ArrayList<>(paramTable.keySet()); val paramEnds = new long[paramNames.size()]; paramEnds[0] = paramTable.get(paramNames.get(0)).length(); @@ -306,8 +307,8 @@ public class GradientCheckUtil { if(c.print == PrintMode.ALL) { int i=0; for (Layer l : c.net.getLayers()) { - Set s = l.paramTable().keySet(); - log.info("Layer " + i + ": " + l.getClass().getSimpleName() + " - params " + s); + Set s = l.getParamTable().keySet(); + log.info("LayerConfiguration " + i + ": " + l.getClass().getSimpleName() + " - params " + s); i++; } } @@ -450,8 +451,8 @@ public class GradientCheckUtil { continue; LayerVertex lv = (LayerVertex) gv; - if (lv.getLayerConf().getLayer() instanceof BaseLayer) { - BaseLayer bl = (BaseLayer) lv.getLayerConf().getLayer(); + if (lv.getLayerConfiguration() instanceof BaseLayer) { + BaseLayer bl = (BaseLayer) lv.getLayerConfiguration(); IUpdater u = bl.getIUpdater(); if (u instanceof Sgd) { //Must have LR of 1.0 @@ -459,7 +460,7 @@ public class GradientCheckUtil { if (lr != 1.0) { throw new IllegalStateException("When using SGD updater, must also use lr=1.0 for layer " + layerCount + "; got " + u + " with lr=" + lr + " for layer \"" - + lv.getLayerConf().getLayer().getLayerName() + "\""); + + lv.getLayerConfiguration().getLayerName() + "\""); } } else if (!(u instanceof NoOp)) { throw new IllegalStateException( @@ -469,7 +470,7 @@ public class GradientCheckUtil { IActivation activation = bl.getActivationFn(); if (activation != null) { if (!VALID_ACTIVATION_FUNCTIONS.contains(activation.getClass())) { - log.warn("Layer \"" + vertexName + "\" is possibly using an unsuitable activation function: " + log.warn("LayerConfiguration \"" + vertexName + "\" is possibly using an unsuitable activation function: " + activation.getClass() + ". 
Activation functions for gradient checks must be smooth (like sigmoid, tanh, softmax) and not " + "contain discontinuities like ReLU or LeakyReLU (these may cause spurious failures)"); @@ -477,10 +478,10 @@ public class GradientCheckUtil { } } - if (lv.getLayerConf().getLayer().getIDropout() != null && c.callEachIter == null) { + if (lv.getLayerConfiguration().getIDropout() != null && c.callEachIter == null) { throw new IllegalStateException("When gradient checking dropout, rng seed must be reset each iteration, or no" + " dropout should be present during gradient checks - got dropout = " - + lv.getLayerConf().getLayer().getIDropout() + " for layer " + layerCount); + + lv.getLayerConfiguration().getIDropout() + " for layer " + layerCount); } } @@ -513,7 +514,7 @@ public class GradientCheckUtil { val nParams = originalParams.length(); - Map paramTable = c.net.paramTable(); + Map paramTable = c.net.getParamTable(); List paramNames = new ArrayList<>(paramTable.keySet()); val paramEnds = new long[paramNames.size()]; paramEnds[0] = paramTable.get(paramNames.get(0)).length(); @@ -646,7 +647,7 @@ public class GradientCheckUtil { val nParams = originalParams.length(); - Map paramTable = layer.paramTable(); + Map paramTable = layer.getParamTable(); List paramNames = new ArrayList<>(paramTable.keySet()); val paramEnds = new long[paramNames.size()]; paramEnds[0] = paramTable.get(paramNames.get(0)).length(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/adapters/YoloModelAdapter.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/adapters/YoloModelAdapter.java index 57ec18aa1..ea435af20 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/adapters/YoloModelAdapter.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/adapters/YoloModelAdapter.java @@ -24,7 +24,7 @@ import lombok.AllArgsConstructor; import lombok.Builder; import lombok.NoArgsConstructor; import lombok.val; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.ModelAdapter; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.objdetect.DetectedObject; @@ -43,7 +43,7 @@ public class YoloModelAdapter implements ModelAdapter> { @Builder.Default private double detectionThreshold = 0.5; @Override - public List apply(Model model, INDArray[] inputs, INDArray[] masks, INDArray[] labelsMasks) { + public List apply(IModel model, INDArray[] inputs, INDArray[] masks, INDArray[] labelsMasks) { if (model instanceof ComputationGraph) { val blindLayer = ((ComputationGraph) model).getOutputLayer(outputLayerIndex); if (blindLayer instanceof Yolo2OutputLayer) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/AbstractParamInitializer.java similarity index 67% rename from cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java rename to cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/AbstractParamInitializer.java index 1ed923bda..d93c96448 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/AbstractLayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/AbstractParamInitializer.java @@ -19,17 +19,21 @@ * */ -package net.brutex.ai.dnn.conf.layer; +package org.deeplearning4j.nn.api; import lombok.Getter; -import lombok.NonNull; -import 
lombok.Setter; -import lombok.experimental.SuperBuilder; -import net.brutex.ai.dnn.api.ILayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; -@SuperBuilder -public abstract class AbstractLayerConfiguration implements ILayerConfiguration { +public abstract class AbstractParamInitializer implements ParamInitializer { + + @Deprecated + public long numParams(NeuralNetConfiguration conf) { + long res = 0; + for(LayerConfiguration lc : conf.getFlattenedLayerConfigurations()) { + res += lc.initializer().numParams(lc); + } + return res; + } - @Getter @Setter @NonNull - private String name; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Classifier.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Classifier.java index 3643297d3..631f1bed4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Classifier.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Classifier.java @@ -20,6 +20,7 @@ package org.deeplearning4j.nn.api; +import net.brutex.ai.dnn.api.IModel; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; @@ -27,7 +28,7 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import java.util.List; -public interface Classifier extends Model { +public interface Classifier extends IModel { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java index e7500055f..41051df53 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java @@ -21,7 +21,11 @@ package org.deeplearning4j.nn.api; +import lombok.NonNull; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.CacheMode; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -42,7 +46,25 @@ import java.util.Collection; * * @see NVIDIA Deep Learning In A Nutshell */ -public interface Layer extends Serializable, Cloneable, Model, Trainable { +public interface Layer extends Serializable, Cloneable, Trainable, IModel { + + /** + * Return the configuration of this layer + * @return the configuration + */ + LayerConfiguration getLayerConfiguration(); + + /** + * Set a new layer configuration, new init() needs to be called afterwards. + * @param lconf layer configuration + */ + void setLayerConfiguration(LayerConfiguration lconf); + /** + * Convenient method to get the network configuration + * @return the configuration of the network this layer is part of + * + */ + NeuralNetConfiguration getNetConfiguration(); /** * This method sets given CacheMode for current layer @@ -107,23 +129,6 @@ public interface Layer extends Serializable, Cloneable, Model, Trainable { */ INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr mgr); - /** - * Get the iteration listeners for this layer. - */ - Collection getListeners(); - - /** - * Set the {@link TrainingListener}s for this model. If any listeners have previously been set, - * they will be replaced by this method - */ - void setListeners(TrainingListener... 
listeners); - - /** - * Set the {@link TrainingListener}s for this model. If any listeners have previously been set, - * they will be replaced by this method - */ - void setListeners(Collection listeners); - /** * Get the layer index. */ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Model.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Model.java deleted file mode 100644 index 53107fdc5..000000000 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Model.java +++ /dev/null @@ -1,237 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nn.api; - -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.gradient.Gradient; -import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.deeplearning4j.optimize.api.ConvexOptimizer; -import org.deeplearning4j.optimize.api.TrainingListener; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.primitives.Pair; - -import java.util.Collection; -import java.util.Map; - -public interface Model { - - /** - * Init the model - */ - void init(); - - - /** - * Set the trainingListeners for the ComputationGraph (and all layers in the network) - */ - void setListeners(Collection listeners); - - - /** - * Set the trainingListeners for the ComputationGraph (and all layers in the network) - */ - void setListeners(TrainingListener... listeners); - - /** - * This method ADDS additional TrainingListener to existing listeners - * - * @param listener - */ - void addListeners(TrainingListener... listener); - - - /** - * All models have a fit method - */ - @Deprecated - void fit(); - - /** - * Update layer weights and biases with gradient change - */ - void update(Gradient gradient); - - /** - * Perform one update applying the gradient - * @param gradient the gradient to apply - */ - void update(INDArray gradient, String paramType); - - - /** - * The score for the model - * @return the score for the model - */ - double score(); - - - /** - * Update the score - */ - void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr); - - /** - * Parameters of the model (if any) - * @return the parameters of the model - */ - INDArray params(); - - /** - * the number of parameters for the model - * @return the number of parameters for the model - * - */ - long numParams(); - - - /** - * the number of parameters for the model - * @return the number of parameters for the model - * - */ - long numParams(boolean backwards); - - /** - * Set the parameters for this model. 
- * This expects a linear ndarray which then be unpacked internally - * relative to the expected ordering of the model - * @param params the parameters for the model - */ - void setParams(INDArray params); - - /** - * Set the initial parameters array as a view of the full (backprop) network parameters - * NOTE: this is intended to be used internally in MultiLayerNetwork and ComputationGraph, not by users. - * @param params a 1 x nParams row vector that is a view of the larger (MLN/CG) parameters array - */ - void setParamsViewArray(INDArray params); - - - INDArray getGradientsViewArray(); - - /** - * Set the gradients array as a view of the full (backprop) network parameters - * NOTE: this is intended to be used internally in MultiLayerNetwork and ComputationGraph, not by users. - * @param gradients a 1 x nParams row vector that is a view of the larger (MLN/CG) gradients array - */ - void setBackpropGradientsViewArray(INDArray gradients); - - /** - * Fit the model to the given data - * @param data the data to fit the model to - */ - void fit(INDArray data, LayerWorkspaceMgr workspaceMgr); - - - /** - * Get the gradient. Note that this method will not calculate the gradient, it will rather return the gradient - * that has been computed before. - * For calculating the gradient, see {@link Model#computeGradientAndScore(LayerWorkspaceMgr)} } . - * @return the gradient for this model, as calculated before - */ - Gradient gradient(); - - /** - * Get the gradient and score - * @return the gradient and score - */ - Pair gradientAndScore(); - - /** - * The current inputs batch size - * @return the current inputs batch size - */ - int batchSize(); - - - /** - * The configuration for the neural network - * @return the configuration for the neural network - */ - NeuralNetConfiguration conf(); - - /** - * Setter for the configuration - * @param conf - */ - void setConf(NeuralNetConfiguration conf); - - /** - * The input/feature matrix for the model - * @return the input/feature matrix for the model - */ - INDArray input(); - - /** - * Returns this models optimizer - * @return this models optimizer - */ - ConvexOptimizer getOptimizer(); - - /** - * Get the parameter - * @param param the key of the parameter - * @return the parameter vector/matrix with that particular key - */ - INDArray getParam(String param); - - /** - * The param table - * @return - */ - Map paramTable(); - - /** - * Table of parameters by key, for backprop - * For many models (dense layers, etc) - all parameters are backprop parameters - * @param backpropParamsOnly If true, return backprop params only. 
If false: return all params (equivalent to - * paramsTable()) - */ - Map paramTable(boolean backpropParamsOnly); - - /** - * Setter for the param table - * @param paramTable - */ - void setParamTable(Map paramTable); - - - /** - * Set the parameter with a new ndarray - * @param key the key to se t - * @param val the new ndarray - */ - void setParam(String key, INDArray val); - - /** - * Clear input - */ - void clear(); - - - /** - * Apply any constraints to the model - */ - void applyConstraints(int iteration, int epoch); - - - void close(); -} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java index 01a60b73e..1f87ea69b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ModelAdapter.java @@ -20,6 +20,7 @@ package org.deeplearning4j.nn.api; +import net.brutex.ai.dnn.api.IModel; import org.nd4j.adapters.OutputAdapter; import org.nd4j.linalg.api.ndarray.INDArray; @@ -28,5 +29,5 @@ public interface ModelAdapter extends OutputAdapter { * This method invokes model internally, and does conversion to T * @return */ - T apply(Model model, INDArray[] inputs, INDArray[] inputMasks, INDArray[] labelsMasks); + T apply(IModel model, INDArray[] inputs, INDArray[] inputMasks, INDArray[] labelsMasks); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ParamInitializer.java index 7b6483483..2505e05f8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ParamInitializer.java @@ -21,7 +21,7 @@ package org.deeplearning4j.nn.api; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.List; @@ -34,9 +34,9 @@ import java.util.Map; */ public interface ParamInitializer { - long numParams(NeuralNetConfiguration conf); - - long numParams(org.deeplearning4j.nn.conf.layers.Layer layer); + long numParams(LayerConfiguration layer); + @Deprecated + long numParams(NeuralNetConfiguration netConfiguration); /** * Get a list of all parameter keys given the layer configuration @@ -44,7 +44,7 @@ public interface ParamInitializer { * @param layer ILayer * @return All parameter keys */ - List paramKeys(org.deeplearning4j.nn.conf.layers.Layer layer); + List paramKeys(LayerConfiguration layer); /** * Weight parameter keys given the layer configuration @@ -52,7 +52,7 @@ public interface ParamInitializer { * @param layer ILayer * @return Weight parameter keys */ - List weightKeys(org.deeplearning4j.nn.conf.layers.Layer layer); + List weightKeys(LayerConfiguration layer); /** * Bias parameter keys given the layer configuration @@ -60,7 +60,7 @@ public interface ParamInitializer { * @param layer ILayer * @return Bias parameter keys */ - List biasKeys(org.deeplearning4j.nn.conf.layers.Layer layer); + List biasKeys(LayerConfiguration layer); /** * Is the specified parameter a weight? 
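For orientation, here is a minimal sketch of a custom initializer written against the renamed ParamInitializer hooks shown in the hunk above. The class name and the "b" parameter key are invented for illustration and are not part of this patch; only a subset of the interface is shown, but the pattern is the same for the remaining methods.

import java.util.Collections;
import java.util.List;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;

// Illustrative only: everything that previously took org.deeplearning4j.nn.conf.layers.Layer
// now takes LayerConfiguration; the NeuralNetConfiguration overload of numParams is deprecated.
public abstract class BiasOnlyInitializerSketch implements ParamInitializer {

  private static final String BIAS_KEY = "b"; // hypothetical parameter key

  @Override
  public long numParams(LayerConfiguration layer) {
    return 1; // toy case: a single bias scalar
  }

  @Override
  public List<String> paramKeys(LayerConfiguration layer) {
    return Collections.singletonList(BIAS_KEY);
  }

  @Override
  public List<String> biasKeys(LayerConfiguration layer) {
    return Collections.singletonList(BIAS_KEY);
  }

  // weightKeys, isWeightParam, isBiasParam, init and getGradientsFromFlattened follow
  // the same pattern and are omitted from this sketch.
}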
@@ -69,7 +69,7 @@ public interface ParamInitializer { * @param key Key to check * @return True if parameter is a weight */ - boolean isWeightParam(Layer layer, String key); + boolean isWeightParam(LayerConfiguration layer, String key); /** * Is the specified parameter a bias? @@ -78,18 +78,18 @@ public interface ParamInitializer { * @param key Key to check * @return True if parameter is a bias */ - boolean isBiasParam(Layer layer, String key); + boolean isBiasParam(LayerConfiguration layer, String key); /** * Initialize the parameters * - * @param conf the configuration + * @param conf the configuration of the layer * @param paramsView a view of the full network (backprop) parameters * @param initializeParams if true: initialize the parameters according to the configuration. If false: don't modify the * values in the paramsView array (but do select out the appropriate subset, reshape etc as required) * @return Map of parameters keyed by type (view of the 'paramsView' array) */ - Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams); + Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams); /** * Return a map of gradients (in their standard non-flattened representation), taken from the flattened (row vector) gradientView array. @@ -100,6 +100,6 @@ public interface ParamInitializer { * @param gradientView The flattened gradients array, as a view of the larger array * @return A map containing an array by parameter type, that is a view of the full network gradients array */ - Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView); + Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Trainable.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Trainable.java index f93e1c5ee..33f87a736 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Trainable.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Trainable.java @@ -42,10 +42,29 @@ public interface Trainable { INDArray params(); /** - * @param backpropOnly If true: return only parameters that are not exclusively used for layerwise pretraining - * @return Parameter table + * The param table + * + * @return */ - Map paramTable(boolean backpropOnly); + Map getParamTable(); + + /** + * Table of parameters by key, for backprop. For many models (dense layers, etc) - all parameters + * are backprop parameters + * + * @param backpropParamsOnly If true, return backprop params only. If false: return all params + * (equivalent to paramsTable()) + */ + Map getParamTable(boolean backpropParamsOnly); + + /** + * Setter for the param table + * + * @param paramTable + */ + void setParamTable(Map paramTable); + + /** * DL4J layers typically produce the sum of the gradients during the backward pass for each layer, and if required diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java index a4f73d3b0..61c50b161 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/layers/RecurrentLayer.java @@ -98,6 +98,4 @@ public interface RecurrentLayer extends Layer { * for standard BPTT. 
*/ Pair tbpttBackpropGradient(INDArray epsilon, int tbpttBackLength, LayerWorkspaceMgr workspaceMgr); - - } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java index 8fe4b99a3..afba61743 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java @@ -29,8 +29,7 @@ import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; -import org.deeplearning4j.nn.conf.layers.Layer; -import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -140,7 +139,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * @return JSON representation of computation graph configuration */ public String toJson() { - //As per MultiLayerConfiguration.toJson() + //As per NeuralNetConfiguration.toJson() ObjectMapper mapper = NeuralNetConfiguration.mapper(); synchronized (mapper) { //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally @@ -160,7 +159,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * @return {@link ComputationGraphConfiguration} */ public static ComputationGraphConfiguration fromJson(String json) { - //As per MultiLayerConfiguration.fromJson() + //As per NeuralNetConfiguration.fromJson() ObjectMapper mapper = NeuralNetConfiguration.mapper(); ComputationGraphConfiguration conf; try { @@ -171,7 +170,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format return JsonMappers.getLegacyMapper().readValue(json, ComputationGraphConfiguration.class); } catch (InvalidTypeIdException e2){ - //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..." + //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.LayerConfiguration]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..." //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. 
Built-in layers from these formats still work String msg = e2.getMessage(); if(msg != null && msg.contains("Could not resolve type id")){ @@ -207,8 +206,8 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { } LayerVertex lv = (LayerVertex) entry.getValue(); - if (lv.getLayerConf() != null && lv.getLayerConf().getLayer() != null) { - Layer layer = lv.getLayerConf().getLayer(); + if (lv.getNetConfiguration() != null && lv.getLayerConfiguration() != null) { + LayerConfiguration layer = lv.getLayerConfiguration(); if (layer instanceof BaseLayer && ((BaseLayer) layer).getActivationFn() == null) { String layerName = layer.getLayerName(); @@ -240,7 +239,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { } } catch (IOException e) { - log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", + log.warn("LayerConfiguration with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e); } } @@ -257,7 +256,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * above. * @return True if all is well and layer iteration shall continue. False else-wise. */ - private static void handleLegacyWeightInitFromJson(String json, Layer layer, ObjectMapper mapper, JsonNode vertices) { + private static void handleLegacyWeightInitFromJson(String json, LayerConfiguration layer, ObjectMapper mapper, JsonNode vertices) { if (layer instanceof BaseLayer && ((BaseLayer) layer).getWeightInitFn() == null) { String layerName = layer.getLayerName(); @@ -294,7 +293,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { } } catch (IOException e) { - log.warn("Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", + log.warn("LayerConfiguration with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", e); } } @@ -331,7 +330,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { conf.trainingWorkspaceMode = trainingWorkspaceMode; conf.inferenceWorkspaceMode = inferenceWorkspaceMode; conf.cacheMode = this.cacheMode; - conf.defaultConfiguration.cacheMode = this.cacheMode; + conf.defaultConfiguration.setCacheMode(this.cacheMode); conf.validateOutputLayerConfig = this.validateOutputLayerConfig; conf.dataType = this.dataType; @@ -517,7 +516,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { inputTypeList.add(layerInput); LayerVertex lv = (LayerVertex) gv; - Layer l = lv.getLayerConf().getLayer(); + LayerConfiguration l = lv.getLayerConfiguration(); //Preprocessors - add if necessary if (lv.getPreProcessor() == null && addPreprocIfNecessary) { @@ -710,7 +709,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { protected Map inputPreProcessors = new LinkedHashMap<>(); - protected NeuralNetConfiguration.Builder globalConfiguration; + protected NeuralNetConfiguration globalConfiguration; protected boolean allowDisconnected = false; protected boolean allowNoOutput = false; @@ -719,11 +718,11 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { protected String lastAdded = null; - public GraphBuilder(NeuralNetConfiguration.Builder globalConfiguration) { - this.globalConfiguration = globalConfiguration; + public GraphBuilder(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfiguration) { + this.globalConfiguration = 
globalConfiguration.build(); } - public GraphBuilder(ComputationGraphConfiguration newConf, NeuralNetConfiguration.Builder globalConfiguration) { + public GraphBuilder(ComputationGraphConfiguration newConf, NeuralNetConfiguration globalConfiguration) { ComputationGraphConfiguration clonedConf = newConf.clone(); @@ -742,7 +741,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { /** * Specify the processors for a given layer * These are used at each layer for doing things like normalization and shaping of input.
- * Note: preprocessors can also be defined using the {@link #addLayer(String, Layer, InputPreProcessor, String...)} method. + * Note: preprocessors can also be defined using the {@link #addLayer(String, LayerConfiguration, InputPreProcessor, String...)} method. * * @param layer the name of the layer that this preprocessor will be used with * @param processor the preprocessor to use for the specified layer @@ -776,7 +775,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * * @param forwardLength Forward length > 0, >= backwardLength */ - public GraphBuilder tBPTTForwardLength(int forwardLength) { + public GraphBuilder tbpttFwdLength(int forwardLength) { this.tbpttFwdLength = forwardLength; return this; } @@ -789,7 +788,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * * @param backwardLength <= forwardLength */ - public GraphBuilder tBPTTBackwardLength(int backwardLength) { + public GraphBuilder tbpttBackLength(int backwardLength) { this.tbpttBackLength = backwardLength; return this; } @@ -802,8 +801,8 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * @param tbpttLength length > 0 */ public GraphBuilder tBPTTLength(int tbpttLength){ - tBPTTForwardLength(tbpttLength); - return tBPTTBackwardLength(tbpttLength); + tbpttFwdLength(tbpttLength); + return tbpttBackLength(tbpttLength); } /** @@ -813,9 +812,9 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * @param layer The layer configuration * @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects, * on a combination of the two. - * @see #addLayer(String, Layer, InputPreProcessor, String...) + * @see #addLayer(String, LayerConfiguration, InputPreProcessor, String...) */ - public GraphBuilder addLayer(String layerName, Layer layer, String... layerInputs) { + public GraphBuilder addLayer(String layerName, LayerConfiguration layer, String... layerInputs) { return addLayer(layerName, layer, null, layerInputs); } @@ -825,9 +824,9 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * * @param layerName Name/label of the layer to add * @param layer The layer configuration - * @see #addLayer(String, Layer, InputPreProcessor, String...) + * @see #addLayer(String, LayerConfiguration, InputPreProcessor, String...) */ - public GraphBuilder appendLayer(String layerName, Layer layer) { + public GraphBuilder appendLayer(String layerName, LayerConfiguration layer) { return appendLayer(layerName, layer, null); } @@ -838,9 +837,9 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * @param layer The layer configuration * @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects, * on a combination of the two. - * @see #addLayer(String, Layer, InputPreProcessor, String...) + * @see #addLayer(String, LayerConfiguration, InputPreProcessor, String...) */ - public GraphBuilder layer(int layerName, Layer layer, String... layerInputs) { + public GraphBuilder layer(int layerName, LayerConfiguration layer, String... layerInputs) { return addLayer(String.valueOf(layerName), layer, null, layerInputs); } @@ -851,9 +850,9 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * @param layer The layer configuration * @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects, * on a combination of the two. 
- * @see #addLayer(String, Layer, InputPreProcessor, String...) + * @see #addLayer(String, LayerConfiguration, InputPreProcessor, String...) */ - public GraphBuilder layer(String layerName, Layer layer, String... layerInputs) { + public GraphBuilder layer(String layerName, LayerConfiguration layer, String... layerInputs) { return addLayer(layerName, layer, null, layerInputs); } @@ -866,11 +865,11 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects, * on a combination of the two. */ - public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor, + public GraphBuilder addLayer(String layerName, LayerConfiguration layer, InputPreProcessor preProcessor, String... layerInputs) { - NeuralNetConfiguration.Builder builder = globalConfiguration.clone(); - builder.layer(layer); - addVertex(layerName, new LayerVertex(builder.build(), preProcessor), layerInputs); + NeuralNetConfiguration conf = globalConfiguration.clone(); + conf.getLayerConfigurations().add(layer); + addVertex(layerName, new LayerVertex(conf, preProcessor), layerInputs); layer.setLayerName(layerName); return this; } @@ -883,7 +882,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * @param layer The layer configuration * @param preProcessor The InputPreProcessor to use with this layer. */ - public GraphBuilder appendLayer(String layerName, Layer layer, InputPreProcessor preProcessor) { + public GraphBuilder appendLayer(String layerName, LayerConfiguration layer, InputPreProcessor preProcessor) { if(lastAdded == null){ throw new IllegalStateException("Can not use appendLayer with no previous layers"); @@ -902,7 +901,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * @param layerInputs Inputs to this layer. Inputs may be other layers, GraphVertex objects, * on a combination of the two. */ - public GraphBuilder layer(String layerName, Layer layer, InputPreProcessor preProcessor, + public GraphBuilder layer(String layerName, LayerConfiguration layer, InputPreProcessor preProcessor, String... 
layerInputs) { return addLayer(layerName, layer, preProcessor, layerInputs); } @@ -1173,13 +1172,13 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { conf.vertices = this.vertices; conf.vertexInputs = this.vertexInputs; - conf.trainingWorkspaceMode = globalConfiguration.trainingWorkspaceMode; - conf.inferenceWorkspaceMode = globalConfiguration.inferenceWorkspaceMode; - conf.cacheMode = globalConfiguration.cacheMode; + conf.trainingWorkspaceMode = getGlobalConfiguration().getTrainingWorkspaceMode(); + conf.inferenceWorkspaceMode = getGlobalConfiguration().getInferenceWorkspaceMode(); + conf.cacheMode = globalConfiguration.getCacheMode(); conf.validateOutputLayerConfig = validateOutputConfig; - conf.dataType = globalConfiguration.dataType; + conf.dataType = globalConfiguration.getDataType(); - conf.defaultConfiguration = globalConfiguration.build(); + conf.defaultConfiguration = globalConfiguration; //Add preprocessors that were defined separately to the Layers to which they belong for (Map.Entry entry : inputPreProcessors.entrySet()) { @@ -1198,7 +1197,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { for (Map.Entry gv : vertices.entrySet()) { if (gv.getValue() instanceof LayerVertex) { LayerVertex lv = (LayerVertex) gv.getValue(); - Layer l = lv.getLayerConf().getLayer(); + LayerConfiguration l = lv.getLayerConfiguration(); } if (gv.getValue() instanceof SameDiffVertex) ((SameDiffVertex) gv.getValue()).applyGlobalConfig(globalConfiguration); @@ -1226,7 +1225,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { //Validate output layer configurations... for (Map.Entry e : conf.getVertices().entrySet()) { if (e.getValue() instanceof LayerVertex) { - Layer l = ((LayerVertex) e.getValue()).getLayerConf().getLayer(); + LayerConfiguration l = ((LayerVertex) e.getValue()).getLayerConfiguration(); OutputLayerUtil.validateOutputLayer(e.getKey(), l); //No-op for non output/loss layers } } @@ -1236,7 +1235,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { //Check for invalid combination - tbptt plus LastTimeStepLayer or for(Map.Entry e : vertices.entrySet()){ GraphVertex gv = e.getValue(); - Layer l = (gv instanceof LayerVertex ? ((LayerVertex)gv).getLayerConf().getLayer() : null); + LayerConfiguration l = (gv instanceof LayerVertex ? ((LayerVertex)gv).getLayerConfiguration() : null); if(gv instanceof LastTimeStepVertex || (l != null && (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer))){ String s = (l == null ? gv.getClass().getName() : l.getClass().getName()); String n = e.getKey(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java deleted file mode 100644 index 47baaebfd..000000000 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/MultiLayerConfiguration.java +++ /dev/null @@ -1,841 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - - -package org.deeplearning4j.nn.conf; - -import lombok.*; -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.conf.distribution.Distribution; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.*; -import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; -import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; -import org.deeplearning4j.nn.conf.memory.MemoryReport; -import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport; -import org.deeplearning4j.nn.conf.serde.JsonMappers; -import org.deeplearning4j.nn.weights.IWeightInit; -import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.util.OutputLayerUtil; -import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.lossfunctions.LossFunctions; -import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT; -import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; -import org.nd4j.linalg.lossfunctions.impl.LossMSE; -import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; -import com.fasterxml.jackson.databind.node.ArrayNode; - -import java.io.IOException; -import java.io.Serializable; -import java.util.*; - -/** - * Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of - * multiple layers. Everything starts with a MultiLayerConfiguration, which organizes those layers - * and their hyperparameters. Hyperparameters are variables that determine how a neural network - * learns. They include how many times to update the weights of the model, how to initialize those - * weights, which activation function to attach to the nodes, which optimization algorithm to use, - * and how fast the model should learn. This is what one configuration would look like: - *

- * - * MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
- * .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)
- * .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
- * .updater(new Sgd(0.05)) //... other hyperparameters
- * .list() .backprop(true)
- * .build();
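For comparison with the legacy builder chain quoted above: the replacement configuration class added later in this patch series (NeuralNetBaseBuilderConfiguration, a Lombok @SuperBuilder type) documents a builder() entry point instead of the inner Builder class. The following sketch mirrors that new javadoc; treat the exact builder methods as provisional while the refactor is in progress, and assume the same imports as in the example above.

// Provisional sketch, mirroring the javadoc of the new builder-based configuration:
NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
    .weightInit(WeightInit.XAVIER)
    .activation(Activation.RELU)
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .build();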

- * - * With Deeplearning4j, you add a layer - * by calling layer on the NeuralNetConfiguration.Builder(), specifying its place in the order of - * layers (the zero-indexed layer below is the input layer), the number of input and output nodes, - * nIn and nOut, as well as the type: DenseLayer.

- * - * .layer(0, new DenseLayer.Builder().nIn(784).nOut(250)
- * .build())
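A fuller version of the snippet above, using the legacy API that this patch series begins to retire, would typically finish the chain with an output layer and then train through MultiLayerNetwork. The nOut of 10, the loss and activation choices, and the trainData iterator below are placeholders for illustration only.

// Legacy-style usage sketch; sizes, loss function and 'trainData' are placeholders.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .list()
    .layer(0, new DenseLayer.Builder().nIn(784).nOut(250).build())
    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .nIn(250).nOut(10).activation(Activation.SOFTMAX).build())
    .build();

MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
model.fit(trainData); // trainData: an org.nd4j.linalg.dataset.api.iterator.DataSetIterator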

- * - * Once you've configured your net, you train the - * model with model.fit. - */ -@Data -@AllArgsConstructor(access = AccessLevel.PRIVATE) -@NoArgsConstructor -@Slf4j -public class MultiLayerConfiguration implements Serializable, Cloneable { - - protected List confs; - protected Map inputPreProcessors = new HashMap<>(); - protected BackpropType backpropType = BackpropType.Standard; - protected int tbpttFwdLength = 20; - protected int tbpttBackLength = 20; - protected boolean validateOutputLayerConfig = true; //Default to legacy for pre 1.0.0-beta3 networks on deserialization - - @Getter - @Setter - protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; - - @Getter - @Setter - protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; - - @Getter - @Setter - protected CacheMode cacheMode; - - @Getter - @Setter - protected DataType dataType = DataType.FLOAT; //Default to float for deserialization of beta3 and earlier nets - - //Counter for the number of parameter updates so far - // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted - // for Spark and model serialization - protected int iterationCount = 0; - - //Counter for the number of epochs completed so far. Used for per-epoch schedules - protected int epochCount = 0; - - /** - * Create a neural net configuration from json - * - * @param json the neural net configuration from json - * @return {@link MultiLayerConfiguration} - */ - public static MultiLayerConfiguration fromYaml(String json) { - ObjectMapper mapper = NeuralNetConfiguration.mapperYaml(); - try { - return mapper.readValue(json, MultiLayerConfiguration.class); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - /** - * Create a neural net configuration from json - * - * @param json the neural net configuration from json - * @return {@link MultiLayerConfiguration} - */ - public static MultiLayerConfiguration fromJson(String json) { - MultiLayerConfiguration conf; - ObjectMapper mapper = NeuralNetConfiguration.mapper(); - try { - conf = mapper.readValue(json, MultiLayerConfiguration.class); - } catch (InvalidTypeIdException e) { - if (e.getMessage().contains("@class")) { - try { - //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format - return JsonMappers.getLegacyMapper().readValue(json, MultiLayerConfiguration.class); - } catch (InvalidTypeIdException e2) { - //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.Layer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..." - //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work - String msg = e2.getMessage(); - if (msg != null && msg.contains("Could not resolve type id")) { - throw new RuntimeException( - "Error deserializing MultiLayerConfiguration - configuration may have a custom " + - "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" - + - " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", - e); - } - throw new RuntimeException(e2); - } catch (IOException e2) { - throw new RuntimeException(e2); - } - } - throw new RuntimeException(e); - } catch (IOException e) { - //Check if this exception came from legacy deserializer... 
- String msg = e.getMessage(); - if (msg != null && msg.contains("legacy")) { - throw new RuntimeException( - "Error deserializing MultiLayerConfiguration - configuration may have a custom " + - "layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be " - + - "deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)", - e); - } - throw new RuntimeException(e); - } - - //To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier) - // Previously: enumeration used for loss functions. Now: use classes - // IN the past, could have only been an OutputLayer or RnnOutputLayer using these enums - int layerCount = 0; - JsonNode confs = null; - for (NeuralNetConfiguration nnc : conf.getConfs()) { - Layer l = nnc.getLayer(); - if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) { - //lossFn field null -> may be an old config format, with lossFunction field being for the enum - //if so, try walking the JSON graph to extract out the appropriate enum value - - BaseOutputLayer ol = (BaseOutputLayer) l; - try { - JsonNode jsonNode = mapper.readTree(json); - if (confs == null) { - confs = jsonNode.get("confs"); - } - if (confs instanceof ArrayNode) { - ArrayNode layerConfs = (ArrayNode) confs; - JsonNode outputLayerNNCNode = layerConfs.get(layerCount); - if (outputLayerNNCNode == null) { - return conf; //Should never happen... - } - JsonNode outputLayerNode = outputLayerNNCNode.get("layer"); - - JsonNode lossFunctionNode = null; - if (outputLayerNode.has("output")) { - lossFunctionNode = outputLayerNode.get("output").get("lossFunction"); - } else if (outputLayerNode.has("rnnoutput")) { - lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction"); - } - - if (lossFunctionNode != null) { - String lossFunctionEnumStr = lossFunctionNode.asText(); - LossFunctions.LossFunction lossFunction = null; - try { - lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr); - } catch (Exception e) { - log.warn( - "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", - e); - } - - if (lossFunction != null) { - switch (lossFunction) { - case MSE: - ol.setLossFn(new LossMSE()); - break; - case XENT: - ol.setLossFn(new LossBinaryXENT()); - break; - case NEGATIVELOGLIKELIHOOD: - ol.setLossFn(new LossNegativeLogLikelihood()); - break; - case MCXENT: - ol.setLossFn(new LossMCXENT()); - break; - - //Remaining: TODO - case SQUARED_LOSS: - case RECONSTRUCTION_CROSSENTROPY: - default: - log.warn( - "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", - lossFunction); - break; - } - } - } - - } else { - log.warn( - "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", - (confs != null ? 
confs.getClass() : null)); - } - } catch (IOException e) { - log.warn( - "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", - e); - break; - } - } - - //Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn") - //Try to load the old format if necessary, and create the appropriate IActivation instance - if ((l instanceof BaseLayer) && ((BaseLayer) l).getActivationFn() == null) { - try { - JsonNode jsonNode = mapper.readTree(json); - if (confs == null) { - confs = jsonNode.get("confs"); - } - if (confs instanceof ArrayNode) { - ArrayNode layerConfs = (ArrayNode) confs; - JsonNode outputLayerNNCNode = layerConfs.get(layerCount); - if (outputLayerNNCNode == null) { - return conf; //Should never happen... - } - JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); - - if (layerWrapperNode == null || layerWrapperNode.size() != 1) { - continue; - } - - JsonNode layerNode = layerWrapperNode.elements().next(); - JsonNode activationFunction = layerNode.get( - "activationFunction"); //Should only have 1 element: "dense", "output", etc - - if (activationFunction != null) { - IActivation ia = Activation.fromString(activationFunction.asText()) - .getActivationFunction(); - ((BaseLayer) l).setActivationFn(ia); - } - } - - } catch (IOException e) { - log.warn( - "Layer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", - e); - } - } - - if (!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) { - return conf; - } - - layerCount++; - } - return conf; - } - - /** - * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied - * from handling of {@link Activation} above. - * - * @return True if all is well and layer iteration shall continue. False else-wise. - */ - private static boolean handleLegacyWeightInitFromJson(String json, Layer l, ObjectMapper mapper, - JsonNode confs, int layerCount) { - if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) { - try { - JsonNode jsonNode = mapper.readTree(json); - if (confs == null) { - confs = jsonNode.get("confs"); - } - if (confs instanceof ArrayNode) { - ArrayNode layerConfs = (ArrayNode) confs; - JsonNode outputLayerNNCNode = layerConfs.get(layerCount); - if (outputLayerNNCNode == null) { - return false; //Should never happen... 
- } - JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); - - if (layerWrapperNode == null || layerWrapperNode.size() != 1) { - return true; - } - - JsonNode layerNode = layerWrapperNode.elements().next(); - JsonNode weightInit = layerNode.get( - "weightInit"); //Should only have 1 element: "dense", "output", etc - JsonNode distribution = layerNode.get("dist"); - - Distribution dist = null; - if (distribution != null) { - dist = mapper.treeToValue(distribution, Distribution.class); - } - - if (weightInit != null) { - final IWeightInit wi = WeightInit.valueOf(weightInit.asText()) - .getWeightInitFunction(dist); - ((BaseLayer) l).setWeightInitFn(wi); - } - } - - } catch (IOException e) { - log.warn( - "Layer with null WeightInit detected: " + l.getLayerName() + ", could not parse JSON", - e); - } - } - return true; - - } - - public int getEpochCount() { - return epochCount; - } - - public void setEpochCount(int epochCount) { - this.epochCount = epochCount; - for (int i = 0; i < confs.size(); i++) { - getConf(i).setEpochCount(epochCount); - } - } - - /** - * @return JSON representation of NN configuration - */ - public String toYaml() { - ObjectMapper mapper = NeuralNetConfiguration.mapperYaml(); - synchronized (mapper) { - try { - return mapper.writeValueAsString(this); - } catch (com.fasterxml.jackson.core.JsonProcessingException e) { - throw new RuntimeException(e); - } - } - } - - /** - * @return JSON representation of NN configuration - */ - public String toJson() { - ObjectMapper mapper = NeuralNetConfiguration.mapper(); - synchronized (mapper) { - //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally - //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243 - try { - return mapper.writeValueAsString(this); - } catch (com.fasterxml.jackson.core.JsonProcessingException e) { - throw new RuntimeException(e); - } - } - } - - @Override - public String toString() { - return toJson(); - } - - public NeuralNetConfiguration getConf(int i) { - return confs.get(i); - } - - @Override - public MultiLayerConfiguration clone() { - try { - MultiLayerConfiguration clone = (MultiLayerConfiguration) super.clone(); - - if (clone.confs != null) { - List list = new ArrayList<>(); - for (NeuralNetConfiguration conf : clone.confs) { - list.add(conf.clone()); - } - clone.confs = list; - } - - if (clone.inputPreProcessors != null) { - Map map = new HashMap<>(); - for (Map.Entry entry : clone.inputPreProcessors.entrySet()) { - map.put(entry.getKey(), entry.getValue().clone()); - } - clone.inputPreProcessors = map; - } - - clone.inferenceWorkspaceMode = this.inferenceWorkspaceMode; - clone.trainingWorkspaceMode = this.trainingWorkspaceMode; - clone.cacheMode = this.cacheMode; - clone.validateOutputLayerConfig = this.validateOutputLayerConfig; - clone.dataType = this.dataType; - - return clone; - - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); - } - } - - public InputPreProcessor getInputPreProcess(int curr) { - return inputPreProcessors.get(curr); - } - - /** - * Get a {@link MemoryReport} for the given MultiLayerConfiguration. 
This is used to estimate the - * memory requirements for the given network configuration and input - * - * @param inputType Input types for the network - * @return Memory report for the network - */ - public NetworkMemoryReport getMemoryReport(InputType inputType) { - - Map memoryReportMap = new LinkedHashMap<>(); - int nLayers = confs.size(); - for (int i = 0; i < nLayers; i++) { - String layerName = confs.get(i).getLayer().getLayerName(); - if (layerName == null) { - layerName = String.valueOf(i); - } - - //Pass input type through preprocessor, if necessary - InputPreProcessor preproc = getInputPreProcess(i); - //TODO memory requirements for preprocessor - if (preproc != null) { - inputType = preproc.getOutputType(inputType); - } - - LayerMemoryReport report = confs.get(i).getLayer().getMemoryReport(inputType); - memoryReportMap.put(layerName, report); - - inputType = confs.get(i).getLayer().getOutputType(i, inputType); - } - - return new NetworkMemoryReport(memoryReportMap, MultiLayerConfiguration.class, - "MultiLayerNetwork", inputType); - } - - /** - * For the given input shape/type for the network, return a list of activation sizes for each - * layer in the network.
i.e., list.get(i) is the output activation sizes for layer i - * - * @param inputType Input type for the network - * @return A lits of activation types for the network, indexed by layer number - */ - public List getLayerActivationTypes(@NonNull InputType inputType) { - List out = new ArrayList<>(); - int nLayers = confs.size(); - for (int i = 0; i < nLayers; i++) { - InputPreProcessor preproc = getInputPreProcess(i); - if (preproc != null) { - inputType = preproc.getOutputType(inputType); - } - - inputType = confs.get(i).getLayer().getOutputType(i, inputType); - out.add(inputType); - } - return out; - } - - @Data - public static class Builder { - - private static final int DEFAULT_TBPTT_LENGTH = 20; - - protected List confs = new ArrayList<>(); - protected double dampingFactor = 100; - protected Map inputPreProcessors = new HashMap<>(); - protected BackpropType backpropType = BackpropType.Standard; - protected int tbpttFwdLength = DEFAULT_TBPTT_LENGTH; - protected int tbpttBackLength = DEFAULT_TBPTT_LENGTH; - protected InputType inputType; - - protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; - protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; - protected CacheMode cacheMode = CacheMode.NONE; - protected boolean validateOutputConfig = true; - protected boolean validateTbpttConfig = true; - protected DataType dataType; - protected boolean overrideNinUponBuild = true; - - - /** - * Whether to over ride the nIn configuration forcibly upon construction. Default value is true - * - * @param overrideNinUponBuild Whether to over ride the nIn configuration forcibly upon - * construction. - * @return builder pattern - */ - public Builder overrideNinUponBuild(boolean overrideNinUponBuild) { - this.overrideNinUponBuild = overrideNinUponBuild; - return this; - } - - /** - * Specify the processors. These are used at each layer for doing things like normalization and - * shaping of input. - * - * @param processor what to use to preProcess the data. 
- * @return builder pattern - */ - public Builder inputPreProcessor(Integer layer, InputPreProcessor processor) { - inputPreProcessors.put(layer, processor); - return this; - } - - public Builder inputPreProcessor(String layer, InputPreProcessor processor) { - int i = 0; - for (NeuralNetConfiguration conf : this.confs) { - if (conf.getLayer().getLayerName().equals(layer)) { - inputPreProcessors.put(i, processor); - log.trace("Assigned preProcessor to layer with name {} at index {}", layer, i); - break; - } - i++; - } - if (i >= this.confs.size()) { - log.warn("Could not assign preprocessor to layer with name {} as layer was not found.", - layer); - } - return this; - } - - public Builder inputPreProcessors(Map processors) { - this.inputPreProcessors = processors; - return this; - } - - /** - * @deprecated Use {@link NeuralNetConfiguration.Builder#trainingWorkspaceMode(WorkspaceMode)} - */ - @Deprecated - public Builder trainingWorkspaceMode(@NonNull WorkspaceMode workspaceMode) { - this.trainingWorkspaceMode = workspaceMode; - return this; - } - - /** - * @deprecated Use {@link NeuralNetConfiguration.Builder#inferenceWorkspaceMode(WorkspaceMode)} - */ - @Deprecated - public Builder inferenceWorkspaceMode(@NonNull WorkspaceMode workspaceMode) { - this.inferenceWorkspaceMode = workspaceMode; - return this; - } - - /** - * This method defines how/if preOutput cache is handled: NONE: cache disabled (default value) - * HOST: Host memory will be used DEVICE: GPU memory will be used (on CPU backends effect will - * be the same as for HOST) - * - * @param cacheMode - * @return - */ - public Builder cacheMode(@NonNull CacheMode cacheMode) { - this.cacheMode = cacheMode; - return this; - } - - /** - * The type of backprop. Default setting is used for most networks (MLP, CNN etc), but - * optionally truncated BPTT can be used for training recurrent neural networks. If using - * TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength() - */ - public Builder backpropType(@NonNull BackpropType type) { - this.backpropType = type; - return this; - } - - /** - * When doing truncated BPTT: how many steps should we do?
Only applicable when doing - * backpropType(BackpropType.TruncatedBPTT)
See: http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf - * - * @param bpttLength length > 0 - */ - public Builder tBPTTLength(int bpttLength) { - tBPTTForwardLength(bpttLength); - return tBPTTBackwardLength(bpttLength); - } - - /** - * When doing truncated BPTT: how many steps of forward pass should we do before doing - * (truncated) backprop?
Only applicable when doing - * backpropType(BackpropType.TruncatedBPTT)
Typically tBPTTForwardLength parameter is same - * as the tBPTTBackwardLength parameter, but may be larger than it in some circumstances (but - * never smaller)
Ideally your training data time series length should be divisible by this - * This is the k1 parameter on pg23 of - * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf - * - * @param forwardLength Forward length > 0, >= backwardLength - */ - public Builder tBPTTForwardLength(int forwardLength) { - this.tbpttFwdLength = forwardLength; - return this; - } - - /** - * When doing truncated BPTT: how many steps of backward should we do?
Only applicable when - * doing backpropType(BackpropType.TruncatedBPTT)
This is the k2 parameter on pg23 of - * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf - * - * @param backwardLength <= forwardLength - */ - public Builder tBPTTBackwardLength(int backwardLength) { - this.tbpttBackLength = backwardLength; - return this; - } - - public Builder confs(List confs) { - this.confs = confs; - return this; - } - - public Builder setInputType(InputType inputType) { - this.inputType = inputType; - return this; - } - - /** - * Enabled by default. If enabled, the output layer configuration will be validated, to throw an - * exception on likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.
If - * disabled (false) no output layer validation will be performed.
Disabling this validation - * is not recommended, as the configurations that fail validation usually will not be able to - * learn correctly. However, the option to disable this validation is provided for advanced - * users when creating non-standard architectures. - * - * @param validate If true: validate output layer configuration. False: don't validate - */ - public Builder validateOutputLayerConfig(boolean validate) { - this.validateOutputConfig = validate; - return this; - } - - /** - * Enabled by default. If enabled, an exception will be throw when using the (invalid) - * combination of truncated backpropagation through time (TBPTT) with either a - * GlobalPoolingLayer or LastTimeStepLayer.
It is possible to disable this validation to - * allow what is almost certainly an invalid configuration to be used, however this is not - * recommended. - * - * @param validate Whether TBPTT validation should be performed - */ - public Builder validateTbpttConfig(boolean validate) { - this.validateTbpttConfig = validate; - return this; - } - - /** - * Set the DataType for the network parameters and activations for all layers in the network. - * Default: Float - * - * @param dataType Datatype to use for parameters and activations - */ - public Builder dataType(@NonNull DataType dataType) { - this.dataType = dataType; - return this; - } - - - public MultiLayerConfiguration build() { - //Validate BackpropType setting - if ((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH) - && backpropType != BackpropType.TruncatedBPTT) { - log.warn("Truncated backpropagation through time lengths have been configured with values " - + tbpttFwdLength - + " and " + tbpttBackLength + " but backprop type is set to " + backpropType - + ". TBPTT configuration" + - " settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT"); - } - - if (backpropType == BackpropType.TruncatedBPTT && validateTbpttConfig) { - //Check for invalid combination - tbptt plus LastTimeStepLayer or - for (int i = 0; i < confs.size(); i++) { - Layer l = confs.get(i).getLayer(); - if (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer) { - throw new IllegalStateException( - "Invalid network configuration detected: Truncated backpropagation through time (TBPTT)" - + - " cannot be used with layer " + i + " of type " + l.getClass().getName() - + ": TBPTT is incompatible with this layer type (which is designed " + - "to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n" - + - "This check can be disabled using validateTbpttConfig(false) but this is not recommended."); - } - } - } - - if (inputType == null && inputPreProcessors.get(0) == null) { - //User hasn't set the InputType. Sometimes we can infer it... - // For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to feed in - // standard feedforward or RNN data - //This isn't the most elegant implementation, but should avoid breaking backward compatibility here - //Can't infer InputType for CNN layers, however (don't know image dimensions/depth) - Layer firstLayer = confs.get(0).getLayer(); - if (firstLayer instanceof BaseRecurrentLayer) { - BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer; - val nIn = brl.getNIn(); - if (nIn > 0) { - inputType = InputType.recurrent(nIn, brl.getRnnDataFormat()); - } - } else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer - || firstLayer instanceof OutputLayer) { - //Can't just use "instanceof FeedForwardLayer" here. ConvolutionLayer is also a FeedForwardLayer - FeedForwardLayer ffl = (FeedForwardLayer) firstLayer; - val nIn = ffl.getNIn(); - if (nIn > 0) { - inputType = InputType.feedForward(nIn); - } - } - } - - //Add preprocessors and set nIns, if InputType has been set - // Builder.inputType field can be set in 1 of 4 ways: - // 1. User calls setInputType directly - // 2. Via ConvolutionLayerSetup -> internally calls setInputType(InputType.convolutional(...)) - // 3. 
Via the above code: i.e., assume input is as expected by the RNN or dense layer -> sets the inputType field - if (inputType != null) { - InputType currentInputType = inputType; - for (int i = 0; i < confs.size(); i++) { - Layer l = confs.get(i).getLayer(); - if (inputPreProcessors.get(i) == null) { - //Don't override preprocessor setting, but set preprocessor if required... - InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType); - if (inputPreProcessor != null) { - inputPreProcessors.put(i, inputPreProcessor); - } - } - - InputPreProcessor inputPreProcessor = inputPreProcessors.get(i); - if (inputPreProcessor != null) { - currentInputType = inputPreProcessor.getOutputType(currentInputType); - } - if (i > 0) { - Layer layer = confs.get(i - 1).getLayer(); - //convolution 1d is an edge case where it has rnn input type but the filters - //should be the output - if (layer instanceof Convolution1DLayer) { - if (l instanceof DenseLayer && inputType instanceof InputType.InputTypeRecurrent) { - FeedForwardLayer feedForwardLayer = (FeedForwardLayer) l; - if (inputType instanceof InputType.InputTypeRecurrent) { - InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType; - feedForwardLayer.setNIn(recurrent.getTimeSeriesLength()); - } - } else { - l.setNIn(currentInputType, - overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user - } - } else { - l.setNIn(currentInputType, - overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user - } - - } else { - l.setNIn(currentInputType, - overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user - } - - currentInputType = l.getOutputType(i, currentInputType); - } - - } - - MultiLayerConfiguration conf = new MultiLayerConfiguration(); - conf.confs = this.confs; - conf.inputPreProcessors = inputPreProcessors; - conf.backpropType = backpropType; - conf.tbpttFwdLength = tbpttFwdLength; - conf.tbpttBackLength = tbpttBackLength; - conf.trainingWorkspaceMode = trainingWorkspaceMode; - conf.inferenceWorkspaceMode = inferenceWorkspaceMode; - conf.cacheMode = cacheMode; - conf.dataType = dataType; - - Nd4j.getRandom().setSeed(conf.getConf(0).getSeed()); - - //Validate output layer configuration - if (validateOutputConfig) { - //Validate output layer configurations... - for (NeuralNetConfiguration n : conf.getConfs()) { - Layer l = n.getLayer(); - OutputLayerUtil.validateOutputLayer(l.getLayerName(), - l); //No-op for non output/loss layers - } - } - - return conf; - - } - } -} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java new file mode 100644 index 000000000..8ff512612 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java @@ -0,0 +1,1021 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. 
+ * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.nn.conf; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.node.ArrayNode; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import lombok.Builder; +import lombok.Data; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.NonNull; +import lombok.Setter; +import lombok.Singular; +import lombok.experimental.SuperBuilder; +import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.INeuralNetworkConfiguration; +import org.deeplearning4j.nn.api.OptimizationAlgorithm; +import org.deeplearning4j.nn.api.layers.LayerConstraint; +import org.deeplearning4j.nn.conf.distribution.Distribution; +import org.deeplearning4j.nn.conf.dropout.Dropout; +import org.deeplearning4j.nn.conf.dropout.IDropout; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; +import org.deeplearning4j.nn.conf.serde.JsonMappers; +import org.deeplearning4j.nn.conf.stepfunctions.StepFunction; +import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; +import org.deeplearning4j.nn.weights.IWeightInit; +import org.deeplearning4j.nn.weights.WeightInit; +import org.deeplearning4j.nn.weights.WeightInitDistribution; +import org.deeplearning4j.nn.weights.WeightInitXavier; +import org.deeplearning4j.util.NetworkUtils; +import org.nd4j.common.base.Preconditions; +import org.nd4j.linalg.activations.Activation; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.activations.impl.ActivationSigmoid; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.learning.config.Sgd; +import org.nd4j.linalg.learning.regularization.L1Regularization; +import org.nd4j.linalg.learning.regularization.L2Regularization; +import org.nd4j.linalg.learning.regularization.Regularization; +import org.nd4j.linalg.learning.regularization.WeightDecay; + +/** + * Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of + * multiple layers. Everything starts with a NeuralNetConfiguration, which organizes those layers + * and their hyperparameters. Hyperparameters are variables that determine how a neural network + * learns. They include how many times to update the weights of the model, how to initialize those + * weights, which activation function to attach to the nodes, which optimization algorithm to use, + * and how fast the model should learn. This is what one configuration would look like: + *

+ * + * NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
+ * .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)
+ * .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+ * .updater(new Sgd(0.05)) //... other hyperparameters
+ * .backprop(true)
+ * .build();

+ * + * With Deeplearning4j, you add a layer + * by calling layer on the NeuralNetConfiguration.NeuralNetConfigurationBuilder(), specifying its place in the order of + * layers (the zero-indexed layer below is the input layer), the number of input and output nodes, + * nIn and nOut, as well as the type: DenseLayer.

+ * + * .layer(0, new DenseLayer.Builder().nIn(784).nOut(250)
+ * .build())
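Read together, the two snippets above amount to one builder chain. A minimal combined sketch (layer sizes and the output layer are illustrative, imports omitted; the old backprop(true) call is left out because no matching field appears in this class):

    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
        .weightInit(WeightInit.XAVIER)
        .activation(Activation.RELU)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(new Sgd(0.05))
        .layer(0, new DenseLayer.Builder().nIn(784).nOut(250).build())
        .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
            .nIn(250).nOut(10).activation(Activation.SOFTMAX).build())
        .build();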

+ * + * Once you've configured your net, you train the + * model with model.fit. + */ + +@Data +@Slf4j +@EqualsAndHashCode(exclude = {"iterationCount", "epochCount"}) +@JsonIgnoreProperties(ignoreUnknown = true) +//The inner builder, that we can then extend ... +@SuperBuilder //TODO fix access +public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetworkConfiguration { + + private static final int DEFAULT_TBPTT_LENGTH = 20; + + + /** + * Set constraints to be applied to all layers. Default: no constraints.
+ * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, + * etc). These constraints are applied at each iteration, after the parameters have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param constraints Constraints to apply to all weight parameters of all layers + */ + @lombok.Builder.Default + protected final List contrainWeights = new ArrayList<>(); + + + + + /** + * Set constraints to be applied to all layers. Default: no constraints.
+ * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, + * etc). These constraints are applied at each iteration, after the parameters have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param constraints Constraints to apply to all bias parameters of all layers + */ + @lombok.Builder.Default + protected final List biasConstraints = new ArrayList<>(); + /** + * Set constraints to be applied to all layers. Default: no constraints.
+ * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, + * etc). These constraints are applied at each iteration, after the parameters have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param constraints Constraints to apply to all parameters of all layers + */ + @lombok.Builder.Default + protected final List allParamContraints = new ArrayList<>(); + /** + * This is a basic concept, a neural network is made of layers, but also can use + * another neural network as a building block. When the configuration is initialized, those + * building blocks will be flattened into a single list of layers. + * Internal ordered list of layers and inner neural networks. If the object is a NeuralNetConfiguration, + * each configuration must contain at least one layer. + */ + @Getter @lombok.Builder.Default + protected final List innerConfigurations = new ArrayList<>(); + @Getter + @Setter + @NonNull + @lombok.Builder.Default + @Deprecated + protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; + @Getter + @Setter + @NonNull + @lombok.Builder.Default + @Deprecated + protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; + /** + * The type of backprop. Default setting is used for most networks (MLP, CNN etc), but optionally + * truncated BPTT can be used for training recurrent neural networks. If using TruncatedBPTT make + * sure you set both tBPTTForwardLength() and tBPTTBackwardLength() + */ + @Getter + @Setter + @NonNull + @lombok.Builder.Default + protected BackpropType backpropType = BackpropType.Standard; + @Getter + @lombok.Builder.Default + protected Map inputPreProcessors = new HashMap<>(); + /** + * When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated) + * backprop?
Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
Typically the + * tBPTTForwardLength parameter is the same as the tBPTTBackwardLength parameter, but may be larger + * than it in some circumstances (but never smaller)
Ideally your training data time series + * length should be divisible by this. This is the k1 parameter on pg23 of http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf + * + * @param forwardLength Forward length > 0, >= backwardLength + */ + @Getter + @Setter + @lombok.Builder.Default + protected int tbpttFwdLength = 20; + /** + * When doing truncated BPTT: how many steps of backward should we do?
Only applicable when + * doing backpropType(BackpropType.TruncatedBPTT)
This is the k2 parameter on pg23 of http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf + * + * @param backwardLength <= forwardLength + */ + @Getter + @Setter + @lombok.Builder.Default + protected int tbpttBackLength = 20; + //Counter for the number of parameter updates so far + // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted + // for Spark and model serialization + @Getter + @Setter + @lombok.Builder.Default + protected int iterationCount = 0; + //Counter for the number of epochs completed so far. Used for per-epoch schedules + @Getter + @Setter + @lombok.Builder.Default + protected int epochCount = 0; + @lombok.Builder.Default + protected double dampingFactor = 100; + //gradient keys used for ensuring order when getting and setting the gradient + //@lombok.Builder.Default + //protected List variables = new ArrayList<>(); + @Getter + @Setter + @lombok.Builder.Default + private boolean miniBatch = false; + /** + * A seed for this network, will be random if not specified. + */ + @Getter + @Setter + @lombok.Builder.Default + private long seed = new Random().nextLong(); + /** + * The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified + * otherwise. This method defines how/if preOutput cache is handled: NONE: cache disabled (default + * value) HOST: Host memory will be used DEVICE: GPU memory will be used (on CPU backends effect + * will be the same as for HOST) + *

+ * Valid values are
CacheMode.NONE,
CacheMode.HOST or
CacheMode.DEVICE
+ * + * @param cacheMode + */ + @NonNull + @Getter + @Setter + @lombok.Builder.Default + private CacheMode cacheMode = CacheMode.NONE; + + /** + * The name for this configuration. Defaults to "Anonymous INeuralNetworkConfiguration" if it is + * not specified. + */ + @lombok.Builder.Default + @Getter + private String name = "Anonymous INeuralNetworkConfiguration"; + /** + * The {@link InputType} of the data for this network configuration + */ + @Getter + @Setter + private InputType inputType; + /** + * Set the DataType for the network parameters and activations for all layers in the network. + * Default: Float + * + * @param dataType Datatype to use for parameters and activations + */ + @Getter + @Setter + @lombok.Builder.Default + @NonNull + private DataType dataType = DataType.FLOAT; + /** + * Whether to override the nIn configuration forcibly upon construction. Default value is true. + * + * @return builder pattern + */ + @Getter + @Setter + @lombok.Builder.Default + private boolean overrideNinUponBuild = true; + /** + * Enabled by default. If enabled, the output layer configuration will be validated, to throw an + * exception on likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.
If + * disabled (false) no output layer validation will be performed.
Disabling this validation is + * not recommended, as the configurations that fail validation usually will not be able to learn + * correctly. However, the option to disable this validation is provided for advanced users when + * creating non-standard architectures. + * + * @param validate If true: validate output layer configuration. False: don't validate + */ + @Getter + @Setter + @lombok.Builder.Default + private boolean validateOutputLayerConfig = true; + /** + * Enabled by default. If enabled, an exception will be thrown when using the (invalid) combination + * of truncated backpropagation through time (TBPTT) with either a GlobalPoolingLayer or + * LastTimeStepLayer.
It is possible to disable this validation to allow what is almost + * certainly an invalid configuration to be used, however this is not recommended. + * + * @param validate Whether TBPTT validation should be performed + */ + @Getter + @Setter + @lombok.Builder.Default + private boolean validateTbpttConfig = true; + /** + * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} or + * {@link org.nd4j.linalg.learning.config.Nesterovs}
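A minimal sketch of how the TBPTT fields and the validation flag above could be set, assuming the builder setters that Lombok's @SuperBuilder generates from the field names (values are illustrative):

    NeuralNetConfiguration tbpttConf = NeuralNetConfiguration.builder()
        .backpropType(BackpropType.TruncatedBPTT)
        .tbpttFwdLength(50)         // k1: forward steps per truncated segment
        .tbpttBackLength(50)        // k2: backward steps, never larger than k1
        .validateTbpttConfig(true)  // keep the GlobalPoolingLayer/LastTimeStep check (the default)
        .build();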
Note: values set by this method will be + * applied to all applicable layers in the network, unless a different value is explicitly set on + * a given layer. In other words: values set via this method are used as the default value, and + * can be overridden on a per-layer basis. + * + * @param updater Updater to use + */ + @Getter + @Setter + private IUpdater updater; + /** + * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping + * etc. See {@link GradientNormalization} for details
Note: values set by this method will be + * applied to all applicable layers in the network, unless a different value is explicitly set on + * a given layer. In other words: values set via this method are used as the default value, and + * can be overridden on a per-layer basis. + * + * @param gradientNormalization Type of normalization to use. Defaults to None. + * @see GradientNormalization + */ + @Getter + @Setter + @NonNull + @lombok.Builder.Default + private GradientNormalization gradientNormalization = GradientNormalization.None; + /** + * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer, + * GradientNormalization.ClipL2PerParamType, and + * GradientNormalization.ClipElementWiseAbsoluteValue
Not used otherwise.
L2 threshold for + * first two types of clipping, or absolute value threshold for last type of clipping.
Note: + * values set by this method will be applied to all applicable layers in the network, unless a + * different value is explicitly set on a given layer. In other words: values set via this method + * are used as the default value, and can be overridden on a per-layer basis. + */ + @Getter + @Setter + private double gradientNormalizationThreshold; + /** + * Activation function / neuron non-linearity
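As a hedged illustration of the updater and gradient-normalization defaults described above (builder method names assumed to follow the field names; the threshold only applies to the clipping modes listed):

    NeuralNetConfiguration clipConf = NeuralNetConfiguration.builder()
        .updater(new Sgd(0.01))                                                    // network-wide updater
        .gradientNormalization(GradientNormalization.ClipElementWiseAbsoluteValue)
        .gradientNormalizationThreshold(1.0)                                       // clip each gradient element to [-1, 1]
        .build();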
Note: values set by this method will be applied + * to all applicable layers in the network, unless a different value is explicitly set on a given + * layer. In other words: values set via this method are used as the default value, and can be + * overridden on a per-layer basis. + */ + @Getter + @Setter + private IActivation activation; + //whether to constrain the gradient to unit norm or not + @Getter + @Setter + private StepFunction stepFunction; + @Getter + @Setter + @lombok.Builder.Default + private OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT; + @Getter + @Setter + @lombok.Builder.Default + private int maxNumLineSearchIterations = 5; + /** + * Set the regularization for the parameters (excluding biases) - for example {@link WeightDecay}
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis.
+ * + * @param regularization Regularization to apply for the network parameters/weights (excluding biases) + */ + @Getter + @lombok.Builder.Default + private List regularization = new ArrayList<>(); + /** + * Set the regularization for the biases only - for example {@link WeightDecay}
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis.
+ * + * @param regularizationBias Regularization to apply for the network biases only + */ + @Getter + @lombok.Builder.Default + private List regularizationBias = new ArrayList<>(); + @Getter + @Setter + @lombok.Builder.Default + private IUpdater iUpdater = new Sgd(); + /** + * Gradient updater configuration, for the biases only. If not set, biases will use the updater as + * set by {@link #setIUpdater(IUpdater)}
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param updater Updater to use for bias parameters + */ + @Getter + @Setter + @lombok.Builder.Default + private IUpdater biasUpdater = null; + @Getter + @Setter + @lombok.Builder.Default + private IActivation activationFn = new ActivationSigmoid(); + /** + * Weight initialization scheme to use, for initial weight values Note: values set by this method + * will be applied to all applicable layers in the network, unless a different value is explicitly + * set on a given layer. In other words: values set via this method are used as the default value, + * and can be overridden on a per-layer basis. + */ + @Getter + @Setter + @lombok.Builder.Default + private IWeightInit weightInitFn = new WeightInitXavier(); + /** + * Sets the convolution mode for convolutional layers, which impacts padding and output sizes. + * See {@link ConvolutionMode} for details. Defaults to ConvolutionMode.TRUNCATE
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * @param convolutionMode Convolution mode to use + */ + @Getter + @Setter + @lombok.Builder.Default + private ConvolutionMode convolutionMode = ConvolutionMode.Truncate; + /** + * Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage of cuDNN. + * See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. + *
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * @param cudnnAlgoMode cuDNN algo mode to use + */ + @Getter + @Setter + @lombok.Builder.Default + private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST; + @Getter + @Setter + @lombok.Builder.Default + private boolean minimize = true; + /** + * Set the dropout for all layers in this network
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different
+ * value is explicitly set on a given layer. In other words: values set via this method are used as the default
+ * value, and can be overridden on a per-layer basis.
+ * Dropout probability. This is the probability of retaining each input activation value for a layer.
+ * dropOut(x) will keep an input activation with probability x, and set it to 0 with probability 1-x.
+ * dropOut(0.0) is a special value / special case - when set to 0.0, dropout is disabled (not applied). Note
+ * that a dropout value of 1.0 is functionally equivalent to no dropout: i.e., 100% probability of retaining
+ * each input activation.
+ * Note 1: Dropout is applied at training time only - and is automatically not applied at test time
+ * (for evaluation, etc).
+ * Note 2: This sets the probability per-layer. Care should be taken when setting lower values for
+ * complex networks (too much information may be lost with aggressive (very low) dropout values).
+ * Note 3: Frequently, dropout is not applied to (or, has a higher retain probability for) input (first layer)
+ * layers. Dropout is also often not applied to output layers. This needs to be handled MANUALLY by the user
+ * - set .dropout(0) on those layers when using a global dropout setting.
+ * Note 4: Implementation detail (most users can ignore): DL4J uses inverted dropout, as described here:
+ * http://cs231n.github.io/neural-networks-2/
+ *
+ * @param dropout Dropout, such as {@link Dropout}, {@link org.deeplearning4j.nn.conf.dropout.GaussianDropout},
+ *                {@link org.deeplearning4j.nn.conf.dropout.GaussianNoise} etc
+ * @return
+ */
+ @Getter
+ @Setter
+ private IDropout idropOut;
+
+ /**
+ * Set the weight noise (such as {@link org.deeplearning4j.nn.conf.weightnoise.DropConnect} and
+ * {@link org.deeplearning4j.nn.conf.weightnoise.WeightNoise}) for the layers in this network.
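A small usage sketch for the dropout default described above, assuming the dropOut(double) and layer(...) builder methods shown elsewhere in this class (sizes are hypothetical):

    NeuralNetConfiguration dropConf = NeuralNetConfiguration.builder()
        .dropOut(0.8)   // retain 80% of activations network-wide (shorthand for idropOut(new Dropout(0.8)))
        .layer(new DenseLayer.Builder().nIn(100).nOut(50)
            .dropOut(0.0)   // per-layer override: dropout disabled for this layer
            .build())
        .build();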
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param weightNoise Weight noise instance to use + */ + @Getter + @Setter + private IWeightNoise weightNoise; + @Getter + @Setter + @lombok.Builder.Default + private double biasInit = 0.0; + @Getter + @Setter + @lombok.Builder.Default + private double gainInit = 1.0; + + /** + * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied + * from handling of {@link Activation} above. + * + * @return True if all is well and layer iteration shall continue. False else-wise. + */ + private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l, + ObjectMapper mapper, + JsonNode confs, int layerCount) { + if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) { + try { + JsonNode jsonNode = mapper.readTree(json); + if (confs == null) { + confs = jsonNode.get("confs"); + } + if (confs instanceof ArrayNode) { + ArrayNode layerConfs = (ArrayNode) confs; + JsonNode outputLayerNNCNode = layerConfs.get(layerCount); + if (outputLayerNNCNode == null) { + return false; //Should never happen... + } + JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); + + if (layerWrapperNode == null || layerWrapperNode.size() != 1) { + return true; + } + + JsonNode layerNode = layerWrapperNode.elements().next(); + JsonNode weightInit = layerNode.get( + "weightInit"); //Should only have 1 element: "dense", "output", etc + JsonNode distribution = layerNode.get("dist"); + + Distribution dist = null; + if (distribution != null) { + dist = mapper.treeToValue(distribution, Distribution.class); + } + + if (weightInit != null) { + final IWeightInit wi = WeightInit.valueOf(weightInit.asText()) + .getWeightInitFunction(dist); + ((BaseLayer) l).setWeightInitFn(wi); + } + } + + } catch (IOException e) { + log.warn( + "ILayer with null WeightInit detected: " + l.getLayerName() + ", could not parse JSON", + e); + } + } + return true; + + } + + /** + * Object mapper for serialization of configurations + * + * @return + */ + public static ObjectMapper mapperYaml() { + return JsonMappers.getMapperYaml(); + } + + /** + * Object mapper for serialization of configurations + * + * @return + */ + public static ObjectMapper mapper() { + return JsonMappers.getMapper(); + } + + public static NeuralNetBaseBuilderConfiguration fromYaml(String input) { + throw new RuntimeException("Needs fixing - not supported."); //TODO + } + + + /** + * @return JSON representation of NN configuration + */ + public String toYaml() { + ObjectMapper mapper = NeuralNetBaseBuilderConfiguration.mapperYaml(); + synchronized (mapper) { + try { + return mapper.writeValueAsString(this); + } catch (com.fasterxml.jackson.core.JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + /** + * @return JSON representation of NN configuration + */ + public String toJson() { + ObjectMapper mapper = NeuralNetBaseBuilderConfiguration.mapper(); + synchronized (mapper) { + //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally + //when writeValueAsString is used by multiple threads. This results in invalid JSON. 
See issue #3243 + try { + return mapper.writeValueAsString(this); + } catch (com.fasterxml.jackson.core.JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + public abstract static class NeuralNetBaseBuilderConfigurationBuilder + > { + + List innerConfigurations$value = new ArrayList<>(); //initialize with an empty list + + /** + * Set constraints to be applied to all layers. Default: no constraints.
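For the serialization helpers above, a brief sketch (fromYaml is shown as unsupported in this revision, so only the writing direction is illustrated):

    NeuralNetConfiguration conf = NeuralNetConfiguration.builder().build();
    String json = conf.toJson();   // uses the shared, synchronized JSON mapper
    String yaml = conf.toYaml();   // uses the YAML mapper from JsonMappers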
+ * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, + * etc). These constraints are applied at each iteration, after the parameters have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param constraints Constraints to apply to all weight parameters of all layers + */ + public B constrainWeights(LayerConstraint... constraints) { + contrainWeights$value = Arrays.asList(constraints); + contrainWeights$set = true; + return (B) this; + } + + /** + * For the (perhaps partially constructed) network configuration, return a list of activation sizes for each + * layer in the network.
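A sketch of the constraint mechanism described above; NonNegativeConstraint is an existing DL4J constraint class and is assumed to still be available in this refactor:

    NeuralNetConfiguration constrained = NeuralNetConfiguration.builder()
        .constrainWeights(new NonNegativeConstraint())   // applied after each parameter update, to every layer's weights
        .layer(new DenseLayer.Builder().nIn(10).nOut(5).build())
        .build();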
Note: To use this method, the network input type must have been set using {@link #setInputType(InputType)} first + * @return A list of activation types for the network, indexed by layer number + */ + public List<InputType> getLayerActivationTypes(){ + Preconditions.checkState(inputType != null, "Can only calculate activation types if input type has " + + "been set. Use setInputType(InputType)"); + + + throw new RuntimeException("Error calculating layer activation types: error instantiating MultiLayerConfiguration"); + + } + + + /** + * Set constraints to be applied to all layers. Default: no constraints.
+ * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, + * etc). These constraints are applied at each iteration, after the parameters have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param constraints Constraints to apply to all parameters of all layers + */ + public B constrainAllParameters(LayerConstraint... constraints){ + allParamContraints$value = Arrays.asList(constraints); + allParamContraints$set = true; + return (B) this; + } + + /** + * Set constraints to be applied to all layers. Default: no constraints.
+ * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, + * etc). These constraints are applied at each iteration, after the parameters have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param constraints Constraints to apply to all bias parameters of all layers + */ + public B constrainBias(LayerConstraint... constraints) { + biasConstraints$value = Arrays.asList(constraints); + biasConstraints$set = true; + return (B) this; + } + + /** + * Specify the processors. These are used at each layer for doing things like normalization and + * shaping of input. + * + * @param processor what to use to preProcess the data. + * @return builder pattern + */ + public B inputPreProcessor(Integer layer, + InputPreProcessor processor) { + inputPreProcessors$value.put(layer, processor); + inputPreProcessors$set = true; + return (B) this; + } + + + /** + * Set layer at index + * + * @param index where to insert + * @param layer the layer + * @return builder + */ + public B layer(Integer index, @NonNull LayerConfiguration layer) { + innerConfigurations$value.add(index, layer); + innerConfigurations$set = true; + return (B) this; + } + + /** + * Add a layer + * + * @param layer the layer + * @return builder + */ + public B layer(@NonNull LayerConfiguration layer) { + innerConfigurations$value.add(layer); + innerConfigurations$set = true; + return (B) this; + } + + //TODO this is a dirty workaround + public boolean isOverrideNinUponBuild() { + return isOverrideNinUponBuild(); + } + + /** + * Specify additional layer configurations + */ + @Deprecated + public B layersFromArray(@NonNull LayerConfiguration[] arrLayers) { + innerConfigurations$value.addAll(List.of(arrLayers)); + innerConfigurations$set = true; + return (B) this; + } + + /** + * Specify additional layer configurations + */ + @Deprecated + public B layersFromList(@NonNull List listLayers) { + innerConfigurations$value.addAll(listLayers); + innerConfigurations$set = true; + return (B) this; + } + + + /** + * L1 regularization coefficient for the weights (excluding biases).
Note: values set by + * this method will be applied to all applicable layers in the network, unless a different value + * is explicitly set on a given layer. In other words: values set via this method are used as + * the default value, and can be overridden on a per-layer basis. + */ + public B l1(double l1) { + //Check if existing L1 exists; if so, replace it + NetworkUtils.removeInstances(regularization$value, L1Regularization.class); + if (l1 > 0.0) { + regularization$value.add(new L1Regularization(l1)); + } + regularization$set = true; + return (B) this; + } + + /** + * L2 regularization coefficient for the weights (excluding biases).
+ * Note: Generally, {@link WeightDecay} (set via {@link #weightDecay(double)}) should be + * preferred to + * L2 regularization. See {@link WeightDecay} javadoc for further details.
Note: values set + * by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used + * as the default value, and can be overridden on a per-layer basis.
Note: L2 regularization + * and weight decay usually should not be used together; if any weight decay (or L2) has been + * added for the biases, these will be removed first. + * + * @see #weightDecay(double, boolean) + */ + public B l2(double l2) { + //Check if existing L2 exists; if so, replace it. Also remove weight decay - it doesn't make sense to use both + NetworkUtils.removeInstances(regularization$value, L2Regularization.class); + if (l2 > 0.0) { + NetworkUtils.removeInstancesWithWarning(regularization$value, WeightDecay.class, + "WeightDecay regularization removed: incompatible with added L2 regularization"); + regularization$value.add(new L2Regularization(l2)); + } + regularization$set = true; + return (B) this; + } + + /** + * L1 regularization coefficient for the bias.
Note: values set by this method will be + * applied to all applicable layers in the network, unless a different value is explicitly set + * on a given layer. In other words: values set via this method are used as the default value, + * and can be overridden on a per-layer basis. + */ + public B l1Bias(double l1Bias) { + NetworkUtils.removeInstances(regularizationBias$value, L1Regularization.class); + if (l1Bias > 0.0) { + regularizationBias$value.add(new L1Regularization(l1Bias)); + } + regularizationBias$set = true; + return (B) this; + } + + /** + * L2 regularization coefficient for the bias.
+ * Note: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double, boolean)}) + * should be preferred to + * L2 regularization. See {@link WeightDecay} javadoc for further details.
Note: values set + * by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used + * as the default value, and can be overridden on a per-layer basis.
Note: L2 regularization + * and weight decay usually should not be used together; if any weight decay (or L2) has been + * added for the biases, these will be removed first. + * + * @see #weightDecayBias(double, boolean) + */ + public B l2Bias(double l2Bias) { + NetworkUtils.removeInstances(regularizationBias$value, L2Regularization.class); + if (l2Bias > 0.0) { + NetworkUtils.removeInstancesWithWarning(regularizationBias$value, WeightDecay.class, + "WeightDecay bias regularization removed: incompatible with added L2 regularization"); + regularizationBias$value.add(new L2Regularization(l2Bias)); + } + regularizationBias$set = true; + return (B) this; + } + + /** + * Add weight decay regularization for the network parameters (excluding biases).
This + * applies weight decay with the learning rate multiplied in (applyLR = true) - see {@link WeightDecay} for + * more details.
Note: values set by this method will be applied to all applicable layers in + * the network, unless a different value is explicitly set on a given layer. In other words: + * values set via this method are used as the default value, and can be overridden on a + * per-layer basis.
+ * + * @param coefficient Weight decay regularization coefficient + * @see #weightDecay(double, boolean) + */ + public B weightDecay(double coefficient) { + return weightDecay(coefficient, true); + } + + /** + * Add weight decay regularization for the network parameters (excluding biases). See + * {@link WeightDecay} for more details.
Note: values set by this method will be applied to + * all applicable layers in the network, unless a different value is explicitly set on a given + * layer. In other words: values set via this method are used as the default value, and can be + * overridden on a per-layer basis.
+ * + * @param coefficient Weight decay regularization coefficient + * @param applyLR Whether the learning rate should be multiplied in when performing weight + * decay updates. See {@link WeightDecay} for more details. + * @see #weightDecay(double, boolean) + */ + public B weightDecay(double coefficient, boolean applyLR) { + //Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both + NetworkUtils.removeInstances(regularization$value, WeightDecay.class); + if (coefficient > 0.0) { + NetworkUtils.removeInstancesWithWarning(regularization$value, L2Regularization.class, + "L2 regularization removed: incompatible with added WeightDecay regularization"); + regularization$value.add(new WeightDecay(coefficient, applyLR)); + } + regularization$set = true; + return (B) this; + } + + /** + * Weight decay for the biases only - see {@link #weightDecay(double)} for more details. This + * applies weight decay with multiplying the learning rate.
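Because l2(...) and weightDecay(...) each remove the other's regularizer (with a warning), the last call wins. An illustrative sketch:

    NeuralNetConfiguration regConf = NeuralNetConfiguration.builder()
        .l2(1e-4)            // added first ...
        .weightDecay(5e-5)   // ... then replaced by weight decay, applied with the learning rate (applyLR = true)
        .build();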
Note: values set by this + * method will be applied to all applicable layers in the network, unless a different value is + * explicitly set on a given layer. In other words: values set via this method are used as the + * default value, and can be overridden on a per-layer basis.
+ * + * @param coefficient Weight decay regularization coefficient + * @see #weightDecayBias(double, boolean) + */ + public B weightDecayBias(double coefficient) { + return weightDecayBias(coefficient, true); + } + + /** + * Weight decay for the biases only - see {@link #weightDecay(double)} for more details
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis.
+ * + * @param coefficient Weight decay regularization coefficient + */ + public B weightDecayBias(double coefficient, boolean applyLR) { + //Check if weight decay already exists; if so, replace it. Also remove L2 - it doesn't make sense to use both + NetworkUtils.removeInstances(regularizationBias$value, WeightDecay.class); + if (coefficient > 0) { + NetworkUtils.removeInstancesWithWarning(regularizationBias$value, L2Regularization.class, + "L2 bias regularization removed: incompatible with added WeightDecay regularization"); + regularizationBias$value.add(new WeightDecay(coefficient, applyLR)); + } + regularizationBias$set = true; + return (B) this; + } + + /** + * Activation function / neuron non-linearity
Note: values set by this method will be + * applied to all applicable layers in the network, unless a different value is explicitly set + * on a given layer. In other words: values set via this method are used as the default value, + * and can be overridden on a per-layer basis. + */ + @Deprecated + public B activation(@NonNull Activation activation) { + return (B) activationFn(activation.getActivationFunction()); + } + + + + @Deprecated + public B weightInit(@NonNull WeightInit wi) { + return (B) weightInitFn(wi.getWeightInitFunction()); + } + + /** + * legacy code, does nothing + * @return + */ + @Deprecated + public B list() { + return (B) this; + } + + + /** + * Set weight initialization scheme to random sampling via the specified distribution. + * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))} Note: values set + * by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used + * as the default value, and can be overridden on a per-layer basis. + * + * @param distribution Distribution to use for weight initialization + */ + public B weightInit(@NonNull Distribution distribution) { + return (B) weightInitFn(new WeightInitDistribution(distribution)); + } + + /** + * Same as {@link #weightInit(Distribution)}. + * @param distribution + * @return + */ + public B dist(@NonNull Distribution distribution) { + return (B) weightInit(distribution); + } + + public B dropOut(@NonNull IDropout dropout) { + return (B) idropOut(dropout); + } + + /** + * Creates a new {@link Dropout} and sets the dropout in the builder for this configuration + * @param dropout activationRetainProbability + * @return builder + */ + public B dropOut( double dropout) { + return (B) idropOut( new Dropout(dropout)); + } + + /** + * Add multiple inner neural net configurations at once + * @param confs list of configurations + * @return builder + */ + @Deprecated + public B confs(@NonNull List confs) { + innerConfigurations$value.addAll(confs); + innerConfigurations$set=true; + return (B) this; + } + } + + @Override + public NeuralNetBaseBuilderConfiguration clone() { + NeuralNetBaseBuilderConfiguration clone; + try { + clone = (NeuralNetBaseBuilderConfiguration) super.clone(); + } catch(CloneNotSupportedException ex) { + throw new RuntimeException(ex); + } + if (clone.stepFunction != null) { + clone.stepFunction = clone.stepFunction.clone(); + } + /** + if (clone.variables != null) { + clone.variables = new ArrayList<>(clone.variables); + } + **/ + + clone.getInnerConfigurations().addAll(innerConfigurations); + + if (clone.getInputPreProcessors() != null) { + Map map = new HashMap<>(); + for (Map.Entry entry : clone.getInputPreProcessors().entrySet()) { + map.put(entry.getKey(), entry.getValue().clone()); + } + clone.getInputPreProcessors().clear(); + clone.getInputPreProcessors().putAll(map); + } + + clone.setInferenceWorkspaceMode(this.inferenceWorkspaceMode); + clone.setTrainingWorkspaceMode(this.trainingWorkspaceMode); + clone.setCacheMode(this.cacheMode); + clone.setValidateOutputLayerConfig(this.validateOutputLayerConfig); + clone.setDataType(this.dataType); + + return clone; + + } +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java index f44a8f3ab..5c221222c 100644 --- 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java @@ -20,33 +20,54 @@ package org.deeplearning4j.nn.conf; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; +import com.fasterxml.jackson.databind.node.ArrayNode; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import lombok.Data; import lombok.EqualsAndHashCode; -import lombok.NoArgsConstructor; +import lombok.Getter; import lombok.NonNull; +import lombok.Setter; +import lombok.experimental.SuperBuilder; +import lombok.extern.jackson.Jacksonized; import lombok.extern.slf4j.Slf4j; -import net.brutex.ai.dnn.api.INeuralNetworkConfiguration; +import lombok.val; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.*; -import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; -import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; -import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; -import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; +import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; +import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.EmbeddingLayer; +import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; +import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; +import org.deeplearning4j.nn.conf.memory.MemoryReport; +import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport; import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.nn.conf.stepfunctions.StepFunction; import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.nn.weights.WeightInitDistribution; -import org.deeplearning4j.nn.weights.WeightInitXavier; -import org.deeplearning4j.util.NetworkUtils; -import org.nd4j.common.base.Preconditions; +import org.deeplearning4j.util.OutputLayerUtil; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationSigmoid; @@ -54,1168 +75,1074 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.IUpdater; import 
org.nd4j.linalg.learning.config.Sgd; -import org.nd4j.linalg.learning.regularization.L1Regularization; -import org.nd4j.linalg.learning.regularization.L2Regularization; import org.nd4j.linalg.learning.regularization.Regularization; import org.nd4j.linalg.learning.regularization.WeightDecay; -import com.fasterxml.jackson.databind.ObjectMapper; - -import java.io.IOException; -import java.io.Serializable; -import java.util.*; +import org.nd4j.linalg.lossfunctions.LossFunctions; +import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT; +import org.nd4j.linalg.lossfunctions.impl.LossMCXENT; +import org.nd4j.linalg.lossfunctions.impl.LossMSE; +import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; +/** + * Deeplearning4j is a domain-specific language to configure deep neural networks, which are made of + * multiple layers. Everything starts with a NeuralNetConfiguration, which organizes those layers + * and their hyperparameters. Hyperparameters are variables that determine how a neural network + * learns. They include how many times to update the weights of the model, how to initialize those + * weights, which activation function to attach to the nodes, which optimization algorithm to use, + * and how fast the model should learn. This is what one configuration would look like: + *

+ * + * NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
+ * .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)
+ * .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+ * .updater(new Sgd(0.05)) //... other hyperparameters
+ * .backprop(true)
+ * .build();

+ * + * With Deeplearning4j, you add a layer + * by calling layer on the NeuralNetConfiguration.NeuralNetConfigurationBuilder(), specifying its place in the order of + * layers (the zero-indexed layer below is the input layer), the number of input and output nodes, + * nIn and nOut, as well as the type: DenseLayer.

+ * + * .layer(0, new DenseLayer.Builder().nIn(784).nOut(250)
+ * .build())

+ * + * Once you've configured your net, you train the + * model with model.fit. + */ @Data -@NoArgsConstructor @Slf4j @EqualsAndHashCode(exclude = {"iterationCount", "epochCount"}) -public class NeuralNetConfiguration implements Serializable, Cloneable, - INeuralNetworkConfiguration { +@Jacksonized +@JsonIgnoreProperties(ignoreUnknown = true) +//The inner builder, that we can then extend ... +@SuperBuilder //TODO fix access +public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { + private static final int DEFAULT_TBPTT_LENGTH = 20; - protected Layer layer; - //batch size: primarily used for conv nets. Will be reinforced if set. - protected boolean miniBatch = true; - //number of line search iterations - protected int maxNumLineSearchIterations; - protected long seed; - protected OptimizationAlgorithm optimizationAlgo; - //gradient keys used for ensuring order when getting and setting the gradient - protected List variables = new ArrayList<>(); - //whether to constrain the gradient to unit norm or not - protected StepFunction stepFunction; - //minimize or maximize objective - protected boolean minimize = true; + /** + * Set constraints to be applied to all layers. Default: no constraints.
+ * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, + * etc). These constraints are applied at each iteration, after the parameters have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param constraints Constraints to apply to all bias parameters of all layers + */ + @lombok.Builder.Default + protected final List biasConstraints = new ArrayList<>(); + /** + * Set constraints to be applied to all layers. Default: no constraints.
+ * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, + * etc). These constraints are applied at each iteration, after the parameters have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param constraints Constraints to apply to all parameters of all layers + */ + @lombok.Builder.Default + protected final List allParamContraints = new ArrayList<>(); - // this field defines preOutput cache - protected CacheMode cacheMode; + @Getter + @Setter + @NonNull + @lombok.Builder.Default + @Deprecated + protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; + @Getter + @Setter + @NonNull + @lombok.Builder.Default + @Deprecated + protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; + /** + * The type of backprop. Default setting is used for most networks (MLP, CNN etc), but optionally + * truncated BPTT can be used for training recurrent neural networks. If using TruncatedBPTT make + * sure you set both tBPTTForwardLength() and tBPTTBackwardLength() + */ + @Getter + @Setter + @NonNull + @lombok.Builder.Default + protected BackpropType backpropType = BackpropType.Standard; + /** + * When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated) + * backprop?
Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
Typically the + * tBPTTForwardLength parameter is the same as the tBPTTBackwardLength parameter, but may be larger + * than it in some circumstances (but never smaller)
Ideally your training data time series + * length should be divisible by this. This is the k1 parameter on pg23 of http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf + * + * @param forwardLength Forward length > 0, >= backwardLength + */ + @Getter + @Setter + @lombok.Builder.Default + protected int tbpttFwdLength = 20; + /** + * When doing truncated BPTT: how many steps of backward should we do?
Only applicable when + * doing backpropType(BackpropType.TruncatedBPTT)
This is the k2 parameter on pg23 of http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf + * + * @param backwardLength <= forwardLength + */ + @Getter + @Setter + @lombok.Builder.Default + protected int tbpttBackLength = 20; + /** + * Creates and returns a copy of this object. + * + * @return a clone of this instance. + * @throws CloneNotSupportedException if the object's class does not support the {@code Cloneable} + * interface. Subclasses that override the {@code clone} method can also throw this exception to + * indicate that an instance cannot be cloned. + * @see Cloneable + */ - protected DataType dataType = DataType.FLOAT; //Default to float for deserialization of legacy format nets + //Nd4j.getRandom().setSeed(getConf(0).getSeed()); //TODO + //Counter for the number of parameter updates so far + // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted + // for Spark and model serialization + @Getter + @Setter + @lombok.Builder.Default + protected int iterationCount = 0; + //Counter for the number of epochs completed so far. Used for per-epoch schedules + @Getter + @Setter + @lombok.Builder.Default + protected int epochCount = 0; + @lombok.Builder.Default + protected double dampingFactor = 100; + //gradient keys used for ensuring order when getting and setting the gradient + @lombok.Builder.Default + protected List netWideVariables = new ArrayList<>(); + @Getter + @Setter + @lombok.Builder.Default + private boolean miniBatch = false; + /** + * A seed for this network, will be random if not specified. - //Counter for the number of parameter updates so far for this layer. - //Note that this is only used for pretrain layers (AE, VAE) - MultiLayerConfiguration and ComputationGraphConfiguration - //contain counters for standard backprop training. - // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted - // for Spark and model serialization - protected int iterationCount = 0; + @Getter + @Setter + @lombok.Builder.Default + private long seed = new Random().nextLong(); */ + /** + * The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified + * otherwise. This method defines how/if preOutput cache is handled: NONE: cache disabled (default + * value) HOST: Host memory will be used DEVICE: GPU memory will be used (on CPU backends effect + * will be the same as for HOST) + *

+ * Valid values are
CacheMode.NONE,
CacheMode.HOST or
CacheMode.DEVICE
+ * + * @param cacheMode + */ + @NonNull + @Getter + @Setter + @lombok.Builder.Default + private CacheMode cacheMode = CacheMode.NONE; + /** + * The list of layer configurations in this configuration. They will be indexed automatically as + * the layers get added starting with index 0. + */ - //Counter for the number of epochs completed so far. Used for per-epoch schedules - protected int epochCount = 0; + @lombok.Builder.Default + @Getter + private String name = "Anonymous INeuralNetworkConfiguration"; + /** + * The {@link InputType} of the data for this network configuration + */ + @Getter + @Setter + private InputType inputType; + /** + * Set the DataType for the network parameters and activations for all layers in the network. + * Default: Float + * + * @param dataType Datatype to use for parameters and activations + */ + @Getter + @Setter + @lombok.Builder.Default + @NonNull + private DataType dataType = DataType.FLOAT; + /** + * Whether to override the nIn configuration forcibly upon construction. Default value is true. + * + * @return builder pattern + */ + @Getter + @Setter + @lombok.Builder.Default + private boolean overrideNinUponBuild = true; + /** + * Enabled by default. If enabled, the output layer configuration will be validated, to throw an + * exception on likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.
If + * disabled (false) no output layer validation will be performed.
Disabling this validation is + * not recommended, as the configurations that fail validation usually will not be able to learn + * correctly. However, the option to disable this validation is provided for advanced users when + * creating non-standard architectures. + * + * @param validate If true: validate output layer configuration. False: don't validate + */ + @Getter + @Setter + @lombok.Builder.Default + private boolean validateOutputLayerConfig = true; + /** + * Enabled by default. If enabled, an exception will be thrown when using the (invalid) combination + * of truncated backpropagation through time (TBPTT) with either a GlobalPoolingLayer or + * LastTimeStepLayer.
It is possible to disable this validation to allow what is almost + * certainly an invalid configuration to be used, however this is not recommended. + * + * @param validate Whether TBPTT validation should be performed + */ + @Getter + @Setter + @lombok.Builder.Default + private boolean validateTbpttConfig = true; + /** + * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} or + * {@link org.nd4j.linalg.learning.config.Nesterovs}
Note: values set by this method will be + * applied to all applicable layers in the network, unless a different value is explicitly set on + * a given layer. In other words: values set via this method are used as the default value, and + * can be overridden on a per-layer basis. + * + * @param updater Updater to use + */ + @Getter + @Setter + private IUpdater updater; + /** + * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping + * etc. See {@link GradientNormalization} for details
Note: values set by this method will be + * applied to all applicable layers in the network, unless a different value is explicitly set on + * a given layer. In other words: values set via this method are used as the default value, and + * can be overridden on a per-layer basis. + * + * @param gradientNormalization Type of normalization to use. Defaults to None. + * @see GradientNormalization + */ + @Getter + @Setter + @NonNull + @lombok.Builder.Default + private GradientNormalization gradientNormalization = GradientNormalization.None; + /** + * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer, + * GradientNormalization.ClipL2PerParamType, and + * GradientNormalization.ClipElementWiseAbsoluteValue
Not used otherwise.<br> L2 threshold for
+ * first two types of clipping, or absolute value threshold for last type of clipping.<br>
Note: + * values set by this method will be applied to all applicable layers in the network, unless a + * different value is explicitly set on a given layer. In other words: values set via this method + * are used as the default value, and can be overridden on a per-layer basis. + */ + @Getter + @Setter + private double gradientNormalizationThreshold; + /** + * Activation function / neuron non-linearity
Note: values set by this method will be applied + * to all applicable layers in the network, unless a different value is explicitly set on a given + * layer. In other words: values set via this method are used as the default value, and can be + * overridden on a per-layer basis. + */ + @Getter + @Setter + private IActivation activation; + //whether to constrain the gradient to unit norm or not + @Getter + @Setter + private StepFunction stepFunction; + @Getter + @Setter + @lombok.Builder.Default + private OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT; + @Getter + @Setter + @lombok.Builder.Default + private int maxNumLineSearchIterations = 5; + /** + * Set the regularization for the parameters (excluding biases) - for example {@link WeightDecay}
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis.
+ * + * @param regularization Regularization to apply for the network parameters/weights (excluding biases) + */ + @Getter + @lombok.Builder.Default + private List regularization = new ArrayList<>(); + /** + * Set the regularization for the biases only - for example {@link WeightDecay}
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis.
+ * + * @param regularizationBias Regularization to apply for the network biases only + */ + @Getter + @lombok.Builder.Default + private List regularizationBias = new ArrayList<>(); + @Getter + @Setter + @lombok.Builder.Default + private IUpdater iUpdater = new Sgd(); + /** + * Gradient updater configuration, for the biases only. If not set, biases will use the updater as + * set by {@link #setIUpdater(IUpdater)}
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param updater Updater to use for bias parameters + */ + @Getter + @Setter + @lombok.Builder.Default + private IUpdater biasUpdater = null; + @Getter + @Setter + @lombok.Builder.Default + private IActivation activationFn = new ActivationSigmoid(); - /** - * Creates and returns a deep copy of the configuration. - */ - @Override - public NeuralNetConfiguration clone() { + /** + * Sets the convolution mode for convolutional layers, which impacts padding and output sizes. + * See {@link ConvolutionMode} for details. Defaults to ConvolutionMode.TRUNCATE
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * @param convolutionMode Convolution mode to use + */ + @Getter + @Setter + @lombok.Builder.Default + private ConvolutionMode convolutionMode = ConvolutionMode.Truncate; + /** + * Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage of cuDNN. + * See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. + *
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * @param cudnnAlgoMode cuDNN algo mode to use + */ + @Getter + @Setter + @lombok.Builder.Default + private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST; + @Getter + @Setter + @lombok.Builder.Default + private boolean minimize = true; + /** + * Set the dropout for all layers in this network
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * * Dropout probability. This is the probability of retaining each input activation value for a layer. + * * dropOut(x) will keep an input activation with probability x, and set to 0 with probability 1-x.
+ * * dropOut(0.0) is a special value / special case - when set to 0.0., dropout is disabled (not applied). Note + * * that a dropout value of 1.0 is functionally equivalent to no dropout: i.e., 100% probability of retaining + * * each input activation.
+ * *

+ * * Note 1: Dropout is applied at training time only - and is automatically not applied at test time + * * (for evaluation, etc)
+ * * Note 2: This sets the probability per-layer. Care should be taken when setting lower values for + * * complex networks (too much information may be lost with aggressive (very low) dropout values).
+ * * Note 3: Frequently, dropout is not applied to (or, has higher retain probability for) input (first layer) + * * layers. Dropout is also often not applied to output layers. This needs to be handled MANUALLY by the user + * * - set .dropout(0) on those layers when using global dropout setting.
+ * * Note 4: Implementation detail (most users can ignore): DL4J uses inverted dropout, as described here: + * * http://cs231n.github.io/neural-networks-2/ + * *

+ * *
+ * * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * * value, and can be overridden on a per-layer basis. + * * + * * @param inputRetainProbability Dropout probability (probability of retaining each input activation value for a layer) + * * @see #dropOut(IDropout) + * + * + * @param dropout Dropout, such as {@link Dropout}, {@link org.deeplearning4j.nn.conf.dropout.GaussianDropout}, + * {@link org.deeplearning4j.nn.conf.dropout.GaussianNoise} etc + * @return + */ + @Getter + @Setter + private IDropout idropOut; + /** + * Set the weight noise (such as {@link org.deeplearning4j.nn.conf.weightnoise.DropConnect} and + * {@link org.deeplearning4j.nn.conf.weightnoise.WeightNoise}) for the layers in this network.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless a different + * value is explicitly set on a given layer. In other words: values set via this method are used as the default + * value, and can be overridden on a per-layer basis. + * + * @param weightNoise Weight noise instance to use + */ + @Getter + @Setter + private IWeightNoise weightNoise; + @Getter + @Setter + @lombok.Builder.Default + private double biasInit = 0.0; + @Getter + @Setter + @lombok.Builder.Default + private double gainInit = 1.0; + + /** + * Create a neural net configuration from json + * + * @param json the neural net configuration from json + * @return {@link NeuralNetConfiguration} + */ + public static NeuralNetConfiguration fromJson(String json) { + NeuralNetConfiguration conf; + ObjectMapper mapper = NeuralNetConfiguration.mapper(); + try { + conf = mapper.readValue(json, NeuralNetConfiguration.class); + } catch (InvalidTypeIdException e) { + if (e.getMessage().contains("@class")) { try { - NeuralNetConfiguration clone = (NeuralNetConfiguration) super.clone(); - if (clone.layer != null) - clone.layer = clone.layer.clone(); - if (clone.stepFunction != null) - clone.stepFunction = clone.stepFunction.clone(); - if (clone.variables != null) - clone.variables = new ArrayList<>(clone.variables); - return clone; - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); + //JSON may be legacy (1.0.0-alpha or earlier), attempt to load it using old format + return JsonMappers.getLegacyMapper().readValue(json, NeuralNetConfiguration.class); + } catch (InvalidTypeIdException e2) { + //Check for legacy custom layers: "Could not resolve type id 'CustomLayer' as a subtype of [simple type, class org.deeplearning4j.nn.conf.layers.ILayer]: known type ids = [Bidirectional, CenterLossOutputLayer, CnnLossLayer, ..." + //1.0.0-beta5: dropping support for custom layers defined in pre-1.0.0-beta format. Built-in layers from these formats still work + String msg = e2.getMessage(); + if (msg != null && msg.contains("Could not resolve type id")) { + throw new RuntimeException( + "Error deserializing NeuralNetConfiguration - configuration may have a custom " + + "layer, vertex or preprocessor, in pre version 1.0.0-beta JSON format.\nModels in legacy format with custom" + + + " layers should be loaded in 1.0.0-beta to 1.0.0-beta4 and saved again, before loading in the current version of DL4J", + e); + } + throw new RuntimeException(e2); + } catch (IOException e2) { + throw new RuntimeException(e2); } + } + throw new RuntimeException(e); + } catch (IOException e) { + //Check if this exception came from legacy deserializer... + String msg = e.getMessage(); + if (msg != null && msg.contains("legacy")) { + throw new RuntimeException( + "Error deserializing NeuralNetConfiguration - configuration may have a custom " + + "layer, vertex or preprocessor, in pre version 1.0.0-alpha JSON format. These layers can be " + + + "deserialized by first registering them with NeuralNetConfiguration.registerLegacyCustomClassesForJSON(Class...)", + e); + } + throw new RuntimeException(e); } - public List variables() { - return new ArrayList<>(variables); - } + //To maintain backward compatibility after loss function refactoring (configs generated with v0.5.0 or earlier) + // Previously: enumeration used for loss functions. 
Now: use classes + // IN the past, could have only been an OutputLayer or RnnOutputLayer using these enums + int layerCount = 0; + JsonNode confs = null; + for (LayerConfiguration nnc : conf.getFlattenedLayerConfigurations()) { + LayerConfiguration l = nnc; + if (l instanceof BaseOutputLayer && ((BaseOutputLayer) l).getLossFn() == null) { + //lossFn field null -> may be an old config format, with lossFunction field being for the enum + //if so, try walking the JSON graph to extract out the appropriate enum value - public List variables(boolean copy) { - if (copy) - return variables(); - return variables; - } + BaseOutputLayer ol = (BaseOutputLayer) l; + try { + JsonNode jsonNode = mapper.readTree(json); + if (confs == null) { + confs = jsonNode.get("confs"); + } + if (confs instanceof ArrayNode) { + ArrayNode layerConfs = (ArrayNode) confs; + JsonNode outputLayerNNCNode = layerConfs.get(layerCount); + if (outputLayerNNCNode == null) { + throw new RuntimeException( + "should never happen"); //return conf; //Should never happen... + } + JsonNode outputLayerNode = outputLayerNNCNode.get("layer"); - public void addVariable(String variable) { - if (!variables.contains(variable)) { - variables.add(variable); + JsonNode lossFunctionNode = null; + if (outputLayerNode.has("output")) { + lossFunctionNode = outputLayerNode.get("output").get("lossFunction"); + } else if (outputLayerNode.has("rnnoutput")) { + lossFunctionNode = outputLayerNode.get("rnnoutput").get("lossFunction"); + } + + if (lossFunctionNode != null) { + String lossFunctionEnumStr = lossFunctionNode.asText(); + LossFunctions.LossFunction lossFunction = null; + try { + lossFunction = LossFunctions.LossFunction.valueOf(lossFunctionEnumStr); + } catch (Exception e) { + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", + e); + } + + if (lossFunction != null) { + switch (lossFunction) { + case MSE: + ol.setLossFn(new LossMSE()); + break; + case XENT: + ol.setLossFn(new LossBinaryXENT()); + break; + case NEGATIVELOGLIKELIHOOD: + ol.setLossFn(new LossNegativeLogLikelihood()); + break; + case MCXENT: + ol.setLossFn(new LossMCXENT()); + break; + + //Remaining: TODO + case SQUARED_LOSS: + case RECONSTRUCTION_CROSSENTROPY: + default: + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not set loss function for {}", + lossFunction); + break; + } + } + } + + } else { + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON: layer 'confs' field is not an ArrayNode (is: {})", + (confs != null ? confs.getClass() : null)); + } + } catch (IOException e) { + log.warn( + "OutputLayer with null LossFunction or pre-0.6.0 loss function configuration detected: could not parse JSON", + e); + break; } + } + + //Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn") + //Try to load the old format if necessary, and create the appropriate IActivation instance + if ((l instanceof BaseLayer) && ((BaseLayer) l).getActivationFn() == null) { + try { + JsonNode jsonNode = mapper.readTree(json); + if (confs == null) { + confs = jsonNode.get("confs"); + } + if (confs instanceof ArrayNode) { + ArrayNode layerConfs = (ArrayNode) confs; + JsonNode outputLayerNNCNode = layerConfs.get(layerCount); + if (outputLayerNNCNode == null) { + throw new RuntimeException( + "Should never happen"); //return conf; //Should never happen... 
+ } + JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); + + if (layerWrapperNode == null || layerWrapperNode.size() != 1) { + continue; + } + + JsonNode layerNode = layerWrapperNode.elements().next(); + JsonNode activationFunction = layerNode.get( + "activationFunction"); //Should only have 1 element: "dense", "output", etc + + if (activationFunction != null) { + IActivation ia = Activation.fromString(activationFunction.asText()) + .getActivationFunction(); + ((BaseLayer) l).setActivationFn(ia); + } + } + + } catch (IOException e) { + log.warn( + "ILayer with null ActivationFn field or pre-0.7.2 activation function detected: could not parse JSON", + e); + } + } + + if (!handleLegacyWeightInitFromJson(json, l, mapper, confs, layerCount)) { + return conf; + } + + layerCount++; } + return conf; + } - public void clearVariables() { - variables.clear(); - } + /** + * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied + * from handling of {@link Activation} above. + * + * @return True if all is well and layer iteration shall continue. False else-wise. + */ + private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l, + ObjectMapper mapper, + JsonNode confs, int layerCount) { + if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) { + try { + JsonNode jsonNode = mapper.readTree(json); + if (confs == null) { + confs = jsonNode.get("confs"); + } + if (confs instanceof ArrayNode) { + ArrayNode layerConfs = (ArrayNode) confs; + JsonNode outputLayerNNCNode = layerConfs.get(layerCount); + if (outputLayerNNCNode == null) { + return false; //Should never happen... + } + JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); - /** - * Fluent interface for building a list of configurations - */ - public static class ListBuilder extends MultiLayerConfiguration.Builder { - private int layerCounter = -1; //Used only for .layer(Layer) method - private final Map layerwise; - private final Builder globalConfig; + if (layerWrapperNode == null || layerWrapperNode.size() != 1) { + return true; + } - // Constructor - public ListBuilder(Builder globalConfig, Map layerMap) { - this.globalConfig = globalConfig; - this.layerwise = layerMap; + JsonNode layerNode = layerWrapperNode.elements().next(); + JsonNode weightInit = layerNode.get( + "weightInit"); //Should only have 1 element: "dense", "output", etc + JsonNode distribution = layerNode.get("dist"); + + Distribution dist = null; + if (distribution != null) { + dist = mapper.treeToValue(distribution, Distribution.class); + } + + if (weightInit != null) { + final IWeightInit wi = WeightInit.valueOf(weightInit.asText()) + .getWeightInitFunction(dist); + ((BaseLayer) l).setWeightInitFn(wi); + } } - public ListBuilder(Builder globalConfig) { - this(globalConfig, new HashMap()); + } catch (IOException e) { + log.warn( + "ILayer with null WeightInit detected: " + l.getLayerName() + ", could not parse JSON", + e); + } + } + return true; + + } + + /** + * Object mapper for serialization of configurations + * + * @return + */ + public static ObjectMapper mapperYaml() { + return JsonMappers.getMapperYaml(); + } + + /** + * Object mapper for serialization of configurations + * + * @return + */ + public static ObjectMapper mapper() { + return JsonMappers.getMapper(); + } + + public static NeuralNetConfiguration fromYaml(String input) { + throw new RuntimeException("Needs fixing - not supported."); //TODO + } + + + /** + * @return JSON representation of NN configuration + 
*/ + public String toYaml() { + ObjectMapper mapper = NeuralNetConfiguration.mapperYaml(); + synchronized (mapper) { + try { + return mapper.writeValueAsString(this); + } catch (com.fasterxml.jackson.core.JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + + /** + * @return JSON representation of NN configuration + */ + public String toJson() { + ObjectMapper mapper = NeuralNetConfiguration.mapper(); + synchronized (mapper) { + //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally + //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243 + try { + return mapper.writeValueAsString(this); + } catch (com.fasterxml.jackson.core.JsonProcessingException e) { + log.error(e.getMessage()); + throw new RuntimeException(e); + } + } + } + + @Override + public String toString() { + return toJson(); + } + + @Override + public NeuralNetConfiguration clone() { + NeuralNetConfiguration clone; + clone = (NeuralNetConfiguration) super.clone(); + clone.stepFunction = clone.stepFunction.clone(); + clone.netWideVariables = new ArrayList<>(netWideVariables); + clone.getInnerConfigurations().addAll(innerConfigurations); + + if (clone.getInputPreProcessors() != null) { + Map map = new HashMap<>(); + for (Map.Entry entry : clone.getInputPreProcessors().entrySet()) { + map.put(entry.getKey(), entry.getValue().clone()); + } + clone.getInputPreProcessors().clear(); + clone.getInputPreProcessors().putAll(map); + } + + clone.setInferenceWorkspaceMode(this.inferenceWorkspaceMode); + clone.setTrainingWorkspaceMode(this.trainingWorkspaceMode); + clone.setCacheMode(this.cacheMode); + clone.setValidateOutputLayerConfig(this.validateOutputLayerConfig); + clone.setDataType(this.dataType); + + return clone; + + } + + /** + * + */ + @Override + public void init() { + getNetConfigurations().stream().forEach( conf -> conf.init()); //call init on all embedded configurations + innerConfigurations.add(0, this); //put this configuration at first place + getLayerConfigurations().stream().forEach( lconf -> lconf.setNetConfiguration(this)); //set this as net config for all layers (defined in here, not stacked + + + //Validate BackpropType setting + if ((tbpttBackLength != DEFAULT_TBPTT_LENGTH || tbpttFwdLength != DEFAULT_TBPTT_LENGTH) + && backpropType != BackpropType.TruncatedBPTT) { + log.warn("Truncated backpropagation through time lengths have been configured with values " + + tbpttFwdLength + + " and " + tbpttBackLength + " but backprop type is set to " + backpropType + + ". 
TBPTT configuration" + + " settings will only take effect if backprop type is set to BackpropType.TruncatedBPTT"); + } + + if (backpropType == BackpropType.TruncatedBPTT && validateTbpttConfig) { + //Check for invalid combination - tbptt plus LastTimeStepLayer or + for (int i = 0; i < getFlattenedLayerConfigurations().size(); i++) { + LayerConfiguration l = getFlattenedLayerConfigurations().get(i); + if (l instanceof LastTimeStep || l instanceof GlobalPoolingLayer) { + throw new IllegalStateException( + "Invalid network configuration detected: Truncated backpropagation through time (TBPTT)" + + + " cannot be used with layer " + i + " of type " + l.getClass().getName() + + ": TBPTT is incompatible with this layer type (which is designed " + + "to process entire sequences at once, and does support the type of sequence segments that TPBTT uses).\n" + + + "This check can be disabled using validateTbpttConfig(false) but this is not recommended."); + } + } + } + + if (inputType == null && inputPreProcessors.get(0) == null) { + //User hasn't set the InputType. Sometimes we can infer it... + // For example, Dense/RNN layers, where preprocessor isn't set -> user is *probably* going to feed in + // standard feedforward or RNN data + //This isn't the most elegant implementation, but should avoid breaking backward compatibility here + //Can't infer InputType for CNN layers, however (don't know image dimensions/depth) + LayerConfiguration firstLayer = getFlattenedLayerConfigurations().get(0); + if (firstLayer instanceof BaseRecurrentLayer) { + BaseRecurrentLayer brl = (BaseRecurrentLayer) firstLayer; + val nIn = brl.getNIn(); + if (nIn > 0) { + inputType = InputType.recurrent(nIn, brl.getRnnDataFormat()); + } + } else if (firstLayer instanceof DenseLayer || firstLayer instanceof EmbeddingLayer + || firstLayer instanceof OutputLayer) { + //Can't just use "instanceof FeedForwardLayer" here. ConvolutionLayer is also a FeedForwardLayer + FeedForwardLayer ffl = (FeedForwardLayer) firstLayer; + val nIn = ffl.getNIn(); + if (nIn > 0) { + inputType = InputType.feedForward(nIn); + } + } + } + + //Add preprocessors and set nIns, if InputType has been set + // Builder.inputType field can be set in 1 of 4 ways: + // 1. User calls setInputType directly + // 2. Via ConvolutionLayerSetup -> internally calls setInputType(InputType.convolutional(...)) + // 3. Via the above code: i.e., assume input is as expected by the RNN or dense layer -> sets the inputType field + if(inputPreProcessors == null) { + inputPreProcessors = new HashMap<>(); + } + if (inputType != null) { + InputType currentInputType = inputType; + for (int i = 0; i < getFlattenedLayerConfigurations().size(); i++) { + LayerConfiguration l = getFlattenedLayerConfigurations().get(i); + if (inputPreProcessors.get(i) == null) { + //Don't override preprocessor setting, but set preprocessor if required... 
+ @NonNull + InputPreProcessor inputPreProcessor = l.getPreProcessorForInputType(currentInputType); + if (inputPreProcessor != null) { + inputPreProcessors.put(i, inputPreProcessor); + } } - public ListBuilder layer(int ind, @NonNull Layer layer) { - if (layerwise.containsKey(ind)) { - log.info("Layer index {} already exists, layer of type {} will be replace by layer type {}", - ind, layerwise.get(ind).getClass().getSimpleName(), layer.getClass().getSimpleName()); - layerwise.get(ind).layer(layer); + InputPreProcessor inputPreProcessor = inputPreProcessors.get(i); + if (inputPreProcessor != null) { + currentInputType = inputPreProcessor.getOutputType(currentInputType); + } + if (i > 0) { + LayerConfiguration layer = getFlattenedLayerConfigurations().get(i - 1); + //convolution 1d is an edge case where it has rnn input type but the filters + //should be the output + if (layer instanceof Convolution1DLayer) { + if (l instanceof DenseLayer && inputType instanceof InputType.InputTypeRecurrent) { + FeedForwardLayer feedForwardLayer = (FeedForwardLayer) l; + if (inputType instanceof InputType.InputTypeRecurrent) { + InputType.InputTypeRecurrent recurrent = (InputType.InputTypeRecurrent) inputType; + feedForwardLayer.setNIn(recurrent.getTimeSeriesLength()); + } } else { - layerwise.put(ind, globalConfig.clone().layer(layer)); + l.setNIn(currentInputType, + overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user } - if(layerCounter < ind){ - //Edge case: user is mixing .layer(Layer) and .layer(int, Layer) calls - //This should allow a .layer(A, X) and .layer(Y) to work such that layer Y is index (A+1) - layerCounter = ind; - } - return this; + } else { + l.setNIn(currentInputType, + overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user + } + + } else { + l.setNIn(currentInputType, + overrideNinUponBuild); //Don't override the nIn setting, if it's manually set by the user } - public ListBuilder layer(Layer layer){ - return layer(++layerCounter, layer); - } + currentInputType = l.getOutputType(i, currentInputType); + } - public Map getLayerwise() { - return layerwise; - } + } - @Override - public ListBuilder setInputType(InputType inputType){ - return (ListBuilder)super.setInputType(inputType); - } + Nd4j.getRandom().setSeed(getNetConfigurations().get(0).getSeed()); - /** - * A convenience method for setting input types: note that for example .inputType().convolutional(h,w,d) - * is equivalent to .setInputType(InputType.convolutional(h,w,d)) - */ - public ListBuilder.InputTypeBuilder inputType(){ - return new InputTypeBuilder(); - } + //Validate output layer configuration + if (isValidateOutputLayerConfig()) { + //Validate output layer configurations... + for (LayerConfiguration n : getFlattenedLayerConfigurations()) { + OutputLayerUtil.validateOutputLayer(n.getLayerName(), n); //No-op for non output/loss layers + } + } + } - /** - * For the (perhaps partially constructed) network configuration, return a list of activation sizes for each - * layer in the network.
- * Note: To use this method, the network input type must have been set using {@link #setInputType(InputType)} first - * @return A list of activation types for the network, indexed by layer number - */ - public List getLayerActivationTypes(){ - Preconditions.checkState(inputType != null, "Can only calculate activation types if input type has" + - "been set. Use setInputType(InputType)"); + public InputPreProcessor getInputPreProcess(int curr) { + return inputPreProcessors.get(curr); + } - MultiLayerConfiguration conf; - try{ - conf = build(); - } catch (Exception e){ - throw new RuntimeException("Error calculating layer activation types: error instantiating MultiLayerConfiguration", e); - } + /** + * Get a {@link MemoryReport} for the given NeuralNetConfiguration. This is used to estimate the + * memory requirements for the given network configuration and input + * + * @param inputType Input types for the network + * @return Memory report for the network + */ + public NetworkMemoryReport getMemoryReport(InputType inputType) { - return conf.getLayerActivationTypes(inputType); - } + Map memoryReportMap = new LinkedHashMap<>(); + int nLayers = getFlattenedLayerConfigurations().size(); + for (int i = 0; i < nLayers; i++) { + String layerName = getFlattenedLayerConfigurations().get(i).getLayerName(); + if (layerName == null) { + layerName = String.valueOf(i); + } - /** - * Build the multi layer network - * based on this neural network and - * overr ridden parameters - * - * @return the configuration to build - */ - public MultiLayerConfiguration build() { - List list = new ArrayList<>(); - if (layerwise.isEmpty()) - throw new IllegalStateException("Invalid configuration: no layers defined"); - for (int i = 0; i < layerwise.size(); i++) { - if (layerwise.get(i) == null) { - throw new IllegalStateException("Invalid configuration: layer number " + i - + " not specified. Expect layer " + "numbers to be 0 to " + (layerwise.size() - 1) - + " inclusive (number of layers defined: " + layerwise.size() + ")"); - } - if (layerwise.get(i).getLayer() == null) - throw new IllegalStateException("Cannot construct network: Layer config for" + "layer with index " - + i + " is not defined)"); + //Pass input type through preprocessor, if necessary + InputPreProcessor preproc = getInputPreProcess(i); + //TODO memory requirements for preprocessor + if (preproc != null) { + inputType = preproc.getOutputType(inputType); + } - //Layer names: set to default, if not set - if (layerwise.get(i).getLayer().getLayerName() == null) { - layerwise.get(i).getLayer().setLayerName("layer" + i); - } + LayerMemoryReport report = getFlattenedLayerConfigurations().get(i).getMemoryReport(inputType); + memoryReportMap.put(layerName, report); - list.add(layerwise.get(i).build()); - } + inputType = getFlattenedLayerConfigurations().get(i).getOutputType(i, inputType); + } - WorkspaceMode wsmTrain = (globalConfig.setTWM ? globalConfig.trainingWorkspaceMode : trainingWorkspaceMode); - WorkspaceMode wsmTest = (globalConfig.setIWM ? globalConfig.inferenceWorkspaceMode : inferenceWorkspaceMode); + return new NetworkMemoryReport(memoryReportMap, NeuralNetConfiguration.class, + "MultiLayerNetwork", inputType); + } + + /** + * For the given input shape/type for the network, return a list of activation sizes for each + * layer in the network.
i.e., list.get(i) is the output activation sizes for layer i + * + * @param inputType Input type for the network + * @return A lits of activation types for the network, indexed by layer number + */ + public List getLayerActivationTypes(@NonNull InputType inputType) { + List out = new ArrayList<>(); + int nLayers = getFlattenedLayerConfigurations().size(); + for (int i = 0; i < nLayers; i++) { + InputPreProcessor preproc = getInputPreProcess(i); + if (preproc != null) { + inputType = preproc.getOutputType(inputType); + } + + inputType = getFlattenedLayerConfigurations().get(i).getOutputType(i, inputType); + out.add(inputType); + } + return out; + } - return new MultiLayerConfiguration.Builder().inputPreProcessors(inputPreProcessors) - .backpropType(backpropType).tBPTTForwardLength(tbpttFwdLength) - .tBPTTBackwardLength(tbpttBackLength).setInputType(this.inputType) - .trainingWorkspaceMode(wsmTrain).cacheMode(globalConfig.cacheMode) - .inferenceWorkspaceMode(wsmTest).confs(list).validateOutputLayerConfig(validateOutputConfig) - .dataType(globalConfig.dataType) - .build(); - } + public List netWideVariables() { + return netWideVariables; + } - /** Helper class for setting input types */ - public class InputTypeBuilder { + public List netWideVariables(boolean copy) { + if (copy) { + return netWideVariables(); + } + return netWideVariables; + } + + public void addNetWideVariable(String variable) { + if (!netWideVariables.contains(variable)) { + netWideVariables.add(variable); + log.trace("Adding neural network wide variable '{}' to the list of variables. New length is {}.", variable, netWideVariables.size()); + } + log.trace("Skipped adding neural network wide variable '{}' to the list of variables. It was already present. Length remains {}.", variable, netWideVariables.size()); + } + + public void clearNetWideVariable() { + + netWideVariables.clear(); + log.trace("Adding neural network wide variables have been cleared. 
New length is {}.", netWideVariables.size()); + } + + + + /** + * From the list of layers and neural net configurations, only return the Layer Configurations that + * are defined in this neural network (it does not include embedded neural network configuration + * layers) + * @return list with layer configurations + */ + public List getLayerConfigurations() { + return innerConfigurations.stream() + .filter(obj -> (obj instanceof LayerConfiguration)) + .map( obj -> (LayerConfiguration)obj ) + .collect( Collectors.toList()); + } + + /** + * From the list of layers and neural net configurations, only return the neural net configurations + * @return list with neural net configurations + */ + public List getNetConfigurations() { + return innerConfigurations.stream() + .filter(obj -> (obj instanceof NeuralNetConfiguration)) + .map( obj -> (NeuralNetConfiguration)obj ) + .collect( Collectors.toList()); + } + + /** + * From the list of layer configurations and inner neural net configurations, create a single, + * flattened list of layer configurations with inheritance parameters resolved + * + * @return list of layer configurations + */ + public List getFlattenedLayerConfigurations(NeuralNetConfiguration conf) { + List ret = new ArrayList<>(); //create the final return list + for( Object obj : conf.getInnerConfigurations().stream().skip(1) //don't include self + .collect(Collectors.toList())) { + //if Layer Config, include in list and inherit parameters from this conf + //else if neural net configuration, call self recursively to resolve layer configurations + if (obj instanceof LayerConfiguration) + ret.add((LayerConfiguration) obj); + else if (obj instanceof NeuralNetConfiguration) + ret.addAll(getFlattenedLayerConfigurations( + (NeuralNetConfiguration) obj)); + else { + log.error( + "The list of layers and neural network configurations does contain an object of {}. 
Element will be ignored.", + obj.getClass().getSimpleName()); + } + } /** - * See {@link InputType#convolutional(long, long, long)} - */ - public ListBuilder convolutional(int height, int width, int depth){ - return ListBuilder.this.setInputType(InputType.convolutional(height, width, depth)); + LayerConfiguration lc = ((LayerConfiguration) lc).getType().getClazz().cast(obj); + switch(lc.getType()) { + case FC: { //fully connected layer + ((FeedForwardLayer) lc).setWeightInitFn(this.getWeightInitFn()); } + if(lc instanceof FeedForwardLayer && ((FeedForwardLayer) lc).getWeightInitFn() == null) { + **/ + return ret; + } - /** - * * See {@link InputType#convolutionalFlat(long, long, long)} - */ - public ListBuilder convolutionalFlat(int height, int width, int depth){ - return ListBuilder.this.setInputType(InputType.convolutionalFlat(height, width, depth)); - } + /** + * Sames as {@link #getFlattenedLayerConfigurations(NeuralNetConfiguration)}, but uses this configurations + * list of configurations + * @return list of layer configurations + */ + public List getFlattenedLayerConfigurations() { + return getFlattenedLayerConfigurations(this); + } - /** - * See {@link InputType#feedForward(long)} - */ - public ListBuilder feedForward(int size){ - return ListBuilder.this.setInputType(InputType.feedForward(size)); - } - /** - * See {@link InputType#recurrent(long)}} - */ - public ListBuilder recurrent(int size){ - return ListBuilder.this.setInputType(InputType.recurrent(size)); - } - } + /** + * Get the configuration of the first layer + * @return layer configuration + */ + /** + public LayerConfiguration getFirstLayer() { + return getFlattenedLayerConfigurations().get(0); + } +**/ + + /** + * Add a new layer to the first position + * @param layer configuration + */ + public void setLayer(@NonNull LayerConfiguration layer) { + innerConfigurations.add(0, layer); + } + + @Deprecated + public LayerConfiguration getConf(int index) { + return getFlattenedLayerConfigurations().get(index); + } + + public static abstract class NeuralNetConfigurationBuilder> extends + NeuralNetBaseBuilderConfigurationBuilder { + + public ComputationGraphConfiguration.GraphBuilder graphBuilder() { + return new ComputationGraphConfiguration.GraphBuilder(this); } - /** - * Return this configuration as json - * - * @return this configuration represented as json - */ - public String toYaml() { - ObjectMapper mapper = mapperYaml(); - - try { - String ret = mapper.writeValueAsString(this); - return ret; - - } catch (com.fasterxml.jackson.core.JsonProcessingException e) { - throw new RuntimeException(e); - } + public NeuralNetConfigurationBuilder clone() { + try { + return (NeuralNetConfigurationBuilder) super.clone(); + } catch(CloneNotSupportedException ex) { + throw new RuntimeException(ex); + } } - /** - * Create a neural net configuration from json - * - * @param json the neural net configuration from json - * @return - */ - public static NeuralNetConfiguration fromYaml(String json) { - ObjectMapper mapper = mapperYaml(); - try { - NeuralNetConfiguration ret = mapper.readValue(json, NeuralNetConfiguration.class); - return ret; - } catch (IOException e) { - throw new RuntimeException(e); - } - } + } - /** - * Return this configuration as json - * - * @return this configuration represented as json - */ - public String toJson() { - ObjectMapper mapper = mapper(); - - try { - return mapper.writeValueAsString(this); - } catch (com.fasterxml.jackson.core.JsonProcessingException e) { - throw new RuntimeException(e); - } - } - - /** - 
* Create a neural net configuration from json - * - * @param json the neural net configuration from json - * @return - */ - public static NeuralNetConfiguration fromJson(String json) { - ObjectMapper mapper = mapper(); - try { - NeuralNetConfiguration ret = mapper.readValue(json, NeuralNetConfiguration.class); - return ret; - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - /** - * Object mapper for serialization of configurations - * - * @return - */ - public static ObjectMapper mapperYaml() { - return JsonMappers.getMapperYaml(); - } - - /** - * Object mapper for serialization of configurations - * - * @return - */ - public static ObjectMapper mapper() { - return JsonMappers.getMapper(); - } - - /** - * NeuralNetConfiguration builder, used as a starting point for creating a MultiLayerConfiguration or - * ComputationGraphConfiguration.
- * Note that values set here on the layer will be applied to all relevant layers - unless the value is overridden - * on a layer's configuration - */ - @Data - public static class Builder implements Cloneable { - protected IActivation activationFn = new ActivationSigmoid(); - protected IWeightInit weightInitFn = new WeightInitXavier(); - protected double biasInit = 0.0; - protected double gainInit = 1.0; - protected List regularization = new ArrayList<>(); - protected List regularizationBias = new ArrayList<>(); - protected IDropout idropOut; - protected IWeightNoise weightNoise; - protected IUpdater iUpdater = new Sgd(); - protected IUpdater biasUpdater = null; - protected Layer layer; - protected boolean miniBatch = true; - protected int maxNumLineSearchIterations = 5; - protected long seed = System.currentTimeMillis(); - protected OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT; - protected StepFunction stepFunction = null; - protected boolean minimize = true; - protected GradientNormalization gradientNormalization = GradientNormalization.None; - protected double gradientNormalizationThreshold = 1.0; - protected List allParamConstraints; - protected List weightConstraints; - protected List biasConstraints; - - protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; - protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; - protected boolean setTWM = false; - protected boolean setIWM = false; - protected CacheMode cacheMode = CacheMode.NONE; - protected DataType dataType = DataType.FLOAT; - - protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate; - protected ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST; - - public Builder() { - // - } - - public Builder(NeuralNetConfiguration newConf) { - if (newConf != null) { - minimize = newConf.minimize; - maxNumLineSearchIterations = newConf.maxNumLineSearchIterations; - layer = newConf.layer; - optimizationAlgo = newConf.optimizationAlgo; - seed = newConf.seed; - stepFunction = newConf.stepFunction; - miniBatch = newConf.miniBatch; - } - } - - /** - * Process input as minibatch vs full dataset. - * Default set to true. - */ - public Builder miniBatch(boolean miniBatch) { - this.miniBatch = miniBatch; - return this; - } - - /** - * This method defines Workspace mode being used during training:
- * NONE: workspace won't be used
- * ENABLED: workspaces will be used for training (reduced memory and better performance) - * - * @param workspaceMode Workspace mode for training - * @return Builder - */ - public Builder trainingWorkspaceMode(@NonNull WorkspaceMode workspaceMode) { - this.trainingWorkspaceMode = workspaceMode; - this.setTWM = true; - return this; - } - - /** - * This method defines Workspace mode being used during inference:
- * NONE: workspace won't be used
- * ENABLED: workspaces will be used for inference (reduced memory and better performance) - * - * @param workspaceMode Workspace mode for inference - * @return Builder - */ - public Builder inferenceWorkspaceMode(@NonNull WorkspaceMode workspaceMode) { - this.inferenceWorkspaceMode = workspaceMode; - this.setIWM = true; - return this; - } - - /** - * This method defines how/if preOutput cache is handled: - * NONE: cache disabled (default value) - * HOST: Host memory will be used - * DEVICE: GPU memory will be used (on CPU backends effect will be the same as for HOST) - * - * @param cacheMode Cache mode to use - * @return Builder - */ - public Builder cacheMode(@NonNull CacheMode cacheMode) { - this.cacheMode = cacheMode; - return this; - } - - /** - * Objective function to minimize or maximize cost function - * Default set to minimize true. - */ - public Builder minimize(boolean minimize) { - this.minimize = minimize; - return this; - } - - /** - * Maximum number of line search iterations. - * Only applies for line search optimizers: Line Search SGD, Conjugate Gradient, LBFGS - * is NOT applicable for standard SGD - * - * @param maxNumLineSearchIterations > 0 - * @return - */ - public Builder maxNumLineSearchIterations(int maxNumLineSearchIterations) { - this.maxNumLineSearchIterations = maxNumLineSearchIterations; - return this; - } - - - /** - * Layer class. - */ - public Builder layer(Layer layer) { - this.layer = layer; - return this; - } - - /** - * Step function to apply for back track line search. - * Only applies for line search optimizers: Line Search SGD, Conjugate Gradient, LBFGS - * Options: DefaultStepFunction (default), NegativeDefaultStepFunction - * GradientStepFunction (for SGD), NegativeGradientStepFunction - */ - @Deprecated - public Builder stepFunction(StepFunction stepFunction) { - this.stepFunction = stepFunction; - return this; - } - - /** - * Create a ListBuilder (for creating a MultiLayerConfiguration)
- * Usage:
- *
-         * {@code .list()
-         * .layer(new DenseLayer.Builder()...build())
-         * ...
-         * .layer(new OutputLayer.Builder()...build())
-         * }
-         * 
- */ - public ListBuilder list() { - return new ListBuilder(this); - } - - /** - * Create a ListBuilder (for creating a MultiLayerConfiguration) with the specified layers
- * Usage:
- *
-         * {@code .list(
-         *      new DenseLayer.Builder()...build(),
-         *      ...,
-         *      new OutputLayer.Builder()...build())
-         * }
-         * 
- * - * @param layers The layer configurations for the network - */ - public ListBuilder list(Layer... layers) { - if (layers == null || layers.length == 0) - throw new IllegalArgumentException("Cannot create network with no layers"); - Map layerMap = new HashMap<>(); - for (int i = 0; i < layers.length; i++) { - Builder b = this.clone(); - b.layer(layers[i]); - layerMap.put(i, b); - } - return new ListBuilder(this, layerMap); - - } - - /** - * Create a GraphBuilder (for creating a ComputationGraphConfiguration). - */ - public ComputationGraphConfiguration.GraphBuilder graphBuilder() { - return new ComputationGraphConfiguration.GraphBuilder(this); - } - - /** - * Random number generator seed. Used for reproducability between runs - */ - public Builder seed(long seed) { - this.seed = seed; - Nd4j.getRandom().setSeed(seed); - return this; - } - - /** - * Optimization algorithm to use. Most common: OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT - * - * @param optimizationAlgo Optimization algorithm to use when training - */ - public Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) { - this.optimizationAlgo = optimizationAlgo; - return this; - } - - @Override - public Builder clone() { - try { - Builder clone = (Builder) super.clone(); - if (clone.layer != null) - clone.layer = clone.layer.clone(); - if (clone.stepFunction != null) - clone.stepFunction = clone.stepFunction.clone(); - - return clone; - - } catch (CloneNotSupportedException e) { - throw new RuntimeException(e); - } - } - - /** - * Activation function / neuron non-linearity
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @see #activation(Activation) - */ - public Builder activation(IActivation activationFunction) { - this.activationFn = activationFunction; - return this; - } - - /** - * Activation function / neuron non-linearity
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - */ - public Builder activation(Activation activation) { - return activation(activation.getActivationFunction()); - } - - - /** - * Weight initialization scheme to use, for initial weight values - * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @see IWeightInit - */ - public Builder weightInit(IWeightInit weightInit) { - this.weightInitFn = weightInit; - return this; - } - - /** - * Weight initialization scheme to use, for initial weight values - * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @see WeightInit - */ - public Builder weightInit(WeightInit weightInit) { - if(weightInit == WeightInit.DISTRIBUTION) { - // throw new UnsupportedOperationException("Not supported!, Use weightInit(Distribution distribution) instead!"); - } - - this.weightInitFn = weightInit.getWeightInitFunction(); - return this; - } - - /** - * Set weight initialization scheme to random sampling via the specified distribution. - * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))} - * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param distribution Distribution to use for weight initialization - */ - public Builder weightInit(Distribution distribution){ - return weightInit(new WeightInitDistribution(distribution)); - } - - /** - * Constant for bias initialization. Default: 0.0
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param biasInit Constant for bias initialization - */ - public Builder biasInit(double biasInit) { - this.biasInit = biasInit; - return this; - } - - /** - * Distribution to sample initial weights from. - * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))}.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @see #weightInit(Distribution) - * @deprecated Use {@link #weightInit(Distribution)} - */ - @Deprecated - public Builder dist(Distribution dist) { - return weightInit(dist); - } - - /** - * L1 regularization coefficient for the weights (excluding biases).
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - */ - public Builder l1(double l1) { - //Check if existing L1 exists; if so, replace it - NetworkUtils.removeInstances(this.regularization, L1Regularization.class); - if(l1 > 0.0) { - this.regularization.add(new L1Regularization(l1)); - } - return this; - } - - /** - * L2 regularization coefficient for the weights (excluding biases).
- * Note: Generally, {@link WeightDecay} (set via {@link #weightDecay(double)} should be preferred to - * L2 regularization. See {@link WeightDecay} javadoc for further details.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis.
- * Note: L2 regularization and weight decay usually should not be used together; if any weight decay (or L2) has - * been added for the biases, these will be removed first. - * - * @see #weightDecay(double, boolean) - */ - public Builder l2(double l2) { - //Check if existing L2 exists; if so, replace it. Also remove weight decay - it doesn't make sense to use both - NetworkUtils.removeInstances(this.regularization, L2Regularization.class); - if(l2 > 0.0) { - NetworkUtils.removeInstancesWithWarning(this.regularization, WeightDecay.class, "WeightDecay regularization removed: incompatible with added L2 regularization"); - this.regularization.add(new L2Regularization(l2)); - } - return this; - } - - /** - * L1 regularization coefficient for the bias.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - */ - public Builder l1Bias(double l1Bias) { - NetworkUtils.removeInstances(this.regularizationBias, L1Regularization.class); - if(l1Bias > 0.0) { - this.regularizationBias.add(new L1Regularization(l1Bias)); - } - return this; - } - - /** - * L2 regularization coefficient for the bias.
- * Note: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double,boolean)} should be preferred to - * L2 regularization. See {@link WeightDecay} javadoc for further details.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis.
- * Note: L2 regularization and weight decay usually should not be used together; if any weight decay (or L2) has - * been added for the biases, these will be removed first. - * - * @see #weightDecayBias(double, boolean) - */ - public Builder l2Bias(double l2Bias) { - NetworkUtils.removeInstances(this.regularizationBias, L2Regularization.class); - if(l2Bias > 0.0) { - NetworkUtils.removeInstancesWithWarning(this.regularizationBias, WeightDecay.class, "L2 bias regularization removed: incompatible with added WeightDecay regularization"); - this.regularizationBias.add(new L2Regularization(l2Bias)); - } - return this; - } - - /** - * Add weight decay regularization for the network parameters (excluding biases).
- * This applies weight decay with multiplying the learning rate - see {@link WeightDecay} for more details.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis.
- * - * @param coefficient Weight decay regularization coefficient - * @see #weightDecay(double, boolean) - */ - public Builder weightDecay(double coefficient) { - return weightDecay(coefficient, true); - } - - /** - * Add weight decay regularization for the network parameters (excluding biases). See {@link WeightDecay} for more details.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis.
- * - * @param coefficient Weight decay regularization coefficient - * @param applyLR Whether the learning rate should be multiplied in when performing weight decay updates. See {@link WeightDecay} for more details. - * @see #weightDecay(double, boolean) - */ - public Builder weightDecay(double coefficient, boolean applyLR) { - //Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both - NetworkUtils.removeInstances(this.regularization, WeightDecay.class); - if(coefficient > 0.0) { - NetworkUtils.removeInstancesWithWarning(this.regularization, L2Regularization.class, "L2 regularization removed: incompatible with added WeightDecay regularization"); - this.regularization.add(new WeightDecay(coefficient, applyLR)); - } - return this; - } - - /** - * Weight decay for the biases only - see {@link #weightDecay(double)} for more details. - * This applies weight decay with multiplying the learning rate.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis.
- * - * @param coefficient Weight decay regularization coefficient - * @see #weightDecayBias(double, boolean) - */ - public Builder weightDecayBias(double coefficient) { - return weightDecayBias(coefficient, true); - } - - /** - * Weight decay for the biases only - see {@link #weightDecay(double)} for more details
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis.
- * - * @param coefficient Weight decay regularization coefficient - */ - public Builder weightDecayBias(double coefficient, boolean applyLR) { - //Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both - NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class); - if(coefficient > 0) { - NetworkUtils.removeInstancesWithWarning(this.regularizationBias, L2Regularization.class, "L2 bias regularization removed: incompatible with added WeightDecay regularization"); - this.regularizationBias.add(new WeightDecay(coefficient, applyLR)); - } - return this; - } - - /** - * Set the regularization for the parameters (excluding biases) - for example {@link WeightDecay}
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis.
- * - * @param regularization Regularization to apply for the network parameters/weights (excluding biases) - */ - public Builder regularization(List regularization) { - this.regularization = regularization; - return this; - } - - /** - * Set the regularization for the biases only - for example {@link WeightDecay}
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis.
- * - * @param regularizationBias Regularization to apply for the network biases only - */ - public Builder regularizationBias(List regularizationBias) { - this.regularizationBias = regularizationBias; - return this; - } - - /** - * Dropout probability. This is the probability of retaining each input activation value for a layer. - * dropOut(x) will keep an input activation with probability x, and set to 0 with probability 1-x.
- * dropOut(0.0) is a special value / special case - when set to 0.0, dropout is disabled (not applied). Note - * that a dropout value of 1.0 is functionally equivalent to no dropout: i.e., 100% probability of retaining - * each input activation.
- *

- * Note 1: Dropout is applied at training time only - and is automatically not applied at test time - * (for evaluation, etc)
- * Note 2: This sets the probability per-layer. Care should be taken when setting lower values for - * complex networks (too much information may be lost with aggressive (very low) dropout values).
- * Note 3: Frequently, dropout is not applied to (or has a higher retain probability for) input (first layer) - * layers. Dropout is also often not applied to output layers. This needs to be handled MANUALLY by the user - * - set .dropout(0) on those layers when using a global dropout setting.
- * Note 4: Implementation detail (most users can ignore): DL4J uses inverted dropout, as described here: - * http://cs231n.github.io/neural-networks-2/ - *

- *
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param inputRetainProbability Dropout probability (probability of retaining each input activation value for a layer) - * @see #dropOut(IDropout) - */ - public Builder dropOut(double inputRetainProbability) { - if(inputRetainProbability == 0.0){ - return dropOut(null); - } - return dropOut(new Dropout(inputRetainProbability)); - } - - /** - * Set the dropout for all layers in this network
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param dropout Dropout, such as {@link Dropout}, {@link org.deeplearning4j.nn.conf.dropout.GaussianDropout}, - * {@link org.deeplearning4j.nn.conf.dropout.GaussianNoise} etc - * @return - */ - public Builder dropOut(IDropout dropout){ - //Clone: Dropout is stateful usually - don't want to have the same instance shared in multiple places - this.idropOut = (dropout == null ? null : dropout.clone()); - return this; - } - - /** - * Set the weight noise (such as {@link org.deeplearning4j.nn.conf.weightnoise.DropConnect} and - * {@link org.deeplearning4j.nn.conf.weightnoise.WeightNoise}) for the layers in this network.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param weightNoise Weight noise instance to use - */ - public Builder weightNoise(IWeightNoise weightNoise){ - this.weightNoise = weightNoise; - return this; - } - - - /** - * @deprecated Use {@link #updater(IUpdater)} - */ - @Deprecated - public Builder updater(Updater updater) { - return updater(updater.getIUpdaterWithDefaultConfig()); - } - - /** - * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} - * or {@link org.nd4j.linalg.learning.config.Nesterovs}
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param updater Updater to use - */ - public Builder updater(IUpdater updater) { - this.iUpdater = updater; - return this; - } - - /** - * Gradient updater configuration, for the biases only. If not set, biases will use the updater as - * set by {@link #updater(IUpdater)}
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param updater Updater to use for bias parameters - */ - public Builder biasUpdater(IUpdater updater){ - this.biasUpdater = updater; - return this; - } - - /** - * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc. - * See {@link GradientNormalization} for details
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param gradientNormalization Type of normalization to use. Defaults to None. - * @see GradientNormalization - */ - public Builder gradientNormalization(GradientNormalization gradientNormalization) { - this.gradientNormalization = gradientNormalization; - return this; - } - - /** - * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer, - * GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue
- * Not used otherwise.
- * L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - */ - public Builder gradientNormalizationThreshold(double threshold) { - this.gradientNormalizationThreshold = threshold; - return this; - } - - /** - * Sets the convolution mode for convolutional layers, which impacts padding and output sizes. - * See {@link ConvolutionMode} for details. Defaults to ConvolutionMode.TRUNCATE
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * @param convolutionMode Convolution mode to use - */ - public Builder convolutionMode(ConvolutionMode convolutionMode) { - this.convolutionMode = convolutionMode; - return this; - } - - /** - * Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage of cuDNN. - * See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. - *
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * @param cudnnAlgoMode cuDNN algo mode to use - */ - public Builder cudnnAlgoMode(ConvolutionLayer.AlgoMode cudnnAlgoMode) { - this.cudnnAlgoMode = cudnnAlgoMode; - return this; - } - - /** - * Set constraints to be applied to all layers. Default: no constraints.
- * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, - * etc). These constraints are applied at each iteration, after the parameters have been updated.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param constraints Constraints to apply to all parameters of all layers - */ - public Builder constrainAllParameters(LayerConstraint... constraints){ - this.allParamConstraints = Arrays.asList(constraints); - return this; - } - - /** - * Set constraints to be applied to all layers. Default: no constraints.
- * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, - * etc). These constraints are applied at each iteration, after the parameters have been updated.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param constraints Constraints to apply to all bias parameters of all layers - */ - public Builder constrainBias(LayerConstraint... constraints) { - this.biasConstraints = Arrays.asList(constraints); - return this; - } - - /** - * Set constraints to be applied to all layers. Default: no constraints.
- * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, - * etc). These constraints are applied at each iteration, after the parameters have been updated.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * - * @param constraints Constraints to apply to all weight parameters of all layers - */ - public Builder constrainWeights(LayerConstraint... constraints) { - this.weightConstraints = Arrays.asList(constraints); - return this; - } - - - /** - * Set the DataType for the network parameters and activations. Must be a floating point type: {@link DataType#DOUBLE}, - * {@link DataType#FLOAT} or {@link DataType#HALF}.
- */ - public Builder dataType(@NonNull DataType dataType){ - Preconditions.checkState(dataType == DataType.DOUBLE || dataType == DataType.FLOAT || dataType == DataType.HALF, - "Data type must be a floating point type: one of DOUBLE, FLOAT, or HALF. Got datatype: %s", dataType); - this.dataType = dataType; - return this; - } - - /** - * Return a configuration based on this builder - * - * @return - */ - public NeuralNetConfiguration build() { - - NeuralNetConfiguration conf = new NeuralNetConfiguration(); - conf.minimize = minimize; - conf.maxNumLineSearchIterations = maxNumLineSearchIterations; - conf.layer = layer; - conf.optimizationAlgo = optimizationAlgo; - conf.seed = seed; - conf.stepFunction = stepFunction; - conf.miniBatch = miniBatch; - conf.cacheMode = this.cacheMode; - conf.dataType = this.dataType; - - configureLayer(layer); - if (layer instanceof FrozenLayer) { - configureLayer(((FrozenLayer) layer).getLayer()); - } - - if (layer instanceof FrozenLayerWithBackprop) { - configureLayer(((FrozenLayerWithBackprop) layer).getUnderlying()); - } - - return conf; - } - - private void configureLayer(Layer layer) { - String layerName; - if (layer == null || layer.getLayerName() == null) - layerName = "Layer not named"; - else - layerName = layer.getLayerName(); - - if(layer instanceof AbstractSameDiffLayer){ - AbstractSameDiffLayer sdl = (AbstractSameDiffLayer)layer; - sdl.applyGlobalConfig(this); - } - - if (layer != null) { - copyConfigToLayer(layerName, layer); - } - - if (layer instanceof FrozenLayer) { - copyConfigToLayer(layerName, ((FrozenLayer) layer).getLayer()); - } - - if (layer instanceof FrozenLayerWithBackprop) { - copyConfigToLayer(layerName, ((FrozenLayerWithBackprop) layer).getUnderlying()); - } - - if (layer instanceof Bidirectional) { - Bidirectional b = (Bidirectional)layer; - copyConfigToLayer(b.getFwd().getLayerName(), b.getFwd()); - copyConfigToLayer(b.getBwd().getLayerName(), b.getBwd()); - } - - if(layer instanceof BaseWrapperLayer){ - BaseWrapperLayer bwr = (BaseWrapperLayer)layer; - configureLayer(bwr.getUnderlying()); - } - - if (layer instanceof ConvolutionLayer) { - ConvolutionLayer cl = (ConvolutionLayer) layer; - if (cl.getConvolutionMode() == null) { - cl.setConvolutionMode(convolutionMode); - } - if (cl.getCudnnAlgoMode() == null) { - cl.setCudnnAlgoMode(cudnnAlgoMode); - } - } - if (layer instanceof SubsamplingLayer) { - SubsamplingLayer sl = (SubsamplingLayer) layer; - if (sl.getConvolutionMode() == null) { - sl.setConvolutionMode(convolutionMode); - } - } - LayerValidation.generalValidation(layerName, layer, idropOut, regularization, regularizationBias, - allParamConstraints, weightConstraints, biasConstraints); - } - - private void copyConfigToLayer(String layerName, Layer layer) { - - if (layer.getIDropout() == null) { - //Dropout is stateful usually - don't want to have the same instance shared by multiple layers - layer.setIDropout(idropOut == null ? 
null : idropOut.clone()); - } - - if (layer instanceof BaseLayer) { - BaseLayer bLayer = (BaseLayer) layer; - if (bLayer.getRegularization() == null || bLayer.getRegularization().isEmpty()) - bLayer.setRegularization(regularization); - if (bLayer.getRegularizationBias() == null || bLayer.getRegularizationBias().isEmpty()) - bLayer.setRegularizationBias(regularizationBias); - if (bLayer.getActivationFn() == null) - bLayer.setActivationFn(activationFn); - if (bLayer.getWeightInitFn() == null) - bLayer.setWeightInitFn(weightInitFn); - if (Double.isNaN(bLayer.getBiasInit())) - bLayer.setBiasInit(biasInit); - if (Double.isNaN(bLayer.getGainInit())) - bLayer.setGainInit(gainInit); - - //Configure weight noise: - if(weightNoise != null && ((BaseLayer) layer).getWeightNoise() == null){ - ((BaseLayer) layer).setWeightNoise(weightNoise.clone()); - } - - //Configure updaters: - if(iUpdater != null && bLayer.getIUpdater() == null){ - bLayer.setIUpdater(iUpdater.clone()); //Clone the updater to avoid shared instances - in case of setLearningRate calls later - } - if(biasUpdater != null && bLayer.getBiasUpdater() == null){ - bLayer.setBiasUpdater(biasUpdater.clone()); //Clone the updater to avoid shared instances - in case of setLearningRate calls later - } - - if(bLayer.getIUpdater() == null && iUpdater == null && bLayer.initializer().numParams(bLayer) > 0){ - //No updater set anywhere - IUpdater u = new Sgd(); - bLayer.setIUpdater(u); - log.warn("*** No updater configuration is set for layer {} - defaulting to {} ***", layerName, u); - } - - if (bLayer.getGradientNormalization() == null) - bLayer.setGradientNormalization(gradientNormalization); - if (Double.isNaN(bLayer.getGradientNormalizationThreshold())) - bLayer.setGradientNormalizationThreshold(gradientNormalizationThreshold); - } - - if (layer instanceof ActivationLayer){ - ActivationLayer al = (ActivationLayer)layer; - if(al.getActivationFn() == null) - al.setActivationFn(activationFn); - } - } - } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/BaseConstraint.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/BaseConstraint.java index fafb7a78e..f9a3e81f0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/BaseConstraint.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/constraint/BaseConstraint.java @@ -53,12 +53,12 @@ public abstract class BaseConstraint implements LayerConstraint { @Override public void applyConstraint(Layer layer, int iteration, int epoch) { - Map paramTable = layer.paramTable(); + Map paramTable = layer.getParamTable(); if(paramTable == null || paramTable.isEmpty() ){ return; } - ParamInitializer i = layer.conf().getLayer().initializer(); + ParamInitializer i = layer.getLayerConfiguration().initializer(); for(Map.Entry e : paramTable.entrySet()){ if(params.contains(e.getKey())){ apply(e.getValue()); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java index 0c7565db1..f93c1619b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java @@ -21,11 +21,13 @@ package org.deeplearning4j.nn.conf.graph; import lombok.Data; +import lombok.Getter; import lombok.NoArgsConstructor; import org.deeplearning4j.nn.conf.InputPreProcessor; 
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; import org.nd4j.linalg.api.buffer.DataType; @@ -37,16 +39,18 @@ import java.util.Arrays; @Data public class LayerVertex extends GraphVertex { - private NeuralNetConfiguration layerConf; + private NeuralNetConfiguration netConfiguration; + @Getter + private LayerConfiguration layerConfiguration; private InputPreProcessor preProcessor; //Set outputVertex to true when ILayer is an OutputLayer, OR For use in specialized situations like reinforcement learning - // For RL situations, this ILayer insn't an OutputLayer, but is the last layer in a graph, that gets its error/epsilon + // For RL situations, this ILayer isn't an OutputLayer, but is the last layer in a graph, that gets its error/epsilon // passed in externally private boolean outputVertex; - public LayerVertex(NeuralNetConfiguration layerConf, InputPreProcessor preProcessor) { - this.layerConf = layerConf; + public LayerVertex(NeuralNetConfiguration netConfiguration, InputPreProcessor preProcessor) { + this.netConfiguration = netConfiguration; this.preProcessor = preProcessor; } @@ -56,7 +60,8 @@ public class LayerVertex extends GraphVertex { @Override public GraphVertex clone() { - return new LayerVertex(layerConf.clone(), (preProcessor != null ? preProcessor.clone() : null)); + return new LayerVertex( + netConfiguration.clone(), (preProcessor != null ? preProcessor.clone() : null)); } @Override @@ -64,10 +69,11 @@ public class LayerVertex extends GraphVertex { if (!(o instanceof LayerVertex)) return false; LayerVertex lv = (LayerVertex) o; - if ((layerConf == null && lv.layerConf != null) || (layerConf != null && lv.layerConf == null)) { + if ((netConfiguration == null && lv.netConfiguration != null) || (netConfiguration != null && lv.netConfiguration + == null)) { return false; } - if (layerConf != null && !layerConf.equals(lv.layerConf)) + if (netConfiguration != null && !netConfiguration.equals(lv.netConfiguration)) return false; if (preProcessor == null && lv.preProcessor != null || preProcessor != null && lv.preProcessor == null) return false; @@ -76,12 +82,12 @@ public class LayerVertex extends GraphVertex { @Override public int hashCode() { - return layerConf.hashCode() ^ (preProcessor != null ? preProcessor.hashCode() : 0); + return netConfiguration.hashCode() ^ (preProcessor != null ? preProcessor.hashCode() : 0); } @Override public long numParams(boolean backprop) { - return layerConf.getLayer().initializer().numParams(layerConf); + return layerConfiguration.initializer().numParams(layerConfiguration); } @Override @@ -99,13 +105,13 @@ public class LayerVertex extends GraphVertex { INDArray paramsView, boolean initializeParams, DataType networkDatatype) { //Now, we need to work out if this vertex is an output vertex or not... 
boolean isOutput = graph.getComputationGraphConfiguration().getNetworkOutputs().contains(name); - + this.layerConfiguration = graph.getLayer(idx).getLayerConfiguration(); org.deeplearning4j.nn.api.Layer layer = - layerConf.getLayer().instantiate(layerConf, null, idx, paramsView, initializeParams, networkDatatype); + layerConfiguration.instantiate(netConfiguration, null, idx, paramsView, initializeParams, networkDatatype); if(layer == null) { throw new IllegalStateException("Encountered null layer during initialization for layer:" + - layerConf.getLayer().getClass().getSimpleName() + " initialization returned null layer?"); + layerConfiguration.getClass().getSimpleName() + " initialization returned null layer?"); } return new org.deeplearning4j.nn.graph.vertex.impl.LayerVertex(graph, name, idx, layer, preProcessor, isOutput, networkDatatype); @@ -125,7 +131,7 @@ public class LayerVertex extends GraphVertex { else afterPreprocessor = preProcessor.getOutputType(vertexInputs[0]); - InputType ret = layerConf.getLayer().getOutputType(layerIndex, afterPreprocessor); + InputType ret = layerConfiguration.getOutputType(layerIndex, afterPreprocessor); return ret; } @@ -142,11 +148,13 @@ public class LayerVertex extends GraphVertex { it = inputTypes[0]; } //TODO preprocessor memory - return layerConf.getLayer().getMemoryReport(it); + return layerConfiguration.getMemoryReport(it); } @Override public void setDataType(DataType dataType){ - layerConf.getLayer().setDataType(dataType); + layerConfiguration.setDataType(dataType); } + + } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java index 0b10cedd4..378ae01a2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import net.brutex.ai.dnn.api.LayerType; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; @@ -48,6 +49,7 @@ public class ActivationLayer extends NoParamLayer { protected ActivationLayer(Builder builder) { super(builder); + setType(LayerType.ACT); this.activationFn = builder.activationFn; initializeConstraints(builder); } @@ -75,13 +77,16 @@ public class ActivationLayer extends NoParamLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - org.deeplearning4j.nn.layers.ActivationLayer ret = new org.deeplearning4j.nn.layers.ActivationLayer(conf, networkDataType); + this.setNetConfiguration(conf); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + + org.deeplearning4j.nn.layers.ActivationLayer ret = new org.deeplearning4j.nn.layers.ActivationLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -126,7 +131,7 @@ public class ActivationLayer extends NoParamLayer { 
@NoArgsConstructor @Getter @Setter - public static class Builder extends org.deeplearning4j.nn.conf.layers.Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { /** * Activation function for the layer diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java index 09f14e034..311359f7f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java @@ -55,14 +55,17 @@ public class AutoEncoder extends BasePretrainNetwork { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + this.setNetConfiguration(conf); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder ret = - new org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder(conf, networkDataType); + new org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder(lconf, networkDataType); + ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(conf.getFlattenedLayerConfigurations().get(layerIndex)); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java index 6aad5b0ef..bf30e0f7a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java @@ -47,9 +47,10 @@ import java.util.List; @Data @EqualsAndHashCode(callSuper = true) @NoArgsConstructor -public abstract class BaseLayer extends Layer implements Serializable, Cloneable { +public abstract class BaseLayer extends LayerConfiguration implements Serializable, Cloneable { protected IActivation activationFn; + @NonNull protected IWeightInit weightInitFn; protected double biasInit; protected double gainInit; @@ -153,7 +154,7 @@ public abstract class BaseLayer extends Layer implements Serializable, Cloneable @SuppressWarnings("unchecked") @Getter @Setter - public abstract static class Builder> extends Layer.Builder { + public abstract static class Builder> extends LayerConfiguration.Builder { /** * Set the activation function for the layer. 
This overload can be used for custom {@link IActivation} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseUpsamplingLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseUpsamplingLayer.java index b92ad390f..07220f89e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseUpsamplingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseUpsamplingLayer.java @@ -21,10 +21,8 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; -import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.params.EmptyParamInitializer; /** * Upsampling base layer @@ -64,7 +62,7 @@ public abstract class BaseUpsamplingLayer extends NoParamLayer { @NoArgsConstructor @Getter @Setter - protected static abstract class UpsamplingBuilder> extends Layer.Builder { + protected static abstract class UpsamplingBuilder> extends LayerConfiguration.Builder { /** * An int array to specify upsampling dimensions, the length of which has to equal to the number of spatial diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java index 2dd228b0e..68e3a0851 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import net.brutex.ai.dnn.api.LayerType; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; @@ -64,6 +65,7 @@ public class BatchNormalization extends FeedForwardLayer { private BatchNormalization(Builder builder) { super(builder); + this.setType(LayerType.BN); this.decay = builder.decay; this.eps = builder.eps; this.isMinibatch = builder.isMinibatch; @@ -89,16 +91,18 @@ public class BatchNormalization extends FeedForwardLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - LayerValidation.assertNOutSet("BatchNormalization", getLayerName(), layerIndex, getNOut()); + this.setNetConfiguration(conf); + LayerValidation.assertNOutSet("BatchNormalization", getLayerName(), layerIndex, getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.normalization.BatchNormalization ret = - new org.deeplearning4j.nn.layers.normalization.BatchNormalization(conf, networkDataType); + new org.deeplearning4j.nn.layers.normalization.BatchNormalization(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java index 4081930c9..05d32dc56 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleLayer.java @@ -211,7 +211,7 @@ public class CapsuleLayer extends SameDiffLayer { } @Override - public E build() { + public E build() { return (E) new CapsuleLayer(this); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleStrengthLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleStrengthLayer.java index bd75b863e..e702b2de1 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleStrengthLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CapsuleStrengthLayer.java @@ -59,7 +59,7 @@ public class CapsuleStrengthLayer extends SameDiffLambdaLayer { public static class Builder extends SameDiffLambdaLayer.Builder{ @Override - public E build() { + public E build() { return (E) new CapsuleStrengthLayer(this); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java index 820d73d5d..a25a10947 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java @@ -61,15 +61,17 @@ public class CenterLossOutputLayer extends BaseOutputLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + setNetConfiguration(conf); LayerValidation.assertNInNOutSet("CenterLossOutputLayer", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); - Layer ret = new org.deeplearning4j.nn.layers.training.CenterLossOutputLayer(conf, networkDataType); - ret.setListeners(trainingListeners); + Layer ret = new org.deeplearning4j.nn.layers.training.CenterLossOutputLayer(lconf, networkDataType); + ret.setListeners(trainingListeners.toArray(new TrainingListener[]{})); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java index 774397ede..5c3cede7e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java @@ -56,14 +56,16 @@ public class Cnn3DLossLayer extends FeedForwardLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + setNetConfiguration(conf); + LayerConfiguration lconf = 
conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer ret = - new org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java index 0b31dd703..bcad7fb65 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java @@ -61,14 +61,16 @@ public class CnnLossLayer extends FeedForwardLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + setNetConfiguration(conf); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.convolution.CnnLossLayer ret = - new org.deeplearning4j.nn.layers.convolution.CnnLossLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.CnnLossLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java index 1bd0e5172..eeb023374 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java @@ -64,16 +64,17 @@ public class Convolution1DLayer extends ConvolutionLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + setNetConfiguration(conf); LayerValidation.assertNInNOutSet("Convolution1DLayer", getLayerName(), layerIndex, getNIn(), getNOut()); - + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.convolution.Convolution1DLayer ret = - new org.deeplearning4j.nn.layers.convolution.Convolution1DLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.Convolution1DLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - 
ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java index f012b0008..28a03ed4e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java @@ -97,13 +97,15 @@ public class Convolution3D extends ConvolutionLayer { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("Convolution3D", getLayerName(), layerIndex, getNIn(), getNOut()); - Convolution3DLayer ret = new Convolution3DLayer(conf, networkDataType); + + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + Convolution3DLayer ret = new Convolution3DLayer(lconf, networkDataType); ret.setListeners(iterationListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index ae26e62f0..a09d33506 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import net.brutex.ai.dnn.api.LayerType; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.*; @@ -113,6 +114,7 @@ public class ConvolutionLayer extends FeedForwardLayer { */ protected ConvolutionLayer(BaseConvBuilder builder) { super(builder); + this.setType(LayerType.CONV); int dim = builder.convolutionDim; this.hasBias = builder.hasBias; @@ -168,16 +170,19 @@ public class ConvolutionLayer extends FeedForwardLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + setNetConfiguration(conf); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + LayerValidation.assertNInNOutSet("ConvolutionLayer", getLayerName(), layerIndex, getNIn(), getNOut()); org.deeplearning4j.nn.layers.convolution.ConvolutionLayer ret = - new org.deeplearning4j.nn.layers.convolution.ConvolutionLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.ConvolutionLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java index e4f789ab7..d5b113b7f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java @@ -81,16 +81,19 @@ public class Deconvolution2D extends ConvolutionLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + setNetConfiguration(conf); LayerValidation.assertNInNOutSet("Deconvolution2D", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer ret = - new org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer(lconf, networkDataType); + ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java index 9f96b25da..99ed3137b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java @@ -30,10 +30,8 @@ import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer; import org.deeplearning4j.nn.layers.convolution.Deconvolution3DLayer; import org.deeplearning4j.nn.params.Deconvolution3DParamInitializer; -import org.deeplearning4j.nn.params.DeconvolutionParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.util.ValidationUtils; import org.nd4j.linalg.api.buffer.DataType; @@ -84,15 +82,15 @@ public class Deconvolution3D extends ConvolutionLayer { public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("Deconvolution2D", getLayerName(), layerIndex, getNIn(), getNOut()); - + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); Deconvolution3DLayer ret = - new Deconvolution3DLayer(conf, networkDataType); + new Deconvolution3DLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java index 1a6ce905c..fce42e8e5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.DefaultParamInitializer; +import org.deeplearning4j.nn.weights.WeightInitXavier; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -35,6 +36,10 @@ import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; import java.util.Map; +/** + * Dense Layer + * Uses WeightInitXavier as default + */ @Data @NoArgsConstructor @ToString(callSuper = true) @@ -55,16 +60,20 @@ public class DenseLayer extends FeedForwardLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret = - new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(lconf, networkDataType); + if(getWeightInitFn() == null) setWeightInitFn(new WeightInitXavier()); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); + return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java index d412c7158..52eb89ecf 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java @@ -68,13 +68,15 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("DepthwiseConvolution2D", getLayerName(), layerIndex, getNIn(), getNOut()); - DepthwiseConvolution2DLayer ret = new DepthwiseConvolution2DLayer(conf, networkDataType); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + DepthwiseConvolution2DLayer ret = new DepthwiseConvolution2DLayer(lconf, networkDataType); + ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + 
ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java index fa20692be..573b6c617 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import net.brutex.ai.dnn.api.LayerType; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -46,7 +47,9 @@ import java.util.Map; public class DropoutLayer extends FeedForwardLayer { private DropoutLayer(Builder builder) { + super(builder); + setType(LayerType.DO); } public DropoutLayer(double activationRetainProb){ @@ -66,13 +69,17 @@ public class DropoutLayer extends FeedForwardLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - org.deeplearning4j.nn.layers.DropoutLayer ret = new org.deeplearning4j.nn.layers.DropoutLayer(conf, networkDataType); + setNetConfiguration(conf); + + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.DropoutLayer ret = new org.deeplearning4j.nn.layers.DropoutLayer(lconf, networkDataType); + ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java index 67199aa64..3ef26352b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java @@ -27,7 +27,6 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; -import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.EmbeddingLayerParamInitializer; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.embeddings.ArrayEmbeddingInitializer; @@ -58,14 +57,16 @@ public class EmbeddingLayer extends FeedForwardLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer ret = - new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(lconf, networkDataType); + ret.setListeners(trainingListeners); 
ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java index ea7b4e6bf..133b0b6c1 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java @@ -65,14 +65,16 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer ret = - new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer(lconf, networkDataType); + ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java index 8e8fd62a3..3728e55bb 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import net.brutex.ai.dnn.api.LayerType; import org.deeplearning4j.nn.conf.DataFormat; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -43,9 +44,11 @@ public abstract class FeedForwardLayer extends BaseLayer { super(builder); this.nIn = builder.nIn; this.nOut = builder.nOut; + setType(LayerType.FC); } + @Override public InputType getOutputType(int layerIndex, InputType inputType) { if (inputType == null || (inputType.getType() != InputType.Type.FF diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java index cdf92720a..1cd9e6c91 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java @@ -69,14 +69,16 @@ public class GlobalPoolingLayer extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + 
LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.pooling.GlobalPoolingLayer ret = - new org.deeplearning4j.nn.layers.pooling.GlobalPoolingLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.pooling.GlobalPoolingLayer(lconf, networkDataType); + ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -189,7 +191,7 @@ public class GlobalPoolingLayer extends NoParamLayer { @Getter @Setter - public static class Builder extends Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { /** * Pooling type for global pooling diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java index 76a943509..ac6242e9a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java @@ -59,7 +59,7 @@ public class GravesBidirectionalLSTM extends BaseRecurrentLayer { } @Override - protected void initializeConstraints(org.deeplearning4j.nn.conf.layers.Layer.Builder builder) { + protected void initializeConstraints(LayerConfiguration.Builder builder) { super.initializeConstraints(builder); if (((Builder) builder).recurrentConstraints != null) { if (constraints == null) { @@ -79,14 +79,16 @@ public class GravesBidirectionalLSTM extends BaseRecurrentLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTM ret = - new org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTM(conf, networkDataType); + new org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTM(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java index e12d6df22..bb84cedae 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java @@ -59,7 +59,7 @@ public class GravesLSTM extends AbstractLSTM { } @Override - protected void initializeConstraints(org.deeplearning4j.nn.conf.layers.Layer.Builder builder) { + protected void initializeConstraints(LayerConfiguration.Builder builder) { super.initializeConstraints(builder); if (((Builder) builder).recurrentConstraints != null) { if (constraints == null) 
{ @@ -77,14 +77,16 @@ public class GravesLSTM extends AbstractLSTM { public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("GravesLSTM", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.recurrent.GravesLSTM ret = - new org.deeplearning4j.nn.layers.recurrent.GravesLSTM(conf, networkDataType); + new org.deeplearning4j.nn.layers.recurrent.GravesLSTM(lconf, networkDataType); + ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java index 0f0d61fc3..8474d3089 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java @@ -57,7 +57,7 @@ public class LSTM extends AbstractLSTM { } @Override - protected void initializeConstraints(org.deeplearning4j.nn.conf.layers.Layer.Builder builder) { + protected void initializeConstraints(LayerConfiguration.Builder builder) { super.initializeConstraints(builder); if (((Builder) builder).recurrentConstraints != null) { if (constraints == null) { @@ -75,13 +75,14 @@ public class LSTM extends AbstractLSTM { public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("LSTM", getLayerName(), layerIndex, getNIn(), getNOut()); - org.deeplearning4j.nn.layers.recurrent.LSTM ret = new org.deeplearning4j.nn.layers.recurrent.LSTM(conf, networkDataType); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.recurrent.LSTM ret = new org.deeplearning4j.nn.layers.recurrent.LSTM(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java similarity index 90% rename from cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java rename to cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java index 66f48dd14..a41870c3d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Layer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java @@ -20,10 +20,20 @@ package org.deeplearning4j.nn.conf.layers; +import com.fasterxml.jackson.annotation.JsonTypeInfo; 
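The hunks around this point rename org.deeplearning4j.nn.conf.layers.Layer to LayerConfiguration and extend it with a per-layer variable list, a LayerType marker, and a back-reference to its owning NeuralNetConfiguration, while keeping a deprecated getLayer() shim for migration. A minimal orientation sketch, not part of the patch, assuming DenseLayer now extends the renamed base class and using only members introduced in these hunks:

    import net.brutex.ai.dnn.api.LayerType;
    import org.deeplearning4j.nn.conf.layers.DenseLayer;
    import org.deeplearning4j.nn.conf.layers.LayerConfiguration;

    public class LayerConfigurationSketch {
        public static void main(String[] args) {
            // Any concrete layer config is now a LayerConfiguration (assumption: DenseLayer extends it).
            LayerConfiguration dense = new DenseLayer.Builder().nIn(4).nOut(3).build();
            dense.setType(LayerType.UNKNOWN);   // new type marker; defaults to UNKNOWN
            dense.addVariable("W");             // per-layer variable names, cleared via clearVariables()
            dense.addVariable("b");
            // Deprecated migration shim: returns the configuration itself.
            LayerConfiguration same = dense.getLayer();
            System.out.println(same.getVariables());
        }
    }
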
+import java.io.Serializable; +import java.lang.reflect.Field; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; import lombok.Data; import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.Setter; +import net.brutex.ai.dnn.api.LayerType; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.api.layers.LayerConstraint; @@ -34,35 +44,49 @@ import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.Regularization; -import com.fasterxml.jackson.annotation.JsonTypeInfo; - -import java.io.Serializable; -import java.lang.reflect.Field; -import java.util.*; /** * A neural network layer. + * */ @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") @Data @NoArgsConstructor -public abstract class Layer implements TrainingConfig, Serializable, Cloneable { + +public abstract class LayerConfiguration implements TrainingConfig, Serializable, Cloneable { protected String layerName; + @Getter + protected List variables = new ArrayList<>(); + public void addVariable(String s) {variables.add(s);} + protected IDropout iDropout; protected List constraints; + /** + * The type of the layer, basically defines the base class and its properties + */ + @Getter @Setter @NonNull + private LayerType type = LayerType.UNKNOWN; - public Layer(Builder builder) { + @Getter @Setter + private NeuralNetConfiguration netConfiguration; + + public LayerConfiguration(Builder builder) { this.layerName = builder.layerName; this.iDropout = builder.iDropout; } + public String toJson() { + throw new RuntimeException("toJson is not implemented for LayerConfiguration"); + } + /** * Initialize the weight constraints. Should be called last, in the outer-most constructor */ @@ -113,10 +137,19 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable { this.constraints = null; } + /** + * Migration workaround //TODO To be removed + * + * @return a layer configuration + */ + @Deprecated + public LayerConfiguration getLayer() { + return this; + } @Override - public Layer clone() { + public LayerConfiguration clone() { try { - Layer ret = (Layer) super.clone(); + LayerConfiguration ret = (LayerConfiguration) super.clone(); //Let's check for any INDArray fields and dup them (in case cloned layer will be used in different threads on CUDA... 
// we don't want it being relocated contantly between devices) Class c = getClass(); @@ -150,7 +183,7 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable { } } - public abstract org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, + public abstract org.deeplearning4j.nn.api.Layer instantiate( @NonNull NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType); @@ -239,7 +272,14 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable { */ public abstract LayerMemoryReport getMemoryReport(InputType inputType); - @SuppressWarnings("unchecked") + public void clearVariables() { + this.variables.clear(); + } + + @Getter + public IActivation activationFn; + + @SuppressWarnings("unchecked") @Getter @Setter public abstract static class Builder> { @@ -344,6 +384,6 @@ public abstract class Layer implements TrainingConfig, Serializable, Cloneable { return (T) this; } - public abstract E build(); + public abstract E build(); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java index 571f884e3..a125d4ffc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java @@ -74,7 +74,7 @@ public class LayerValidation { } } - public static void generalValidation(String layerName, Layer layer, IDropout iDropout, List regularization, + public static void generalValidation(String layerName, LayerConfiguration layer, IDropout iDropout, List regularization, List regularizationBias, List allParamConstraints, List weightConstraints, List biasConstraints) { @@ -82,8 +82,8 @@ public class LayerValidation { if (layer instanceof BaseLayer) { BaseLayer bLayer = (BaseLayer) layer; configureBaseLayer(layerName, bLayer, iDropout, regularization, regularizationBias); - } else if (layer instanceof FrozenLayer && ((FrozenLayer) layer).getLayer() instanceof BaseLayer) { - BaseLayer bLayer = (BaseLayer) ((FrozenLayer) layer).getLayer(); + } else if (layer instanceof FrozenLayer && ((FrozenLayer) layer).getInnerConfiguration() instanceof BaseLayer) { + BaseLayer bLayer = (BaseLayer) ((FrozenLayer) layer).getInnerConfiguration(); configureBaseLayer(layerName, bLayer, iDropout, regularization, regularizationBias); } else if (layer instanceof Bidirectional) { Bidirectional l = (Bidirectional) layer; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java index 98d7fa093..77483640c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java @@ -44,7 +44,7 @@ import java.util.Map; @NoArgsConstructor @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) -public class LocalResponseNormalization extends Layer { +public class LocalResponseNormalization extends LayerConfiguration { // Defaults as per http://www.cs.toronto.edu/~fritz/absps/imagenet.pdf //Set defaults here as well as in builder, in case users use no-arg constructor instead of builder @@ 
-75,14 +75,16 @@ public class LocalResponseNormalization extends Layer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization ret = - new org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization(conf, networkDataType); + new org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization(lconf, networkDataType); + ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -154,7 +156,7 @@ public class LocalResponseNormalization extends Layer { @AllArgsConstructor @Getter @Setter - public static class Builder extends Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { // defaults based on AlexNet model @@ -275,7 +277,7 @@ public class LocalResponseNormalization extends Layer { * Set the data format for the CNN activations - NCHW (channels first) or NHWC (channels last). * See {@link CNN2DFormat} for more details.
* Default: NCHW - * @param format Format for activations (in and out) + * @param dataFormat Format for activations (in and out) */ public Builder dataFormat(CNN2DFormat dataFormat){ this.dataFormat = dataFormat; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index 921d0f9ea..2a8afacb7 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -212,12 +212,13 @@ public class LocallyConnected1D extends SameDiffLayer { } @Override - public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) { + public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { + NeuralNetConfiguration global_conf = globalConfig.build(); if (activation == null) { - activation = SameDiffLayerUtils.fromIActivation(globalConfig.getActivationFn()); + activation = SameDiffLayerUtils.fromIActivation(global_conf.getActivationFn()); } if (cm == null) { - cm = globalConfig.getConvolutionMode(); + cm = global_conf.getConvolutionMode(); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index b44055332..a33445ce7 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -229,12 +229,13 @@ public class LocallyConnected2D extends SameDiffLayer { } @Override - public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) { + public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { + NeuralNetConfiguration gconf = globalConfig.build(); if (activation == null) { - activation = SameDiffLayerUtils.fromIActivation(globalConfig.getActivationFn()); + activation = SameDiffLayerUtils.fromIActivation(gconf.getActivationFn()); } if (cm == null) { - cm = globalConfig.getConvolutionMode(); + cm = gconf.getConvolutionMode(); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java index e88a66298..2e89f7ee7 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java @@ -57,13 +57,15 @@ public class LossLayer extends FeedForwardLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - org.deeplearning4j.nn.layers.LossLayer ret = new org.deeplearning4j.nn.layers.LossLayer(conf, networkDataType); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.LossLayer ret = new org.deeplearning4j.nn.layers.LossLayer(lconf, networkDataType); + ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = 
initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/NoParamLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/NoParamLayer.java index 227650a5f..7d0c181f8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/NoParamLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/NoParamLayer.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.NoArgsConstructor; +import net.brutex.ai.dnn.api.LayerType; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -30,10 +31,12 @@ import org.nd4j.linalg.learning.regularization.Regularization; import java.util.List; @NoArgsConstructor -public abstract class NoParamLayer extends Layer { +public abstract class NoParamLayer extends LayerConfiguration { protected NoParamLayer(Builder builder) { + super(builder); + setType(LayerType.POOL); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java index d31ff854a..2616ed8d9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java @@ -53,14 +53,15 @@ public class OutputLayer extends BaseOutputLayer { public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("OutputLayer", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); - org.deeplearning4j.nn.layers.OutputLayer ret = new org.deeplearning4j.nn.layers.OutputLayer(conf, networkDataType); + org.deeplearning4j.nn.layers.OutputLayer ret = new org.deeplearning4j.nn.layers.OutputLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java index 289009ad7..e44f7f709 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java @@ -59,13 +59,14 @@ public class PReLULayer extends BaseLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - org.deeplearning4j.nn.layers.feedforward.PReLU ret = new org.deeplearning4j.nn.layers.feedforward.PReLU(conf, networkDataType); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + 
org.deeplearning4j.nn.layers.feedforward.PReLU ret = new org.deeplearning4j.nn.layers.feedforward.PReLU(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java index 4d3f56a84..fc0c256f7 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PrimaryCapsules.java @@ -417,7 +417,7 @@ public class PrimaryCapsules extends SameDiffLayer { } @Override - public E build() { + public E build() { return (E) new PrimaryCapsules(this); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java index 161acc44e..10924fd90 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java @@ -150,9 +150,9 @@ public class RecurrentAttentionLayer extends SameDiffLayer { } @Override - public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) { + public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { if (activation == null) { - activation = SameDiffLayerUtils.fromIActivation(globalConfig.getActivationFn()); + activation = SameDiffLayerUtils.fromIActivation(globalConfig.build().getActivationFn()); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java index 376886cc4..1127d0be0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java @@ -59,14 +59,17 @@ public class RnnLossLayer extends FeedForwardLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.recurrent.RnnLossLayer ret = - new org.deeplearning4j.nn.layers.recurrent.RnnLossLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.recurrent.RnnLossLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java index 9f17c2cee..629e70da6 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java @@ -59,15 +59,16 @@ public class RnnOutputLayer extends BaseOutputLayer { public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("RnnOutputLayer", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer ret = - new org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java index a6efb86b1..34bc03086 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java @@ -79,7 +79,7 @@ public class SeparableConvolution2D extends ConvolutionLayer { } @Override - protected void initializeConstraints(org.deeplearning4j.nn.conf.layers.Layer.Builder builder) { + protected void initializeConstraints(LayerConfiguration.Builder builder) { super.initializeConstraints(builder); if (((Builder) builder).pointWiseConstraints != null) { if (constraints == null) { @@ -117,15 +117,16 @@ public class SeparableConvolution2D extends ConvolutionLayer { public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("SeparableConvolution2D", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer ret = - new org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java index b8de7a4e4..ff4082075 100644 --- 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java @@ -67,14 +67,16 @@ public class SpaceToBatchLayer extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.convolution.SpaceToBatch ret = - new org.deeplearning4j.nn.layers.convolution.SpaceToBatch(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.SpaceToBatch(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -130,7 +132,7 @@ public class SpaceToBatchLayer extends NoParamLayer { @NoArgsConstructor @Getter @Setter - public static class Builder> extends Layer.Builder { + public static class Builder> extends LayerConfiguration.Builder { /** * Block size for SpaceToBatch layer. Should be a length 2 array for the height and width diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java index b35092359..110d127b0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java @@ -73,14 +73,16 @@ public class SpaceToDepthLayer extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.convolution.SpaceToDepth ret = - new org.deeplearning4j.nn.layers.convolution.SpaceToDepth(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.SpaceToDepth(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -134,7 +136,7 @@ public class SpaceToDepthLayer extends NoParamLayer { @NoArgsConstructor @Getter @Setter - public static class Builder> extends Layer.Builder { + public static class Builder> extends LayerConfiguration.Builder { protected int blockSize; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java index 267e67005..5d48dfa6b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java +++ 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java @@ -61,14 +61,16 @@ public class Subsampling1DLayer extends SubsamplingLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling1DLayer ret = - new org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling1DLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling1DLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java index cb643cd7b..d201c88b2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java @@ -113,14 +113,16 @@ public class Subsampling3DLayer extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling3DLayer ret = - new org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling3DLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling3DLayer(lconf, networkDataType); ret.setListeners(iterationListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -338,7 +340,7 @@ public class Subsampling3DLayer extends NoParamLayer { @Setter @NoArgsConstructor protected static abstract class BaseSubsamplingBuilder> - extends Layer.Builder { + extends LayerConfiguration.Builder { protected org.deeplearning4j.nn.conf.layers.PoolingType poolingType = org.deeplearning4j.nn.conf.layers.PoolingType.MAX; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java index f1d546234..32983b01c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java @@ -134,14 +134,16 @@ public class SubsamplingLayer extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int 
layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer ret = - new org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -390,7 +392,7 @@ public class SubsamplingLayer extends NoParamLayer { @Getter @Setter protected static abstract class BaseSubsamplingBuilder> - extends Layer.Builder { + extends LayerConfiguration.Builder { protected org.deeplearning4j.nn.conf.layers.PoolingType poolingType = org.deeplearning4j.nn.conf.layers.PoolingType.MAX; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java index 6a012ed15..6f7a7c091 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java @@ -56,14 +56,17 @@ public class Upsampling1D extends BaseUpsamplingLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling1D ret = - new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling1D(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling1D(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java index bdbbb0c73..61693091a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java @@ -63,14 +63,16 @@ public class Upsampling2D extends BaseUpsamplingLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D ret = - new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D(conf, networkDataType); + new 
org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java index ef5d832b4..f4d5fa280 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java @@ -61,14 +61,18 @@ public class Upsampling3D extends BaseUpsamplingLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling3D ret = - new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling3D(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling3D(lconf, networkDataType); + + ret.setListeners(iterationListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java index 98f6f8077..aa0268be1 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java @@ -66,13 +66,15 @@ public class ZeroPadding1DLayer extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.convolution.ZeroPadding1DLayer ret = - new org.deeplearning4j.nn.layers.convolution.ZeroPadding1DLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.ZeroPadding1DLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -125,7 +127,7 @@ public class ZeroPadding1DLayer extends NoParamLayer { @Getter @Setter - public static class Builder extends Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { /** * Padding value for left and right. 
Must be length 2 array diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java index f6b97cfcc..21d77ae03 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java @@ -53,13 +53,15 @@ public class ZeroPadding3DLayer extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.convolution.ZeroPadding3DLayer ret = - new org.deeplearning4j.nn.layers.convolution.ZeroPadding3DLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.ZeroPadding3DLayer(lconf, networkDataType); ret.setListeners(iterationListeners); ret.setIndex(layerIndex); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -114,7 +116,7 @@ public class ZeroPadding3DLayer extends NoParamLayer { @Getter @Setter - public static class Builder extends Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { /** * [padLeftD, padRightD, padLeftH, padRightH, padLeftW, padRightW] diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java index 459205609..0d0e85d56 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java @@ -70,13 +70,15 @@ public class ZeroPaddingLayer extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer ret = - new org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -118,7 +120,7 @@ public class ZeroPaddingLayer extends NoParamLayer { @Getter @Setter - public static class Builder extends Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { /** * Padding value for top, bottom, left, and right. 
Must be length 4 array diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java index fd2546019..2124e9eb9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java @@ -25,7 +25,7 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.NoParamLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.layers.convolution.Cropping1DLayer; @@ -76,12 +76,14 @@ public class Cropping1D extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - Cropping1DLayer ret = new Cropping1DLayer(conf, networkDataType); + setNetConfiguration(conf); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + Cropping1DLayer ret = new Cropping1DLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -113,7 +115,7 @@ public class Cropping1D extends NoParamLayer { @Getter @Setter - public static class Builder extends Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { /** * Cropping amount for top/bottom (in that order). Must be length 1 or 2 array. 
*/ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java index 29aad71bd..604a269cb 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java @@ -26,7 +26,7 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.NoParamLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.layers.convolution.Cropping2DLayer; @@ -92,12 +92,14 @@ public class Cropping2D extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - Cropping2DLayer ret = new Cropping2DLayer(conf, networkDataType); + setNetConfiguration(conf); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + Cropping2DLayer ret = new Cropping2DLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -131,7 +133,7 @@ public class Cropping2D extends NoParamLayer { @Getter @Setter - public static class Builder extends Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { /** * Cropping amount for top/bottom/left/right (in that order). A length 4 array. 
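Nearly every layer touched so far receives the same change to instantiate(): the per-layer LayerConfiguration is looked up from the flattened list on the NeuralNetConfiguration, the runtime layer is built from that configuration instead of the whole network configuration, the parameter table is initialized from the layer configuration itself, and setConf(conf) becomes setLayerConfiguration(lconf). A condensed sketch of that pattern, illustrative only and limited to the names that appear in these hunks (SomeLayerImpl is a hypothetical stand-in for the concrete runtime layer class):

    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
            Collection<TrainingListener> trainingListeners, int layerIndex,
            INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        // Resolve this layer's own configuration from the network-level configuration.
        LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);
        // The runtime layer now receives the layer configuration, not the whole NeuralNetConfiguration.
        SomeLayerImpl ret = new SomeLayerImpl(lconf, networkDataType);
        ret.setListeners(trainingListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        // The param initializer is now driven by the layer configuration ("this") as well.
        Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setLayerConfiguration(lconf);
        return ret;
    }
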
diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java index 1ab34b17b..c22c8f429 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java @@ -25,7 +25,7 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.InputTypeUtil; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.NoParamLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.layers.convolution.Cropping3DLayer; @@ -84,12 +84,14 @@ public class Cropping3D extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - Cropping3DLayer ret = new Cropping3DLayer(conf, networkDataType); + setNetConfiguration(conf); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + Cropping3DLayer ret = new Cropping3DLayer(lconf, networkDataType); ret.setListeners(iterationListeners); ret.setIndex(layerIndex); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -121,7 +123,7 @@ public class Cropping3D extends NoParamLayer { @Getter @Setter - public static class Builder extends Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { /** * Cropping amount, a length 6 array, i.e. 
crop left depth, crop right depth, crop left height, crop right height, crop left width, crop right width diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java index 9eea40cfc..dc7e9b93d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.ElementWiseParamInitializer; @@ -58,18 +59,21 @@ public class ElementWiseMultiplicationLayer extends org.deeplearning4j.nn.conf.l @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + setNetConfiguration(conf); if (this.nIn != this.nOut) { throw new IllegalStateException("Element wise layer must have the same input and output size. Got nIn=" + nIn + ", nOut=" + nOut); } + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.feedforward.elementwise.ElementWiseMultiplicationLayer ret = - new org.deeplearning4j.nn.layers.feedforward.elementwise.ElementWiseMultiplicationLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.feedforward.elementwise.ElementWiseMultiplicationLayer(lconf, networkDataType); + ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java index ba5674bbb..35a4cae8d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java @@ -29,7 +29,7 @@ import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.FrozenLayerParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; @@ -38,36 +38,32 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.Regularization; import com.fasterxml.jackson.annotation.JsonProperty; -import 
com.fasterxml.jackson.databind.annotation.JsonDeserialize; import java.util.Collection; import java.util.List; @EqualsAndHashCode(callSuper = false) -public class FrozenLayer extends Layer { +public class FrozenLayer extends LayerConfiguration { - @Getter - protected Layer layer; + /** + * A layer configuration, only if this layer config has been created from another one + */ + @Getter @Setter + private LayerConfiguration innerConfiguration; private FrozenLayer(Builder builder) { super(builder); - this.layer = builder.layer; + this.innerConfiguration = builder.layer; } - public FrozenLayer(@JsonProperty("layer") Layer layer) { - this.layer = layer; - } - - public NeuralNetConfiguration getInnerConf(NeuralNetConfiguration conf) { - NeuralNetConfiguration nnc = conf.clone(); - nnc.setLayer(layer); - return nnc; + public FrozenLayer(@JsonProperty("layer") LayerConfiguration layer) { + this.innerConfiguration = layer; } @Override - public Layer clone() { + public LayerConfiguration clone() { FrozenLayer l = (FrozenLayer) super.clone(); - l.layer = layer.clone(); + l.innerConfiguration = innerConfiguration.clone(); return l; } @@ -77,17 +73,17 @@ public class FrozenLayer extends Layer { boolean initializeParams, DataType networkDataType) { //Need to be able to instantiate a layer, from a config - for JSON -> net type situations - org.deeplearning4j.nn.api.Layer underlying = layer.instantiate(getInnerConf(conf), trainingListeners, + org.deeplearning4j.nn.api.Layer underlying = innerConfiguration.instantiate(getNetConfiguration(), trainingListeners, layerIndex, layerParamsView, initializeParams, networkDataType); - NeuralNetConfiguration nncUnderlying = underlying.conf(); - if (nncUnderlying.variables() != null) { - List vars = nncUnderlying.variables(true); - nncUnderlying.clearVariables(); - conf.clearVariables(); + NeuralNetConfiguration nncUnderlying = underlying.getNetConfiguration(); + if (nncUnderlying.netWideVariables() != null) { + List vars = nncUnderlying.netWideVariables(true); + nncUnderlying.clearNetWideVariable(); + conf.clearNetWideVariable(); for (String s : vars) { - conf.variables(false).add(s); - nncUnderlying.variables(false).add(s); + conf.netWideVariables(false).add(s); + nncUnderlying.netWideVariables(false).add(s); } } @@ -101,17 +97,17 @@ public class FrozenLayer extends Layer { @Override public InputType getOutputType(int layerIndex, InputType inputType) { - return layer.getOutputType(layerIndex, inputType); + return innerConfiguration.getOutputType(layerIndex, inputType); } @Override public void setNIn(InputType inputType, boolean override) { - layer.setNIn(inputType, override); + innerConfiguration.setNIn(inputType, override); } @Override public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return layer.getPreProcessorForInputType(inputType); + return innerConfiguration.getPreProcessorForInputType(inputType); } @Override @@ -131,38 +127,38 @@ public class FrozenLayer extends Layer { @Override public GradientNormalization getGradientNormalization() { - return layer.getGradientNormalization(); + return innerConfiguration.getGradientNormalization(); } @Override public double getGradientNormalizationThreshold() { - return layer.getGradientNormalizationThreshold(); + return innerConfiguration.getGradientNormalizationThreshold(); } @Override public LayerMemoryReport getMemoryReport(InputType inputType) { - return layer.getMemoryReport(inputType); + return innerConfiguration.getMemoryReport(inputType); } @Override public void setLayerName(String 
layerName) { super.setLayerName(layerName); - layer.setLayerName(layerName); + innerConfiguration.setLayerName(layerName); } @Override public void setConstraints(List constraints) { this.constraints = constraints; - this.layer.setConstraints(constraints); + this.innerConfiguration.setConstraints(constraints); } @Getter @Setter - public static class Builder extends Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { - private Layer layer; + private LayerConfiguration layer; - public Builder layer(Layer layer) { + public Builder layer(LayerConfiguration layer) { this.setLayer(layer); return this; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java index 53d7ff914..ae438958f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java @@ -25,7 +25,7 @@ import lombok.EqualsAndHashCode; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.params.FrozenLayerWithBackpropParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; @@ -42,18 +42,19 @@ import java.util.List; @EqualsAndHashCode(callSuper = false) public class FrozenLayerWithBackprop extends BaseWrapperLayer { - public FrozenLayerWithBackprop(@JsonProperty("layer") Layer layer) { + public FrozenLayerWithBackprop(@JsonProperty("layer") LayerConfiguration layer) { super(layer); + underlying = layer; } public NeuralNetConfiguration getInnerConf(NeuralNetConfiguration conf) { NeuralNetConfiguration nnc = conf.clone(); - nnc.setLayer(underlying); + nnc.getLayerConfigurations().add(0, underlying); return nnc; } @Override - public Layer clone() { + public LayerConfiguration clone() { FrozenLayerWithBackprop l = (FrozenLayerWithBackprop) super.clone(); l.underlying = underlying.clone(); return l; @@ -65,18 +66,18 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayer { boolean initializeParams, DataType networkDataType) { //Need to be able to instantiate a layer, from a config - for JSON -> net type situations - org.deeplearning4j.nn.api.Layer underlying = getUnderlying().instantiate(getInnerConf(conf), trainingListeners, + org.deeplearning4j.nn.api.Layer underlying = getUnderlying().instantiate(conf, trainingListeners, layerIndex, layerParamsView, initializeParams, networkDataType); - NeuralNetConfiguration nncUnderlying = underlying.conf(); + NeuralNetConfiguration nncUnderlying = underlying.getNetConfiguration(); - if (nncUnderlying.variables() != null) { - List vars = nncUnderlying.variables(true); - nncUnderlying.clearVariables(); - conf.clearVariables(); + if (nncUnderlying.netWideVariables() != null) { + List vars = nncUnderlying.netWideVariables(true); + nncUnderlying.clearNetWideVariable(); + conf.clearNetWideVariable(); for (String s : vars) { - conf.variables(false).add(s); - nncUnderlying.variables(false).add(s); + conf.netWideVariables(false).add(s); + nncUnderlying.netWideVariables(false).add(s); } } diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java index 127502b68..ba85f879c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; @@ -65,13 +66,15 @@ public class RepeatVector extends FeedForwardLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - org.deeplearning4j.nn.layers.RepeatVector ret = new org.deeplearning4j.nn.layers.RepeatVector(conf, networkDataType); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + + org.deeplearning4j.nn.layers.RepeatVector ret = new org.deeplearning4j.nn.layers.RepeatVector(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java index 1229e8cfd..d2d4bec81 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor; import org.deeplearning4j.nn.params.EmptyParamInitializer; @@ -51,7 +52,7 @@ import java.util.Map; @Data @EqualsAndHashCode(callSuper = false) -public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer { +public class Yolo2OutputLayer extends LayerConfiguration { private double lambdaCoord; private double lambdaNoObj; @@ -79,14 +80,16 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer ret 
= - new org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -156,7 +159,7 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer { @Getter @Setter - public static class Builder extends org.deeplearning4j.nn.conf.layers.Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { /** * Loss function coefficient for position and size/scale components of the loss function. Default (as per diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java index 388e131cd..5eda741e4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java @@ -29,7 +29,7 @@ import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer; @@ -47,13 +47,12 @@ import java.util.List; import java.util.Map; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; -import static org.nd4j.linalg.indexing.NDArrayIndex.point; @NoArgsConstructor @Data @EqualsAndHashCode(callSuper = true, exclude = {"initializer"}) @JsonIgnoreProperties({"initializer"}) -public class Bidirectional extends Layer { +public class Bidirectional extends LayerConfiguration { /** * This Mode enumeration defines how the activations for the forward and backward networks should be combined.
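The RepeatVector, Yolo2OutputLayer and Bidirectional hunks above, like most layer configurations touched later in this patch, converge on one instantiate() pattern: the layer's own LayerConfiguration is looked up from the flattened layer list of the NeuralNetConfiguration, the runtime layer is constructed from that per-layer configuration, and the parameter table is initialized against the configuration object itself rather than the whole network configuration. A condensed sketch of that pattern follows; MyLayerImpl is only a placeholder for a concrete runtime layer class and is not part of this patch.

    @Override
    public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf,
            Collection<TrainingListener> trainingListeners, int layerIndex,
            INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        // Resolve this layer's configuration from the network-wide flattened layer list.
        LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex);

        // The runtime layer is now built from the per-layer configuration, not the net configuration.
        MyLayerImpl ret = new MyLayerImpl(lconf, networkDataType);
        ret.setListeners(trainingListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);

        // Parameters are initialized against the LayerConfiguration ('this'), replacing init(conf, ...).
        Map<String, INDArray> paramTable = initializer().init(this, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setLayerConfiguration(lconf);   // replaces the old ret.setConf(conf)
        return ret;
    }

The same sequence repeats, with only the concrete layer class changing, in SimpleRnn, MaskLayer, VariationalAutoencoder, the SameDiff layers and OCNNOutputLayer further down.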
@@ -68,8 +67,8 @@ public class Bidirectional extends Layer { ADD, MUL, AVERAGE, CONCAT } - private Layer fwd; - private Layer bwd; + private LayerConfiguration fwd; + private LayerConfiguration bwd; private Mode mode; private transient BidirectionalParamInitializer initializer; @@ -82,7 +81,7 @@ public class Bidirectional extends Layer { * * @param layer layer to wrap */ - public Bidirectional(@NonNull Layer layer) { + public Bidirectional(@NonNull LayerConfiguration layer) { this(Mode.CONCAT, layer); } @@ -92,7 +91,7 @@ public class Bidirectional extends Layer { * @param mode Mode to use to combine activations. See {@link Mode} for details * @param layer layer to wrap */ - public Bidirectional(@NonNull Mode mode, @NonNull Layer layer) { + public Bidirectional(@NonNull Mode mode, @NonNull LayerConfiguration layer) { if (!(layer instanceof BaseRecurrentLayer || layer instanceof LastTimeStep || layer instanceof BaseWrapperLayer)) { throw new IllegalArgumentException("Cannot wrap a non-recurrent layer: " @@ -128,6 +127,7 @@ public class Bidirectional extends Layer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); NeuralNetConfiguration c1 = conf.clone(); NeuralNetConfiguration c2 = conf.clone(); c1.setLayer(fwd); @@ -140,10 +140,10 @@ public class Bidirectional extends Layer { org.deeplearning4j.nn.api.Layer b = bwd.instantiate(c2, trainingListeners, layerIndex, bp, initializeParams, networkDataType); - BidirectionalLayer ret = new BidirectionalLayer(conf, f, b, layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + BidirectionalLayer ret = new BidirectionalLayer(lconf, f, b, layerParamsView); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } @@ -238,12 +238,12 @@ public class Bidirectional extends Layer { @AllArgsConstructor @Getter @Setter - public static class Builder extends Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { private Mode mode; - private Layer layer; + private LayerConfiguration layer; - public void setLayer(Layer layer) { + public void setLayer(LayerConfiguration layer) { rnnLayer(layer); } @@ -252,7 +252,7 @@ public class Bidirectional extends Layer { return this; } - public Builder rnnLayer(Layer layer) { + public Builder rnnLayer(LayerConfiguration layer) { if (!(layer instanceof BaseRecurrentLayer || layer instanceof LastTimeStep || layer instanceof BaseWrapperLayer)) { throw new IllegalArgumentException("Cannot wrap a non-recurrent layer: " diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java index ce87b8051..a869999dc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java @@ -22,7 +22,7 @@ package org.deeplearning4j.nn.conf.layers.recurrent; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.Layer; +import 
org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer; import org.deeplearning4j.optimize.api.TrainingListener; @@ -35,12 +35,12 @@ public class LastTimeStep extends BaseWrapperLayer { private LastTimeStep() {} - public LastTimeStep(Layer underlying) { + public LastTimeStep(LayerConfiguration underlying) { super(underlying); this.layerName = underlying.getLayerName(); // needed for keras import to match names } - public Layer getUnderlying() { + public LayerConfiguration getUnderlying() { return underlying; } @@ -49,8 +49,9 @@ public class LastTimeStep extends BaseWrapperLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); NeuralNetConfiguration conf2 = conf.clone(); - conf2.setLayer(((LastTimeStep) conf2.getLayer()).getUnderlying()); + conf2.setLayer(((LastTimeStep) lconf).getUnderlying()); return new LastTimeStepLayer(underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, initializeParams, networkDataType)); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java index 7cbebeaf2..bda494c1d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerValidation; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.SimpleRnnParamInitializer; @@ -55,15 +56,16 @@ public class SimpleRnn extends BaseRecurrentLayer { public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("SimpleRnn", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.recurrent.SimpleRnn ret = - new org.deeplearning4j.nn.layers.recurrent.SimpleRnn(conf, networkDataType); + new org.deeplearning4j.nn.layers.recurrent.SimpleRnn(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java index 54a93b904..7ab6370b7 100644 --- 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java @@ -27,7 +27,7 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.layers.recurrent.TimeDistributedLayer; import org.deeplearning4j.optimize.api.TrainingListener; @@ -46,20 +46,22 @@ public class TimeDistributed extends BaseWrapperLayer { /** * @param underlying Underlying (internal) layer - should be a feed forward type such as DenseLayerConfiguration */ - public TimeDistributed(@JsonProperty("underlying") @NonNull Layer underlying, @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) { + public TimeDistributed(@JsonProperty("underlying") @NonNull LayerConfiguration underlying, @JsonProperty("rnnDataFormat") RNNFormat rnnDataFormat) { super(underlying); this.rnnDataFormat = rnnDataFormat; } - public TimeDistributed(Layer underlying){ + public TimeDistributed(LayerConfiguration underlying){ super(underlying); } @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + NeuralNetConfiguration conf2 = conf.clone(); - conf2.setLayer(((TimeDistributed) conf2.getLayer()).getUnderlying()); + conf2.setLayer(((TimeDistributed) lconf).getUnderlying()); return new TimeDistributedLayer(underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, initializeParams, networkDataType), rnnDataFormat); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java index 71bb2a95a..18c4601c8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java @@ -30,7 +30,7 @@ import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.SameDiffParamInitializer; import org.deeplearning4j.nn.weights.WeightInit; @@ -54,7 +54,7 @@ import java.util.Map; @Slf4j @Data @EqualsAndHashCode(callSuper = true, doNotUseGetters = true) -public abstract class AbstractSameDiffLayer extends Layer { +public abstract class AbstractSameDiffLayer extends LayerConfiguration { protected List regularization; protected List regularizationBias; @@ -121,7 +121,7 @@ public abstract class AbstractSameDiffLayer extends Layer { } - public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) { + public void 
applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { //Default implementation: no op } @@ -187,24 +187,25 @@ public abstract class AbstractSameDiffLayer extends Layer { WeightInitUtil.initWeights(fanIn, fanOut, array.shape(), weightInit, null, paramReshapeOrder(null), array); } - public void applyGlobalConfig(NeuralNetConfiguration.Builder b) { + public void applyGlobalConfig(NeuralNetConfiguration.NeuralNetConfigurationBuilder b) { + NeuralNetConfiguration bConf = b.build(); if (regularization == null || regularization.isEmpty()) { - regularization = b.getRegularization(); + regularization = bConf.getRegularization(); } if (regularizationBias == null || regularizationBias.isEmpty()) { - regularizationBias = b.getRegularizationBias(); + regularizationBias = bConf.getRegularizationBias(); } if (updater == null) { - updater = b.getIUpdater(); + updater = bConf.getIUpdater(); } if (biasUpdater == null) { - biasUpdater = b.getBiasUpdater(); + biasUpdater = bConf.getBiasUpdater(); } if (gradientNormalization == null) { - gradientNormalization = b.getGradientNormalization(); + gradientNormalization = bConf.getGradientNormalization(); } if (Double.isNaN(gradientNormalizationThreshold)) { - gradientNormalizationThreshold = b.getGradientNormalizationThreshold(); + gradientNormalizationThreshold = bConf.getGradientNormalizationThreshold(); } applyGlobalConfigToLayer(b); @@ -234,7 +235,7 @@ public abstract class AbstractSameDiffLayer extends Layer { @Getter @Setter - public static abstract class Builder> extends Layer.Builder { + public static abstract class Builder> extends LayerConfiguration.Builder { protected List regularization = new ArrayList<>(); protected List regularizationBias = new ArrayList<>(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java index ea8fc2b09..cb16d2f26 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffLayer.java @@ -24,6 +24,7 @@ import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.api.TrainingListener; @@ -85,16 +86,19 @@ public abstract class SameDiffLayer extends AbstractSameDiffLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.samediff.SameDiffLayer ret = - new org.deeplearning4j.nn.layers.samediff.SameDiffLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.samediff.SameDiffLayer(lconf, networkDataType); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } + 
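In the AbstractSameDiffLayer hunk above, applyGlobalConfig() no longer reads defaults straight off the legacy NeuralNetConfiguration.Builder: the new NeuralNetConfigurationBuilder is materialized first and any value the layer left unset is copied from the resulting configuration. A trimmed sketch of that flow, with field names as in the hunk and only a subset of the copied values shown:

    public void applyGlobalConfig(NeuralNetConfiguration.NeuralNetConfigurationBuilder b) {
        // Materialize the builder once, then copy every value this layer did not set explicitly.
        NeuralNetConfiguration bConf = b.build();
        if (regularization == null || regularization.isEmpty()) {
            regularization = bConf.getRegularization();
        }
        if (updater == null) {
            updater = bConf.getIUpdater();
        }
        if (gradientNormalization == null) {
            gradientNormalization = bConf.getGradientNormalization();
        }
        if (Double.isNaN(gradientNormalizationThreshold)) {
            gradientNormalizationThreshold = bConf.getGradientNormalizationThreshold();
        }
        applyGlobalConfigToLayer(b);
    }

SameDiffVertex, further down, goes one step further and takes an already-built NeuralNetConfiguration instead of the builder.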
@SuppressWarnings("unchecked") @Getter @Setter diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java index d781dd244..8fa7fd4d0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffOutputLayer.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.conf.layers.samediff; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; @@ -74,13 +75,15 @@ public abstract class SameDiffOutputLayer extends AbstractSameDiffLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.samediff.SameDiffOutputLayer ret = - new org.deeplearning4j.nn.layers.samediff.SameDiffOutputLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.samediff.SameDiffOutputLayer(lconf, networkDataType); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java index 94a13ffec..cfec8d653 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java @@ -147,30 +147,31 @@ public abstract class SameDiffVertex extends GraphVertex implements TrainingConf } - public void applyGlobalConfig(NeuralNetConfiguration.Builder b) { + public void applyGlobalConfig(NeuralNetConfiguration b_conf) { + if(regularization == null || regularization.isEmpty()){ - regularization = b.getRegularization(); + regularization = b_conf.getRegularization(); } if(regularizationBias == null || regularizationBias.isEmpty()){ - regularizationBias = b.getRegularizationBias(); + regularizationBias = b_conf.getRegularizationBias(); } if (updater == null) { - updater = b.getIUpdater(); + updater = b_conf.getIUpdater(); } if (biasUpdater == null) { - biasUpdater = b.getBiasUpdater(); + biasUpdater = b_conf.getBiasUpdater(); } if (gradientNormalization == null) { - gradientNormalization = b.getGradientNormalization(); + gradientNormalization = b_conf.getGradientNormalization(); } if (Double.isNaN(gradientNormalizationThreshold)) { - gradientNormalizationThreshold = b.getGradientNormalizationThreshold(); + gradientNormalizationThreshold = b_conf.getGradientNormalizationThreshold(); } - applyGlobalConfigToLayer(b); + applyGlobalConfigToLayer(b_conf); } - public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) { + 
public void applyGlobalConfigToLayer(NeuralNetConfiguration globalConfig) { //Default implementation: no op } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskLayer.java index 181d32b4c..bd39eb828 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskLayer.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.NoParamLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.EmptyParamInitializer; @@ -43,12 +44,13 @@ public class MaskLayer extends NoParamLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - org.deeplearning4j.nn.layers.util.MaskLayer ret = new org.deeplearning4j.nn.layers.util.MaskLayer(conf, networkDataType); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + org.deeplearning4j.nn.layers.util.MaskLayer ret = new org.deeplearning4j.nn.layers.util.MaskLayer(lconf, networkDataType); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java index 8a3d309a5..7f11874e8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java @@ -24,7 +24,7 @@ import lombok.*; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.optimize.api.TrainingListener; @@ -49,7 +49,7 @@ public class MaskZeroLayer extends BaseWrapperLayer { } - public MaskZeroLayer(@JsonProperty("underlying") Layer underlying, @JsonProperty("maskingValue") double maskingValue) { + public MaskZeroLayer(@JsonProperty("underlying") LayerConfiguration underlying, @JsonProperty("maskingValue") double maskingValue) { this.underlying = underlying; this.maskingValue = maskingValue; } @@ -61,7 +61,7 @@ public class MaskZeroLayer extends BaseWrapperLayer { boolean initializeParams, DataType networkDataType) { NeuralNetConfiguration conf2 = conf.clone(); - conf2.setLayer(((BaseWrapperLayer) conf2.getLayer()).getUnderlying()); + conf2.setLayer(((BaseWrapperLayer) 
this).getUnderlying()); org.deeplearning4j.nn.api.Layer underlyingLayer = underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, initializeParams, networkDataType); @@ -102,12 +102,12 @@ public class MaskZeroLayer extends BaseWrapperLayer { @NoArgsConstructor @Getter @Setter - public static class Builder extends Layer.Builder { + public static class Builder extends LayerConfiguration.Builder { - private Layer underlying; + private LayerConfiguration underlying; private double maskValue; - public Builder setUnderlying(Layer underlying) { + public Builder setUnderlying(LayerConfiguration underlying) { this.underlying = underlying; return this; } @@ -117,7 +117,7 @@ public class MaskZeroLayer extends BaseWrapperLayer { return this; } - public Builder underlying(Layer underlying){ + public Builder underlying(LayerConfiguration underlying){ setUnderlying(underlying); return this; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java index ca1f10bd0..4e6a0c41c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerValidation; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; @@ -68,15 +69,16 @@ public class VariationalAutoencoder extends BasePretrainNetwork { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("VariationalAutoencoder", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret = - new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(conf, networkDataType); + new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java index ca90ee7a1..2495fbd56 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java @@ -22,11 +22,13 @@ package org.deeplearning4j.nn.conf.layers.wrapper; import lombok.Data; import lombok.EqualsAndHashCode; +import lombok.NonNull; import 
org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.WrapperLayerParamInitializer; import org.nd4j.linalg.learning.regularization.Regularization; @@ -35,9 +37,24 @@ import java.util.List; @Data @EqualsAndHashCode(callSuper = false) -public abstract class BaseWrapperLayer extends Layer { +public abstract class BaseWrapperLayer extends LayerConfiguration { - protected Layer underlying; + /** + * Set the net configuration for this configuration as well as for the underlying layer + * (if not null there) + * + * @param netConfiguration the neural net configuration + */ + @Override + public void setNetConfiguration(NeuralNetConfiguration netConfiguration) { + super.setNetConfiguration(netConfiguration); + if(getUnderlying().getNetConfiguration() == null) { + getUnderlying().setNetConfiguration( + netConfiguration); //also set netconf for underlying if not set + } + } + + protected LayerConfiguration underlying; protected BaseWrapperLayer(Builder builder) { super(builder); @@ -45,8 +62,9 @@ public abstract class BaseWrapperLayer extends Layer { protected BaseWrapperLayer() {} - public BaseWrapperLayer(Layer underlying) { + public BaseWrapperLayer(LayerConfiguration underlying) { this.underlying = underlying; + this.setNetConfiguration(underlying.getNetConfiguration()); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java index 696d63f5d..8469c6f62 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerValidation; import org.deeplearning4j.nn.layers.ocnn.OCNNParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; @@ -102,15 +103,16 @@ public class OCNNOutputLayer extends BaseOutputLayer { public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("OCNNOutputLayer", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.ocnn.OCNNOutputLayer ret = - new org.deeplearning4j.nn.layers.ocnn.OCNNOutputLayer(conf, networkDataType); + new org.deeplearning4j.nn.layers.ocnn.OCNNOutputLayer(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); 
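Worth calling out from the BaseWrapperLayer hunk above: wrapper configurations now keep their wrapped LayerConfiguration attached to the same NeuralNetConfiguration in both directions. Restated below with comments, behaviour as in the hunk; the wrapper subclasses (LastTimeStep, TimeDistributed, MaskZeroLayer) rely on this when they clone the net configuration and delegate instantiation to the underlying layer.

    public BaseWrapperLayer(LayerConfiguration underlying) {
        this.underlying = underlying;
        // Adopt the net configuration of the wrapped layer so both sides agree from the start.
        this.setNetConfiguration(underlying.getNetConfiguration());
    }

    @Override
    public void setNetConfiguration(NeuralNetConfiguration netConfiguration) {
        super.setNetConfiguration(netConfiguration);
        // Propagate downwards, but never overwrite a configuration the underlying layer already has.
        if (getUnderlying().getNetConfiguration() == null) {
            getUnderlying().setNetConfiguration(netConfiguration);
        }
    }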
ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(lconf); ret.setActivation(activationFn); if (lastEpochSinceRUpdated == 0 && configureR) { paramTable.get(OCNNParamInitializer.R_KEY).putScalar(0, initialRValue); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java index abd52c0c3..c6a2cbb26 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java @@ -26,11 +26,10 @@ import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.*; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; -import org.nd4j.linalg.activations.impl.*; import org.nd4j.linalg.learning.config.*; import org.nd4j.linalg.learning.regularization.L1Regularization; import org.nd4j.linalg.learning.regularization.Regularization; @@ -38,7 +37,6 @@ import org.nd4j.linalg.learning.regularization.WeightDecay; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.linalg.lossfunctions.impl.*; import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.JsonDeserializer; import com.fasterxml.jackson.databind.JsonMappingException; @@ -66,8 +64,8 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im public abstract T deserialize(JsonParser jp, DeserializationContext ctxt) throws IOException; - protected boolean requiresIUpdaterFromLegacy(Layer[] layers){ - for(Layer l : layers){ + protected boolean requiresIUpdaterFromLegacy(LayerConfiguration[] layers){ + for(LayerConfiguration l : layers){ if(l instanceof BaseLayer){ BaseLayer bl = (BaseLayer)l; if(bl.getIUpdater() == null && bl.initializer().numParams(bl) > 0){ @@ -78,8 +76,8 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im return false; } - protected boolean requiresDropoutFromLegacy(Layer[] layers){ - for(Layer l : layers){ + protected boolean requiresDropoutFromLegacy(LayerConfiguration[] layers){ + for(LayerConfiguration l : layers){ if(l.getIDropout() != null){ return false; } @@ -87,8 +85,8 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im return true; } - protected boolean requiresRegularizationFromLegacy(Layer[] layers){ - for(Layer l : layers){ + protected boolean requiresRegularizationFromLegacy(LayerConfiguration[] layers){ + for(LayerConfiguration l : layers){ if(l instanceof BaseLayer && ((BaseLayer)l).getRegularization() == null){ return true; } @@ -96,8 +94,8 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im return false; } - protected boolean requiresWeightInitFromLegacy(Layer[] layers){ - for(Layer l : layers){ + protected boolean requiresWeightInitFromLegacy(LayerConfiguration[] layers){ + for(LayerConfiguration l : layers){ if(l instanceof BaseLayer && ((BaseLayer)l).getWeightInitFn() == null){ return true; } @@ 
-105,8 +103,8 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im return false; } - protected boolean requiresActivationFromLegacy(Layer[] layers){ - for(Layer l : layers){ + protected boolean requiresActivationFromLegacy(LayerConfiguration[] layers){ + for(LayerConfiguration l : layers){ if(l instanceof BaseLayer && ((BaseLayer)l).getActivationFn() == null){ return true; } @@ -114,8 +112,8 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im return false; } - protected boolean requiresLegacyLossHandling(Layer[] layers){ - for(Layer l : layers){ + protected boolean requiresLegacyLossHandling(LayerConfiguration[] layers){ + for(LayerConfiguration l : layers){ if(l instanceof BaseOutputLayer && ((BaseOutputLayer)l).getLossFn() == null){ return true; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java index edd9cbef8..cf9282771 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java @@ -29,7 +29,7 @@ import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.BatchNormalization; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.weightnoise.DropConnect; import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer; import com.fasterxml.jackson.core.JsonLocation; @@ -65,16 +65,16 @@ public class ComputationGraphConfigurationDeserializer //Previously: enumerations and fields. 
Now: classes //Here, we manually create the appropriate Updater instances, if the IUpdater field is empty - List layerList = new ArrayList<>(); + List layerList = new ArrayList<>(); Map vertices = conf.getVertices(); for (Map.Entry entry : vertices.entrySet()) { if (entry.getValue() instanceof LayerVertex) { LayerVertex lv = (LayerVertex) entry.getValue(); - layerList.add(lv.getLayerConf().getLayer()); + layerList.add(lv.getLayerConfiguration()); } } - Layer[] layers = layerList.toArray(new Layer[layerList.size()]); + LayerConfiguration[] layers = layerList.toArray(new LayerConfiguration[layerList.size()]); //Now, check if we need to manually handle IUpdater deserialization from legacy format boolean attemptIUpdaterFromLegacy = requiresIUpdaterFromLegacy(layers); boolean requireLegacyRegularizationHandling = requiresRegularizationFromLegacy(layers); @@ -171,9 +171,9 @@ public class ComputationGraphConfigurationDeserializer // but, as there is no useLogStdev=false property for legacy batchnorm JSON, the 'real' value (useLogStdev=false) // is not set to override the default, unless we do it manually here for(GraphVertex gv : conf.getVertices().values()){ - if(gv instanceof LayerVertex && ((LayerVertex) gv).getLayerConf().getLayer() instanceof BatchNormalization){ - BatchNormalization bn = (BatchNormalization) ((LayerVertex) gv).getLayerConf().getLayer(); - List vars = ((LayerVertex) gv).getLayerConf().getVariables(); + if(gv instanceof LayerVertex && ((LayerVertex) gv).getLayerConfiguration() instanceof BatchNormalization){ + BatchNormalization bn = (BatchNormalization) ((LayerVertex) gv).getLayerConfiguration(); + List vars = ((LayerVertex) gv).getNetConfiguration().getNetWideVariables(); boolean isVariance = vars.contains(BatchNormalizationParamInitializer.GLOBAL_VAR); bn.setUseLogStd(!isVariance); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java index 8097111d6..0b6871524 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/JsonMappers.java @@ -22,7 +22,7 @@ package org.deeplearning4j.nn.conf.serde; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.serde.legacy.LegacyJsonFormat; import com.fasterxml.jackson.databind.*; import com.fasterxml.jackson.databind.deser.BeanDeserializerModifier; @@ -76,8 +76,8 @@ public class JsonMappers { public JsonDeserializer modifyDeserializer(DeserializationConfig config, BeanDescription beanDesc, JsonDeserializer deserializer) { //Use our custom deserializers to handle backward compatibility for updaters -> IUpdater - if (beanDesc.getBeanClass() == MultiLayerConfiguration.class) { - return new MultiLayerConfigurationDeserializer(deserializer); + if (beanDesc.getBeanClass() == NeuralNetConfiguration.class) { + return new NeuralNetConfigurationDeserializer(deserializer); } else if (beanDesc.getBeanClass() == ComputationGraphConfiguration.class) { return new ComputationGraphConfigurationDeserializer(deserializer); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/MultiLayerConfigurationDeserializer.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/NeuralNetConfigurationDeserializer.java similarity index 89% rename from cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/MultiLayerConfigurationDeserializer.java rename to cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/NeuralNetConfigurationDeserializer.java index 36f4a9b45..17a474e78 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/MultiLayerConfigurationDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/NeuralNetConfigurationDeserializer.java @@ -21,13 +21,12 @@ package org.deeplearning4j.nn.conf.serde; import org.apache.commons.io.IOUtils; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.BatchNormalization; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.weightnoise.DropConnect; import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer; import com.fasterxml.jackson.core.JsonLocation; @@ -43,21 +42,19 @@ import java.io.IOException; import java.io.StringReader; import java.util.List; -public class MultiLayerConfigurationDeserializer extends BaseNetConfigDeserializer { +public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserializer { - public MultiLayerConfigurationDeserializer(JsonDeserializer defaultDeserializer) { - super(defaultDeserializer, MultiLayerConfiguration.class); + public NeuralNetConfigurationDeserializer(JsonDeserializer defaultDeserializer) { + super(defaultDeserializer, NeuralNetConfiguration.class); } @Override - public MultiLayerConfiguration deserialize(JsonParser jp, DeserializationContext ctxt) throws IOException { + public NeuralNetConfiguration deserialize(JsonParser jp, DeserializationContext ctxt) throws IOException { long charOffsetStart = jp.getCurrentLocation().getCharOffset(); - MultiLayerConfiguration conf = (MultiLayerConfiguration) defaultDeserializer.deserialize(jp, ctxt); - Layer[] layers = new Layer[conf.getConfs().size()]; - for (int i = 0; i < layers.length; i++) { - layers[i] = conf.getConf(i).getLayer(); - } + NeuralNetConfiguration conf = (NeuralNetConfiguration) defaultDeserializer.deserialize(jp, ctxt); + + LayerConfiguration[] layers = conf.getFlattenedLayerConfigurations().toArray(new LayerConfiguration[0]); //Now, check if we need to manually handle IUpdater deserialization from legacy format boolean attemptIUpdaterFromLegacy = requiresIUpdaterFromLegacy(layers); @@ -162,11 +159,11 @@ public class MultiLayerConfigurationDeserializer extends BaseNetConfigDeserializ //JSON deserialization uses public BatchNormalization() constructor which defaults to log10stdev now // but, as there is no useLogStdev=false property for legacy batchnorm JSON, the 'real' value (useLogStdev=false) // is not set to override the default, unless we do it manually here - for(NeuralNetConfiguration nnc : conf.getConfs()){ - Layer l = nnc.getLayer(); + for(NeuralNetConfiguration nnc : conf.getNetConfigurations()){ + LayerConfiguration l = nnc.getLayerConfigurations().get(0); if(l instanceof BatchNormalization){ BatchNormalization bn = (BatchNormalization)l; - List vars = 
nnc.getVariables(); + List vars = nnc.getNetWideVariables(); boolean isVariance = vars.contains(BatchNormalizationParamInitializer.GLOBAL_VAR); bn.setUseLogStd(!isVariance); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyJsonFormat.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyJsonFormat.java index c654b2698..ceb645be7 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyJsonFormat.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/legacy/LegacyJsonFormat.java @@ -61,7 +61,7 @@ public class LegacyJsonFormat { om.addMixIn(InputPreProcessor.class, InputPreProcessorMixin.class); om.addMixIn(GraphVertex.class, GraphVertexMixin.class); - om.addMixIn(Layer.class, LayerMixin.class); + om.addMixIn(LayerConfiguration.class, LayerMixin.class); om.addMixIn(ReconstructionDistribution.class, ReconstructionDistributionMixin.class); om.addMixIn(IActivation.class, IActivationMixin.class); om.addMixIn(ILossFunction.class, ILossFunctionMixin.class); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/DropConnect.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/DropConnect.java index cabb01843..926d2017d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/DropConnect.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/DropConnect.java @@ -78,7 +78,7 @@ public class DropConnect implements IWeightNoise { @Override public INDArray getParameter(Layer layer, String paramKey, int iteration, int epoch, boolean train, LayerWorkspaceMgr workspaceMgr) { - ParamInitializer init = layer.conf().getLayer().initializer(); + ParamInitializer init = layer.getLayerConfiguration().initializer(); INDArray param = layer.getParam(paramKey); double p; @@ -88,8 +88,8 @@ public class DropConnect implements IWeightNoise { p = weightRetainProbSchedule.valueAt(iteration, epoch); } - if (train && init.isWeightParam(layer.conf().getLayer(), paramKey) - || (applyToBiases && init.isBiasParam(layer.conf().getLayer(), paramKey))) { + if (train && init.isWeightParam(layer.getLayerConfiguration(), paramKey) + || (applyToBiases && init.isBiasParam(layer.getLayerConfiguration(), paramKey))) { INDArray out = workspaceMgr.createUninitialized(ArrayType.INPUT, param.dataType(), param.shape(), param.ordering()); Nd4j.getExecutioner().exec(new DropOut(param, out, p)); return out; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java index 0e789749b..fdf01ad66 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/weightnoise/WeightNoise.java @@ -71,10 +71,10 @@ public class WeightNoise implements IWeightNoise { @Override public INDArray getParameter(Layer layer, String paramKey, int iteration, int epoch, boolean train, LayerWorkspaceMgr workspaceMgr) { - ParamInitializer init = layer.conf().getLayer().initializer(); + ParamInitializer init = layer.getLayerConfiguration().initializer(); INDArray param = layer.getParam(paramKey); - if (train && init.isWeightParam(layer.conf().getLayer(), paramKey) || - (applyToBias && init.isBiasParam(layer.conf().getLayer(), paramKey))) { + if (train 
&& init.isWeightParam(layer.getLayerConfiguration(), paramKey) || + (applyToBias && init.isBiasParam(layer.getLayerConfiguration(), paramKey))) { org.nd4j.linalg.api.rng.distribution.Distribution dist = Distributions.createDistribution(distribution); INDArray noise = dist.sample(param.ulike()); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index 4a080bb28..34d9b8c50 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -25,12 +25,13 @@ import lombok.NonNull; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import lombok.val; -import net.brutex.ai.dnn.api.INeuralNetwork; +import net.brutex.ai.dnn.api.IModel; import net.brutex.ai.dnn.networks.ArtificialNeuralNetwork; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.bytedeco.javacpp.Pointer; import org.deeplearning4j.exception.DL4JInvalidConfigException; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.util.*; import org.nd4j.adapters.OutputAdapter; import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator; @@ -105,8 +106,7 @@ import java.util.*; import java.util.concurrent.atomic.AtomicLong; @Slf4j -public class ComputationGraph extends ArtificialNeuralNetwork implements Serializable, Model, - INeuralNetwork { +public class ComputationGraph extends ArtificialNeuralNetwork implements Serializable { /** * This method returns configuration of this ComputationGraph @@ -220,6 +220,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali public ComputationGraph(ComputationGraphConfiguration computationGraphConfiguration) { + super(computationGraphConfiguration.getDefaultConfiguration()); this.computationGraphConfiguration = computationGraphConfiguration; this.numInputArrays = computationGraphConfiguration.getNetworkInputs().size(); this.numOutputArrays = computationGraphConfiguration.getNetworkOutputs().size(); @@ -543,7 +544,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali //Set RNG seed, for repeatability between initializations when set if (initializeParams) { - Nd4j.getRandom().setSeed(conf().getSeed()); + Nd4j.getRandom().setSeed(getNetConfiguration().getSeed()); } //Given the topological ordering: work out the subset of the parameters array used for each layer @@ -564,8 +565,8 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali int numLayers = 0; List tempLayerList = new ArrayList<>(); - defaultConfiguration.clearVariables(); - List variables = defaultConfiguration.variables(false); + defaultConfiguration.clearNetWideVariable(); + List variables = defaultConfiguration.netWideVariables(false); i = computationGraphConfiguration.getNetworkInputs().size(); for(; i layerVariables = l.conf().variables(); + List layerVariables = l.getNetConfiguration().netWideVariables(); if (layerVariables != null) { for (String s : layerVariables) { variables.add(gv.getVertexName() + "_" + s); @@ -689,7 +690,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali // now we init solver & optimizer if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new 
Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build(); + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this).build(); solver.initOptimizer(); } } @@ -710,7 +711,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali } for(Layer l : layers){ - String layerName = l.conf().getLayer().getLayerName(); + String layerName = l.getLayerConfiguration().getLayerName(); List inputs = computationGraphConfiguration.getVertexInputs().get(layerName); String in = inputs.get(0); //For now: layers should have exactly 1 input @@ -1158,7 +1159,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali } else { if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build(); + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this).build(); } } @@ -2381,8 +2382,8 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali //Standard feed-forward case if(i > 0 && current.hasLayer() && prev.hasLayer() && - ConvolutionUtils.layerHasConvolutionLayout(prev.getLayer().conf().getLayer()) - && ConvolutionUtils.layerHasConvolutionLayout(current.getLayer().conf().getLayer())) { + ConvolutionUtils.layerHasConvolutionLayout(prev.getLayer().getLayerConfiguration()) + && ConvolutionUtils.layerHasConvolutionLayout(current.getLayer().getLayerConfiguration())) { /** * Not QUITE the proper fix, but getting close. @@ -2390,8 +2391,8 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali * Need to play with output sizes a bit to make sure we put the right parameters in there to get * correct behavior. 
*/ - CNN2DFormat preLayerFormat = ConvolutionUtils.getFormatForLayer(prev.getLayer().conf().getLayer()); - CNN2DFormat currLayerFormat = ConvolutionUtils.getFormatForLayer(current.getLayer().conf().getLayer()); + CNN2DFormat preLayerFormat = ConvolutionUtils.getFormatForLayer(prev.getLayer().getLayerConfiguration()); + CNN2DFormat currLayerFormat = ConvolutionUtils.getFormatForLayer(current.getLayer().getLayerConfiguration()); if(preLayerFormat != currLayerFormat) { int inputIdx = -1; for(int inputVertex = 0; inputVertex < current.getInputVertices().length; inputVertex++) { @@ -2417,10 +2418,10 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali else out = current.doForward(train, workspaceMgr); } else if(i > 0 && current.hasLayer() && prev.hasLayer() && - Convolution1DUtils.hasRnnDataFormat(prev.getLayer().conf().getLayer()) - && Convolution1DUtils.hasRnnDataFormat(current.getLayer().conf().getLayer())) { - RNNFormat preLayerFormat = Convolution1DUtils.getRnnFormatFromLayer(prev.getLayer().conf().getLayer()); - RNNFormat currLayerFormat = Convolution1DUtils.getRnnFormatFromLayer(current.getLayer().conf().getLayer()); + Convolution1DUtils.hasRnnDataFormat(prev.getLayer().getLayerConfiguration()) + && Convolution1DUtils.hasRnnDataFormat(current.getLayer().getLayerConfiguration())) { + RNNFormat preLayerFormat = Convolution1DUtils.getRnnFormatFromLayer(prev.getLayer().getLayerConfiguration()); + RNNFormat currLayerFormat = Convolution1DUtils.getRnnFormatFromLayer(current.getLayer().getLayerConfiguration()); int inputIdx = -1; for(int inputVertex = 0; inputVertex < current.getInputVertices().length; inputVertex++) { if(current.getInputVertices()[inputVertex].getVertexIndex() == prev.getVertexIndex()) { @@ -2923,7 +2924,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali init(); for (Layer l : layers) { - l.setListeners(listeners); + l.setListeners(listeners.toArray(new TrainingListener[]{})); } if (solver != null) { @@ -2936,6 +2937,28 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali } } + /** + * The param table + * + * @return + */ + + public Map getParamTable() { + return null; + } + + /** + * Table of parameters by key, for backprop. For many models (dense layers, etc) - all parameters + * are backprop parameters + * + * @param backpropParamsOnly If true, return backprop params only. 
If false: return all params + * (equivalent to paramsTable()) + */ + + public Map getParamTable(boolean backpropParamsOnly) { + return null; + } + /** * Set the trainingListeners for the ComputationGraph (and all layers in the network) */ @@ -2994,7 +3017,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali */ public ComputationGraphUpdater getUpdater(boolean initializeIfAbsent){ if (solver == null && initializeIfAbsent) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build(); + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this).build(); solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this)); } if(solver != null) { @@ -3008,7 +3031,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali */ public void setUpdater(ComputationGraphUpdater updater) { if (solver == null) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build(); + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this).build(); } solver.getOptimizer().setUpdaterComputationGraph(updater); } @@ -3399,14 +3422,10 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali } @Override - public NeuralNetConfiguration conf() { + public NeuralNetConfiguration getNetConfiguration() { return defaultConfiguration; } - @Override - public void setConf(NeuralNetConfiguration conf) { - throw new UnsupportedOperationException(); - } @Override public INDArray input() { @@ -3434,16 +3453,11 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali } - @Override - public Map paramTable() { - return paramTable(false); - } - public Map paramTable(boolean backpropParamsOnly) { //Get all parameters from all layers/vertices Map allParams = new LinkedHashMap<>(); for(GraphVertex gv : vertices){ - Map paramMap = gv.paramTable(backpropParamsOnly); + Map paramMap = gv.getParamTable(backpropParamsOnly); for (Map.Entry entry : paramMap.entrySet()) { String newKey = gv.getVertexName() + "_" + entry.getKey(); allParams.put(newKey, entry.getValue()); @@ -3452,11 +3466,11 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali return allParams; } - @Override + public void setParamTable(@NonNull Map paramTable) { - Map m = paramTable(); + Map m = getParamTable(); Preconditions.checkArgument(paramTable.keySet().equals(m.keySet()), "Cannot set param table: parameter set keys are not equal"); - Map current = paramTable(); + Map current = getParamTable(); //Check shapes before doing partial assigment to avoid leaving net in incorrect state for(String s : current.keySet()){ INDArray arrCurrent = current.get(s); @@ -3580,7 +3594,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali * @return Hidden state, or null if layer is not an RNN layer */ public Map rnnGetPreviousState(int layer) { - return rnnGetPreviousState(layers[layer].conf().getLayer().getLayerName()); + return rnnGetPreviousState(layers[layer].getLayerConfiguration().getLayerName()); } /** @@ -3613,7 +3627,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying(); } if (l instanceof RecurrentLayer) { - states.put(l.conf().getLayer().getLayerName(), ((RecurrentLayer) l).rnnGetPreviousState()); + 
states.put(l.getLayerConfiguration().getLayerName(), ((RecurrentLayer) l).rnnGetPreviousState()); } } return states; @@ -3626,7 +3640,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali * @param state The state to set the specified layer to */ public void rnnSetPreviousState(int layer, Map state) { - rnnSetPreviousState(layers[layer].conf().getLayer().getLayerName(), state); + rnnSetPreviousState(layers[layer].getLayerConfiguration().getLayerName(), state); } /** @@ -3729,7 +3743,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) .build(); } } @@ -3975,7 +3989,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali Layer outputLayer = getOutputLayer(0); if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class); + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), Evaluation.class); } return (T)doEvaluation(iterator, new org.deeplearning4j.eval.Evaluation(labelsList, topN))[0]; @@ -3993,7 +4007,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali public T evaluate(MultiDataSetIterator iterator, List labelsList, int topN) { Layer outputLayer = getOutputLayer(0); if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), Evaluation.class); + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), Evaluation.class); } return (T)doEvaluation(iterator, new org.deeplearning4j.eval.Evaluation(labelsList, topN))[0]; } @@ -4058,7 +4072,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali public T evaluateROC(DataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(0); if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class); + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), ROC.class); } return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0]; } @@ -4081,7 +4095,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali public T evaluateROC(MultiDataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(0); if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROC.class); + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), ROC.class); } return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0]; } @@ -4104,7 +4118,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali public T evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(0); 
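        // As in the evaluate()/evaluateROC() overloads above, the check below reads the output
        // layer's configuration through the renamed accessor: outputLayer.getLayerConfiguration()
        // in place of the former outputLayer.conf().getLayer(). Other call sites in this class
        // follow the same pattern, e.g. layers[layer].getLayerConfiguration().getLayerName()
        // in rnnGetPreviousState(int).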
if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class); + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), ROCMultiClass.class); } return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0]; } @@ -4119,7 +4133,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali public T evaluateROCMultiClass(MultiDataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(0); if(this.getComputationGraphConfiguration().isValidateOutputLayerConfig()){ - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), ROCMultiClass.class); + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), ROCMultiClass.class); } return (T)doEvaluation(iterator, new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0]; } @@ -4396,19 +4410,19 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali paramShape = ""; if (currentLayer instanceof BidirectionalLayer) { // Bidirectional layer is not an FFL BidirectionalLayer bi = (BidirectionalLayer) currentLayer; - in = String.valueOf(((Bidirectional)bi.conf().getLayer()).getNIn()); - out = String.valueOf(((Bidirectional)bi.conf().getLayer()).getNOut()); + in = String.valueOf(((Bidirectional)bi.getLayerConfiguration()).getNIn()); + out = String.valueOf(((Bidirectional)bi.getLayerConfiguration()).getNOut()); } else { try { - in = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNIn()); - out = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNOut()); + in = String.valueOf(((FeedForwardLayer) currentLayer.getLayerConfiguration()).getNIn()); + out = String.valueOf(((FeedForwardLayer) currentLayer.getLayerConfiguration()).getNOut()); } catch (Exception e) { // Some layers, like PReLU, are just BaseLayers (but have parameters) } } - List paraNames = currentLayer.conf().variables(); + List paraNames = currentLayer.getNetConfiguration().netWideVariables(); for (String aP : paraNames) { - String paramS = ArrayUtils.toString(currentLayer.paramTable().get(aP).shape()); + String paramS = ArrayUtils.toString(currentLayer.getParamTable().get(aP).shape()); paramShape += aP + ":" + paramS + ", "; } paramShape = paramShape.subSequence(0, paramShape.lastIndexOf(",")).toString(); @@ -4738,7 +4752,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali throw new IllegalArgumentException("Invalid layer index: " + layer + ". ILayer index must be between 0 and " + (layers.length - 1) + " inclusive"); } - return layerSize(layers[layer].conf().getLayer().getLayerName()); + return layerSize(layers[layer].getLayerConfiguration().getLayerName()); } /** @@ -4757,7 +4771,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali throw new IllegalArgumentException("Invalid layer index: " + layer + ". 
ILayer index must be between 0 and " + (layers.length - 1) + " inclusive"); } - return layerInputSize(layers[layer].conf().getLayer().getLayerName()); + return layerInputSize(layers[layer].getLayerConfiguration().getLayerName()); } /** @@ -4775,7 +4789,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali if(l == null){ throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists"); } - org.deeplearning4j.nn.conf.layers.Layer conf = l.conf().getLayer(); + LayerConfiguration conf = l.getLayerConfiguration(); if (conf == null || !(conf instanceof FeedForwardLayer)) { return 0; } @@ -4800,7 +4814,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali if(l == null){ throw new IllegalArgumentException("No layer with name \"" + layerName + "\" exists"); } - org.deeplearning4j.nn.conf.layers.Layer conf = l.conf().getLayer(); + LayerConfiguration conf = l.getLayerConfiguration(); if (conf == null || !(conf instanceof FeedForwardLayer)) { return 0; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java index cdb124d75..759f214bc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java @@ -38,6 +38,16 @@ public abstract class BaseGraphVertex implements GraphVertex { protected ComputationGraph graph; + public BaseGraphVertex(){}; + @Override + public Map getParamTable() { + return null; + } + + public void setParamTable(Map params) { + throw new RuntimeException("Not implemented."); + } + protected String vertexName; /** The index of this vertex */ @@ -197,7 +207,7 @@ public abstract class BaseGraphVertex implements GraphVertex { } @Override - public Map paramTable(boolean backpropOnly) { + public Map getParamTable(boolean backpropOnly) { return Collections.emptyMap(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java index 949ee0f7e..0d2a3a26d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java @@ -179,8 +179,8 @@ public abstract class BaseWrapperVertex implements GraphVertex { } @Override - public Map paramTable(boolean backpropOnly) { - return underlying.paramTable(backpropOnly); + public Map getParamTable(boolean backpropOnly) { + return underlying.getParamTable(backpropOnly); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java index 61136e0db..96ac34c19 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java @@ -156,5 +156,5 @@ public interface GraphVertex extends Trainable, Serializable { * @param backpropOnly If true: exclude unsupervised training parameters * @return Parameter table */ - Map paramTable(boolean backpropOnly); + Map getParamTable(boolean backpropOnly); } diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java index 77107c6ee..c0b5999ac 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java @@ -20,6 +20,7 @@ package org.deeplearning4j.nn.graph.vertex.impl; +import java.util.Map; import lombok.AllArgsConstructor; import lombok.EqualsAndHashCode; import org.deeplearning4j.nn.api.TrainingConfig; @@ -46,4 +47,24 @@ public class FrozenVertex extends BaseWrapperVertex { } return config; } + + /** + * The param table + * + * @return + */ + @Override + public Map getParamTable() { + return null; + } + + /** + * Setter for the param table + * + * @param paramTable + */ + @Override + public void setParamTable(Map paramTable) { + + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java index 60f3dad0b..5f9ebdb35 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java @@ -85,12 +85,12 @@ public class LayerVertex extends BaseGraphVertex { return; this.layer = new FrozenLayer(this.layer); - this.layer.conf().getLayer().setLayerName(vertexName); + this.layer.getLayerConfiguration().setLayerName(vertexName); } @Override - public Map paramTable(boolean backpropOnly) { - return layer.paramTable(backpropOnly); + public Map getParamTable(boolean backpropOnly) { + return layer.getParamTable(backpropOnly); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java index 5c4c8ee16..edaa3fb80 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java @@ -20,9 +20,19 @@ package org.deeplearning4j.nn.layers; +import java.lang.ref.Cleaner; +import java.lang.ref.PhantomReference; +import java.lang.ref.Reference; +import java.lang.ref.WeakReference; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; import lombok.AccessLevel; import lombok.Data; +import lombok.Getter; import lombok.NoArgsConstructor; +import lombok.NonNull; import lombok.Setter; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; @@ -30,411 +40,784 @@ import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.common.primitives.Pair; +import org.nd4j.evaluation.IEvaluation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import 
org.nd4j.common.primitives.Pair; - -import java.util.*; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; /** * A layer with input and output, no parameters or gradients */ @Data @NoArgsConstructor -public abstract class AbstractLayer implements Layer { +public abstract class AbstractLayer implements Layer { - @Setter(AccessLevel.NONE) - protected INDArray input; - protected INDArray preOutput; - protected NeuralNetConfiguration conf; - protected boolean dropoutApplied = false; - protected Collection trainingListeners = new ArrayList<>(); - protected int index = 0; - protected INDArray maskArray; - protected MaskState maskState; - protected CacheMode cacheMode = CacheMode.NONE; - protected boolean inputModificationAllowed = false; - protected DataType dataType; + @Setter(AccessLevel.NONE) + protected INDArray input; + protected INDArray preOutput; + @Getter + @NonNull + protected LayerConf_T layerConfiguration; + protected boolean dropoutApplied = false; + @Getter @Setter @NonNull + protected Collection trainingListeners = new ArrayList<>(); + @Deprecated public Collection getListeners() {return getTrainingListeners();} + @Deprecated public void setListeners(TrainingListener ... listeners) { setTrainingListeners(List.of(listeners));} + /** + * Set the {@link TrainingListener}s for this model. If any listeners have previously been set, + * they will be replaced by this method + * + * @param listeners + */ + @Deprecated + public void setListeners(Collection listeners) { + setTrainingListeners(listeners); + } - protected int iterationCount; - protected int epochCount; - public AbstractLayer(NeuralNetConfiguration conf, DataType dataType) { - this.conf = conf; - if (conf != null) - cacheMode = conf.getCacheMode(); - this.dataType = dataType; + protected int index = 0; + protected INDArray maskArray; + protected MaskState maskState; + protected CacheMode cacheMode = CacheMode.NONE; + protected boolean inputModificationAllowed = false; + protected DataType dataType; + protected int iterationCount; + protected int epochCount; + private List variables = new ArrayList<>(); + public AbstractLayer(LayerConfiguration layerConfiguration, DataType dataType) { + this.layerConfiguration = (LayerConf_T) layerConfiguration; + if (layerConfiguration != null) { + cacheMode = layerConfiguration.getNetConfiguration().getCacheMode(); + } + this.dataType = dataType; + } + + /** + * @param backpropOnly If true: return only parameters that are not exclusively used for layerwise + * pretraining + * @return Parameter table + */ + @Override + public Map getParamTable(boolean backpropOnly) { + return null; + } + + public void setParamTable(Map map) { + throw new RuntimeException("Not implemented."); + } + /** + * @return 1D gradients view array + */ + @Override + public INDArray getGradientsViewArray() { + return null; + } + + /** + * Creates and returns a copy of this object. The precise meaning of "copy" may depend on the + * class of the object. The general intent is that, for any object {@code x}, the expression: + *
+ *
+   * x.clone() != x
+ * will be true, and that the expression: + *
+ *
+   * x.clone().getClass() == x.getClass()
+ * will be {@code true}, but these are not absolute requirements. While it is typically the case + * that: + *
+ *
+   * x.clone().equals(x)
+ * will be {@code true}, this is not an absolute requirement. + *

+ * By convention, the returned object should be obtained by calling {@code super.clone}. If a + * class and all of its superclasses (except {@code Object}) obey this convention, it will be the + * case that {@code x.clone().getClass() == x.getClass()}. + *

+ * By convention, the object returned by this method should be independent of this object (which + * is being cloned). To achieve this independence, it may be necessary to modify one or more + * fields of the object returned by {@code super.clone} before returning it. Typically, this + * means copying any mutable objects that comprise the internal "deep structure" of the object + * being cloned and replacing the references to these objects with references to the copies. If a + * class contains only primitive fields or references to immutable objects, then it is usually the + * case that no fields in the object returned by {@code super.clone} need to be modified. + *

+ * The method {@code clone} for class {@code Object} performs a specific cloning operation. First, + * if the class of this object does not implement the interface {@code Cloneable}, then a + * {@code CloneNotSupportedException} is thrown. Note that all arrays are considered to implement + * the interface {@code Cloneable} and that the return type of the {@code clone} method of an + * array type {@code T[]} is {@code T[]} where T is any reference or primitive type. Otherwise, + * this method creates a new instance of the class of this object and initializes all its fields + * with exactly the contents of the corresponding fields of this object, as if by assignment; the + * contents of the fields are not themselves cloned. Thus, this method performs a "shallow copy" + * of this object, not a "deep copy" operation. + *

+ * The class {@code Object} does not itself implement the interface {@code Cloneable}, so calling + * the {@code clone} method on an object whose class is {@code Object} will result in throwing an + * exception at run time. + * + * @return a clone of this instance. + * @throws CloneNotSupportedException if the object's class does not support the {@code Cloneable} + * interface. Subclasses that override the {@code clone} method + * can also throw this exception to indicate that an instance + * cannot be cloned. + * @see Cloneable + */ + @Override + protected Object clone() throws CloneNotSupportedException { + return super.clone(); + } + + /** + * Called by the garbage collector on an object when garbage collection determines that there are + * no more references to the object. A subclass overrides the {@code finalize} method to dispose + * of system resources or to perform other cleanup. + *

+ * The general contract of {@code finalize} is that it is invoked if and when the Java™ + * virtual machine has determined that there is no longer any means by which this object can be + * accessed by any thread that has not yet died, except as a result of an action taken by the + * finalization of some other object or class which is ready to be finalized. The {@code finalize} + * method may take any action, including making this object available again to other threads; the + * usual purpose of {@code finalize}, however, is to perform cleanup actions before the object is + * irrevocably discarded. For example, the finalize method for an object that represents an + * input/output connection might perform explicit I/O transactions to break the connection before + * the object is permanently discarded. + *

+ * The {@code finalize} method of class {@code Object} performs no special action; it simply + * returns normally. Subclasses of {@code Object} may override this definition. + *

+ * The Java programming language does not guarantee which thread will invoke the {@code finalize} + * method for any given object. It is guaranteed, however, that the thread that invokes finalize + * will not be holding any user-visible synchronization locks when finalize is invoked. If an + * uncaught exception is thrown by the finalize method, the exception is ignored and finalization + * of that object terminates. + *

+ * After the {@code finalize} method has been invoked for an object, no further action is taken + * until the Java virtual machine has again determined that there is no longer any means by which + * this object can be accessed by any thread that has not yet died, including possible actions by + * other objects or classes which are ready to be finalized, at which point the object may be + * discarded. + *

+ * The {@code finalize} method is never invoked more than once by a Java virtual machine for any + * given object. + *

+ * Any exception thrown by the {@code finalize} method causes the finalization of this object to + * be halted, but is otherwise ignored. + * + * @throws Throwable the {@code Exception} raised by this method + * @apiNote Classes that embed non-heap resources have many options for cleanup of those + * resources. The class must ensure that the lifetime of each instance is longer than that of any + * resource it embeds. {@link Reference#reachabilityFence} can be used to ensure that objects + * remain reachable while resources embedded in the object are in use. + *

+ * A subclass should avoid overriding the {@code finalize} method unless the subclass embeds + * non-heap resources that must be cleaned up before the instance is collected. Finalizer + * invocations are not automatically chained, unlike constructors. If a subclass overrides + * {@code finalize} it must invoke the superclass finalizer explicitly. To guard against + * exceptions prematurely terminating the finalize chain, the subclass should use a + * {@code try-finally} block to ensure {@code super.finalize()} is always invoked. For example, + *

{@code      @Override
+   *     protected void finalize() throws Throwable {
+   *         try {
+   *             ... // cleanup subclass state
+   *         } finally {
+   *             super.finalize();
+   *         }
+   *     }
+   * }
+ * @jls 12.6 Finalization of Class Instances + * @see WeakReference + * @see PhantomReference + * @deprecated The finalization mechanism is inherently problematic. Finalization can lead to + * performance issues, deadlocks, and hangs. Errors in finalizers can lead to resource leaks; + * there is no way to cancel finalization if it is no longer necessary; and no ordering is + * specified among calls to {@code finalize} methods of different objects. Furthermore, there are + * no guarantees regarding the timing of finalization. The {@code finalize} method might be called + * on a finalizable object only after an indefinite delay, if at all. + *

+ * Classes whose instances hold non-heap resources should provide a method to enable explicit + * release of those resources, and they should also implement {@link AutoCloseable} if + * appropriate. The {@link Cleaner} and {@link PhantomReference} provide more flexible and + * efficient ways to release resources when an object becomes unreachable. + */ + @Override + protected void finalize() throws Throwable { + super.finalize(); + } + + /** + * This method returns updater state (if applicable), null otherwise + * + * @return + */ + @Override + public INDArray updaterState() { + return null; + } + + /** + * This method returns Optimizer used for training + * + * @return + */ + @Override + public ConvexOptimizer getOptimizer() { + return null; + } + + /** + * This method fits model with a given DataSet + * + * @param dataSet + */ + @Override + public void fit(DataSet dataSet) { + + } + + /** + * This method fits model with a given MultiDataSet + * + * @param dataSet + */ + @Override + public void fit(MultiDataSet dataSet) { + + } + + /** + * This method fits model with a given DataSetIterator + * + * @param iterator + */ + @Override + public void fit(DataSetIterator iterator) { + + } + + /** + * This method fits model with a given MultiDataSetIterator + * + * @param iterator + */ + @Override + public void fit(MultiDataSetIterator iterator) { + + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(DataSetIterator iterator, T... evaluations) { + return null; + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations) { + return null; + } + + /** + * @param netConfiguration + */ + @Override + public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) { + + } + + /** + * Init the model + */ + @Override + public void init() { + + } + + /** + * This method ADDS additional TrainingListener to existing listeners + * + * @param listener + */ + @Override + public void addListeners(TrainingListener... listener) { + this.trainingListeners.addAll(List.of(listener)); + } + + /** + * Update layer weights and biases with gradient change + * + * @param gradient + */ + @Override + public void update(Gradient gradient) { + + } + + /** + * Perform one update applying the gradient + * + * @param gradient the gradient to apply + * @param paramType + */ + @Override + public void update(INDArray gradient, String paramType) { + + } + + /** + * Update the score + * + * @param workspaceMgr + */ + @Override + public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { + + } + + /** + * the number of parameters for the model + * + * @param backwards + * @return the number of parameters for the model + */ + @Override + public long numParams(boolean backwards) { + return 0; + } + + /** + * Set the parameters for this model. 
This expects a linear ndarray which then be unpacked + * internally relative to the expected ordering of the model + * + * @param params the parameters for the model + */ + @Override + public void setParams(INDArray params) { + + } + + /** + * Set the initial parameters array as a view of the full (backprop) network parameters NOTE: this + * is intended to be used internally in MultiLayerNetwork and ComputationGraph, not by users. + * + * @param params a 1 x nParams row vector that is a view of the larger (MLN/CG) parameters array + */ + @Override + public void setParamsViewArray(INDArray params) { + + } + + /** + * Set the gradients array as a view of the full (backprop) network parameters NOTE: this is + * intended to be used internally in MultiLayerNetwork and ComputationGraph, not by users. + * + * @param gradients a 1 x nParams row vector that is a view of the larger (MLN/CG) gradients + * array + */ + @Override + public void setBackpropGradientsViewArray(INDArray gradients) { + + } + + /** + * The current inputs batch size + * + * @return the current inputs batch size + */ + @Override + public int batchSize() { + return 0; + } + + /** + * The input/feature matrix for the model + * + * @return the input/feature matrix for the model + */ + @Override + public INDArray input() { + return null; + } + + /** + * Get a parameter array for a given parameter type key + * + * @param param the key of the parameter + * @return ndarray of parameters + */ + @Override + public INDArray getParam(String param) { + return null; + } + + + /** + * The param table + * + * @return + */ + @Override + public Map getParamTable() { + return null; + } + + /** + * Set the parameters for a given parameter type. + * + * @param key the param type key to set + * @param val the new parameters ndarray + */ + @Override + public void setParam(String key, INDArray val) { + + } + + /** + * + */ + @Override + public void close() { + + } + + /** + * Calculate the gradient relative to the error in the next layer + * + * @param epsilon w^(L+1)*delta^(L+1). Or, equiv: dC/da, i.e., (dC/dz)*(dz/da) = dC/da, where + * C is cost function a=sigma(z) is activation. + * @param workspaceMgr Workspace manager + * @return Pair where Gradient is gradient for this layer, INDArray is epsilon + * (activation gradient) needed by next layer, but before element-wise multiply by sigmaPrime(z). + * So for standard feed-forward layer, if this layer is L, then return.getSecond() == dL/dIn = + * (w^(L)*(delta^(L))^T)^T. Note that the returned array should be placed in the + * {@link ArrayType#ACTIVATION_GRAD} workspace via the workspace manager + */ + @Override + public Pair backpropGradient(INDArray epsilon, + LayerWorkspaceMgr workspaceMgr) { + return null; + } + + /** + * Perform forward pass and return the activations array with the last set input + * + * @param training training or test mode + * @param workspaceMgr Workspace manager + * @return the activation (layer output) of the last specified input. 
Note that the returned array + * should be placed in the {@link ArrayType#ACTIVATIONS} workspace via the workspace manager + */ + @Override + public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { + return null; + } + + /** + * Returns true if the layer can be trained in an unsupervised/pretrain manner (AE, VAE, etc) + * + * @return true if the layer can be pretrained (using fit(INDArray), false otherwise + */ + @Override + public boolean isPretrainLayer() { + return false; + } + + /** + * + */ + @Override + public void clearNoiseWeightParams() { + + } + + public List variables() { + return variables; + } + + public List variables(boolean copy) { + if (copy) { + return variables(); + } + return variables; + } + + /** + * The configuration for the neural network + * + * @return the configuration for the neural network + */ + @Override + public NeuralNetConfiguration getNetConfiguration() { + return layerConfiguration.getNetConfiguration(); + } + + public void addVariable(String variable) { + if (!variables.contains(variable)) { + variables.add(variable); + } + } + + /** + * Return the configuration of this layer + * + * @return the configuration + */ + @Override + public LayerConfiguration getLayerConfiguration() { + return layerConf(); + } + + public void setLayerConfiguration(LayerConfiguration layerConfiguration) { + this.layerConfiguration = (LayerConf_T) layerConfiguration; + } + + @Override + public void setCacheMode(CacheMode mode) { + if (mode == null) { + mode = CacheMode.NONE; } - @Override - public void setCacheMode(CacheMode mode) { - if (mode == null) - mode = CacheMode.NONE; + this.cacheMode = mode; + } - this.cacheMode = mode; + public LayerConf_T layerConf() { + return this.layerConfiguration; + } + + @Override + public TrainingConfig getConfig() { + return layerConfiguration; + } + + protected String layerId() { + String name = this.layerConfiguration.getLayerName(); + return "(layer name: " + (name == null ? 
"\"\"" : name) + ", layer index: " + index + + ", layer type: " + + getClass().getSimpleName() + ")"; + } + + public INDArray getInput() { + return input; + } + + public int getEpochCount() { + return epochCount; + } + + public void setEpochCount(int epochCount) { + this.epochCount = epochCount; + } + + @Override + public void setInput(INDArray input, LayerWorkspaceMgr workspaceMgr) { + this.input = workspaceMgr.leverageTo(ArrayType.INPUT, input); + dropoutApplied = false; + } + + @Override + public int getIndex() { + return index; + } + + @Override + public void setIndex(int index) { + this.index = index; + } + + /** + * Returns the parameters of the neural network as a flattened row vector + * + * @return the parameters of the neural network + */ + @Override + public INDArray params() { + return null; + } + + protected void setParams(INDArray params, char order) { + throw new UnsupportedOperationException("Not supported"); + } + + /** + * @return Number of parameters + */ + @Override + public long numParams() { + return 0; + } + + protected void applyMask(INDArray to) { + to.muliColumnVector(maskArray.castTo(to.dataType())); + } + + @Override + public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) { + setInput(input, workspaceMgr); + return activate(training, workspaceMgr); + } + + @Override + public double calcRegularizationScore(boolean backpropParamsOnly) { + return 0.0; + } + + + @Deprecated + public void clear() { + input = null; + maskArray = null; + maskState = null; + if (layerConf().getIDropout() != null) { + layerConf().getIDropout().clear(); } + } - public LayerConfT layerConf() { - return (LayerConfT) this.conf.getLayer(); + protected void applyDropOutIfNecessary(boolean training, LayerWorkspaceMgr workspaceMgr) { + if (training && !dropoutApplied && layerConf().getIDropout() != null) { + INDArray result; + if (inputModificationAllowed) { + result = input; + } else { + result = workspaceMgr.createUninitialized(ArrayType.INPUT, input.dataType(), input.shape(), + input.ordering()); + } + + input = layerConf().getIDropout() + .applyDropout(input, result, getIterationCount(), getEpochCount(), workspaceMgr); + dropoutApplied = true; } + } - @Override - public TrainingConfig getConfig(){ - return conf.getLayer(); + protected INDArray backpropDropOutIfPresent(INDArray epsilon) { + if (layerConf().getIDropout() != null) { + layerConf().getIDropout().backprop(epsilon, epsilon, getIterationCount(), getEpochCount()); } + return epsilon; + } - protected String layerId() { - String name = this.conf().getLayer().getLayerName(); - return "(layer name: " + (name == null ? 
"\"\"" : name) + ", layer index: " + index + ", layer type: " + - getClass().getSimpleName() + ")"; + + @Override + public Type type() { + return Type.FEED_FORWARD; + } + + + public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) { + throw new UnsupportedOperationException("Not supported"); + } + + + public Pair gradientAndScore() { + return new Pair<>(gradient(), score()); + } + + @Override + public int getInputMiniBatchSize() { + return (int) input.size(0); + } + + @Override + public void setInputMiniBatchSize(int size) { + } + + @Override + public INDArray getMaskArray() { + return maskArray; + } + + @Override + public void setMaskArray(INDArray maskArray) { + this.maskArray = maskArray; + } + + @Override + public Pair feedForwardMaskArray(INDArray maskArray, + MaskState currentMaskState, int minibatchSize) { + //Most layers: CNN, dense, activation, etc - set mask array, mask state and then leave the mask unmodified + + this.maskArray = maskArray; + this.maskState = currentMaskState; + + return new Pair<>(maskArray, currentMaskState); + } + + + public Gradient gradient() { + throw new UnsupportedOperationException( + "Not supported for this layer, or should be overridden for layers requiring it"); + } + + + public void fit() { + throw new UnsupportedOperationException( + "Not supported for this layer, or should be overridden for layers requiring it"); + } + + + public double score() { + throw new UnsupportedOperationException( + "Not supported for this layer, or should be overridden for layers requiring it"); + } + + + public void applyConstraints(int iteration, int epoch) { + if (layerConf().getConstraints() != null) { + for (LayerConstraint lc : layerConf().getConstraints()) { + lc.applyConstraint(this, iteration, epoch); + } } + } - public INDArray getInput() { - return input; + public void assertInputSet(boolean backprop) { + if (input == null) { + if (backprop) { + throw new IllegalStateException( + "Cannot perform backprop in layer " + getClass().getSimpleName() + + ": layer input field is not set"); + } else { + throw new IllegalStateException( + "Cannot perform forward pass in layer " + getClass().getSimpleName() + + ": layer input field is not set"); + } } + } - public int getEpochCount() { - return epochCount; - } + @Override + public void allowInputModification(boolean allow) { + inputModificationAllowed = allow; + } - public void setEpochCount(int epochCount) { - this.epochCount = epochCount; - } + @Override + public LayerHelper getHelper() { + //Layers with helpers should override this method! + return null; + } - /** - * Init the model - */ - @Override - public void init() { + @Override + public boolean updaterDivideByMinibatch(String paramName) { + //Majority of params's gradients should be... Exception: batch norm mean/variance estimate + return true; + } - } - - @Override - public void setInput(INDArray input, LayerWorkspaceMgr workspaceMgr) { - this.input = workspaceMgr.leverageTo(ArrayType.INPUT, input); - dropoutApplied = false; - } - - @Override - public int getIndex() { - return index; - } - - @Override - public void setIndex(int index) { - this.index = index; - } - - - @Override - public Collection getListeners() { - return trainingListeners; - } - - @Override - public void setListeners(Collection listeners) { - this.trainingListeners = listeners != null ? listeners : new ArrayList(); - } - - /** - * This method ADDS additional TrainingListener to existing listeners - * - * @param listeners - */ - @Override - public void addListeners(TrainingListener... 
listeners) { - if (this.trainingListeners == null) { - setListeners(listeners); - return; - } - - Collections.addAll(trainingListeners, listeners); - } - - @Override - public void setListeners(TrainingListener... listeners) { - setListeners(Arrays.asList(listeners)); - } - - @Override - public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { - throw new UnsupportedOperationException("Not supported"); - } - - @Override - public void update(Gradient gradient) { - throw new UnsupportedOperationException(); - } - - @Override - public void update(INDArray gradient, String paramType) { - throw new UnsupportedOperationException(); - } - - - @Override - public ConvexOptimizer getOptimizer() { - throw new UnsupportedOperationException("Not supported"); - } - - @Override - public void setConf(NeuralNetConfiguration conf) { - this.conf = conf; - } - - /**Returns the parameters of the neural network as a flattened row vector - * @return the parameters of the neural network - */ - @Override - public INDArray params() { - return null; - } - - @Override - public INDArray getParam(String param) { - throw new UnsupportedOperationException("Not supported"); - } - - @Override - public void setParam(String key, INDArray val) { - throw new UnsupportedOperationException("Not supported"); - } - - @Override - public void setParams(INDArray params) { - if (params != null) { - throw new UnsupportedOperationException("Not supported"); - } - } - - protected void setParams(INDArray params, char order) { - throw new UnsupportedOperationException("Not supported"); - } - - @Override - public void setParamsViewArray(INDArray params) { - if (params != null) { - throw new UnsupportedOperationException("Not supported"); - } - } - - @Override - public INDArray getGradientsViewArray() { - return null; - } - - @Override - public void setBackpropGradientsViewArray(INDArray gradients) { - if (gradients != null) { - throw new UnsupportedOperationException("Not supported"); - } - } - - @Override - public void setParamTable(Map paramTable) { - if (paramTable != null && !paramTable.isEmpty()) { - throw new UnsupportedOperationException("Not supported"); - } - } - - @Override - public Map paramTable() { - return paramTable(false); - } - - @Override - public Map paramTable(boolean backpropParamsOnly) { - return Collections.emptyMap(); - } - - protected void applyMask(INDArray to) { - to.muliColumnVector(maskArray.castTo(to.dataType())); - } - - @Override - public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) { - setInput(input, workspaceMgr); - return activate(training, workspaceMgr); - } - - @Override - public double calcRegularizationScore(boolean backpropParamsOnly){ - return 0.0; - } - - @Override - public int batchSize() { - return (int) input.size(0); - } - - @Override - public NeuralNetConfiguration conf() { - return conf; - } - - - @Override - public void clear() { - input = null; - maskArray = null; - maskState = null; - if(layerConf().getIDropout() != null){ - layerConf().getIDropout().clear(); - } - } - - protected void applyDropOutIfNecessary(boolean training, LayerWorkspaceMgr workspaceMgr){ - if(training && !dropoutApplied && layerConf().getIDropout() != null ){ - INDArray result; - if(inputModificationAllowed){ - result = input; - } else { - result = workspaceMgr.createUninitialized(ArrayType.INPUT, input.dataType(), input.shape(), input.ordering()); - } - - input = layerConf().getIDropout().applyDropout(input, result, getIterationCount(), getEpochCount(), 
workspaceMgr); - dropoutApplied = true; - } - } - - protected INDArray backpropDropOutIfPresent(INDArray epsilon){ - if(layerConf().getIDropout() != null ){ - layerConf().getIDropout().backprop(epsilon, epsilon, getIterationCount(), getEpochCount()); - } - return epsilon; - } - - - @Override - public Type type() { - return Type.FEED_FORWARD; - } - - /** - * The number of parameters for the model - * - * @return the number of parameters for the model - */ - @Override - public long numParams() { - return 0; - } - - @Override - public long numParams(boolean backwards) { - return numParams(); - } - - @Override - public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) { - throw new UnsupportedOperationException("Not supported"); - } - - - @Override - public Pair gradientAndScore() { - return new Pair<>(gradient(), score()); - } - - @Override - public INDArray input() { - return input; - } - - @Override - public void setInputMiniBatchSize(int size) {} - - @Override - public int getInputMiniBatchSize() { - return (int) input.size(0); - } - - @Override - public void setMaskArray(INDArray maskArray) { - this.maskArray = maskArray; - } - - @Override - public INDArray getMaskArray() { - return maskArray; - } - - - @Override - public Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { - //Most layers: CNN, dense, activation, etc - set mask array, mask state and then leave the mask unmodified - - this.maskArray = maskArray; - this.maskState = currentMaskState; - - return new Pair<>(maskArray, currentMaskState); - } - - - @Override - public Gradient gradient() { - throw new UnsupportedOperationException( - "Not supported for this layer, or should be overridden for layers requiring it"); - } - - @Override - public void fit() { - throw new UnsupportedOperationException( - "Not supported for this layer, or should be overridden for layers requiring it"); - } - - @Override - public double score() { - throw new UnsupportedOperationException( - "Not supported for this layer, or should be overridden for layers requiring it"); - } - - - @Override - public void applyConstraints(int iteration, int epoch){ - if(layerConf().getConstraints() != null){ - for(LayerConstraint lc : layerConf().getConstraints()){ - lc.applyConstraint(this, iteration, epoch); - } - } - } - - public void assertInputSet(boolean backprop){ - if(input == null){ - if(backprop){ - throw new IllegalStateException("Cannot perform backprop in layer " + getClass().getSimpleName() - + ": layer input field is not set"); - } else { - throw new IllegalStateException("Cannot perform forward pass in layer " + getClass().getSimpleName() - + ": layer input field is not set"); - } - } - } - - @Override - public void allowInputModification(boolean allow){ - inputModificationAllowed = allow; - } - - @Override - public LayerHelper getHelper() { - //Layers with helpers should override this method! - return null; - } - - @Override - public boolean updaterDivideByMinibatch(String paramName) { - //Majority of params's gradients should be... 
Exception: batch norm mean/variance estimate - return true; - } - - @Override - public void close(){ - //No-op for individual layers - } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java index f83b1cf31..7043275a0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java @@ -22,6 +22,7 @@ package org.deeplearning4j.nn.layers; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.nd4j.linalg.api.buffer.DataType; @@ -33,7 +34,7 @@ import org.deeplearning4j.nn.workspace.ArrayType; public class ActivationLayer extends AbstractLayer { - public ActivationLayer(NeuralNetConfiguration conf, DataType dataType) { + public ActivationLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java index ed1176133..68de26b7c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java @@ -20,10 +20,21 @@ package org.deeplearning4j.nn.layers; +import java.lang.reflect.Constructor; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import lombok.Getter; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.DefaultParamInitializer; @@ -31,421 +42,650 @@ import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.Solver; import org.deeplearning4j.optimize.api.ConvexOptimizer; +import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.common.primitives.Pair; +import org.nd4j.evaluation.IEvaluation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm; import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.regularization.Regularization; -import org.nd4j.common.primitives.Pair; - -import java.lang.reflect.Constructor; -import java.util.*; /** * A layer with parameters + * * @author Adam Gibson */ @Slf4j public abstract class BaseLayer - extends AbstractLayer { + extends 
AbstractLayer { - protected INDArray paramsFlattened; - protected INDArray gradientsFlattened; - protected Map params; - protected transient Map gradientViews; - protected double score = 0.0; - protected ConvexOptimizer optimizer; - protected Gradient gradient; - protected Solver solver; + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(DataSetIterator iterator, T... evaluations) { + return null; + } - protected Map weightNoiseParams = new HashMap<>(); + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations) { + return null; + } - public BaseLayer(NeuralNetConfiguration conf, DataType dataType) { - super(conf, dataType); + /** + * @param netConfiguration + */ + @Override + public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) { + + } + + /** + * Init the model + */ + @Override + public void init() { + + } + + /** + * This method ADDS additional TrainingListener to existing listeners + * + * @param listener + */ + @Override + public void addListeners(TrainingListener... listener) { + + } + + /** + * Update layer weights and biases with gradient change + * + * @param gradient + */ + @Override + public void update(Gradient gradient) { + + } + + /** + * Perform one update applying the gradient + * + * @param gradient the gradient to apply + * @param paramType + */ + @Override + public void update(INDArray gradient, String paramType) { + + } + + /** + * the number of parameters for the model + * + * @param backwards + * @return the number of parameters for the model + */ + @Override + public long numParams(boolean backwards) { + return 0; + } + + /** + * Set the parameters for this model. This expects a linear ndarray which then be unpacked + * internally relative to the expected ordering of the model + * + * @param params the parameters for the model + */ + @Override + public void setParams(INDArray params) { + + } + + /** + * The current inputs batch size + * + * @return the current inputs batch size + */ + @Override + public int batchSize() { + return 0; + } + + /** + * The input/feature matrix for the model + * + * @return the input/feature matrix for the model + */ + @Override + public INDArray input() { + return null; + } + + /** + * Get a parameter array for a given parameter type key + * + * @param param the key of the parameter + * @return ndarray of parameters + */ + @Override + public INDArray getParam(String param) { + return null; + } + + /** + * Set the {@link TrainingListener}s for this model. If any listeners have previously been set, + * they will be replaced by this method + * + * @param listeners + */ + @Override + public void setListeners(TrainingListener... listeners) { + + } + + /** + * Set the parameters for a given parameter type. 
+ * + * @param key the param type key to set + * @param val the new parameters ndarray + */ + @Override + public void setParam(String key, INDArray val) { + + } + + /** + * + */ + @Override + public void close() { + + } + + /** + * This method fits model with a given DataSet + * + * @param dataSet + */ + @Override + public void fit(DataSet dataSet) { + + } + + /** + * This method fits model with a given MultiDataSet + * + * @param dataSet + */ + @Override + public void fit(MultiDataSet dataSet) { + + } + + /** + * This method fits model with a given DataSetIterator + * + * @param iterator + */ + @Override + public void fit(DataSetIterator iterator) { + + } + + /** + * This method fits model with a given MultiDataSetIterator + * + * @param iterator + */ + @Override + public void fit(MultiDataSetIterator iterator) { + + } + + /** + * This method returns updater state (if applicable), null otherwise + * + * @return + */ + @Override + public INDArray updaterState() { + return null; + } + + protected double score = 0.0; + protected ConvexOptimizer optimizer; + protected Gradient gradient; + protected Solver solver; + protected Map weightNoiseParams = new HashMap<>(); + protected INDArray paramsFlattened; + protected INDArray gradientsFlattened; + /** + * Full table of parameters + */ + protected Map paramsTable; + @Getter protected transient Map gradientViews; + + public BaseLayer(LayerConfiguration conf, DataType dataType) { + super(conf, dataType); + } + + + /** + * and others even use \epsilon (epsilon) + * http://web.cs.swarthmore.edu/~meeden/cs81/s10/BackPropDeriv.pdf + * + * @param epsilon w^(L+1)*delta^(L+1). Or, equiv: dC/da, i.e., (dC/dz)*(dz/da) = dC/da, where + * C is cost function a=sigma(z) is activation. + * @param workspaceMgr Workspace manager + * @return + */ + @Override + public Pair backpropGradient(INDArray epsilon, + LayerWorkspaceMgr workspaceMgr) { + assertInputSet(true); + //If this layer is layer L, then epsilon is (w^(L+1)*(d^(L+1))^T) (or equivalent) + Pair zAndPreNorm = preOutputWithPreNorm(true, true, workspaceMgr); + INDArray z = zAndPreNorm.getFirst(); //Note: using preOutput(INDArray) can't be used as this does a setInput(input) and resets the 'appliedDropout' flag + INDArray preNorm = zAndPreNorm.getSecond(); + INDArray delta = layerConf().getActivationFn().backprop(z, epsilon) + .getFirst(); //TODO handle activation function params + + if (maskArray != null) { + applyMask(delta); } - public LayerConfT layerConf() { - return (LayerConfT) this.conf.getLayer(); + Gradient ret = new DefaultGradient(); + + if (hasBias()) { + INDArray biasGrad = gradientViews.get(DefaultParamInitializer.BIAS_KEY); + delta.sum(biasGrad, 0); //biasGrad is initialized/zeroed first + ret.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, biasGrad); } - @Override - public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { - assertInputSet(true); - //If this layer is layer L, then epsilon is (w^(L+1)*(d^(L+1))^T) (or equivalent) - Pair zAndPreNorm = preOutputWithPreNorm(true, true, workspaceMgr); - INDArray z = zAndPreNorm.getFirst(); //Note: using preOutput(INDArray) can't be used as this does a setInput(input) and resets the 'appliedDropout' flag - INDArray preNorm = zAndPreNorm.getSecond(); - INDArray delta = layerConf().getActivationFn().backprop(z, epsilon).getFirst(); //TODO handle activation function params + INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, true, workspaceMgr); - if (maskArray != null) { - applyMask(delta); + 
INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, + delta.dataType(), new long[]{W.size(0), delta.size(0)}, 'f'); + if (hasLayerNorm()) { + INDArray g = getParam(DefaultParamInitializer.GAIN_KEY); + + INDArray dldg = gradientViews.get(DefaultParamInitializer.GAIN_KEY); + Nd4j.getExecutioner().exec(new LayerNormBp(preNorm, g, delta, delta, dldg, true, 1)); + ret.gradientForVariable().put(DefaultParamInitializer.GAIN_KEY, dldg); + + } + + epsilonNext = W.mmuli(delta.transpose(), epsilonNext) + .transpose(); //W.mmul(delta.transpose()).transpose(); + + INDArray weightGrad = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY); //f order + Nd4j.gemm(input.castTo(weightGrad.dataType()), delta, weightGrad, true, false, 1.0, + 0.0); //TODO avoid castTo? + ret.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, weightGrad); + + weightNoiseParams.clear(); + + epsilonNext = backpropDropOutIfPresent(epsilonNext); + return new Pair<>(ret, epsilonNext); + } + + + public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { + if (this.input == null) { + log.warn("There is no input for this layer '{}'", layerConfiguration); + return; + } + INDArray output = activate(true, workspaceMgr); + setScoreWithZ(output); + } + + + protected void setScoreWithZ(INDArray z) { + } + + /** + * Objective function: the specified objective + * + * @return the score for the objective + */ + + @Override + public double score() { + return score; + } + + @Override + public Gradient gradient() { + return gradient; + } + + + @Override + public ConvexOptimizer getOptimizer() { + if (optimizer == null) { + Solver solver = new Solver.Builder().model(this).configure(getNetConfiguration()).build(); + this.optimizer = solver.getOptimizer(); + } + return optimizer; + } + + /** + * Returns the parameters of the neural network as a flattened row vector + * + * @return the parameters of the neural network + */ + @Override + public INDArray params() { + return paramsFlattened; + } + + + public void setParamsTable(INDArray paramsTable) { + if (paramsTable == paramsFlattened) { + return; //no op + } + setParams(paramsTable, 'f'); + } + + protected void setParams(INDArray params, char order) { + List parameterList = layerConfiguration.getVariables(); //netWideVariables(); + int length = 0; + for (String s : parameterList) { + length += getParam(s).length(); + } + if (params.length() != length) { + throw new IllegalArgumentException("Unable to set parameters: must be of length " + length + + ", got params of length " + params.length() + " - " + layerId()); + } + int idx = 0; + Set paramKeySet = this.getParamTable().keySet(); + for (String s : paramKeySet) { + INDArray param = getParam(s); + INDArray get = params.get(NDArrayIndex.point(0), + NDArrayIndex.interval(idx, idx + param.length())); + if (param.length() != get.length()) { + throw new IllegalStateException( + "Parameter " + s + " should have been of length " + param.length() + + " but was " + get.length() + " - " + layerId()); } + param.assign(get.reshape(order, + param.shape())); //Use assign due to backprop params being a view of a larger array + idx += param.length(); + } + } - Gradient ret = new DefaultGradient(); + @Override + public void setParamsViewArray(INDArray params) { + if (this.paramsTable != null && params.length() != numParams()) { + throw new IllegalArgumentException("Invalid input: expect params of length " + numParams() + + ", got params of length " + params.length() + " - " + layerId()); + } - if(hasBias()){ - 
INDArray biasGrad = gradientViews.get(DefaultParamInitializer.BIAS_KEY); - delta.sum(biasGrad, 0); //biasGrad is initialized/zeroed first - ret.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, biasGrad); + this.paramsFlattened = params; + } + + @Override + public INDArray getGradientsViewArray() { + return gradientsFlattened; + } + + @Override + public void setBackpropGradientsViewArray(INDArray gradients) { + if (this.paramsTable != null && gradients.length() != numParams()) { + throw new IllegalArgumentException( + "Invalid input: expect gradients array of length " + numParams(true) + + ", got array of length " + gradients.length() + " - " + layerId()); + } + + this.gradientsFlattened = gradients; + this.gradientViews = layerConfiguration.initializer() + .getGradientsFromFlattened(layerConfiguration, gradients); + } + + /** + * Get the parameter, after applying any weight noise (such as DropConnect) if necessary. Note + * that during training, this will store the post-noise parameters, as these should be used for + * both forward pass and backprop, for a single iteration. Consequently, the parameters (post + * noise) should be cleared after each training iteration + * + * @param param Parameter key + * @param training If true: during training + * @return The parameter, after applying any noise + */ + protected INDArray getParamWithNoise(String param, boolean training, + LayerWorkspaceMgr workspaceMgr) { + INDArray p; + if (layerConf().getWeightNoise() != null) { + if (training && weightNoiseParams.size() > 0 && weightNoiseParams.containsKey(param)) { + //Re-use these weights for both forward pass and backprop - don't want to use 2 different params here + //These should be cleared during backprop + return weightNoiseParams.get(param); + } else { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + p = layerConf().getWeightNoise() + .getParameter(this, param, getIterationCount(), getEpochCount(), training, + workspaceMgr); } + } - INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, true, workspaceMgr); - - INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, delta.dataType(), new long[]{W.size(0), delta.size(0)}, 'f'); - if(hasLayerNorm()) { - INDArray g = getParam(DefaultParamInitializer.GAIN_KEY); - - INDArray dldg = gradientViews.get(DefaultParamInitializer.GAIN_KEY); - Nd4j.getExecutioner().exec(new LayerNormBp(preNorm, g, delta, delta, dldg, true, 1)); - ret.gradientForVariable().put(DefaultParamInitializer.GAIN_KEY, dldg); - - } - - epsilonNext = W.mmuli(delta.transpose(),epsilonNext).transpose(); //W.mmul(delta.transpose()).transpose(); - - INDArray weightGrad = gradientViews.get(DefaultParamInitializer.WEIGHT_KEY); //f order - Nd4j.gemm(input.castTo(weightGrad.dataType()), delta, weightGrad, true, false, 1.0, 0.0); //TODO avoid castTo? 
-        ret.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, weightGrad);
-
-        weightNoiseParams.clear();
-
-        epsilonNext = backpropDropOutIfPresent(epsilonNext);
-        return new Pair<>(ret, epsilonNext);
+      if (training) {
+        //Store for re-use in backprop
+        weightNoiseParams.put(param, p);
+      }
+    } else {
+      return getParam(param);
     }

-    public void fit() {
-        throw new UnsupportedOperationException("Not supported");
+    return p;
+  }
+
+  protected INDArray preOutput(boolean training, LayerWorkspaceMgr workspaceMgr) {
+    return preOutputWithPreNorm(training, false, workspaceMgr).getFirst();
+  }
+
+  protected Pair preOutputWithPreNorm(boolean training, boolean forBackprop,
+      LayerWorkspaceMgr workspaceMgr) {
+    assertInputSet(forBackprop);
+    applyDropOutIfNecessary(training, workspaceMgr);
+    INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr);
+    INDArray b = getParamWithNoise(DefaultParamInitializer.BIAS_KEY, training, workspaceMgr);
+    INDArray g = (hasLayerNorm() ? getParam(DefaultParamInitializer.GAIN_KEY) : null);
+
+    INDArray input = this.input.castTo(dataType);
+
+    //Input validation:
+    if (input.rank() != 2 || input.columns() != W.rows()) {
+      if (input.rank() != 2) {
+        throw new DL4JInvalidInputException(
+            "Input that is not a matrix; expected matrix (rank 2), got rank "
+                + input.rank() + " array with shape " + Arrays.toString(input.shape())
+                + ". Missing preprocessor or wrong input type? " + layerId());
+      }
+      throw new DL4JInvalidInputException(
+          "Input size (" + input.columns() + " columns; shape = " + Arrays.toString(input.shape())
+              + ") is invalid: does not match layer input size (layer # inputs = "
+              + W.size(0) + ") " + layerId());
     }

-    @Override
-    public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {
-        if (this.input == null)
-            return;
+    INDArray ret = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, W.dataType(),
+        input.size(0), W.size(1));
+    input.castTo(ret.dataType()).mmuli(W,
+        ret); //TODO Can we avoid this cast? (It should be a no op if not required, however)

-        INDArray output = activate(true, workspaceMgr);
-        setScoreWithZ(output);
+    INDArray preNorm = ret;
+    if (hasLayerNorm()) {
+      preNorm = (forBackprop ?
ret.dup(ret.ordering()) : ret); + Nd4j.getExecutioner().exec(new LayerNorm(preNorm, g, ret, true, 1)); } - - protected void setScoreWithZ(INDArray z) {} - - /** - * Objective function: the specified objective - * @return the score for the objective - */ - - @Override - public double score() { - return score; + if (hasBias()) { + ret.addiRowVector(b); } - @Override - public Gradient gradient() { - return gradient; + if (maskArray != null) { + applyMask(ret); } - @Override - public void update(Gradient gradient) { - for (String paramType : gradient.gradientForVariable().keySet()) { - update(gradient.getGradientFor(paramType), paramType); - } + return new Pair<>(ret, preNorm); + } + + @Override + public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { + INDArray z = preOutput(training, workspaceMgr); + INDArray ret = layerConf().getActivationFn().getActivation(z, training); + + if (maskArray != null) { + applyMask(ret); } - @Override - public void update(INDArray gradient, String paramType) { - setParam(paramType, getParam(paramType).addi(gradient)); + return ret; + } + + @Override + public double calcRegularizationScore(boolean backpropParamsOnly) { + double scoreSum = 0.0; + for (Map.Entry e : paramsTable.entrySet()) { + List l = layerConf().getRegularizationByParam(e.getKey()); + if (l == null || l.isEmpty()) { + continue; + } + for (Regularization r : l) { + scoreSum += r.score(e.getValue(), getIterationCount(), getEpochCount()); + } } + return scoreSum; + } + + @Override + public Layer clone() { + Layer layer = null; + try { + Constructor c = getClass().getConstructor(NeuralNetConfiguration.class); + layer = (Layer) c.newInstance(layerConfiguration); + Map linkedTable = new LinkedHashMap<>(); + for (Map.Entry entry : paramsTable.entrySet()) { + linkedTable.put(entry.getKey(), entry.getValue().dup()); + } + layer.setParamTable(linkedTable); + } catch (Exception e) { + log.error("", e); + } + + return layer; + + } + + + @Override + public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) { + if (input != null) { + setInput(input, workspaceMgr); + applyDropOutIfNecessary(true, workspaceMgr); + } + if (solver == null) { + solver = new Solver.Builder().model(this).configure(getNetConfiguration()) + .listeners(getTrainingListeners()).build(); + } + this.optimizer = solver.getOptimizer(); + solver.optimize(workspaceMgr); + } @Override - public ConvexOptimizer getOptimizer() { - if (optimizer == null) { - Solver solver = new Solver.Builder().model(this).configure(conf()).build(); - this.optimizer = solver.getOptimizer(); - } - return optimizer; - } + public String toString() { + return getClass().getName() + "{" + "conf=" + layerConfiguration + ", score=" + score + + ", optimizer=" + optimizer + ", listeners=" + trainingListeners + '}'; + } - /**Returns the parameters of the neural network as a flattened row vector - * @return the parameters of the neural network - */ - @Override - public INDArray params() { - return paramsFlattened; - } + @Override + public void clear() { + super.clear(); + weightNoiseParams.clear(); + } - @Override - public INDArray getParam(String param) { - return params.get(param); - } + @Override + public void clearNoiseWeightParams() { + weightNoiseParams.clear(); + } - @Override - public void setParam(String key, INDArray val) { - if (params.containsKey(key)) - params.get(key).assign(val); - else - params.put(key, val); - } + /** + * Does this layer have no bias term? 
Many layers (dense, convolutional, output, embedding) have
+   * biases by default, but no-bias versions are possible via configuration
+   *
+   * @return True if a bias term is present, false otherwise
+   */
+  public boolean hasBias() {
+    //Overridden by layers supporting no bias mode: dense, output, convolutional, embedding
+    return true;
+  }

-    @Override
-    public void setParams(INDArray params) {
-        if (params == paramsFlattened)
-            return; //no op
-        setParams(params, 'f');
-    }
+  /**
+   * Does this layer support layer normalization, and is it enabled? Only Dense and SimpleRNN
+   * layers support layer normalization.
+   *
+   * @return True if layer normalization is enabled on this layer, false otherwise
+   */
+  public boolean hasLayerNorm() {
+    // Overridden by layers supporting layer normalization.
+    return false;
+  }

-    protected void setParams(INDArray params, char order) {
-        List parameterList = conf.variables();
-        int length = 0;
-        for (String s : parameterList)
-            length += getParam(s).length();
-        if (params.length() != length)
-            throw new IllegalArgumentException("Unable to set parameters: must be of length " + length
-                + ", got params of length " + params.length() + " - " + layerId());
-        int idx = 0;
-        Set paramKeySet = this.params.keySet();
-        for (String s : paramKeySet) {
-            INDArray param = getParam(s);
-            INDArray get = params.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx, idx + param.length()));
-            if (param.length() != get.length())
-                throw new IllegalStateException("Parameter " + s + " should have been of length " + param.length()
-                    + " but was " + get.length() + " - " + layerId());
-            param.assign(get.reshape(order, param.shape())); //Use assign due to backprop params being a view of a larger array
-            idx += param.length();
-        }
-    }
+  /**
+   * The number of parameters (all types) for the model
+   *
+   * @return the number of parameters for the model
+   */
+  public long numParams() {
+    int ret = 0;
+    for (INDArray val : paramsTable.values()) {
+      ret += val.length();
+    }
+    return ret;
+  }

-    @Override
-    public void setParamsViewArray(INDArray params) {
-        if (this.params != null && params.length() != numParams())
-            throw new IllegalArgumentException("Invalid input: expect params of length " + numParams()
-                + ", got params of length " + params.length() + " - " + layerId());
+  /**
+   * Return a map of all parameters in the network. Parameter names are as described in
+   * {@link #getParam(String)}. 
As per {@link #getParam(String)} the returned arrays are views - + * modifications to these will impact the underlying network parameters + * + * @return A map of all parameters in the network + */ + @Override + public Map getParamTable() { + return getParamTable(false); + } - this.paramsFlattened = params; - } + /** + * Set the full table of parameters (of all types) + * + * @param paramTable ndarray parameters table + */ + @Override + public void setParamTable(@NonNull Map paramTable) { + this.paramsTable = paramTable; + } - @Override - public INDArray getGradientsViewArray() { - return gradientsFlattened; - } - - @Override - public void setBackpropGradientsViewArray(INDArray gradients) { - if (this.params != null && gradients.length() != numParams()) - throw new IllegalArgumentException("Invalid input: expect gradients array of length " + numParams(true) - + ", got array of length " + gradients.length() + " - " + layerId()); - - this.gradientsFlattened = gradients; - this.gradientViews = conf.getLayer().initializer().getGradientsFromFlattened(conf, gradients); - } - - @Override - public void setParamTable(Map paramTable) { - this.params = paramTable; - } - - @Override - public Map paramTable() { - return paramTable(false); - } - - @Override - public Map paramTable(boolean backpropParamsOnly) { - return params; - } - - /** - * Get the parameter, after applying any weight noise (such as DropConnect) if necessary. - * Note that during training, this will store the post-noise parameters, as these should be used - * for both forward pass and backprop, for a single iteration. - * Consequently, the parameters (post noise) should be cleared after each training iteration - * - * @param param Parameter key - * @param training If true: during training - * @return The parameter, after applying any noise - */ - protected INDArray getParamWithNoise(String param, boolean training, LayerWorkspaceMgr workspaceMgr){ - INDArray p; - if(layerConf().getWeightNoise() != null){ - if(training && weightNoiseParams.size() > 0 && weightNoiseParams.containsKey(param) ){ - //Re-use these weights for both forward pass and backprop - don't want to use 2 different params here - //These should be cleared during backprop - return weightNoiseParams.get(param); - } else { - try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - p = layerConf().getWeightNoise().getParameter(this, param, getIterationCount(), getEpochCount(), training, workspaceMgr); - } - } - - if(training){ - //Store for re-use in backprop - weightNoiseParams.put(param, p); - } - } else { - return getParam(param); - } - - return p; - } - - protected INDArray preOutput(boolean training, LayerWorkspaceMgr workspaceMgr) { - return preOutputWithPreNorm(training, false, workspaceMgr).getFirst(); - } - - protected Pair preOutputWithPreNorm(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) { - assertInputSet(forBackprop); - applyDropOutIfNecessary(training, workspaceMgr); - INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr); - INDArray b = getParamWithNoise(DefaultParamInitializer.BIAS_KEY, training, workspaceMgr); - INDArray g = (hasLayerNorm() ? 
getParam(DefaultParamInitializer.GAIN_KEY) : null); - - INDArray input = this.input.castTo(dataType); - - //Input validation: - if (input.rank() != 2 || input.columns() != W.rows()) { - if (input.rank() != 2) { - throw new DL4JInvalidInputException("Input that is not a matrix; expected matrix (rank 2), got rank " - + input.rank() + " array with shape " + Arrays.toString(input.shape()) - + ". Missing preprocessor or wrong input type? " + layerId()); - } - throw new DL4JInvalidInputException( - "Input size (" + input.columns() + " columns; shape = " + Arrays.toString(input.shape()) - + ") is invalid: does not match layer input size (layer # inputs = " - + W.size(0) + ") " + layerId()); - } - - - INDArray ret = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, W.dataType(), input.size(0), W.size(1)); - input.castTo(ret.dataType()).mmuli(W, ret); //TODO Can we avoid this cast? (It sohuld be a no op if not required, however) - - INDArray preNorm = ret; - if(hasLayerNorm()){ - preNorm = (forBackprop ? ret.dup(ret.ordering()) : ret); - Nd4j.getExecutioner().exec(new LayerNorm(preNorm, g, ret, true, 1)); - } - - if(hasBias()){ - ret.addiRowVector(b); - } - - if (maskArray != null) { - applyMask(ret); - } - - return new Pair<>(ret, preNorm); - } - - @Override - public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { - INDArray z = preOutput(training, workspaceMgr); - INDArray ret = layerConf().getActivationFn().getActivation(z, training); - - if (maskArray != null) { - applyMask(ret); - } - - return ret; - } - - @Override - public double calcRegularizationScore(boolean backpropParamsOnly){ - double scoreSum = 0.0; - for (Map.Entry e : paramTable().entrySet()) { - List l = layerConf().getRegularizationByParam(e.getKey()); - if(l == null || l.isEmpty()){ - continue; - } - for(Regularization r : l){ - scoreSum += r.score(e.getValue(), getIterationCount(), getEpochCount()); - } - } - return scoreSum; - } - - @Override - public Layer clone() { - Layer layer = null; - try { - Constructor c = getClass().getConstructor(NeuralNetConfiguration.class); - layer = (Layer) c.newInstance(conf); - Map linkedTable = new LinkedHashMap<>(); - for (Map.Entry entry : params.entrySet()) { - linkedTable.put(entry.getKey(), entry.getValue().dup()); - } - layer.setParamTable(linkedTable); - } catch (Exception e) { - log.error("",e); - } - - return layer; - - } - - /** - * The number of parameters for the model - * - * @return the number of parameters for the model - */ - @Override - public long numParams() { - int ret = 0; - for (INDArray val : params.values()) - ret += val.length(); - return ret; - } - - @Override - public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) { - if (input != null) { - setInput(input, workspaceMgr); - applyDropOutIfNecessary(true, workspaceMgr); - } - if (solver == null) { - solver = new Solver.Builder().model(this).configure(conf()).listeners(getListeners()).build(); - } - this.optimizer = solver.getOptimizer(); - solver.optimize(workspaceMgr); - } - - @Override - public String toString() { - return getClass().getName() + "{" + "conf=" + conf + ", score=" + score - + ", optimizer=" + optimizer + ", listeners=" + trainingListeners + '}'; - } - - @Override - public void clear(){ - super.clear(); - weightNoiseParams.clear(); - } - - @Override - public void clearNoiseWeightParams(){ - weightNoiseParams.clear(); - } - - /** - * Does this layer have no bias term? 
Many layers (dense, convolutional, output, embedding) have biases by - * default, but no-bias versions are possible via configuration - * - * @return True if a bias term is present, false otherwise - */ - public boolean hasBias(){ - //Overridden by layers supporting no bias mode: dense, output, convolutional, embedding - return true; - } - - /** - * Does this layer support and is it enabled layer normalization? Only Dense and SimpleRNN Layers support - * layer normalization. - * - * @return True if layer normalization is enabled on this layer, false otherwise - */ - public boolean hasLayerNorm(){ - // Overridden by layers supporting layer normalization. - return false; - } + @Override + public Map getParamTable(boolean backpropParamsOnly) { + return paramsTable; + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java index 1f317eee6..b06c9a3ed 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseOutputLayer.java @@ -23,6 +23,7 @@ package org.deeplearning4j.nn.layers; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.DefaultParamInitializer; @@ -58,7 +59,7 @@ public abstract class BaseOutputLayer { - public BasePretrainNetwork(NeuralNetConfiguration conf, DataType dataType) { + public BasePretrainNetwork(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } @@ -119,52 +120,47 @@ public abstract class BasePretrainNetwork paramTable(boolean backpropParamsOnly) { + public Map getParamTable(boolean backpropParamsOnly) { if (!backpropParamsOnly) - return params; + return getParamTable(); Map map = new LinkedHashMap<>(); - map.put(PretrainParamInitializer.WEIGHT_KEY, params.get(PretrainParamInitializer.WEIGHT_KEY)); - map.put(PretrainParamInitializer.BIAS_KEY, params.get(PretrainParamInitializer.BIAS_KEY)); + map.put(PretrainParamInitializer.WEIGHT_KEY, super.getParamTable().get(PretrainParamInitializer.WEIGHT_KEY)); + map.put(PretrainParamInitializer.BIAS_KEY, super.getParamTable().get(PretrainParamInitializer.BIAS_KEY)); return map; } - - public INDArray params() { - return paramsFlattened; - } - /**The number of parameters for the model, for backprop (i.e., excluding visible bias) * @return the number of parameters for the model (ex. visible bias) */ public long numParams() { int ret = 0; - for (Map.Entry entry : params.entrySet()) { + for (Map.Entry entry : getParamTable().entrySet()) { ret += entry.getValue().length(); } return ret; } @Override - public void setParams(INDArray params) { - if (params == paramsFlattened) + public void setParamsTable(INDArray paramsTable) { + if (paramsTable == paramsFlattened) return; //No op //SetParams has two different uses: during pretrain vs. backprop. //pretrain = 3 sets of params (inc. 
visible bias); backprop = 2 - List parameterList = conf.variables(); + List parameterList = layerConfiguration.getVariables(); long paramLength = 0; for (String s : parameterList) { val len = getParam(s).length(); paramLength += len; } - if (params.length() != paramLength) { + if (paramsTable.length() != paramLength) { throw new IllegalArgumentException("Unable to set parameters: must be of length " + paramLength - + ", got params of length " + params.length() + " " + layerId()); + + ", got params of length " + paramsTable.length() + " " + layerId()); } // Set for backprop and only W & hb - paramsFlattened.assign(params); + paramsFlattened.assign(paramsTable); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/DropoutLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/DropoutLayer.java index 743186706..c89165431 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/DropoutLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/DropoutLayer.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.layers; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.nd4j.linalg.api.buffer.DataType; @@ -31,7 +32,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; public class DropoutLayer extends BaseLayer { - public DropoutLayer(NeuralNetConfiguration conf, DataType dataType) { + public DropoutLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayer.java index 1e6c60add..75bf8ae01 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayer.java @@ -49,8 +49,8 @@ public class FrozenLayer extends BaseWrapperLayer { throw new IllegalArgumentException("Output Layers are not allowed to be frozen " + layerId()); } this.zeroGradient = new DefaultGradient(insideLayer.params()); - if (insideLayer.paramTable() != null) { - for (String paramType : insideLayer.paramTable().keySet()) { + if (insideLayer.getParamTable() != null) { + for (String paramType : insideLayer.getParamTable().keySet()) { //save memory?? zeroGradient.setGradientFor(paramType, null); } @@ -63,7 +63,7 @@ public class FrozenLayer extends BaseWrapperLayer { } protected String layerId() { - String name = underlying.conf().getLayer().getLayerName(); + String name = underlying.getLayerConfiguration().getLayerName(); return "(layer name: " + (name == null ? 
"\"\"" : name) + ", layer index: " + underlying.getIndex() + ")"; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java index 918a21a4a..425ec454f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java @@ -46,7 +46,7 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayer { } protected String layerId() { - String name = underlying.conf().getLayer().getLayerName(); + String name = underlying.getLayerConfiguration().getLayerName(); return "(layer name: " + (name == null ? "\"\"" : name) + ", layer index: " + underlying.getIndex() + ")"; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java index e53fc6619..e13a06219 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java @@ -24,6 +24,7 @@ package org.deeplearning4j.nn.layers; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.optimize.Solver; @@ -54,7 +55,7 @@ public class LossLayer extends BaseLayer { - public OutputLayer(NeuralNetConfiguration conf, DataType dataType) { + public OutputLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java index 84dd1fd1f..c36bf9366 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/RepeatVector.java @@ -25,6 +25,7 @@ import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.workspace.ArrayType; @@ -38,7 +39,7 @@ import java.util.Arrays; public class RepeatVector extends AbstractLayer { - public RepeatVector(NeuralNetConfiguration conf, DataType dataType) { + public RepeatVector(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java index b63978b10..2fcdcd17a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cnn3DLossLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.layers.IOutputLayer; 
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Convolution3D; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseLayer; @@ -49,7 +50,19 @@ public class Cnn3DLossLayer extends BaseLayer { private final int[] cropping; //[padTop, padBottom] - public Cropping1DLayer(NeuralNetConfiguration conf, DataType dataType) { + public Cropping1DLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); - this.cropping = ((org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D) conf.getLayer()).getCropping(); + this.cropping = layerConfiguration.getCropping(); } @Override @@ -79,7 +80,8 @@ public class Cropping1DLayer extends AbstractLayer { @Override public Layer clone() { - return new Cropping2DLayer(conf.clone(), dataType); + + return new Cropping2DLayer(layerConfiguration.clone(), dataType); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java index 3d6beac05..d72d2f3eb 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java @@ -24,6 +24,7 @@ import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; @@ -40,11 +41,12 @@ public class Cropping2DLayer extends AbstractLayer { - public Deconvolution3DLayer(NeuralNetConfiguration conf, DataType dataType) { + public Deconvolution3DLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } @@ -135,7 +136,7 @@ public class Deconvolution3DLayer extends BaseLayer { boolean ncdhw = layerConf().getDataFormat() == Convolution3D.DataFormat.NCDHW; int chDim = ncdhw ? 
1 : 4; if (input.size(chDim) != layerConf().getNIn() ) { - String layerName = conf.getLayer().getLayerName(); + String layerName = getLayerConfiguration().getLayerName(); if (layerName == null) layerName = "(not named)"; throw new DL4JInvalidInputException("Cannot do forward pass in Deconvolution3D layer (layer name = " + layerName diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java index c63aeb3f9..888875129 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.DepthwiseConvolutionParamInitializer; @@ -45,7 +46,7 @@ import java.util.Arrays; public class DepthwiseConvolution2DLayer extends ConvolutionLayer { - public DepthwiseConvolution2DLayer(NeuralNetConfiguration conf, DataType dataType) { + public DepthwiseConvolution2DLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } @@ -152,7 +153,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { getParamWithNoise(DepthwiseConvolutionParamInitializer.WEIGHT_KEY, training, workspaceMgr); if (input.rank() != 4) { - String layerName = conf.getLayer().getLayerName(); + String layerName = layerConfiguration.getLayerName(); if (layerName == null) layerName = "(not named)"; throw new DL4JInvalidInputException("Got rank " + input.rank() @@ -174,7 +175,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { long outDepth = depthMultiplier * inDepth; if (input.size(nchw ? 
1 : 3) != inDepth) { - String layerName = conf.getLayer().getLayerName(); + String layerName = layerConfiguration.getLayerName(); if (layerName == null) layerName = "(not named)"; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java index d5d0ebf0f..d205017bf 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java @@ -26,6 +26,7 @@ import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.ConvolutionParamInitializer; @@ -46,7 +47,7 @@ import java.util.Arrays; public class SeparableConvolution2DLayer extends ConvolutionLayer { - public SeparableConvolution2DLayer(NeuralNetConfiguration conf, DataType dataType) { + public SeparableConvolution2DLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } @@ -176,7 +177,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { int wIdx = 3; if (input.rank() != 4) { - String layerName = conf.getLayer().getLayerName(); + String layerName = getLayerConfiguration().getLayerName(); if (layerName == null) layerName = "(not named)"; throw new DL4JInvalidInputException("Got rank " + input.rank() @@ -193,7 +194,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { long outDepth = pointWiseWeights.size(0); if (input.size(chIdx) != inDepth) { - String layerName = conf.getLayer().getLayerName(); + String layerName = getLayerConfiguration().getLayerName(); if (layerName == null) layerName = "(not named)"; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java index 6abd39baa..fb824dfa3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java @@ -24,6 +24,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; @@ -42,7 +43,7 @@ import java.util.Arrays; @Slf4j public class SpaceToBatch extends AbstractLayer { - public SpaceToBatch(NeuralNetConfiguration conf, DataType dataType) { + public SpaceToBatch(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java index aa0fd2ebb..32a74d80d 100644 --- 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepth.java @@ -24,6 +24,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; @@ -43,7 +44,7 @@ import java.util.Arrays; @Slf4j public class SpaceToDepth extends AbstractLayer { - public SpaceToDepth(NeuralNetConfiguration conf, DataType dataType) { + public SpaceToDepth(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java index 386c312e6..f6bfcb3cf 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ZeroPadding1DLayer.java @@ -23,6 +23,7 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; @@ -38,9 +39,9 @@ public class ZeroPadding1DLayer extends AbstractLayer { - public ZeroPaddingLayer(NeuralNetConfiguration conf, DataType dataType) { + public ZeroPaddingLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } @@ -115,7 +116,7 @@ public class ZeroPaddingLayer extends AbstractLayer { - public Upsampling2D(NeuralNetConfiguration conf, DataType dataType) { + public Upsampling2D(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling3D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling3D.java index 0df9431c7..9fa3e6365 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling3D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling3D.java @@ -24,6 +24,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; @@ -44,7 +45,7 @@ import java.util.Arrays; public class Upsampling3D extends AbstractLayer { - public Upsampling3D(NeuralNetConfiguration conf, DataType dataType) { + public Upsampling3D(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java index fcf624932..8a715f100 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/PReLU.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.layers.feedforward; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseLayer; @@ -38,7 +39,7 @@ public class PReLU extends BaseLayer { - public AutoEncoder(NeuralNetConfiguration conf, DataType dataType) { + public AutoEncoder(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java index d2aa10406..362fb8db8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseLayer.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.layers.feedforward.dense; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.layers.BaseLayer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -31,7 +32,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; */ public class DenseLayer extends BaseLayer { - public DenseLayer(NeuralNetConfiguration conf, DataType dataType) { + public DenseLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/elementwise/ElementWiseMultiplicationLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/elementwise/ElementWiseMultiplicationLayer.java index b65172652..f799b7373 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/elementwise/ElementWiseMultiplicationLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/elementwise/ElementWiseMultiplicationLayer.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.layers.feedforward.elementwise; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.params.ElementWiseParamInitializer; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -38,7 +39,7 @@ import java.util.Arrays; public class ElementWiseMultiplicationLayer extends BaseLayer { - public ElementWiseMultiplicationLayer(NeuralNetConfiguration conf, DataType dataType){ + public ElementWiseMultiplicationLayer(LayerConfiguration conf, DataType dataType){ super(conf, dataType); } @@ -68,7 +69,7 @@ public class ElementWiseMultiplicationLayer extends BaseLayer { private static final int[] DIM_1 = new int[]{1}; - public EmbeddingLayer(NeuralNetConfiguration conf, DataType dataType) { + public EmbeddingLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java index 762407264..6760caa1a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingSequenceLayer.java @@ -25,6 +25,7 @@ import lombok.val; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseLayer; @@ -45,7 +46,7 @@ import static org.nd4j.linalg.api.shape.Shape.hasDefaultStridesForShape; public class EmbeddingSequenceLayer extends BaseLayer { private static final int[] WEIGHT_DIM = new int[]{1}; - public EmbeddingSequenceLayer(NeuralNetConfiguration conf, DataType dataType) { + public EmbeddingSequenceLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java index a8803eda2..c5b159fb9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNLSTMHelper.java @@ -118,7 +118,7 @@ public class MKLDNNLSTMHelper implements LSTMHelper { if(prevMemCellState != null) args.add(prevMemCellState); - IActivation a = ((LSTM)conf.getLayer()).getActivationFn(); + IActivation a = ((LSTM)layer.getLayerConfiguration()).getActivationFn(); DynamicCustomOp op = DynamicCustomOp.builder("lstmLayer") .addInputs(args.toArray(new INDArray[0])) diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java index 6d4b65b0b..75bd4866d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/BatchNormalization.java @@ -25,6 +25,7 @@ import lombok.val; import org.deeplearning4j.common.config.DL4JClassLoading; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseLayer; @@ -47,7 +48,6 @@ import org.nd4j.linalg.exception.ND4JOpProfilerException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.common.primitives.Pair; -import org.nd4j.common.util.OneTimeLogger; import java.util.*; @@ -63,7 +63,7 @@ public class BatchNormalization extends BaseLayer getListeners() { - return listeners; - } @Override public void setListeners(TrainingListener... 
listeners) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java index 9bfa02687..957a56632 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/normalization/LocalResponseNormalization.java @@ -26,6 +26,7 @@ import org.deeplearning4j.common.config.DL4JClassLoading; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; @@ -54,11 +55,12 @@ public class LocalResponseNormalization @Override public Layer clone() { - return new LocalResponseNormalization(conf.clone(), dataType); + return new LocalResponseNormalization(getLayerConfiguration().clone(), dataType); } - public LocalResponseNormalization(NeuralNetConfiguration conf, DataType dataType) { + public LocalResponseNormalization(LayerConfiguration conf, DataType dataType) { super(conf, dataType); + layerConfiguration = (org.deeplearning4j.nn.conf.layers.LocalResponseNormalization) conf; initializeHelper(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java index 69016dca4..e5f0fbf1e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java @@ -25,6 +25,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; @@ -68,7 +69,7 @@ public class Yolo2OutputLayer extends AbstractLayer pair = getGradientsAndDelta(preOutput2d(true, workspaceMgr), workspaceMgr); //Returns Gradient and delta^(this), not Gradient and epsilon^(this-1) //150 - long inputShape = (( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) this.getConf().getLayer()).getNIn(); + long inputShape = (( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) getLayerConfiguration()).getNIn(); INDArray delta = pair.getSecond(); //4 x 150 INDArray epsilonNext = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, input.dataType(), new long[]{inputShape, delta.length()}, 'f'); @@ -125,7 +126,7 @@ public class OCNNOutputLayer extends BaseOutputLayer paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { return PARAM_KEYS; } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { return WEIGHT_KEYS; } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { return Collections.emptyList(); } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean 
isWeightParam(LayerConfiguration layer, String key) { return WEIGHT_KEYS.contains(key); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return false; } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) conf.getLayer(); + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) conf; Map params = Collections.synchronizedMap(new LinkedHashMap()); val nIn = ocnnOutputLayer.getNIn(); int hiddenLayer = ocnnOutputLayer.getHiddenSize(); @@ -133,8 +127,8 @@ public class OCNNParamInitializer extends DefaultParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { - org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) conf.getLayer(); + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { + org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) conf; Map params = Collections.synchronizedMap(new LinkedHashMap()); val nIn = ocnnOutputLayer.getNIn(); val hiddenLayer = ocnnOutputLayer.getHiddenSize(); @@ -155,11 +149,11 @@ public class OCNNParamInitializer extends DefaultParamInitializer { } - protected INDArray createWeightMatrix(NeuralNetConfiguration configuration, + protected INDArray createWeightMatrix(LayerConfiguration configuration, INDArray weightParamView, boolean initializeParameters) { - org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) configuration.getLayer(); + org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) configuration; IWeightInit weightInit = ocnnOutputLayer.getWeightInitFn(); if (initializeParameters) { INDArray ret = weightInit.init(weightParamView.size(0), //Fan in diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingLayer.java index e8b8ae9a3..b342f5032 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/pooling/GlobalPoolingLayer.java @@ -20,11 +20,13 @@ package org.deeplearning4j.nn.layers.pooling; +import java.util.Map; import lombok.val; import org.apache.commons.lang3.ArrayUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.PoolingType; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -54,11 +56,11 @@ public class GlobalPoolingLayer extends AbstractLayer paramTable) { + throw new RuntimeException("Not implemented."); + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java index 6b02b6c6e..1216c9d8c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/BaseRecurrentLayer.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.layers.RecurrentLayer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.layers.BaseLayer; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.api.buffer.DataType; @@ -50,7 +51,7 @@ public abstract class BaseRecurrentLayer gradientViews; @@ -65,14 +75,25 @@ public class BidirectionalLayer implements RecurrentLayer { private INDArray outFwd; private INDArray outBwd; - public BidirectionalLayer(@NonNull NeuralNetConfiguration conf, @NonNull Layer fwd, @NonNull Layer bwd, @NonNull INDArray paramsView) { - this.conf = conf; + public BidirectionalLayer(@NonNull LayerConfiguration conf, @NonNull Layer fwd, @NonNull Layer bwd, @NonNull INDArray paramsView) { + this.layerConfiguration = conf; + this.conf = conf.getNetConfiguration(); this.fwd = fwd; this.bwd = bwd; - this.layerConf = (Bidirectional) conf.getLayer(); + this.layerConf = (Bidirectional) layerConfiguration; this.paramsView = paramsView; } + /** + * Return the configuration of this layer + * + * @return the configuration + */ + @Override + public LayerConfiguration getLayerConfiguration() { + return layerConf; + } + private RNNFormat getRNNDataFormat(){ return layerConf.getRNNDataFormat(); } @@ -283,7 +304,7 @@ public class BidirectionalLayer implements RecurrentLayer { @Override public TrainingConfig getConfig() { - return conf.getLayer(); + return layerConfiguration; } @Override @@ -349,13 +370,23 @@ public class BidirectionalLayer implements RecurrentLayer { } @Override - public NeuralNetConfiguration conf() { + public NeuralNetConfiguration getNetConfiguration() { return conf; } + /** + * @param netConfiguration + */ @Override - public void setConf(NeuralNetConfiguration conf) { - this.conf = conf; + public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) { + + } + + + public void setLayerConfiguration(LayerConfiguration layerConfiguration) { + this.layerConfiguration = layerConfiguration; + this.layerConf = (Bidirectional) layerConfiguration; + this.conf = layerConfiguration.getNetConfiguration(); } @Override @@ -363,11 +394,86 @@ public class BidirectionalLayer implements RecurrentLayer { return input; } + /** + * This method returns updater state (if applicable), null otherwise + * + * @return + */ + @Override + public INDArray updaterState() { + return null; + } + @Override public ConvexOptimizer getOptimizer() { return null; } + /** + * This method fits model with a given DataSet + * + * @param dataSet + */ + @Override + public void fit(DataSet dataSet) { + + } + + /** + * This method fits model with a given MultiDataSet + * + * @param dataSet + */ + @Override + public void fit(MultiDataSet dataSet) { + + } + + /** + * This method fits model with a given DataSetIterator + * + * @param iterator + */ + @Override + public void fit(DataSetIterator iterator) { + + } + + /** + * This method fits model with a given MultiDataSetIterator + * + * @param iterator + */ + @Override + public void 
fit(MultiDataSetIterator iterator) { + + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(DataSetIterator iterator, T... evaluations) { + return null; + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(MultiDataSetIterator iterator, + T... evaluations) { + return null; + } + @Override public INDArray getParam(String param) { String sub = param.substring(1); @@ -379,17 +485,17 @@ public class BidirectionalLayer implements RecurrentLayer { } @Override - public Map paramTable() { - return paramTable(false); + public Map getParamTable() { + return getParamTable(false); } @Override - public Map paramTable(boolean backpropParamsOnly) { + public Map getParamTable(boolean backpropParamsOnly) { Map m = new LinkedHashMap<>(); - for(Map.Entry e : fwd.paramTable(backpropParamsOnly).entrySet()){ + for(Map.Entry e : fwd.getParamTable(backpropParamsOnly).entrySet()){ m.put(BidirectionalParamInitializer.FORWARD_PREFIX + e.getKey(), e.getValue()); } - for(Map.Entry e : bwd.paramTable(backpropParamsOnly).entrySet()){ + for(Map.Entry e : bwd.getParamTable(backpropParamsOnly).entrySet()){ m.put(BidirectionalParamInitializer.BACKWARD_PREFIX + e.getKey(), e.getValue()); } return m; @@ -442,10 +548,9 @@ public class BidirectionalLayer implements RecurrentLayer { //No op } - @Override public void setListeners(Collection listeners) { - fwd.setListeners(listeners); - bwd.setListeners(listeners); + fwd.setListeners(listeners.toArray(new TrainingListener[]{})); + bwd.setListeners(listeners.toArray(new TrainingListener[]{})); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java index 99a2081dc..ac5c57165 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java @@ -24,6 +24,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.GravesBidirectionalLSTMParamInitializer; @@ -41,7 +42,7 @@ public class GravesBidirectionalLSTM protected FwdPassReturn cachedPassForward; protected FwdPassReturn cachedPassBackward; - public GravesBidirectionalLSTM(NeuralNetConfiguration conf, DataType dataType) { + public GravesBidirectionalLSTM(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } @@ -74,7 +75,7 @@ public class GravesBidirectionalLSTM final FwdPassReturn fwdPass = activateHelperDirectional(true, null, null, true, true, workspaceMgr); fwdPass.fwdPassOutput = permuteIfNWC(fwdPass.fwdPassOutput); final Pair forwardsGradient = LSTMHelpers.backpropGradientHelper(this, - this.conf, + this.layerConfiguration.getNetConfiguration(), this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), 
getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), permuteIfNWC(epsilon), @@ -89,7 +90,7 @@ public class GravesBidirectionalLSTM final FwdPassReturn backPass = activateHelperDirectional(true, null, null, true, false, workspaceMgr); final Pair backwardsGradient = LSTMHelpers.backpropGradientHelper(this, - this.conf, + this.layerConfiguration.getNetConfiguration(), this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), permuteIfNWC(epsilon), @@ -117,7 +118,7 @@ public class GravesBidirectionalLSTM final Gradient correctOrderedGradient = new DefaultGradient(); - for (final String key : params.keySet()) { + for (final String key : paramsTable.keySet()) { correctOrderedGradient.setGradientFor(key, combinedGradient.getGradientFor(key)); } @@ -155,7 +156,7 @@ public class GravesBidirectionalLSTM cachedPassForward = null; } else { - forwardsEval = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), + forwardsEval = LSTMHelpers.activateHelper(this, this.layerConfiguration.getNetConfiguration(), this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), training, null, null, @@ -163,7 +164,7 @@ public class GravesBidirectionalLSTM GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, maskArray, true, null, forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); - backwardsEval = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), + backwardsEval = LSTMHelpers.activateHelper(this, this.layerConfiguration.getNetConfiguration(), this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), @@ -215,7 +216,7 @@ public class GravesBidirectionalLSTM biasKey = GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS; } - FwdPassReturn ret = LSTMHelpers.activateHelper(this, this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + FwdPassReturn ret = LSTMHelpers.activateHelper(this, this.layerConfiguration.getNetConfiguration(), this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), getParam(recurrentKey), getParam(inputKey), getParam(biasKey), training, prevOutputActivations, prevMemCellState, forBackprop, forwards, inputKey, maskArray, true, null, forBackprop ? 
cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java index 1e37cfe32..5aedd780b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java @@ -24,6 +24,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; import org.nd4j.common.base.Preconditions; @@ -40,7 +41,7 @@ public class GravesLSTM extends BaseRecurrentLayer p = LSTMHelpers.backpropGradientHelper(this, - this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + this.layerConfiguration.getNetConfiguration(), this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true, GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY, GravesLSTMParamInitializer.BIAS_KEY, gradientViews, maskArray, true, null, @@ -128,7 +129,7 @@ public class GravesLSTM extends BaseRecurrentLayer { @@ -45,7 +45,7 @@ public class LSTM extends BaseRecurrentLayer p = LSTMHelpers.backpropGradientHelper(this, - this.conf, this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + getNetConfiguration(), this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true, LSTMParamInitializer.INPUT_WEIGHT_KEY, LSTMParamInitializer.RECURRENT_WEIGHT_KEY, LSTMParamInitializer.BIAS_KEY, gradientViews, null, false, helper, workspaceMgr, @@ -161,7 +161,7 @@ public class LSTM extends BaseRecurrentLayer= endIdx; iTimeIndex--) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java index 4656ce9d1..da5f0b782 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/LastTimeStepLayer.java @@ -57,7 +57,7 @@ public class LastTimeStepLayer extends BaseWrapperLayer { public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { long[] newEpsShape = origOutputShape; - boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(underlying.conf().getLayer()) == RNNFormat.NWC; + boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(underlying.getLayerConfiguration()) == RNNFormat.NWC; INDArray newEps = Nd4j.create(epsilon.dataType(), newEpsShape, 'f'); if(lastTimeStepIdxs == null){ //no mask case @@ -119,7 +119,7 @@ public class LastTimeStepLayer extends BaseWrapperLayer { "rank " + in.rank() + " with shape " + Arrays.toString(in.shape())); } origOutputShape = in.shape(); - boolean nwc = TimeSeriesUtils.getFormatFromRnnLayer(underlying.conf().getLayer()) == RNNFormat.NWC; + boolean nwc = 
TimeSeriesUtils.getFormatFromRnnLayer(underlying.getLayerConfiguration()) == RNNFormat.NWC; // underlying instanceof BaseRecurrentLayer && ((BaseRecurrentLayer)underlying).getDataFormat() == RNNFormat.NWC)|| // underlying instanceof MaskZeroLayer && ((MaskZeroLayer)underlying).getUnderlying() instanceof BaseRecurrentLayer && // ((BaseRecurrentLayer)((MaskZeroLayer)underlying).getUnderlying()).getDataFormat() == RNNFormat.NWC; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java index fb2117b9b..de9d75928 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseLayer; @@ -47,7 +48,7 @@ import java.util.List; public class RnnLossLayer extends BaseLayer implements IOutputLayer { @Setter @Getter protected INDArray labels; - public RnnLossLayer(NeuralNetConfiguration conf, DataType dataType) { + public RnnLossLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java index edede1c1f..63f64e95e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnOutputLayer.java @@ -24,6 +24,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseOutputLayer; import org.deeplearning4j.nn.params.DefaultParamInitializer; @@ -40,7 +41,7 @@ import java.util.Arrays; public class RnnOutputLayer extends BaseOutputLayer { - public RnnOutputLayer(NeuralNetConfiguration conf, DataType dataType) { + public RnnOutputLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java index 0176ce720..b0437fe1a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java @@ -22,6 +22,7 @@ package org.deeplearning4j.nn.layers.recurrent; import lombok.val; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.SimpleRnnParamInitializer; 
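
A minimal, self-contained sketch of the access-pattern change applied in the hunks above: layer code that previously received the whole network configuration and unwrapped its own entry via conf.getLayer() now receives its LayerConfiguration directly and casts it. NetCfg, LayerCfg and RnnCfg below are placeholder types for illustration only, not the DL4J classes.

import java.util.List;

class NetCfg {
    final List<LayerCfg> layers;
    NetCfg(List<LayerCfg> layers) { this.layers = layers; }
}

class LayerCfg {
    final String name;
    LayerCfg(String name) { this.name = name; }
}

class RnnCfg extends LayerCfg {
    final String dataFormat; // e.g. "NWC" or "NCW"
    RnnCfg(String name, String dataFormat) { super(name); this.dataFormat = dataFormat; }
}

public class ConfigAccessSketch {

    // Old shape: the layer gets the network configuration and unwraps its own entry.
    static String formatOld(NetCfg net, int layerIndex) {
        return ((RnnCfg) net.layers.get(layerIndex)).dataFormat;
    }

    // New shape: the layer gets its own configuration object directly and casts it.
    static String formatNew(LayerCfg layerCfg) {
        return ((RnnCfg) layerCfg).dataFormat;
    }

    public static void main(String[] args) {
        RnnCfg rnn = new RnnCfg("lstm_0", "NWC");
        List<LayerCfg> layers = List.of(rnn);
        NetCfg net = new NetCfg(layers);
        System.out.println(formatOld(net, 0)); // NWC
        System.out.println(formatNew(rnn));    // NWC
    }
}
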
@@ -47,7 +48,7 @@ public class SimpleRnn extends BaseRecurrentLayer paramTable(boolean backpropOnly) { + public Map getParamTable(boolean backpropOnly) { return paramTable; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java index 021c1a5aa..748f91564 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java @@ -24,6 +24,7 @@ import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -64,7 +65,7 @@ public class SameDiffLayer extends AbstractLayer { protected Map gradTable; - public SameDiffLayer(NeuralNetConfiguration conf, DataType dataType){ + public SameDiffLayer(LayerConfiguration conf, DataType dataType){ super(conf, dataType); } @@ -271,7 +272,7 @@ public class SameDiffLayer extends AbstractLayer { @Override public void setBackpropGradientsViewArray(INDArray gradients) { this.gradients = gradients; - this.gradTable = layerConf().initializer().getGradientsFromFlattened(conf(), gradients); + this.gradTable = layerConf().initializer().getGradientsFromFlattened(this.getLayerConfiguration(), gradients); } @Override @@ -286,12 +287,12 @@ public class SameDiffLayer extends AbstractLayer { } @Override - public Map paramTable() { - return paramTable(false); + public Map getParamTable() { + return getParamTable(false); } @Override - public Map paramTable(boolean backpropParamsOnly) { + public Map getParamTable(boolean backpropParamsOnly) { return paramTable; } @@ -301,7 +302,7 @@ public class SameDiffLayer extends AbstractLayer { sameDiff = SameDiff.create(); //Use SingleThreadArrayHolder so we can use views (also don't nede multithreading here, DL4J is not thread safe) sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false); - Map p = paramTable(); + Map p = getParamTable(); long[] inputShape = input.shape().clone(); inputShape[0] = -1; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java index 67cf9e648..d3cc93049 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java @@ -26,6 +26,7 @@ import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; @@ -67,7 +68,7 @@ public class SameDiffOutputLayer extends AbstractLayer gradTable; - public SameDiffOutputLayer(NeuralNetConfiguration conf, DataType dataType){ + public SameDiffOutputLayer(LayerConfiguration conf, DataType dataType){ 
super(conf, dataType); } @@ -277,7 +278,7 @@ public class SameDiffOutputLayer extends AbstractLayer paramTable() { - return paramTable(false); + public Map getParamTable() { + return getParamTable(false); } @Override - public Map paramTable(boolean backpropParamsOnly) { + public Map getParamTable(boolean backpropParamsOnly) { return paramTable; } @@ -307,7 +308,7 @@ public class SameDiffOutputLayer extends AbstractLayer p = paramTable(); + Map p = getParamTable(); long[] inputShape = input.shape().clone(); inputShape[0] = -1; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/training/CenterLossOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/training/CenterLossOutputLayer.java index 4f85849e8..683f8b4d9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/training/CenterLossOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/training/CenterLossOutputLayer.java @@ -22,6 +22,7 @@ package org.deeplearning4j.nn.layers.training; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseOutputLayer; @@ -39,7 +40,7 @@ public class CenterLossOutputLayer extends BaseOutputLayer { private final Gradient emptyGradient = new DefaultGradient(); - public MaskLayer(NeuralNetConfiguration conf, DataType dataType) { + public MaskLayer(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java index 75df1dfad..b3ff49db5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.variational.CompositeReconstructionDistribution; import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper; import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution; @@ -39,12 +40,17 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.Solver; import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.nd4j.evaluation.IEvaluation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.blas.Level1; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import 
org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.regularization.Regularization; @@ -65,7 +71,7 @@ public class VariationalAutoencoder implements Layer { protected Map params; @Getter protected transient Map gradientViews; - protected NeuralNetConfiguration conf; + protected double score = 0.0; protected ConvexOptimizer optimizer; protected Gradient gradient; @@ -91,27 +97,50 @@ public class VariationalAutoencoder implements Layer { @Getter @Setter protected int epochCount; - public VariationalAutoencoder(NeuralNetConfiguration conf, DataType dataType) { - this.conf = conf; + @Getter @Setter @NonNull + private LayerConfiguration layerConfiguration; + + public VariationalAutoencoder(@NonNull LayerConfiguration layerConfiguration, DataType dataType) { + this.layerConfiguration = layerConfiguration; this.dataType = dataType; this.encoderLayerSizes = - ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) conf.getLayer()) + ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) layerConfiguration) .getEncoderLayerSizes(); this.decoderLayerSizes = - ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) conf.getLayer()) + ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) layerConfiguration) .getDecoderLayerSizes(); this.reconstructionDistribution = - ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) conf.getLayer()) + ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) layerConfiguration) .getOutputDistribution(); - this.pzxActivationFn = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) conf.getLayer()) + this.pzxActivationFn = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) layerConfiguration) .getPzxActivationFn(); - this.numSamples = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) conf.getLayer()) + this.numSamples = ((org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) layerConfiguration) .getNumSamples(); } protected org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder layerConf() { - return (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) conf().getLayer(); + return (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) layerConfiguration; + } + + /** + * Return the configuration of this layer + * + * @return the configuration + */ + @Override + public LayerConfiguration getLayerConfiguration() { + return layerConf(); + } + + /** + * Set a new layer configuration, new init() needs to be called afterwards. + * + * @param lconf layer configuration + */ + @Override + public void setLayerConfiguration(LayerConfiguration lconf) { + } @Override @@ -123,7 +152,7 @@ public class VariationalAutoencoder implements Layer { } protected String layerId() { - String name = this.conf().getLayer().getLayerName(); + String name = this.getLayerConfiguration().getLayerName(); return "(layer name: " + (name == null ? 
"\"\"" : name) + ", layer index: " + index + ")"; } @@ -470,9 +499,19 @@ public class VariationalAutoencoder implements Layer { return paramsFlattened; } + /** + * The param table + * + * @return + */ + @Override + public Map getParamTable() { + return null; + } + @Override public TrainingConfig getConfig() { - return conf.getLayer(); + return layerConfiguration; } @Override @@ -522,7 +561,7 @@ public class VariationalAutoencoder implements Layer { } this.gradientsFlattened = gradients; - this.gradientViews = conf.getLayer().initializer().getGradientsFromFlattened(conf, gradients); + this.gradientViews = layerConfiguration.initializer().getGradientsFromFlattened(this.layerConfiguration, gradients); } @Override @@ -548,14 +587,22 @@ public class VariationalAutoencoder implements Layer { return (int) input.size(0); } + /** + * The configuration for the neural network + * + * @return the configuration for the neural network + */ @Override - public NeuralNetConfiguration conf() { - return conf; + public NeuralNetConfiguration getNetConfiguration() { + return this.layerConfiguration.getNetConfiguration(); } + /** + * @param netConfiguration + */ @Override - public void setConf(NeuralNetConfiguration conf) { - this.conf = conf; + public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) { + } @Override @@ -563,23 +610,94 @@ public class VariationalAutoencoder implements Layer { return input; } + /** + * This method returns updater state (if applicable), null otherwise + * + * @return + */ + @Override + public INDArray updaterState() { + return null; + } + @Override public ConvexOptimizer getOptimizer() { return optimizer; } + /** + * This method fits model with a given DataSet + * + * @param dataSet + */ + @Override + public void fit(DataSet dataSet) { + + } + + /** + * This method fits model with a given MultiDataSet + * + * @param dataSet + */ + @Override + public void fit(MultiDataSet dataSet) { + + } + + /** + * This method fits model with a given DataSetIterator + * + * @param iterator + */ + @Override + public void fit(DataSetIterator iterator) { + + } + + /** + * This method fits model with a given MultiDataSetIterator + * + * @param iterator + */ + @Override + public void fit(MultiDataSetIterator iterator) { + + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(DataSetIterator iterator, T... evaluations) { + return null; + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(MultiDataSetIterator iterator, + T... 
evaluations) { + return null; + } + @Override public INDArray getParam(String param) { return params.get(param); } - @Override - public Map paramTable() { - return new LinkedHashMap<>(params); - } @Override - public Map paramTable(boolean backpropParamsOnly) { + public Map getParamTable(boolean backpropParamsOnly) { Map map = new LinkedHashMap<>(); for (Map.Entry e : params.entrySet()) { if (!backpropParamsOnly || !isPretrainParam(e.getKey())) { @@ -601,8 +719,8 @@ public class VariationalAutoencoder implements Layer { @Override public void setParam(String key, INDArray val) { - if (paramTable().containsKey(key)) { - paramTable().get(key).assign(val); + if (getParamTable().containsKey(key)) { + getParamTable().get(key).assign(val); } else { throw new IllegalArgumentException("Unknown parameter: " + key + " - " + layerId()); } @@ -630,7 +748,7 @@ public class VariationalAutoencoder implements Layer { @Override public double calcRegularizationScore(boolean backpropParamsOnly){ double scoreSum = 0.0; - for (Map.Entry e : paramTable().entrySet()) { + for (Map.Entry e : getParamTable().entrySet()) { if(backpropParamsOnly && isPretrainParam(e.getKey())) continue; List l = layerConf().getRegularizationByParam(e.getKey()); @@ -799,7 +917,6 @@ public class VariationalAutoencoder implements Layer { setListeners(Arrays.asList(listeners)); } - @Override public void setListeners(Collection listeners) { if (trainingListeners == null) trainingListeners = new ArrayList<>(); @@ -905,7 +1022,7 @@ public class VariationalAutoencoder implements Layer { if (solver == null) { try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().model(this).configure(conf()).listeners(getListeners()).build(); + solver = new Solver.Builder().model(this).configure(getNetConfiguration()).listeners(getListeners()).build(); } } this.optimizer = solver.getOptimizer(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/wrapper/BaseWrapperLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/wrapper/BaseWrapperLayer.java index 80439cbc5..d27d9cfbb 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/wrapper/BaseWrapperLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/wrapper/BaseWrapperLayer.java @@ -20,6 +20,8 @@ package org.deeplearning4j.nn.layers.wrapper; +import java.util.Collection; +import java.util.Map; import lombok.Data; import lombok.NonNull; import org.deeplearning4j.nn.api.Layer; @@ -27,308 +29,321 @@ import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.nn.layers.AbstractLayer; import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.deeplearning4j.optimize.api.TrainingListener; -import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; - -import java.util.Collection; -import java.util.Map; +import org.nd4j.linalg.api.ndarray.INDArray; @Data -public abstract class BaseWrapperLayer implements Layer { +public abstract class BaseWrapperLayer extends AbstractLayer { - protected Layer underlying; + protected Layer underlying; - public 
BaseWrapperLayer(@NonNull Layer underlying){ - this.underlying = underlying; - } - @Override - public void setCacheMode(CacheMode mode) { - underlying.setCacheMode(mode); - } + public BaseWrapperLayer(@NonNull Layer underlying) { + this.underlying = underlying; + } - @Override - public double calcRegularizationScore(boolean backpropParamsOnly){ - return underlying.calcRegularizationScore(backpropParamsOnly); - } + /** + * Return the configuration of this layer (which is the configuration of the underlying layer in + * this case + * + * @return the underlying layer configuration + */ + @Override + public LayerConfiguration getLayerConfiguration() { + return underlying.getLayerConfiguration(); + } - @Override - public Type type() { - return underlying.type(); - } + @Override + public void setLayerConfiguration(LayerConfiguration layerConfiguration) { + underlying.setLayerConfiguration(layerConfiguration); + } - @Override - public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { - return underlying.backpropGradient(epsilon, workspaceMgr); - } + @Override + public void setCacheMode(CacheMode mode) { + underlying.setCacheMode(mode); + } - @Override - public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { - return underlying.activate(training, workspaceMgr); - } + @Override + public double calcRegularizationScore(boolean backpropParamsOnly) { + return underlying.calcRegularizationScore(backpropParamsOnly); + } - @Override - public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) { - return underlying.activate(input, training, workspaceMgr); - } + @Override + public Type type() { + return underlying.type(); + } - @Override - public Collection getListeners() { - return underlying.getListeners(); - } + @Override + public Pair backpropGradient(INDArray epsilon, + LayerWorkspaceMgr workspaceMgr) { + return underlying.backpropGradient(epsilon, workspaceMgr); + } - @Override - public void setListeners(TrainingListener... listeners) { - underlying.setListeners(listeners); - } + @Override + public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { + return underlying.activate(training, workspaceMgr); + } - @Override - public void addListeners(TrainingListener... listener) { - underlying.addListeners(listener); - } + @Override + public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) { + return underlying.activate(input, training, workspaceMgr); + } - @Override - public void fit() { - underlying.fit(); - } + @Override + public Collection getListeners() { + return underlying.getListeners(); + } - @Override - public void update(Gradient gradient) { - underlying.update(gradient); - } + @Override + public void setListeners(TrainingListener... listeners) { + underlying.setListeners(listeners); + } - @Override - public void update(INDArray gradient, String paramType) { - underlying.update(gradient, paramType); - } + @Override + public void addListeners(TrainingListener... 
listener) { + underlying.addListeners(listener); + } - @Override - public double score() { - return underlying.score(); - } + @Override + public void fit() { + underlying.fit(); + } - @Override - public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { - underlying.computeGradientAndScore(workspaceMgr); - } + @Override + public void update(Gradient gradient) { + underlying.update(gradient); + } - @Override - public INDArray params() { - return underlying.params(); - } + @Override + public void update(INDArray gradient, String paramType) { + underlying.update(gradient, paramType); + } - @Override - public long numParams() { - return underlying.numParams(); - } + @Override + public double score() { + return underlying.score(); + } - @Override - public long numParams(boolean backwards) { - return underlying.numParams(); - } + @Override + public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { + underlying.computeGradientAndScore(workspaceMgr); + } - @Override - public void setParams(INDArray params) { - underlying.setParams(params); - } + @Override + public INDArray params() { + return underlying.params(); + } - @Override - public void setParamsViewArray(INDArray params) { - underlying.setParamsViewArray(params); - } + @Override + public long numParams() { + return underlying.numParams(); + } - @Override - public INDArray getGradientsViewArray() { - return underlying.getGradientsViewArray(); - } + @Override + public long numParams(boolean backwards) { + return underlying.numParams(); + } - @Override - public void setBackpropGradientsViewArray(INDArray gradients) { - underlying.setBackpropGradientsViewArray(gradients); - } + @Override + public void setParams(INDArray params) { + underlying.setParams(params); + } - @Override - public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) { - underlying.fit(data, workspaceMgr); - } + @Override + public void setParamsViewArray(INDArray params) { + underlying.setParamsViewArray(params); + } - @Override - public Gradient gradient() { - return underlying.gradient(); - } + @Override + public INDArray getGradientsViewArray() { + return underlying.getGradientsViewArray(); + } - @Override - public Pair gradientAndScore() { - return underlying.gradientAndScore(); - } + @Override + public void setBackpropGradientsViewArray(INDArray gradients) { + underlying.setBackpropGradientsViewArray(gradients); + } - @Override - public int batchSize() { - return underlying.batchSize(); - } + @Override + public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) { + underlying.fit(data, workspaceMgr); + } - @Override - public NeuralNetConfiguration conf() { - return underlying.conf(); - } + @Override + public Gradient gradient() { + return underlying.gradient(); + } - @Override - public void setConf(NeuralNetConfiguration conf) { - underlying.setConf(conf); - } + @Override + public Pair gradientAndScore() { + return underlying.gradientAndScore(); + } - @Override - public INDArray input() { - return underlying.input(); - } + @Override + public int batchSize() { + return underlying.batchSize(); + } - @Override - public ConvexOptimizer getOptimizer() { - return underlying.getOptimizer(); - } + @Override + public NeuralNetConfiguration getNetConfiguration() { + return underlying.getNetConfiguration(); + } - @Override - public INDArray getParam(String param) { - return underlying.getParam(param); - } + @Override + public INDArray input() { + return underlying.input(); + } - @Override - public Map paramTable() { - return 
underlying.paramTable(); - } + @Override + public ConvexOptimizer getOptimizer() { + return underlying.getOptimizer(); + } - @Override - public Map paramTable(boolean backpropParamsOnly) { - return underlying.paramTable(backpropParamsOnly); - } + @Override + public INDArray getParam(String param) { + return underlying.getParam(param); + } - @Override - public void setParamTable(Map paramTable) { - underlying.setParamTable(paramTable); - } + @Override + public Map getParamTable() { + return underlying.getParamTable(); + } - @Override - public void setParam(String key, INDArray val) { - underlying.setParam(key, val); - } + /** + * Setter for the param table + * + * @param paramTable Map<String, INDArray> + */ + @Override + public void setParamTable(Map paramTable) { + underlying.setParamTable(paramTable); + } - @Override - public void clear() { - underlying.clear(); - } + @Override + public Map getParamTable(boolean backpropParamsOnly) { + return underlying.getParamTable(backpropParamsOnly); + } - @Override - public void applyConstraints(int iteration, int epoch) { - underlying.applyConstraints(iteration, epoch); - } + @Override + public void setParam(String key, INDArray val) { + underlying.setParam(key, val); + } - @Override - public void init() { - underlying.init(); - } + @Override + public void clear() { + underlying.clear(); + } - @Override - public void setListeners(Collection listeners) { - underlying.setListeners(listeners); - } + @Override + public void applyConstraints(int iteration, int epoch) { + underlying.applyConstraints(iteration, epoch); + } - @Override - public void setIndex(int index) { - underlying.setIndex(index); - } + @Override + public void init() { + underlying.init(); + } - @Override - public int getIndex() { - return underlying.getIndex(); - } + @Override + public int getIndex() { + return underlying.getIndex(); + } - @Override - public int getIterationCount() { - return underlying.getIterationCount(); - } + @Override + public void setIndex(int index) { + underlying.setIndex(index); + } - @Override - public int getEpochCount() { - return underlying.getEpochCount(); - } + @Override + public int getIterationCount() { + return underlying.getIterationCount(); + } - @Override - public void setIterationCount(int iterationCount) { - underlying.setIterationCount(iterationCount); - } + @Override + public void setIterationCount(int iterationCount) { + underlying.setIterationCount(iterationCount); + } - @Override - public void setEpochCount(int epochCount) { - underlying.setEpochCount(epochCount); - } + @Override + public int getEpochCount() { + return underlying.getEpochCount(); + } - @Override - public void setInput(INDArray input, LayerWorkspaceMgr workspaceMgr) { - underlying.setInput(input, workspaceMgr); - } + @Override + public void setEpochCount(int epochCount) { + underlying.setEpochCount(epochCount); + } - @Override - public void setInputMiniBatchSize(int size) { - underlying.setInputMiniBatchSize(size); - } + @Override + public void setInput(INDArray input, LayerWorkspaceMgr workspaceMgr) { + underlying.setInput(input, workspaceMgr); + } - @Override - public int getInputMiniBatchSize() { - return underlying.getInputMiniBatchSize(); - } + @Override + public int getInputMiniBatchSize() { + return underlying.getInputMiniBatchSize(); + } - @Override - public void setMaskArray(INDArray maskArray) { - underlying.setMaskArray(maskArray); - } + @Override + public void setInputMiniBatchSize(int size) { + underlying.setInputMiniBatchSize(size); + } - @Override - public 
INDArray getMaskArray() { - return underlying.getMaskArray(); - } + @Override + public INDArray getMaskArray() { + return underlying.getMaskArray(); + } - @Override - public boolean isPretrainLayer() { - return underlying.isPretrainLayer(); - } + @Override + public void setMaskArray(INDArray maskArray) { + underlying.setMaskArray(maskArray); + } - @Override - public void clearNoiseWeightParams() { - underlying.clearNoiseWeightParams(); - } + @Override + public boolean isPretrainLayer() { + return underlying.isPretrainLayer(); + } - @Override - public Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { - return underlying.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize); - } + @Override + public void clearNoiseWeightParams() { + underlying.clearNoiseWeightParams(); + } - @Override - public void allowInputModification(boolean allow) { - underlying.allowInputModification(allow); - } + @Override + public Pair feedForwardMaskArray(INDArray maskArray, + MaskState currentMaskState, int minibatchSize) { + return underlying.feedForwardMaskArray(maskArray, currentMaskState, minibatchSize); + } - @Override - public LayerHelper getHelper() { - return underlying.getHelper(); - } + @Override + public void allowInputModification(boolean allow) { + underlying.allowInputModification(allow); + } - @Override - public TrainingConfig getConfig() { - return underlying.getConfig(); - } + @Override + public LayerHelper getHelper() { + return underlying.getHelper(); + } - @Override - public boolean updaterDivideByMinibatch(String paramName) { - return underlying.updaterDivideByMinibatch(paramName); - } + @Override + public TrainingConfig getConfig() { + return underlying.getConfig(); + } - @Override - public void close(){ - //No-op for individual layers - } + @Override + public boolean updaterDivideByMinibatch(String paramName) { + return underlying.updaterDivideByMinibatch(paramName); + } + + @Override + public void close() { + //No-op for individual layers + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 0f81392f9..7da7f837c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -21,12 +21,27 @@ package org.deeplearning4j.nn.multilayer; +import java.io.File; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; import lombok.Getter; import lombok.NonNull; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import lombok.val; -import net.brutex.ai.dnn.api.INeuralNetwork; import net.brutex.ai.dnn.networks.ArtificialNeuralNetwork; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; @@ -34,13 +49,26 @@ import org.bytedeco.javacpp.Pointer; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JInvalidInputException; 
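
The rewritten BaseWrapperLayer above is pure delegation: it extends the common base class but forwards every call, including the renamed getParamTable accessors, to the wrapped layer. The sketch below models that pattern with illustrative types (SimpleLayer, SimpleWrapper, MapBackedLayer are not the DL4J interfaces); as the hunks above show, the real class forwards training state such as listeners and epoch/iteration counters the same way.

import java.util.LinkedHashMap;
import java.util.Map;

interface SimpleLayer {
    Map<String, float[]> getParamTable(boolean backpropParamsOnly);
    void setParam(String key, float[] value);
}

class SimpleWrapper implements SimpleLayer {
    private final SimpleLayer underlying;

    SimpleWrapper(SimpleLayer underlying) { this.underlying = underlying; }

    @Override
    public Map<String, float[]> getParamTable(boolean backpropParamsOnly) {
        return underlying.getParamTable(backpropParamsOnly); // pure delegation, no local state
    }

    @Override
    public void setParam(String key, float[] value) {
        underlying.setParam(key, value); // mutations also pass straight through
    }
}

class MapBackedLayer implements SimpleLayer {
    private final Map<String, float[]> params = new LinkedHashMap<>();

    @Override
    public Map<String, float[]> getParamTable(boolean backpropParamsOnly) { return params; }

    @Override
    public void setParam(String key, float[] value) { params.put(key, value); }
}

public class WrapperSketch {
    public static void main(String[] args) {
        SimpleLayer wrapped = new SimpleWrapper(new MapBackedLayer());
        wrapped.setParam("W", new float[]{0.5f});
        System.out.println(wrapped.getParamTable(false).keySet()); // [W]
    }
}
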
+import org.deeplearning4j.nn.api.Classifier; +import org.deeplearning4j.nn.api.FwdPassType; +import org.deeplearning4j.nn.api.Layer; +import org.deeplearning4j.nn.api.MaskState; +import org.deeplearning4j.nn.api.ModelAdapter; +import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.api.Updater; -import org.deeplearning4j.nn.api.*; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.api.layers.RecurrentLayer; -import org.deeplearning4j.nn.conf.*; +import org.deeplearning4j.nn.conf.BackpropType; +import org.deeplearning4j.nn.conf.CNN2DFormat; +import org.deeplearning4j.nn.conf.CacheMode; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.NeuralNetBaseBuilderConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.RNNFormat; +import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -57,9 +85,18 @@ import org.deeplearning4j.optimize.Solver; import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator; -import org.deeplearning4j.util.*; +import org.deeplearning4j.util.Convolution1DUtils; +import org.deeplearning4j.util.ConvolutionUtils; +import org.deeplearning4j.util.CrashReportingUtil; +import org.deeplearning4j.util.ModelSerializer; +import org.deeplearning4j.util.NetworkUtils; +import org.deeplearning4j.util.OutputLayerUtil; +import org.jetbrains.annotations.NotNull; import org.nd4j.adapters.OutputAdapter; import org.nd4j.common.base.Preconditions; +import org.nd4j.common.primitives.Pair; +import org.nd4j.common.primitives.Triple; +import org.nd4j.common.util.OneTimeLogger; import org.nd4j.evaluation.IEvaluation; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.evaluation.classification.ROC; @@ -88,16 +125,10 @@ import org.nd4j.linalg.heartbeat.reports.Task; import org.nd4j.linalg.heartbeat.utils.EnvironmentUtils; import org.nd4j.linalg.heartbeat.utils.TaskUtils; import org.nd4j.linalg.indexing.NDArrayIndex; -import org.nd4j.common.primitives.Pair; -import org.nd4j.common.primitives.Triple; import org.nd4j.linalg.schedule.ISchedule; import org.nd4j.linalg.util.FeatureUtil; import org.nd4j.linalg.workspace.ND4JWorkspaceException; import org.nd4j.linalg.workspace.WorkspaceUtils; -import org.nd4j.common.util.OneTimeLogger; - -import java.io.*; -import java.util.*; /** * Artificial Neural Network An artificial neural network (1) takes some input data, and (2) @@ -115,8 +146,8 @@ import java.util.*; * weights (or parameters) so that predictions get more accurate. 
*/ @Slf4j -public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serializable, Classifier, Layer, - INeuralNetwork { +public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serializable, Classifier, + Layer { /** * Workspace for working memory for a single layer: forward pass and backward pass Note that this @@ -155,15 +186,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial .initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT) .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE) .policyLearning(LearningPolicy.FIRST_LOOP).build(); + //the hidden neural network layers (including output layer) protected Layer[] layers; - protected LinkedHashMap layerMap = new LinkedHashMap<>(); + //Current training data: input features and labels protected INDArray input, labels; protected boolean initCalled = false; protected Collection trainingListeners = new ArrayList<>(); - protected NeuralNetConfiguration defaultConfiguration; - protected MultiLayerConfiguration layerWiseConfigurations; protected Gradient gradient; protected double score; @Setter @@ -174,7 +204,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial protected boolean clearTbpttState = true; //Mainly for unit testing (should be enabled otherwise) protected transient ThreadLocal lastEtlTime = new ThreadLocal<>(); protected INDArray mask; - protected int layerIndex; //For Layer.get/setIndex() + protected int layerIndex; //For LayerConfiguration.get/setIndex() protected transient Solver solver; //Used to call optimizers during backprop //Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers @Getter @@ -183,27 +213,34 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG; - public MultiLayerNetwork(MultiLayerConfiguration conf) { - this.layerWiseConfigurations = conf; - this.defaultConfiguration = conf.getConf(0).clone(); + public MultiLayerNetwork(@NotNull NeuralNetConfiguration conf) { + super(conf); //Working memory: should learn over course of: (a) full forward pass, and (b) full backward pass //Working memory should be opened once per layer and once per preprocessor, for each of forward and backward passes - int numWorkingMem = 2 * (layerWiseConfigurations.getConfs().size() - + layerWiseConfigurations.getInputPreProcessors().size()); + int numWorkingMem = 2 * (conf.getFlattenedLayerConfigurations().size() + + conf.getInputPreProcessors().size()); WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(numWorkingMem); - WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig(layerWiseConfigurations.getConfs().size()); + WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig( + conf.getFlattenedLayerConfigurations().size()); + + init(); } + public MultiLayerNetwork(@NotNull NeuralNetBaseBuilderConfiguration conf) { + this(( NeuralNetConfiguration) conf); + } + + /** - * Initialize the network based on the configuration (a MultiLayerConfiguration in JSON format) - * and parameters array + * Initialize the network based on the configuration (a NeuralNetConfiguration in JSON format) and + * parameters array * * @param conf the configuration json * @param params the parameters for the network */ public MultiLayerNetwork(String conf, INDArray params) { - this(MultiLayerConfiguration.fromJson(conf)); + this(NeuralNetConfiguration.fromJson(conf)); init(); setParameters(params); } @@ -214,7 
+251,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * @param conf the configuration * @param params the parameters */ - public MultiLayerNetwork(MultiLayerConfiguration conf, INDArray params) { + public MultiLayerNetwork(NeuralNetConfiguration conf, INDArray params) { this(conf); init(); setParameters(params); @@ -261,6 +298,28 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return ModelSerializer.restoreMultiLayerNetwork(f, loadUpdater); } + /** + * Return the configuration of this layer + * + * @return the configuration + */ + @Override + public LayerConfiguration getLayerConfiguration() { + //TODO + throw new RuntimeException( + "getLayerConfiguration cannot be called on a MultiLayerNetwork. This function is here because of inheritance from Layer (which should be fixed)."); + } + + /** + * Set a new layer configuration, new init() needs to be called afterwards. + * + * @param lconf layer configuration + */ + @Override + public void setLayerConfiguration(LayerConfiguration lconf) { + throw new RuntimeException("setLayerConfiguration has no effect on a MultiLayerNetwork"); + } + /** * This method sets specified CacheMode for all layers within network * @@ -299,20 +358,6 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial lastEtlTime.set(time); } - protected void intializeConfigurations() { - if (layerWiseConfigurations == null) { - layerWiseConfigurations = new MultiLayerConfiguration.Builder().build(); - } - - if (layers == null) { - layers = new Layer[getnLayers()]; - } - - if (defaultConfiguration == null) { - defaultConfiguration = new NeuralNetConfiguration.Builder().build(); - } - } - /** * Perform layerwise pretraining for one epoch - see {@link #pretrain(DataSetIterator, int)} */ @@ -396,8 +441,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } - int ec = getLayer(layerIdx).conf().getEpochCount() + 1; - getLayer(layerIdx).conf().setEpochCount(ec); + int ec = getLayer(layerIdx).getNetConfiguration().getEpochCount() + 1; + getLayer(layerIdx).getNetConfiguration().setEpochCount(ec); } /** @@ -422,7 +467,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } LayerWorkspaceMgr workspaceMgr; - if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { workspaceMgr = LayerWorkspaceMgr.builder() @@ -450,12 +495,12 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } try (MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { - if (layerWiseConfigurations.getInputPreProcess(layerIdx) != null) { + if (getNetConfiguration().getInputPreProcess(layerIdx) != null) { if (input.size(0) > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } - outputOfPrevLayer = layerWiseConfigurations.getInputPreProcess(layerIdx) + outputOfPrevLayer = getNetConfiguration().getInputPreProcess(layerIdx) .preProcess(outputOfPrevLayer, (int) input.size(0), LayerWorkspaceMgr.noWorkspaces(helperWorkspaces)); } @@ -475,16 +520,6 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return labels == null ? 
(int) input.size(0) : (int) labels.size(0); } - @Override - public NeuralNetConfiguration conf() { - return defaultConfiguration; - } - - @Override - public void setConf(NeuralNetConfiguration conf) { - throw new UnsupportedOperationException(); - } - @Override public INDArray input() { return input; @@ -498,14 +533,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Get one parameter array for the network.
In MultiLayerNetwork, parameters are keyed like * "0_W" and "0_b" to mean "weights of layer index 0" and "biases of layer index 0" respectively. - * Numbers increment sequentially, and the suffixes ("W", "b" etc) depend on the layer type, and + * Numbers increment sequentially, and the suffixes ("W", "b" etc.) depend on the layer type, and * are defined in the relevant parameter initializers for each layer.
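
A short illustration of the key convention described above: the layer index prefixes the parameter name, and the lookup strips the prefix before delegating to the layer. The helper below is illustrative plain Java only, not code from this patch.

public class ParamKeySketch {

    // "0_W" -> layer 0, parameter "W"
    static int layerIndex(String key) {
        return Integer.parseInt(key.substring(0, key.indexOf('_')));
    }

    static String paramName(String key) {
        return key.substring(key.indexOf('_') + 1);
    }

    public static void main(String[] args) {
        String key = "0_W";
        System.out.println(layerIndex(key)); // 0
        System.out.println(paramName(key));  // W
    }
}
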
Note that the returned * INDArrays are views of the underlying network parameters, so modifications of the returned * arrays will impact the parameters of the network. * * @param param the key of the parameter * @return The specified parameter array for the network - * @see #paramTable() paramTable() method, for a map of all parameters + * @see #getParamTable() paramTable() method, for a map of all parameters */ @Override public INDArray getParam(String param) { @@ -521,20 +556,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return layers[layerIdx].getParam(newKey); } - /** - * Return a map of all parameters in the network. Parameter names are as described in - * {@link #getParam(String)}. As per {@link #getParam(String)} the returned arrays are views - - * modifications to these will impact the underlying network parameters - * - * @return A map of all parameters in the network - */ - @Override - public Map paramTable() { - return paramTable(false); - } /** - * Returns a map of all parameters in the network as per {@link #paramTable()}.
Optionally + * Returns a map of all parameters in the network as per {@link #getParamTable()}.
Optionally * (with backpropParamsOnly=true) only the 'backprop' parameters are returned - that is, any * parameters involved only in unsupervised layerwise pretraining not standard inference/backprop * are excluded from the returned list. @@ -546,7 +570,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial //Get all parameters from all layers Map allParams = new LinkedHashMap<>(); for (int i = 0; i < layers.length; i++) { - Map paramMap = layers[i].paramTable(backpropParamsOnly); + Map paramMap = layers[i].getParamTable(backpropParamsOnly); for (Map.Entry entry : paramMap.entrySet()) { String newKey = i + "_" + entry.getKey(); allParams.put(newKey, entry.getValue()); @@ -568,7 +592,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Set the parameters of the netowrk. Note that the parameter keys must match the format as - * described in {@link #getParam(String)} and {@link #paramTable()}. Note that the values of the + * described in {@link #getParam(String)} and {@link #getParamTable()}. Note that the values of the * parameters used as an argument to this method are copied - i.e., it is safe to later * modify/reuse the values in the provided paramTable without this impacting the network. * @@ -576,7 +600,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial */ @Override public void setParamTable(Map paramTable) { - Map currParamTable = paramTable(); + Map currParamTable = getParamTable(); if (!currParamTable.keySet().equals(paramTable.keySet())) { throw new IllegalArgumentException( "Cannot set param table: parameter keys do not match.\n" + "Current: " @@ -623,22 +647,6 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial layers[layerIdx].setParam(newKey, val); } - /** - * Get the configuration for the network - * - * @return Network configuration - */ - public MultiLayerConfiguration getLayerWiseConfigurations() { - return layerWiseConfigurations; - } - - /** - * This method is intended for internal/developer use only. - */ - public void setLayerWiseConfigurations(MultiLayerConfiguration layerWiseConfigurations) { - this.layerWiseConfigurations = layerWiseConfigurations; - } - /** * Initialize the MultiLayerNetwork. This should be called once before the network is used. This * is functionally equivalent to calling {@code init(null, false)}. @@ -660,20 +668,17 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * directly */ public void init(INDArray parameters, boolean cloneParametersArray) { - if (layerWiseConfigurations == null || layers == null) { - intializeConfigurations(); - } if (initCalled) { return; } - DataType netDtype = getLayerWiseConfigurations().getDataType(); + DataType netDtype = getNetConfiguration().getDataType(); if (parameters != null && parameters.dataType() != netDtype) { Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. 
Got %ndShape", parameters); if (cloneParametersArray) { - try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { parameters = parameters.castTo(netDtype); } } else { @@ -685,29 +690,25 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } - if (layerMap == null) { - layerMap = new LinkedHashMap<>(); + if (getNetConfiguration().getTrainingWorkspaceMode() == null) { + getNetConfiguration().setTrainingWorkspaceMode(WorkspaceMode.NONE); } - if (layerWiseConfigurations.getTrainingWorkspaceMode() == null) { - layerWiseConfigurations.setTrainingWorkspaceMode(WorkspaceMode.NONE); + if (getNetConfiguration().getInferenceWorkspaceMode() == null) { + getNetConfiguration().setInferenceWorkspaceMode(WorkspaceMode.NONE); } - if (layerWiseConfigurations.getInferenceWorkspaceMode() == null) { - layerWiseConfigurations.setInferenceWorkspaceMode(WorkspaceMode.NONE); - } - - if (layerWiseConfigurations.getCacheMode() == null) { - layerWiseConfigurations.setCacheMode(CacheMode.NONE); + if (getNetConfiguration().getCacheMode() == null) { + getNetConfiguration().setCacheMode(CacheMode.NONE); } OneTimeLogger.info(log, "Starting MultiLayerNetwork with WorkspaceModes set to [training: {}; inference: {}], cacheMode set to [{}]", - layerWiseConfigurations.getTrainingWorkspaceMode(), - layerWiseConfigurations.getInferenceWorkspaceMode(), - layerWiseConfigurations.getCacheMode()); + getNetConfiguration().getTrainingWorkspaceMode(), + getNetConfiguration().getInferenceWorkspaceMode(), + getNetConfiguration().getCacheMode()); - int nLayers = getnLayers(); + int nLayers = getNetConfiguration().getFlattenedLayerConfigurations().size(); if (nLayers < 1) { throw new IllegalStateException("Unable to create network: number of layers is less than 1"); @@ -722,9 +723,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial long paramLength = 0; val nParamsPerLayer = new long[nLayers]; for (int i = 0; i < nLayers; i++) { - NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i); - conf.getLayer().setDataType(netDtype); - nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf); + LayerConfiguration layer_conf = getNetConfiguration().getFlattenedLayerConfigurations().get(i); + layer_conf.setDataType(netDtype); + nParamsPerLayer[i] = layer_conf.initializer().numParams(layer_conf); paramLength += nParamsPerLayer[i]; } @@ -757,7 +758,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial //Set RNG seed, for repeatability between initializations when set if (initializeParams) { - Nd4j.getRandom().setSeed(getDefaultConfiguration().getSeed()); + Nd4j.getRandom().setSeed(getNetConfiguration().getSeed()); } // construct multi-layer @@ -771,28 +772,27 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial paramsView = null; } paramCountSoFar += nParamsPerLayer[i]; - - NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i); - layers[i] = conf.getLayer() - .instantiate(conf, trainingListeners, i, paramsView, initializeParams, netDtype); - layerMap.put(conf.getLayer().getLayerName(), layers[i]); + @NonNull + LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i); + layers[i] = lc.instantiate(lc.getNetConfiguration(), trainingListeners, i, paramsView, initializeParams, + netDtype); } initCalled = true; } - //Set parameters in 
MultiLayerNetwork.defaultConfiguration for later use in BaseOptimizer.setupSearchState() etc - defaultConfiguration.clearVariables(); - List variables = defaultConfiguration.variables(false); + //Set parameters in MultiLayerNetwork.getNetConfiguration() for later use in BaseOptimizer.setupSearchState() etc + getNetConfiguration().clearNetWideVariable(); + List variables = getNetConfiguration().netWideVariables(false); for (int i = 0; i < layers.length; i++) { if (layers[i] == null) { throw new IllegalStateException( "Encountered null layer during initialization for layer " + i + - ": " + layerWiseConfigurations.getConf(i).getLayer().getClass().getSimpleName() + ": " + layers[i].getClass().getSimpleName() + " initialization " + "returned null layer?"); } - for (String s : layers[i].conf().variables()) { + for (String s : layers[i].getNetConfiguration().netWideVariables()) { variables.add(i + "_" + s); } } @@ -800,7 +800,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial // now we init solver & optimizer if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) .build(); solver.initOptimizer(); } @@ -832,7 +832,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) .build(); } } @@ -861,8 +861,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial long paramLength = 0; val nParamsPerLayer = new long[nLayers]; for (int i = 0; i < nLayers; i++) { - NeuralNetConfiguration conf = layerWiseConfigurations.getConf(i); - nParamsPerLayer[i] = conf.getLayer().initializer().numParams(conf); + LayerConfiguration layerConfiguration = getNetConfiguration().getFlattenedLayerConfigurations().get(i); + nParamsPerLayer[i] = layerConfiguration.initializer().numParams(layerConfiguration); //TODO better initialisation paramLength += nParamsPerLayer[i]; } @@ -886,8 +886,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial protected INDArray activationFromPrevLayer(int curr, INDArray input, boolean training, LayerWorkspaceMgr mgr) { - if (getLayerWiseConfigurations().getInputPreProcess(curr) != null) { - input = getLayerWiseConfigurations().getInputPreProcess(curr) + if (getNetConfiguration().getInputPreProcess(curr) != null) { + input = getNetConfiguration().getInputPreProcess(curr) .preProcess(input, getInputMiniBatchSize(), mgr); } @@ -1060,10 +1060,10 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial try { mgr.validateArrayLocation(arrayType, array, false, layerIdx > 0); } catch (ND4JWorkspaceException e) { - String layerName = layers[layerIdx].conf().getLayer().getLayerName(); + String layerName = layers[layerIdx].getLayerConfiguration().getLayerName(); String clazz; if (isPreprocessor) { - clazz = layerWiseConfigurations.getInputPreProcess(layerIdx).getClass().getName(); + clazz = getNetConfiguration().getInputPreProcess(layerIdx).getClass().getName(); } else { clazz = layers[layerIdx].getClass().getName(); } @@ -1106,8 +1106,8 @@ 
public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial "Expected no workspace active in ffToLayerActivationsDetached"); LayerWorkspaceMgr workspaceMgr; - WorkspaceMode wsm = (train ? layerWiseConfigurations.getTrainingWorkspaceMode() - : layerWiseConfigurations.getInferenceWorkspaceMode()); + WorkspaceMode wsm = (train ? getNetConfiguration().getTrainingWorkspaceMode() + : getNetConfiguration().getInferenceWorkspaceMode()); if (wsm == WorkspaceMode.NONE) { workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { @@ -1137,8 +1137,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial for (int i = 0; i <= layerIndex; i++) { try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered( ArrayType.FF_WORKING_MEM)) { - if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { - input = getLayerWiseConfigurations().getInputPreProcess(i) + if (getNetConfiguration().getInputPreProcess(i) != null) { + input = getNetConfiguration().getInputPreProcess(i) .preProcess(input, getInputMiniBatchSize(), workspaceMgr); //Validation: Exception if invalid (bad preprocessor implementation) validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, true, @@ -1207,7 +1207,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial setLayerMaskArrays(fMask, lMask); LayerWorkspaceMgr workspaceMgr; - if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { WorkspaceUtils.assertNoWorkspacesOpen( "Expected no workspace active in ffToLayerActivationsInWs when training workspace is set to NONE"); workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); @@ -1225,7 +1225,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial workspaceMgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); } - if (layerWiseConfigurations.getCacheMode() != CacheMode.NONE) { + if (getNetConfiguration().getCacheMode() != CacheMode.NONE) { //For now: store cache mode activations in activations workspace workspaceMgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); workspaceMgr.setWorkspace(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, @@ -1245,8 +1245,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial for (int i = 0; i <= layerIndex; i++) { try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered( ArrayType.FF_WORKING_MEM)) { - if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { - input = getLayerWiseConfigurations().getInputPreProcess(i) + if (getNetConfiguration().getInputPreProcess(i) != null) { + input = getNetConfiguration().getInputPreProcess(i) .preProcess(input, getInputMiniBatchSize(), workspaceMgr); //Validation: Exception if invalid (bad preprocessor implementation) validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, true, @@ -1280,7 +1280,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } if (input == null) { - throw new IllegalStateException("Layer " + i + " returned null activations"); + throw new IllegalStateException("LayerConfiguration " + i + " returned null activations"); } //Validation: Exception if invalid (bad layer implementation) @@ -1355,8 +1355,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial LayerWorkspaceMgr mgrEven; LayerWorkspaceMgr mgrOdd; - WorkspaceMode wsm = train ? 
layerWiseConfigurations.getTrainingWorkspaceMode() - : layerWiseConfigurations.getInferenceWorkspaceMode(); + WorkspaceMode wsm = train ? getNetConfiguration().getTrainingWorkspaceMode() + : getNetConfiguration().getInferenceWorkspaceMode(); if (wsm == WorkspaceMode.NONE) { mgrEven = LayerWorkspaceMgr.noWorkspaces(); mgrOdd = mgrEven; @@ -1368,7 +1368,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial : "inference") + " workspace mode is set to NONE. Cannot put output activations into the specified workspace if" + - "workspaces are disabled for the network. use getConfiguration().setTraining/InferenceWorkspaceMode(WorkspaceMode.ENABLED)"); + "workspaces are disabled for the network. use getNetConfiguration().setTraining/InferenceWorkspaceMode(WorkspaceMode.ENABLED)"); } } else { mgrEven = LayerWorkspaceMgr.builder() @@ -1430,8 +1430,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); } - if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { - input = getLayerWiseConfigurations().getInputPreProcess(i) + if (getNetConfiguration().getInputPreProcess(i) != null) { + input = getNetConfiguration().getInputPreProcess(i) .preProcess(input, getInputMiniBatchSize(), mgr); //Validation: Exception if invalid (bad preprocessor implementation) validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, @@ -1451,13 +1451,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (fwdPassType == FwdPassType.STANDARD) { //Standard feed-forward case - if (i > 0 && ConvolutionUtils.layerHasConvolutionLayout(layers[i - 1].conf().getLayer()) - && ConvolutionUtils.layerHasConvolutionLayout(layers[i].conf().getLayer())) { + if (i > 0 && ConvolutionUtils.layerHasConvolutionLayout( + layers[i - 1].getLayerConfiguration()) + && ConvolutionUtils.layerHasConvolutionLayout(layers[i].getLayerConfiguration())) { CNN2DFormat preLayerFormat = ConvolutionUtils.getFormatForLayer( - layers[i - 1].conf().getLayer()); + layers[i - 1].getLayerConfiguration()); CNN2DFormat currLayerFormat = ConvolutionUtils.getFormatForLayer( - layers[i].conf().getLayer()); + layers[i].getLayerConfiguration()); if (preLayerFormat != currLayerFormat) { //NHWC case if (preLayerFormat == CNN2DFormat.NCHW) { @@ -1474,12 +1475,13 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } input = layers[i].activate(input, train, mgr); - } else if (i > 0 && Convolution1DUtils.hasRnnDataFormat(layers[i - 1].conf().getLayer()) - && Convolution1DUtils.hasRnnDataFormat(layers[i].conf().getLayer())) { + } else if (i > 0 && Convolution1DUtils.hasRnnDataFormat( + layers[i - 1].getLayerConfiguration()) + && Convolution1DUtils.hasRnnDataFormat(layers[i].getLayerConfiguration())) { RNNFormat preLayerFormat = Convolution1DUtils.getRnnFormatFromLayer( - layers[i - 1].conf().getLayer()); + layers[i - 1].getLayerConfiguration()); RNNFormat currLayerFormat = Convolution1DUtils.getRnnFormatFromLayer( - layers[i].conf().getLayer()); + layers[i].getLayerConfiguration()); //permute for next layer if (preLayerFormat != currLayerFormat) { input = input.permute(0, 2, 1); @@ -1653,7 +1655,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (!initCalled) { init(); } - MultiLayerConfiguration conf = this.layerWiseConfigurations.clone(); + NeuralNetConfiguration conf = this.getNetConfiguration().clone(); MultiLayerNetwork ret = 
new MultiLayerNetwork(conf); ret.init(this.params().dup(), false); @@ -1698,7 +1700,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Returns a 1 x m vector where the vector is composed of a flattened vector of all of the - * parameters in the network.
See {@link #getParam(String)} and {@link #paramTable()} for a + * parameters in the network.
See {@link #getParam(String)} and {@link #getParamTable()} for a * more useful/interpretable representation of the parameters.
Note that the parameter vector * is not a copy, and changes to the returned INDArray will impact the network parameters. * @@ -1709,6 +1711,28 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return flattenedParams; } + /** + * The param table + * + * @return + */ + @Override + public Map getParamTable() { + return null; + } + + /** + * Table of parameters by key, for backprop. For many models (dense layers, etc) - all parameters + * are backprop parameters + * + * @param backpropParamsOnly If true, return backprop params only. If false: return all params + * (equivalent to paramsTable()) + */ + @Override + public Map getParamTable(boolean backpropParamsOnly) { + return null; + } + /** * Set the parameters for this model. This expects a linear ndarray which then be unpacked * internally relative to the expected ordering of the model.
See also: @@ -1868,7 +1892,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } LayerWorkspaceMgr workspaceMgr; - if (getLayerWiseConfigurations().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { workspaceMgr = LayerWorkspaceMgr.builder() @@ -1908,7 +1932,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial boolean hasMaskArrays = next.hasMaskArrays(); - if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) { + if (getNetConfiguration().getBackpropType() == BackpropType.TruncatedBPTT) { doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(), next.getLabelsMaskArray(), workspaceMgr); } else { @@ -1921,7 +1945,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) .build(); } } @@ -1983,7 +2007,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial setLayerMaskArrays(fMask, labelMask); LayerWorkspaceMgr mgr; - if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { mgr = LayerWorkspaceMgr.noWorkspaces(); } else { mgr = LayerWorkspaceMgr.builder() @@ -1997,7 +2021,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial WS_RNN_LOOP_WORKING_MEM_CONFIG) .build(); - if (layerWiseConfigurations.getCacheMode() != null) { + if (getNetConfiguration().getCacheMode() != null) { //For now: store cache mode activations in activations workspace mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); } @@ -2018,8 +2042,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } INDArray inputToOutputLayer = activations.get(activations.size() - 1); - if (layerWiseConfigurations.getInputPreProcess(layers.length - 1) != null) { - inputToOutputLayer = layerWiseConfigurations.getInputPreProcess(layers.length - 1) + if (getNetConfiguration().getInputPreProcess(layers.length - 1) != null) { + inputToOutputLayer = getNetConfiguration().getInputPreProcess(layers.length - 1) .preProcess(inputToOutputLayer, getInputMiniBatchSize(), mgr); //Validate activations location } @@ -2059,7 +2083,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial LayerWorkspaceMgr mgrEven; LayerWorkspaceMgr mgrOdd; - if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { mgrEven = LayerWorkspaceMgr.noWorkspaces(); mgrOdd = mgrEven; WorkspaceUtils.assertNoWorkspacesOpen( @@ -2188,7 +2212,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial //TBPTT gradient if (layers[i] instanceof RecurrentLayer) { currPair = ((RecurrentLayer) layers[i]).tbpttBackpropGradient(currPair.getSecond(), - layerWiseConfigurations.getTbpttBackLength(), workspaceMgr); + getNetConfiguration().getTbpttBackLength(), workspaceMgr); } else { currPair = layers[i].backpropGradient(currPair.getSecond(), workspaceMgr); } 
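  /*
   * Illustrative sketch only (not part of the original patch; the helper name is hypothetical):
   * how the parameter accessors shown above fit together on an already-initialised network.
   * Keys follow the "<layerIndex>_<paramName>" convention described in getParam(String), e.g. "0_W".
   */
  static void parameterAccessSketch(MultiLayerNetwork net) {
      INDArray snapshot = net.params().dup();   // params() returns a 1 x numParams view, so copy it first
      INDArray w0 = net.getParam("0_W");        // weight matrix of layer 0, also a view
      w0.muli(0.5);                             // in-place change is immediately visible network-wide
      net.setParams(snapshot);                  // setParams copies the original values back from the flat vector
  }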
@@ -2208,9 +2232,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial gradientList.addLast(new Triple<>(multiGradientKey, entry.getValue(), currPair.getFirst().flatteningOrderForVariable(origName))); } - if (getLayerWiseConfigurations().getInputPreProcess(i) != null) { + if (getNetConfiguration().getInputPreProcess(i) != null) { currPair = new Pair<>(currPair.getFirst(), - this.layerWiseConfigurations.getInputPreProcess(i) + this.getNetConfiguration().getInputPreProcess(i) .backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr)); if (i > 0 && currPair.getSecond() != null) { validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, @@ -2276,7 +2300,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } - if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { WorkspaceUtils.assertNoWorkspacesOpen( "Expected no workspace active in calcBackpropGradients when " + "training workspace is set to none"); @@ -2313,7 +2337,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return; } - int fwdLen = layerWiseConfigurations.getTbpttFwdLength(); + int fwdLen = getNetConfiguration().getTbpttFwdLength(); update(TaskUtils.buildTask(input, labels)); val timeSeriesLength = input.size(2); long nSubsets = timeSeriesLength / fwdLen; @@ -2342,7 +2366,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) .build(); } } @@ -2401,7 +2425,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } @Override - public void setListeners(Collection listeners) { + public void setListeners(TrainingListener ... listeners) { if (layers == null) { init(); } @@ -2410,30 +2434,15 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } if (solver != null) { - solver.setListeners(listeners); + solver.setListeners(List.of(listeners)); } this.trainingListeners.clear(); if (listeners != null) { - this.trainingListeners.addAll(listeners); + this.trainingListeners.addAll(List.of(listeners)); } } - @Override - public void setListeners(TrainingListener... listeners) { - Collection cListeners = new ArrayList<>(); - //Check: user might have done setListeners(null) thinking this would clear the current listeners. 
- //This results in an TrainingListener[1] with a single null value -> results in a NPE later - if (listeners != null && listeners.length > 0) { - for (TrainingListener i : listeners) { - if (i != null) { - cListeners.add(i); - } - } - } - setListeners(cListeners); - } - /** * @deprecated Use {@link #getListeners()} */ @@ -2542,7 +2551,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial update(TaskUtils.buildTask(features, labels)); LayerWorkspaceMgr workspaceMgr; - if (layerWiseConfigurations.getTrainingWorkspaceMode() == null) { + if (getNetConfiguration().getTrainingWorkspaceMode() == null) { workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { workspaceMgr = LayerWorkspaceMgr.builder() @@ -2556,12 +2565,12 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); - if (layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT) { + if (getNetConfiguration().getBackpropType() == BackpropType.TruncatedBPTT) { doTruncatedBPTT(features, labels, featuresMask, labelsMask, workspaceMgr); } else { if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) .build(); } } @@ -2599,7 +2608,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial @Override public void fit(INDArray examples, int[] labels) { org.deeplearning4j.nn.conf.layers.OutputLayer layerConf = - (org.deeplearning4j.nn.conf.layers.OutputLayer) getOutputLayer().conf().getLayer(); + (org.deeplearning4j.nn.conf.layers.OutputLayer) getOutputLayer().getLayerConfiguration(); if (layerConf.getNOut() > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); @@ -2861,8 +2870,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial "Final layer is of type: " + getOutputLayer().getClass()); } - WorkspaceMode wsm = (training ? layerWiseConfigurations.getTrainingWorkspaceMode() - : layerWiseConfigurations.getInferenceWorkspaceMode()); + WorkspaceMode wsm = (training ? 
getNetConfiguration().getTrainingWorkspaceMode() + : getNetConfiguration().getInferenceWorkspaceMode()); LayerWorkspaceMgr mgr; if (wsm == WorkspaceMode.NONE) { mgr = LayerWorkspaceMgr.noWorkspaces(); @@ -2886,8 +2895,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial throw new ND4JArraySizeException(); } IOutputLayer ol = (IOutputLayer) getOutputLayer(); - if (getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) != null) { - inputToOutputLayer = getLayerWiseConfigurations().getInputPreProcess(layers.length - 1) + if (getNetConfiguration().getInputPreProcess(layers.length - 1) != null) { + inputToOutputLayer = getNetConfiguration().getInputPreProcess(layers.length - 1) .preProcess(inputToOutputLayer, (int) data.getFeatures().size(0), mgr); } ol.setInput(inputToOutputLayer, mgr); //Feedforward doesn't include output layer for efficiency @@ -2953,12 +2962,12 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial INDArray out; if (getOutputLayer() instanceof IOutputLayer) { IOutputLayer ol = (IOutputLayer) getOutputLayer(); - if (layerWiseConfigurations.getInputPreProcess(layers.length - 1) != null) { + if (getNetConfiguration().getInputPreProcess(layers.length - 1) != null) { if (data.getFeatures().size(0) > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } - inputLast = layerWiseConfigurations.getInputPreProcess(layers.length - 1) + inputLast = getNetConfiguration().getInputPreProcess(layers.length - 1) .preProcess(inputLast, (int) data.getFeatures().size(0), mgr); } @@ -3023,7 +3032,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial //Note: Workspace manager is only ose here for score calculation... other workspace managers are used in the // various FF/backprop methds LayerWorkspaceMgr mgr; - if (layerWiseConfigurations.getTrainingWorkspaceMode() == WorkspaceMode.NONE) { + if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { mgr = LayerWorkspaceMgr.noWorkspaces(); } else { mgr = LayerWorkspaceMgr.builder() @@ -3037,13 +3046,13 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial WS_RNN_LOOP_WORKING_MEM_CONFIG) .build(); - if (layerWiseConfigurations.getCacheMode() != null) { + if (getNetConfiguration().getCacheMode() != null) { //For now: store cache mode activations in activations workspace mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); } } - boolean tbptt = layerWiseConfigurations.getBackpropType() == BackpropType.TruncatedBPTT; + boolean tbptt = getNetConfiguration().getBackpropType() == BackpropType.TruncatedBPTT; FwdPassType fwdType = (tbptt ? 
FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE : FwdPassType.STANDARD); synchronizeIterEpochCounts(); @@ -3062,8 +3071,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } INDArray inputToOutputLayer = activations.get(activations.size() - 1); - if (layerWiseConfigurations.getInputPreProcess(layers.length - 1) != null) { - inputToOutputLayer = layerWiseConfigurations.getInputPreProcess(layers.length - 1) + if (getNetConfiguration().getInputPreProcess(layers.length - 1) != null) { + inputToOutputLayer = getNetConfiguration().getInputPreProcess(layers.length - 1) .preProcess(inputToOutputLayer, getInputMiniBatchSize(), mgr); //Validate activations location } @@ -3138,12 +3147,6 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial setParams(params); } - /** - * Intended for internal/developer use - */ - public NeuralNetConfiguration getDefaultConfiguration() { - return defaultConfiguration; - } public INDArray getLabels() { return labels; @@ -3189,7 +3192,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * @return the number of layers in the network */ public int getnLayers() { - return layerWiseConfigurations.getConfs().size(); + return getNetConfiguration().getFlattenedLayerConfigurations().size(); } /** @@ -3210,12 +3213,17 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return layers[i]; } - public Layer getLayer(String name) { - return layerMap.get(name); + public Layer getLayer(@NotNull String name) { + return Arrays.stream(layers) + .filter(l -> !l.getLayerConfiguration().getLayerName().equals(name)) + .findFirst() + .get(); } public List getLayerNames() { - return new ArrayList<>(layerMap.keySet()); + return Arrays.stream(layers) + .map(l -> l.getLayerConfiguration().getLayerName()) + .collect(Collectors.toList()); } public INDArray getMask() { @@ -3253,7 +3261,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } //========== - //Layer methods + //LayerConfiguration methods @Override public Pair feedForwardMaskArray(INDArray maskArray, @@ -3266,7 +3274,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } else { //Do a forward pass through each preprocessor and layer for (int i = 0; i < layers.length; i++) { - InputPreProcessor preProcessor = getLayerWiseConfigurations().getInputPreProcess(i); + InputPreProcessor preProcessor = getNetConfiguration().getInputPreProcess(i); if (preProcessor != null) { Pair p = @@ -3342,22 +3350,22 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial @Override public int getIterationCount() { - return getLayerWiseConfigurations().getIterationCount(); + return getNetConfiguration().getIterationCount(); } @Override public void setIterationCount(int iterationCount) { - getLayerWiseConfigurations().setIterationCount(iterationCount); + getNetConfiguration().setIterationCount(iterationCount); } @Override public int getEpochCount() { - return getLayerWiseConfigurations().getEpochCount(); + return getNetConfiguration().getEpochCount(); } @Override public void setEpochCount(int epochCount) { - getLayerWiseConfigurations().setEpochCount(epochCount); + getNetConfiguration().setEpochCount(epochCount); } @Override @@ -3407,7 +3415,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial @Override public int getInputMiniBatchSize() { - if (!conf().isMiniBatch()) { + if (!getNetConfiguration().isMiniBatch()) { 
return 1; } @@ -3498,7 +3506,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying(); } if (!(l instanceof RecurrentLayer)) { - throw new IllegalArgumentException("Layer is not an RNN layer"); + throw new IllegalArgumentException("LayerConfiguration is not an RNN layer"); } return ((RecurrentLayer) l).rnnGetPreviousState(); } @@ -3518,7 +3526,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial l = ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying(); } if (!(l instanceof RecurrentLayer)) { - throw new IllegalArgumentException("Layer is not an RNN layer"); + throw new IllegalArgumentException("LayerConfiguration is not an RNN layer"); } RecurrentLayer r = (RecurrentLayer) l; r.rnnSetPreviousState(state); @@ -3575,7 +3583,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial */ public void setUpdater(Updater updater) { if (solver == null) { - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this).build(); + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this).build(); } solver.getOptimizer().setUpdater(updater); } @@ -3584,7 +3592,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (solver == null && initializeIfReq) { synchronized (this) { if (solver == null) { //May have been created while waiting for lock - solver = new Solver.Builder().configure(conf()).listeners(getListeners()).model(this) + solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) .build(); solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this)); } @@ -3605,9 +3613,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * [miniBatchSize,timeSeriesLength] and contain values 0 or 1 at each element (to specify whether * a given input/example is present - or merely padding - at a given time step).
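   * <p>
   * Illustrative example (not from the original source): a features mask for three examples of
   * length five, where the third example is padding in its final time step:
   * <pre>
   * INDArray featuresMask = Nd4j.create(new double[][]{
   *         {1, 1, 1, 1, 1},
   *         {1, 1, 1, 1, 1},
   *         {1, 1, 1, 1, 0}});   // shape [3, 5]; 0 marks padded steps
   * </pre>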
* NOTE: This method is not usually used directly. Instead, methods such as - * {@link #feedForward(INDArray, INDArray, INDArray)} - * and {@link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking - * internally. + * {@link #feedForward(INDArray, INDArray, INDArray)} and + * {@link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking internally. * * @param featuresMaskArray Mask array for features (input) * @param labelsMaskArray Mask array for labels (output) @@ -3723,8 +3730,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial */ public T evaluateROC(DataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(); - if (getLayerWiseConfigurations().isValidateOutputLayerConfig()) { - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), + if (getNetConfiguration().isValidateOutputLayerConfig()) { + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), ROC.class); } return (T) doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0]; @@ -3749,8 +3756,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial public T evaluateROCMultiClass(DataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(); - if (getLayerWiseConfigurations().isValidateOutputLayerConfig()) { - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), + if (getNetConfiguration().isValidateOutputLayerConfig()) { + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), ROCMultiClass.class); } return (T) doEvaluation(iterator, @@ -3780,9 +3787,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial DataSetIterator iter = iterator.asyncSupported() ? new AsyncDataSetIterator(iterator, 2, true) : iterator; - WorkspaceMode cMode = layerWiseConfigurations.getTrainingWorkspaceMode(); - layerWiseConfigurations.setTrainingWorkspaceMode( - layerWiseConfigurations.getInferenceWorkspaceMode()); + WorkspaceMode cMode = getNetConfiguration().getTrainingWorkspaceMode(); + getNetConfiguration().setTrainingWorkspaceMode( + getNetConfiguration().getInferenceWorkspaceMode()); //First: let's determine if we should do 'split feed forward' for long time series //The idea: RNN 20k time steps. Train using TBPTT length 100 -> 200 segments of length 100. If we naively @@ -3790,11 +3797,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial // evaluation in segments... 
//Only do this if TBPTT is enabled - if not, it means we can train without TBPTT and hence should be able // to test without splitting also - boolean useRnnSegments = (layerWiseConfigurations.getBackpropType() + boolean useRnnSegments = (getNetConfiguration().getBackpropType() == BackpropType.TruncatedBPTT); MemoryWorkspace outputWs; - if (getLayerWiseConfigurations().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED) { + if (getNetConfiguration().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED) { outputWs = Nd4j.getWorkspaceManager() .getWorkspaceForCurrentThread(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM); } else { @@ -3830,7 +3837,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial rnnClearPreviousState(); //Get subset of features and labels: - val fwdLen = layerWiseConfigurations.getTbpttFwdLength(); + val fwdLen = getNetConfiguration().getTbpttFwdLength(); val tsLength = features.size(2); long nSubsets = tsLength / fwdLen; if (tsLength % fwdLen != 0) { @@ -3867,7 +3874,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial ((AsyncDataSetIterator) iter).shutdown(); } - layerWiseConfigurations.setTrainingWorkspaceMode(cMode); + getNetConfiguration().setTrainingWorkspaceMode(cMode); return evaluations; } @@ -3973,8 +3980,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } Layer outputLayer = getOutputLayer(); - if (getLayerWiseConfigurations().isValidateOutputLayerConfig()) { - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.conf().getLayer(), + if (getNetConfiguration().isValidateOutputLayerConfig()) { + OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), Evaluation.class); } @@ -4034,7 +4041,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial int frozenParams = 0; for (org.deeplearning4j.nn.api.Layer currentLayer : getLayers()) { - String name = currentLayer.conf().getLayer().getLayerName(); + String name = currentLayer.getLayerConfiguration().getLayerName(); if (name == null) { name = String.valueOf(currentLayer.getIndex()); } @@ -4049,13 +4056,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial InputPreProcessor preProcessor; InputType outType; if (inputType != null) { - preProcessor = getLayerWiseConfigurations().getInputPreProcess(currentLayer.getIndex()); + preProcessor = getNetConfiguration().getInputPreProcess(currentLayer.getIndex()); inShape = inputType.toString(); if (preProcessor != null) { inputType = preProcessor.getOutputType(inputType); inShape += "--> " + inputType.toString(); } - outType = currentLayer.conf().getLayer().getOutputType(currentLayer.getIndex(), inputType); + outType = currentLayer.getLayerConfiguration() + .getOutputType(currentLayer.getIndex(), inputType); outShape = outType.toString(); inputType = outType; } @@ -4063,19 +4071,20 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial paramShape = ""; if (currentLayer instanceof BidirectionalLayer) { // Bidirectional layer is not an FFL BidirectionalLayer bi = (BidirectionalLayer) currentLayer; - in = String.valueOf(((Bidirectional) bi.conf().getLayer()).getNIn()); - out = String.valueOf(((Bidirectional) bi.conf().getLayer()).getNOut()); + in = String.valueOf(((Bidirectional) bi.getLayerConfiguration()).getNIn()); + out = String.valueOf(((Bidirectional) bi.getLayerConfiguration()).getNOut()); } else { try { - in = 
String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNIn()); - out = String.valueOf(((FeedForwardLayer) currentLayer.conf().getLayer()).getNOut()); + in = String.valueOf(((FeedForwardLayer) currentLayer.getLayerConfiguration()).getNIn()); + out = String.valueOf( + ((FeedForwardLayer) currentLayer.getLayerConfiguration()).getNOut()); } catch ( Exception e) { // Some layers, like PReLU, are just BaseLayers (but have parameters) } } - Set paraNames = currentLayer.paramTable().keySet(); + Set paraNames = currentLayer.getParamTable().keySet(); for (String aP : paraNames) { - String paramS = ArrayUtils.toString(currentLayer.paramTable().get(aP).shape()); + String paramS = ArrayUtils.toString(currentLayer.getParamTable().get(aP).shape()); paramShape += aP + ":" + paramS + ", "; } paramShape = paramShape.subSequence(0, paramShape.lastIndexOf(",")).toString(); @@ -4168,7 +4177,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Increment the epoch count (in the underlying {@link MultiLayerConfiguration} by 1). Note that + * Increment the epoch count (in the underlying {@link NeuralNetConfiguration} by 1). Note that * this is done automatically when using iterator-based fitting methods, such as * {@link #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, * INDArray/INDArray etc), the network has no way to know when one epoch ends and another starts. @@ -4176,10 +4185,10 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * epoch counter is used for situations such as some learning rate schedules, and the like. *
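   * <p>
   * A minimal sketch (illustrative only) of manual epoch handling when fitting from raw arrays:
   * <pre>
   * for (int epoch = 0; epoch < nEpochs; epoch++) {
   *     net.fit(features, labels);     // non-iterator fit: one pass over the data
   *     net.incrementEpochCount();     // tell the network that an epoch has ended
   * }
   * </pre>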

* The current epoch count can be obtained using - * {@code MultiLayerConfiguration.getLayerwiseConfiguration().getEpochCount()} + * {@code NeuralNetConfiguration.getLayerwiseConfiguration().getEpochCount()} */ public void incrementEpochCount() { - layerWiseConfigurations.setEpochCount(layerWiseConfigurations.getEpochCount() + 1); + getNetConfiguration().setEpochCount(getNetConfiguration().getEpochCount() + 1); synchronizeIterEpochCounts(); } @@ -4246,8 +4255,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { INDArray newParams = params().castTo(dataType); - String jsonConfig = getLayerWiseConfigurations().toJson(); - MultiLayerConfiguration newConf = MultiLayerConfiguration.fromJson(jsonConfig); + String jsonConfig = getNetConfiguration().toJson(); + NeuralNetConfiguration newConf = NeuralNetConfiguration.fromJson(jsonConfig); newConf.setDataType(dataType); MultiLayerNetwork newNet = new MultiLayerNetwork(newConf); newNet.init(newParams, false); @@ -4267,8 +4276,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * (fixed) learning rate.
*
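   * For example (sketch): {@code net.setLearningRate(1e-3);} applies a new fixed learning rate
   * of 1e-3 to every layer's updater.
   * <p>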
* Note: This method not free from a performance point of view: a proper learning - * rate schedule - * should be used in preference to calling this method at every iteration. + * rate schedule should be used in preference to calling this method at every iteration. * * @param newLr New learning rate for all layers * @see #setLearningRate(ISchedule) @@ -4282,8 +4290,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * Set the learning rate schedule for all layers in the network to the specified schedule. This * schedule will replace any/all existing schedules, and also any fixed learning rate values.
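   * <p>
   * Sketch (assuming ND4J's {@code ExponentialSchedule} from {@code org.nd4j.linalg.schedule}):
   * <pre>
   * net.setLearningRate(new ExponentialSchedule(ScheduleType.EPOCH, 0.1, 0.95));
   * </pre>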
* Note that the iteration/epoch counts will not be reset. Use - * {@link MultiLayerConfiguration#setIterationCount(int)} and - * {@link MultiLayerConfiguration#setEpochCount(int)} if this is required + * {@link NeuralNetConfiguration#setIterationCount(int)} and + * {@link NeuralNetConfiguration#setEpochCount(int)} if this is required * * @param newLr New learning rate schedule for all layers * @see #setLearningRate(ISchedule) @@ -4299,10 +4307,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * (fixed) learning rate.
*
* Note: This method not free from a performance point of view: a proper learning - * rate schedule - * should be used in preference to calling this method at every iteration. Note also that - * {@link #setLearningRate(double)} should also be used in preference, when all layers need to be - * set to a new LR + * rate schedule should be used in preference to calling this method at every iteration. Note also + * that {@link #setLearningRate(double)} should also be used in preference, when all layers need + * to be set to a new LR * * @param layerNumber Number of the layer to set the LR for * @param newLr New learning rate for a single layer @@ -4318,8 +4325,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * Note also that {@link #setLearningRate(ISchedule)} should also be used in preference, when all * layers need to be set to a new LR schedule.
This schedule will replace any/all existing * schedules, and also any fixed learning rate values.
Note also that the iteration/epoch - * counts will not be reset. Use {@link MultiLayerConfiguration#setIterationCount(int)} and - * {@link MultiLayerConfiguration#setEpochCount(int)} if this is required + * counts will not be reset. Use {@link NeuralNetConfiguration#setIterationCount(int)} and + * {@link NeuralNetConfiguration#setEpochCount(int)} if this is required * * @param layerNumber Number of the layer to set the LR schedule for * @param newLr New learning rate for a single layer @@ -4335,7 +4342,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * has no learning rate (no parameters, or an updater without a learning rate) then null is * returned * - * @param layerNumber Layer number to get the learning rate for + * @param layerNumber LayerConfiguration number to get the learning rate for * @return Learning rate for the specified layer, or null */ public Double getLearningRate(int layerNumber) { @@ -4355,10 +4362,10 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial public int layerSize(int layer) { if (layer < 0 || layer > layers.length) { throw new IllegalArgumentException( - "Invalid layer index: " + layer + ". Layer index must be between 0 and " + "Invalid layer index: " + layer + ". LayerConfiguration index must be between 0 and " + (layers.length - 1) + " inclusive"); } - org.deeplearning4j.nn.conf.layers.Layer conf = layers[layer].conf().getLayer(); + LayerConfiguration conf = layers[layer].getLayerConfiguration(); if (conf == null || !(conf instanceof FeedForwardLayer)) { return 0; } @@ -4384,10 +4391,10 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial public int layerInputSize(int layer) { if (layer < 0 || layer > layers.length) { throw new IllegalArgumentException( - "Invalid layer index: " + layer + ". Layer index must be between 0 and " + "Invalid layer index: " + layer + ". 
LayerConfiguration index must be between 0 and " + (layers.length - 1) + " inclusive"); } - org.deeplearning4j.nn.conf.layers.Layer conf = layers[layer].conf().getLayer(); + LayerConfiguration conf = layers[layer].getLayerConfiguration(); if (conf == null || !(conf instanceof FeedForwardLayer)) { return 0; } @@ -4451,8 +4458,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (obj instanceof MultiLayerNetwork) { MultiLayerNetwork network = (MultiLayerNetwork) obj; boolean paramsEquals = network.params().equals(params()); - boolean confEquals = getLayerWiseConfigurations().equals( - network.getLayerWiseConfigurations()); + boolean confEquals = getNetConfiguration().equals( + network.getNetConfiguration()); boolean updaterEquals = getUpdater().equals(network.getUpdater()); return paramsEquals && confEquals && updaterEquals; } @@ -4466,15 +4473,15 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { val mln = ModelSerializer.restoreMultiLayerNetwork(ois, true); - this.defaultConfiguration = mln.defaultConfiguration.clone(); - this.layerWiseConfigurations = mln.layerWiseConfigurations.clone(); + this.setNetConfiguration( mln.getNetConfiguration().clone() ); this.init(); this.flattenedParams.assign(mln.flattenedParams); - int numWorkingMem = 2 * (layerWiseConfigurations.getConfs().size() - + layerWiseConfigurations.getInputPreProcessors().size()); + int numWorkingMem = 2 * (getNetConfiguration().getFlattenedLayerConfigurations().size() + + getNetConfiguration().getInputPreProcessors().size()); WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(numWorkingMem); - WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig(layerWiseConfigurations.getConfs().size()); + WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig( + getNetConfiguration().getFlattenedLayerConfigurations().size()); if (mln.getUpdater() != null && mln.getUpdater(false).getStateViewArray() != null) { this.getUpdater(true).getStateViewArray().assign(mln.getUpdater(false).getStateViewArray()); @@ -4508,4 +4515,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); System.gc(); } + + /** + * Returns a string representation of the underlying configuration. + * + * @return a string representation of the configuration. 
+ */ + @Override + public String toString() { + return getNetConfiguration().toString(); + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java index 5215e2276..c68403835 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BatchNormalizationParamInitializer.java @@ -21,16 +21,15 @@ package org.deeplearning4j.nn.params; import lombok.val; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.conf.layers.BatchNormalization; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.*; -public class BatchNormalizationParamInitializer implements ParamInitializer { +public class BatchNormalizationParamInitializer extends AbstractParamInitializer { private static final BatchNormalizationParamInitializer INSTANCE = new BatchNormalizationParamInitializer(); @@ -45,12 +44,7 @@ public class BatchNormalizationParamInitializer implements ParamInitializer { public static final String GLOBAL_LOG_STD = "log10stdev"; @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer l) { + public long numParams(LayerConfiguration l) { BatchNormalization layer = (BatchNormalization) l; //Parameters in batch norm: //gamma, beta, global mean estimate, global variance estimate @@ -66,7 +60,7 @@ public class BatchNormalizationParamInitializer implements ParamInitializer { } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { if(((BatchNormalization)layer).isUseLogStd()){ return Arrays.asList(GAMMA, BETA, GLOBAL_MEAN, GLOBAL_LOG_STD); } else { @@ -75,30 +69,30 @@ public class BatchNormalizationParamInitializer implements ParamInitializer { } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { return Collections.emptyList(); } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { return Collections.emptyList(); } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return false; } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return false; } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramView, boolean initializeParams) { + public Map init(LayerConfiguration conf, INDArray paramView, boolean initializeParams) { Map params = Collections.synchronizedMap(new LinkedHashMap()); // TODO setup for RNN - BatchNormalization layer = (BatchNormalization) conf.getLayer(); + BatchNormalization layer = (BatchNormalization) conf; val nOut = layer.getNOut(); long meanOffset = 0; @@ -107,9 +101,9 @@ public class BatchNormalizationParamInitializer implements ParamInitializer { INDArray betaView = paramView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, 2 * nOut)); params.put(GAMMA, 
createGamma(conf, gammaView, initializeParams)); - conf.addVariable(GAMMA); + conf.getNetConfiguration().addNetWideVariable(GAMMA); params.put(BETA, createBeta(conf, betaView, initializeParams)); - conf.addVariable(BETA); + conf.getNetConfiguration().addNetWideVariable(BETA); meanOffset = 2 * nOut; } @@ -131,21 +125,21 @@ public class BatchNormalizationParamInitializer implements ParamInitializer { } params.put(GLOBAL_MEAN, globalMeanView); - conf.addVariable(GLOBAL_MEAN); + conf.getNetConfiguration().addNetWideVariable(GLOBAL_MEAN); if(layer.isUseLogStd()){ params.put(GLOBAL_LOG_STD, globalVarView); - conf.addVariable(GLOBAL_LOG_STD); + conf.getNetConfiguration().addNetWideVariable(GLOBAL_LOG_STD); } else { params.put(GLOBAL_VAR, globalVarView); - conf.addVariable(GLOBAL_VAR); + conf.getNetConfiguration().addNetWideVariable(GLOBAL_VAR); } return params; } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { - BatchNormalization layer = (BatchNormalization) conf.getLayer(); + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { + BatchNormalization layer = (BatchNormalization) conf; val nOut = layer.getNOut(); Map out = new LinkedHashMap<>(); @@ -171,15 +165,15 @@ public class BatchNormalizationParamInitializer implements ParamInitializer { return out; } - private INDArray createBeta(NeuralNetConfiguration conf, INDArray betaView, boolean initializeParams) { - BatchNormalization layer = (BatchNormalization) conf.getLayer(); + private INDArray createBeta(LayerConfiguration conf, INDArray betaView, boolean initializeParams) { + BatchNormalization layer = (BatchNormalization) conf; if (initializeParams) betaView.assign(layer.getBeta()); return betaView; } - private INDArray createGamma(NeuralNetConfiguration conf, INDArray gammaView, boolean initializeParams) { - BatchNormalization layer = (BatchNormalization) conf.getLayer(); + private INDArray createGamma(LayerConfiguration conf, INDArray gammaView, boolean initializeParams) { + BatchNormalization layer = (BatchNormalization) conf; if (initializeParams) gammaView.assign(layer.getGamma()); return gammaView; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BidirectionalParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BidirectionalParamInitializer.java index a75128790..27905d60f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BidirectionalParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/BidirectionalParamInitializer.java @@ -21,12 +21,10 @@ package org.deeplearning4j.nn.params; import lombok.val; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.BaseLayer; -import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; -import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; import org.nd4j.linalg.api.ndarray.INDArray; @@ -36,14 +34,13 @@ import java.util.List; import java.util.Map; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; -import static org.nd4j.linalg.indexing.NDArrayIndex.point; -public class BidirectionalParamInitializer implements ParamInitializer { +public 
class BidirectionalParamInitializer extends AbstractParamInitializer { public static final String FORWARD_PREFIX = "f"; public static final String BACKWARD_PREFIX = "b"; private final Bidirectional layer; - private final Layer underlying; + private final LayerConfiguration underlying; private List paramKeys; private List weightKeys; @@ -55,19 +52,14 @@ public class BidirectionalParamInitializer implements ParamInitializer { } @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer layer) { + public long numParams(LayerConfiguration layer) { return 2 * underlying(layer).initializer().numParams(underlying(layer)); } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { if(paramKeys == null) { - Layer u = underlying(layer); + LayerConfiguration u = underlying(layer); List orig = u.initializer().paramKeys(u); paramKeys = withPrefixes(orig); } @@ -75,9 +67,9 @@ public class BidirectionalParamInitializer implements ParamInitializer { } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { if(weightKeys == null) { - Layer u = underlying(layer); + LayerConfiguration u = underlying(layer); List orig = u.initializer().weightKeys(u); weightKeys = withPrefixes(orig); } @@ -85,9 +77,9 @@ public class BidirectionalParamInitializer implements ParamInitializer { } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { if(biasKeys == null) { - Layer u = underlying(layer); + LayerConfiguration u = underlying(layer); List orig = u.initializer().weightKeys(u); biasKeys = withPrefixes(orig); } @@ -95,27 +87,27 @@ public class BidirectionalParamInitializer implements ParamInitializer { } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return weightKeys(this.layer).contains(key); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return biasKeys(this.layer).contains(key); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { val n = paramsView.length()/2; INDArray forwardView = paramsView.get(interval(0,0,true), interval(0, n)); INDArray backwardView = paramsView.get(interval(0,0,true), interval(n, 2*n)); conf.clearVariables(); - NeuralNetConfiguration c1 = conf.clone(); - NeuralNetConfiguration c2 = conf.clone(); - c1.setLayer(underlying); - c2.setLayer(underlying); + LayerConfiguration c1 = conf.clone(); + LayerConfiguration c2 = conf.clone(); + //c1.setLayer(underlying); + //c2.setLayer(underlying); Map origFwd = underlying.initializer().init(c1, forwardView, initializeParams); Map origBwd = underlying.initializer().init(c2, backwardView, initializeParams); List variables = addPrefixes(c1.getVariables(), c2.getVariables()); @@ -156,7 +148,7 @@ public class BidirectionalParamInitializer implements ParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { val n = gradientView.length()/2; INDArray forwardView = gradientView.get(interval(0,0,true), interval(0, n)); INDArray backwardView = 
gradientView.get(interval(0,0,true), interval(n, 2*n)); @@ -175,7 +167,7 @@ public class BidirectionalParamInitializer implements ParamInitializer { return out; } - private Layer underlying(Layer layer){ + private LayerConfiguration underlying(LayerConfiguration layer){ Bidirectional b = (Bidirectional)layer; return b.getFwd(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/CenterLossParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/CenterLossParamInitializer.java index 65df1fea5..8a02c397e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/CenterLossParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/CenterLossParamInitializer.java @@ -22,7 +22,9 @@ package org.deeplearning4j.nn.params; import lombok.val; +import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -43,20 +45,20 @@ public class CenterLossParamInitializer extends DefaultParamInitializer { public final static String CENTER_KEY = "cL"; @Override - public long numParams(NeuralNetConfiguration conf) { + public long numParams(LayerConfiguration conf) { org.deeplearning4j.nn.conf.layers.FeedForwardLayer layerConf = - (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; val nIn = layerConf.getNIn(); val nOut = layerConf.getNOut(); // also equal to numClasses return nIn * nOut + nOut + nIn * nOut; //weights + bias + embeddings } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { Map params = Collections.synchronizedMap(new LinkedHashMap()); org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer layerConf = - (org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer) conf; val nIn = layerConf.getNIn(); val nOut = layerConf.getNOut(); // also equal to numClasses @@ -81,9 +83,9 @@ public class CenterLossParamInitializer extends DefaultParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer layerConf = - (org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer) conf; val nIn = layerConf.getNIn(); val nOut = layerConf.getNOut(); // also equal to numClasses @@ -107,10 +109,10 @@ public class CenterLossParamInitializer extends DefaultParamInitializer { } - protected INDArray createCenterLossMatrix(NeuralNetConfiguration conf, INDArray centerLossView, + protected INDArray createCenterLossMatrix(LayerConfiguration conf, INDArray centerLossView, boolean initializeParameters) { org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer layerConf = - (org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.CenterLossOutputLayer) conf; if (initializeParameters) { centerLossView.assign(0.0); diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java index 8ebabb433..745e77a69 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java @@ -24,7 +24,7 @@ package org.deeplearning4j.nn.params; import lombok.val; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Convolution3D; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -44,13 +44,9 @@ public class Convolution3DParamInitializer extends ConvolutionParamInitializer { public final static String WEIGHT_KEY = DefaultParamInitializer.WEIGHT_KEY; public final static String BIAS_KEY = DefaultParamInitializer.BIAS_KEY; - @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } @Override - public long numParams(Layer l) { + public long numParams(LayerConfiguration l) { Convolution3D layerConf = (Convolution3D) l; @@ -62,13 +58,13 @@ public class Convolution3DParamInitializer extends ConvolutionParamInitializer { @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - Convolution3D layer = (Convolution3D) conf.getLayer(); + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + Convolution3D layer = (Convolution3D) conf; if (layer.getKernelSize().length != 3) throw new IllegalArgumentException("Filter size must be == 3"); Map params = Collections.synchronizedMap(new LinkedHashMap()); - Convolution3D layerConf = (Convolution3D) conf.getLayer(); + Convolution3D layerConf = (Convolution3D) conf; val nOut = layerConf.getNOut(); if (layer.hasBias()) { @@ -88,9 +84,9 @@ public class Convolution3DParamInitializer extends ConvolutionParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { - Convolution3D layerConf = (Convolution3D) conf.getLayer(); + Convolution3D layerConf = (Convolution3D) conf; int[] kernel = layerConf.getKernelSize(); val nIn = layerConf.getNIn(); @@ -112,7 +108,7 @@ public class Convolution3DParamInitializer extends ConvolutionParamInitializer { } - protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) { + protected INDArray createWeightMatrix(LayerConfiguration conf, INDArray weightView, boolean initializeParams) { /* Create a 5d weight matrix of: (number of kernels, num input channels, kernel depth, kernel height, kernel width) @@ -120,7 +116,7 @@ public class Convolution3DParamInitializer extends ConvolutionParamInitializer { Inputs to the convolution layer are: (batch size, num input feature maps, image depth, image height, image width) */ - Convolution3D layerConf = (Convolution3D) conf.getLayer(); + Convolution3D layerConf = (Convolution3D) conf; if (initializeParams) { int[] kernel = layerConf.getKernelSize(); diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java index 4618e2c3e..a8b3ce7aa 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java @@ -22,17 +22,16 @@ package org.deeplearning4j.nn.params; import lombok.val; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.*; -public class ConvolutionParamInitializer implements ParamInitializer { +public class ConvolutionParamInitializer extends AbstractParamInitializer { private static final ConvolutionParamInitializer INSTANCE = new ConvolutionParamInitializer(); @@ -45,12 +44,7 @@ public class ConvolutionParamInitializer implements ParamInitializer { public final static String BIAS_KEY = DefaultParamInitializer.BIAS_KEY; @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer l) { + public long numParams(LayerConfiguration l) { org.deeplearning4j.nn.conf.layers.ConvolutionLayer layerConf = (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) l; @@ -61,7 +55,7 @@ public class ConvolutionParamInitializer implements ParamInitializer { } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { org.deeplearning4j.nn.conf.layers.ConvolutionLayer layerConf = (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) layer; if(layerConf.hasBias()){ @@ -72,12 +66,12 @@ public class ConvolutionParamInitializer implements ParamInitializer { } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { return Collections.singletonList(WEIGHT_KEY); } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { org.deeplearning4j.nn.conf.layers.ConvolutionLayer layerConf = (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) layer; if(layerConf.hasBias()){ @@ -88,24 +82,24 @@ public class ConvolutionParamInitializer implements ParamInitializer { } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return WEIGHT_KEY.equals(key); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return BIAS_KEY.equals(key); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - ConvolutionLayer layer = (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf.getLayer(); + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + ConvolutionLayer layer = (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf; if (layer.getKernelSize().length != 2) throw new IllegalArgumentException("Filter size must be == 2"); Map params = Collections.synchronizedMap(new 
LinkedHashMap()); org.deeplearning4j.nn.conf.layers.ConvolutionLayer layerConf = - (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf; val nOut = layerConf.getNOut(); @@ -115,23 +109,23 @@ public class ConvolutionParamInitializer implements ParamInitializer { INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nOut, numParams(conf))); params.put(BIAS_KEY, createBias(conf, biasView, initializeParams)); params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); - conf.addVariable(WEIGHT_KEY); - conf.addVariable(BIAS_KEY); - conf.addVariable(BIAS_KEY); + conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY); + conf.getNetConfiguration().addNetWideVariable(BIAS_KEY); + conf.getNetConfiguration().addNetWideVariable(BIAS_KEY); } else { INDArray weightView = paramsView; params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); - conf.addVariable(WEIGHT_KEY); + conf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY); } return params; } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { org.deeplearning4j.nn.conf.layers.ConvolutionLayer layerConf = - (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf; int[] kernel = layerConf.getKernelSize(); val nIn = layerConf.getNIn(); @@ -154,17 +148,17 @@ public class ConvolutionParamInitializer implements ParamInitializer { } //1 bias per feature map - protected INDArray createBias(NeuralNetConfiguration conf, INDArray biasView, boolean initializeParams) { + protected INDArray createBias(LayerConfiguration conf, INDArray biasView, boolean initializeParams) { //the bias is a 1D tensor -- one bias per output feature map org.deeplearning4j.nn.conf.layers.ConvolutionLayer layerConf = - (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf; if (initializeParams) biasView.assign(layerConf.getBiasInit()); return biasView; } - protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) { + protected INDArray createWeightMatrix(LayerConfiguration conf, INDArray weightView, boolean initializeParams) { /* Create a 4d weight matrix of: (number of kernels, num input channels, kernel height, kernel width) @@ -173,7 +167,7 @@ public class ConvolutionParamInitializer implements ParamInitializer { (batch size, num input feature maps, image height, image width) */ org.deeplearning4j.nn.conf.layers.ConvolutionLayer layerConf = - (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf; if (initializeParams) { int[] kernel = layerConf.getKernelSize(); int[] stride = layerConf.getStride(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Deconvolution3DParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Deconvolution3DParamInitializer.java index 8169fec5f..7f8b8e9e6 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Deconvolution3DParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Deconvolution3DParamInitializer.java @@ -24,7 +24,7 @@ package 
org.deeplearning4j.nn.params; import lombok.val; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Deconvolution3D; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -45,12 +45,7 @@ public class Deconvolution3DParamInitializer extends ConvolutionParamInitializer public final static String BIAS_KEY = DefaultParamInitializer.BIAS_KEY; @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer l) { + public long numParams(LayerConfiguration l) { Deconvolution3D layerConf = (Deconvolution3D) l; int[] kernel = layerConf.getKernelSize(); @@ -61,13 +56,13 @@ public class Deconvolution3DParamInitializer extends ConvolutionParamInitializer @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - Deconvolution3D layer = (Deconvolution3D) conf.getLayer(); + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + Deconvolution3D layer = (Deconvolution3D) conf; if (layer.getKernelSize().length != 3) throw new IllegalArgumentException("Filter size must be == 3"); Map params = Collections.synchronizedMap(new LinkedHashMap()); - Deconvolution3D layerConf = (Deconvolution3D) conf.getLayer(); + Deconvolution3D layerConf = (Deconvolution3D) conf; val nOut = layerConf.getNOut(); if (layer.hasBias()) { @@ -87,9 +82,9 @@ public class Deconvolution3DParamInitializer extends ConvolutionParamInitializer } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { - Deconvolution3D layerConf = (Deconvolution3D) conf.getLayer(); + Deconvolution3D layerConf = (Deconvolution3D) conf; int[] kernel = layerConf.getKernelSize(); val nIn = layerConf.getNIn(); @@ -111,7 +106,7 @@ public class Deconvolution3DParamInitializer extends ConvolutionParamInitializer } - protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) { + protected INDArray createWeightMatrix(LayerConfiguration conf, INDArray weightView, boolean initializeParams) { /* Create a 5d weight matrix of: (number of kernels, num input channels, kernel depth, kernel height, kernel width) @@ -119,7 +114,7 @@ public class Deconvolution3DParamInitializer extends ConvolutionParamInitializer Inputs to the convolution layer are: (batch size, num input feature maps, image depth, image height, image width) */ - Deconvolution3D layerConf = (Deconvolution3D) conf.getLayer(); + Deconvolution3D layerConf = (Deconvolution3D) conf; if (initializeParams) { int[] kernel = layerConf.getKernelSize(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java index a39a1f454..1c7ac91d9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java @@ -22,6 +22,7 @@ package org.deeplearning4j.nn.params; import lombok.val; import 
org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -38,7 +39,7 @@ public class DeconvolutionParamInitializer extends ConvolutionParamInitializer { } @Override - protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) { + protected INDArray createWeightMatrix(LayerConfiguration conf, INDArray weightView, boolean initializeParams) { /* Create a 4d weight matrix of: (number of kernels, num input channels, kernel height, kernel width) @@ -47,7 +48,7 @@ public class DeconvolutionParamInitializer extends ConvolutionParamInitializer { (batch size, num input feature maps, image height, image width) */ org.deeplearning4j.nn.conf.layers.Deconvolution2D layerConf = - (org.deeplearning4j.nn.conf.layers.Deconvolution2D) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.Deconvolution2D) conf; if (initializeParams) { int[] kernel = layerConf.getKernelSize(); int[] stride = layerConf.getStride(); @@ -76,10 +77,10 @@ public class DeconvolutionParamInitializer extends ConvolutionParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { org.deeplearning4j.nn.conf.layers.Deconvolution2D layerConf = - (org.deeplearning4j.nn.conf.layers.Deconvolution2D) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.Deconvolution2D) conf; int[] kernel = layerConf.getKernelSize(); val nIn = layerConf.getNIn(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java index b41f05b4e..c20562223 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java @@ -20,18 +20,20 @@ package org.deeplearning4j.nn.params; +import lombok.NonNull; +import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitUtil; +import org.deeplearning4j.nn.weights.WeightInitXavier; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.*; - -public class DefaultParamInitializer implements ParamInitializer { +@Slf4j +public class DefaultParamInitializer extends AbstractParamInitializer { private static final DefaultParamInitializer INSTANCE = new DefaultParamInitializer(); @@ -44,12 +46,7 @@ public class DefaultParamInitializer implements ParamInitializer { public final static String GAIN_KEY = "g"; @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer l) { + public long numParams(LayerConfiguration l) { FeedForwardLayer layerConf = (FeedForwardLayer) l; val nIn = layerConf.getNIn(); val nOut = layerConf.getNOut(); @@ -57,7 +54,7 @@ public class DefaultParamInitializer implements ParamInitializer { 
} @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { final ArrayList keys = new ArrayList<>(3); keys.addAll(weightKeys(layer)); keys.addAll(biasKeys(layer)); @@ -65,7 +62,7 @@ public class DefaultParamInitializer implements ParamInitializer { } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { if(hasLayerNorm(layer)){ return Arrays.asList(WEIGHT_KEY, GAIN_KEY); } @@ -73,7 +70,7 @@ public class DefaultParamInitializer implements ParamInitializer { } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { if(hasBias(layer)){ return Collections.singletonList(BIAS_KEY); } else { @@ -83,19 +80,19 @@ public class DefaultParamInitializer implements ParamInitializer { @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return WEIGHT_KEY.equals(key) || (hasLayerNorm(layer) && GAIN_KEY.equals(key)); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return BIAS_KEY.equals(key); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - if (!(conf.getLayer() instanceof org.deeplearning4j.nn.conf.layers.FeedForwardLayer)) - throw new IllegalArgumentException("unsupported layer type: " + conf.getLayer().getClass().getName()); + public Map init(@NonNull LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + if (!(conf instanceof org.deeplearning4j.nn.conf.layers.FeedForwardLayer)) + throw new IllegalArgumentException("unsupported layer type: " + conf.getClass().getName()); Map params = Collections.synchronizedMap(new LinkedHashMap()); @@ -105,22 +102,22 @@ public class DefaultParamInitializer implements ParamInitializer { "Expected params view of length " + length + ", got length " + paramsView.length()); org.deeplearning4j.nn.conf.layers.FeedForwardLayer layerConf = - (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; val nIn = layerConf.getNIn(); val nOut = layerConf.getNOut(); val nWeightParams = nIn * nOut; INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams)); - params.put(WEIGHT_KEY, createWeightMatrix(conf, weightView, initializeParams)); - conf.addVariable(WEIGHT_KEY); + params.put(WEIGHT_KEY, createWeightMatrix(layerConf, weightView, initializeParams)); + layerConf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY); long offset = nWeightParams; if(hasBias(layerConf)){ INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(offset, offset + nOut)); - params.put(BIAS_KEY, createBias(conf, biasView, initializeParams)); - conf.addVariable(BIAS_KEY); + params.put(BIAS_KEY, createBias(layerConf, biasView, initializeParams)); + layerConf.getNetConfiguration().addNetWideVariable(BIAS_KEY); offset += nOut; } @@ -128,16 +125,16 @@ public class DefaultParamInitializer implements ParamInitializer { INDArray gainView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(offset, offset + nOut)); params.put(GAIN_KEY, createGain(conf, gainView, initializeParams)); - conf.addVariable(GAIN_KEY); + conf.getNetConfiguration().addNetWideVariable(GAIN_KEY); } return params; } @Override - public Map 
getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { org.deeplearning4j.nn.conf.layers.FeedForwardLayer layerConf = - (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; val nIn = layerConf.getNIn(); val nOut = layerConf.getNOut(); val nWeightParams = nIn * nOut; @@ -166,9 +163,9 @@ public class DefaultParamInitializer implements ParamInitializer { } - protected INDArray createBias(NeuralNetConfiguration conf, INDArray biasParamView, boolean initializeParameters) { + protected INDArray createBias(LayerConfiguration conf, INDArray biasParamView, boolean initializeParameters) { org.deeplearning4j.nn.conf.layers.FeedForwardLayer layerConf = - (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; return createBias(layerConf.getNOut(), layerConf.getBiasInit(), biasParamView, initializeParameters); } @@ -179,9 +176,9 @@ public class DefaultParamInitializer implements ParamInitializer { return biasParamView; } - protected INDArray createGain(NeuralNetConfiguration conf, INDArray gainParamView, boolean initializeParameters) { + protected INDArray createGain(LayerConfiguration conf, INDArray gainParamView, boolean initializeParameters) { org.deeplearning4j.nn.conf.layers.FeedForwardLayer layerConf = - (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; return createGain(layerConf.getNOut(), layerConf.getGainInit(), gainParamView, initializeParameters); } @@ -193,12 +190,18 @@ public class DefaultParamInitializer implements ParamInitializer { } - protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightParamView, + protected INDArray createWeightMatrix(LayerConfiguration conf, INDArray weightParamView, boolean initializeParameters) { org.deeplearning4j.nn.conf.layers.FeedForwardLayer layerConf = - (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; if (initializeParameters) { + if( layerConf.getWeightInitFn() == null) { + // set a default and set warning + layerConf.setWeightInitFn(new WeightInitXavier()); + log.warn("Weight Initializer function was not set on layer {} of class {}, it will default to {}", conf.getLayerName(), + conf.getClass().getSimpleName(), WeightInitXavier.class.getSimpleName()); + } return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInitFn(), weightParamView, true); } else { @@ -206,7 +209,8 @@ public class DefaultParamInitializer implements ParamInitializer { } } - protected INDArray createWeightMatrix(long nIn, long nOut, IWeightInit weightInit, + protected INDArray createWeightMatrix(long nIn, long nOut, + @NonNull IWeightInit weightInit, INDArray weightParamView, boolean initializeParameters) { val shape = new long[] {nIn, nOut}; @@ -220,7 +224,7 @@ public class DefaultParamInitializer implements ParamInitializer { } } - protected boolean hasBias(Layer layer){ + protected boolean hasBias(LayerConfiguration layer){ if(layer instanceof BaseOutputLayer ) { return ((BaseOutputLayer) layer).hasBias(); } else if(layer instanceof DenseLayer){ @@ -233,7 +237,7 @@ public class DefaultParamInitializer implements ParamInitializer { return true; } - protected boolean hasLayerNorm(Layer layer){ + protected 
boolean hasLayerNorm(LayerConfiguration layer){ if(layer instanceof DenseLayer){ return ((DenseLayer) layer).hasLayerNorm(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java index b9f682818..72f2ac6ba 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java @@ -22,17 +22,18 @@ package org.deeplearning4j.nn.params; import lombok.val; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.*; -public class DepthwiseConvolutionParamInitializer implements ParamInitializer { +public class DepthwiseConvolutionParamInitializer extends AbstractParamInitializer { private static final DepthwiseConvolutionParamInitializer INSTANCE = new DepthwiseConvolutionParamInitializer(); @@ -44,12 +45,7 @@ public class DepthwiseConvolutionParamInitializer implements ParamInitializer { public final static String BIAS_KEY = DefaultParamInitializer.BIAS_KEY; @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer l) { + public long numParams(LayerConfiguration l) { DepthwiseConvolution2D layerConf = (DepthwiseConvolution2D) l; val depthWiseParams = numDepthWiseParams(layerConf); @@ -79,7 +75,7 @@ public class DepthwiseConvolutionParamInitializer implements ParamInitializer { } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { DepthwiseConvolution2D layerConf = (DepthwiseConvolution2D) layer; if(layerConf.hasBias()){ @@ -90,12 +86,12 @@ public class DepthwiseConvolutionParamInitializer implements ParamInitializer { } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { return Collections.singletonList(WEIGHT_KEY); } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { DepthwiseConvolution2D layerConf = (DepthwiseConvolution2D) layer; if(layerConf.hasBias()){ @@ -106,23 +102,23 @@ public class DepthwiseConvolutionParamInitializer implements ParamInitializer { } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return WEIGHT_KEY.equals(key); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return BIAS_KEY.equals(key); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - DepthwiseConvolution2D layer = (DepthwiseConvolution2D) conf.getLayer(); + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + DepthwiseConvolution2D layer = (DepthwiseConvolution2D) conf; if (layer.getKernelSize().length != 2) throw new 
IllegalArgumentException("Filter size must be == 2"); Map params = Collections.synchronizedMap(new LinkedHashMap()); - DepthwiseConvolution2D layerConf = (DepthwiseConvolution2D) conf.getLayer(); + DepthwiseConvolution2D layerConf = (DepthwiseConvolution2D) conf; val depthWiseParams = numDepthWiseParams(layerConf); val biasParams = numBiasParams(layerConf); @@ -143,9 +139,9 @@ public class DepthwiseConvolutionParamInitializer implements ParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { - DepthwiseConvolution2D layerConf = (DepthwiseConvolution2D) conf.getLayer(); + DepthwiseConvolution2D layerConf = (DepthwiseConvolution2D) conf; int[] kernel = layerConf.getKernelSize(); val nIn = layerConf.getNIn(); @@ -169,21 +165,21 @@ public class DepthwiseConvolutionParamInitializer implements ParamInitializer { return out; } - protected INDArray createBias(NeuralNetConfiguration conf, INDArray biasView, boolean initializeParams) { - DepthwiseConvolution2D layerConf = (DepthwiseConvolution2D) conf.getLayer(); + protected INDArray createBias(LayerConfiguration conf, INDArray biasView, boolean initializeParams) { + DepthwiseConvolution2D layerConf = (DepthwiseConvolution2D) conf; if (initializeParams) biasView.assign(layerConf.getBiasInit()); return biasView; } - protected INDArray createDepthWiseWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) { + protected INDArray createDepthWiseWeightMatrix(LayerConfiguration conf, INDArray weightView, boolean initializeParams) { /* Create a 4d weight matrix of: (channels multiplier, num input channels, kernel height, kernel width) Inputs to the convolution layer are: (batch size, num input feature maps, image height, image width) */ DepthwiseConvolution2D layerConf = - (DepthwiseConvolution2D) conf.getLayer(); + (DepthwiseConvolution2D) conf; int depthMultiplier = layerConf.getDepthMultiplier(); if (initializeParams) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ElementWiseParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ElementWiseParamInitializer.java index 7245d6dab..665a47d7f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ElementWiseParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ElementWiseParamInitializer.java @@ -23,7 +23,7 @@ package org.deeplearning4j.nn.params; import lombok.val; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -41,7 +41,7 @@ public class ElementWiseParamInitializer extends DefaultParamInitializer{ } @Override - public long numParams(Layer layer) { + public long numParams(LayerConfiguration layer) { FeedForwardLayer layerConf = (FeedForwardLayer) layer; val nIn = layerConf.getNIn(); return nIn*2; //weights + bias @@ -57,9 +57,9 @@ public class ElementWiseParamInitializer extends DefaultParamInitializer{ * @return Map of parameters keyed by type (view of the 'paramsView' array) */ @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean 
initializeParams) { - if (!(conf.getLayer() instanceof org.deeplearning4j.nn.conf.layers.FeedForwardLayer)) - throw new IllegalArgumentException("unsupported layer type: " + conf.getLayer().getClass().getName()); + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + if (!(conf instanceof org.deeplearning4j.nn.conf.layers.FeedForwardLayer)) + throw new IllegalArgumentException("unsupported layer type: " + conf.getClass().getName()); Map params = Collections.synchronizedMap(new LinkedHashMap()); @@ -69,7 +69,7 @@ public class ElementWiseParamInitializer extends DefaultParamInitializer{ "Expected params view of length " + length + ", got length " + paramsView.length()); org.deeplearning4j.nn.conf.layers.FeedForwardLayer layerConf = - (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; val nIn = layerConf.getNIn(); val nWeightParams = nIn ; @@ -96,9 +96,9 @@ public class ElementWiseParamInitializer extends DefaultParamInitializer{ * @return A map containing an array by parameter type, that is a view of the full network gradients array */ @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { org.deeplearning4j.nn.conf.layers.FeedForwardLayer layerConf = - (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; val nIn = layerConf.getNIn(); val nOut = layerConf.getNOut(); val nWeightParams = nIn ; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/EmptyParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/EmptyParamInitializer.java index 7ec9ea885..28d458e78 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/EmptyParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/EmptyParamInitializer.java @@ -20,9 +20,10 @@ package org.deeplearning4j.nn.params; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collections; @@ -32,7 +33,7 @@ import java.util.Map; /** * @author Adam Gibson */ -public class EmptyParamInitializer implements ParamInitializer { +public class EmptyParamInitializer extends AbstractParamInitializer { private static final EmptyParamInitializer INSTANCE = new EmptyParamInitializer(); @@ -41,47 +42,42 @@ public class EmptyParamInitializer implements ParamInitializer { } @Override - public long numParams(NeuralNetConfiguration conf) { + public long numParams(LayerConfiguration layer) { return 0; } @Override - public long numParams(Layer layer) { - return 0; - } - - @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { return Collections.emptyList(); } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { return Collections.emptyList(); } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { return Collections.emptyList(); } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean 
isWeightParam(LayerConfiguration layer, String key) { return false; } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return false; } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { return Collections.EMPTY_MAP; } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { return Collections.emptyMap(); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/FrozenLayerParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/FrozenLayerParamInitializer.java index 71bff7702..580d07402 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/FrozenLayerParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/FrozenLayerParamInitializer.java @@ -20,81 +20,75 @@ package org.deeplearning4j.nn.params; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; -import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; -import org.nd4j.linalg.api.ndarray.INDArray; - import java.util.Collections; import java.util.List; import java.util.Map; +import org.deeplearning4j.nn.api.AbstractParamInitializer; +import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; +import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; +import org.nd4j.linalg.api.ndarray.INDArray; -public class FrozenLayerParamInitializer implements ParamInitializer { +public class FrozenLayerParamInitializer extends AbstractParamInitializer { - private static final FrozenLayerParamInitializer INSTANCE = new FrozenLayerParamInitializer(); + private static final FrozenLayerParamInitializer INSTANCE = new FrozenLayerParamInitializer(); - public static FrozenLayerParamInitializer getInstance() { - return INSTANCE; - } + public static FrozenLayerParamInitializer getInstance() { + return INSTANCE; + } - @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } + @Override + public long numParams(LayerConfiguration layer) { + FrozenLayer fl = (FrozenLayer) layer; + ParamInitializer initializer = fl.getInnerConfiguration().initializer(); + return initializer.numParams(fl.getInnerConfiguration()); + } - @Override - public long numParams(Layer layer) { - FrozenLayer fl = (FrozenLayer) layer; - ParamInitializer initializer = fl.getLayer().initializer(); - return initializer.numParams(fl.getLayer()); - } + @Override + public List paramKeys(LayerConfiguration layer) { + return Collections.emptyList(); + } - @Override - public List paramKeys(Layer layer) { - return Collections.emptyList(); - } + @Override + public List weightKeys(LayerConfiguration layer) { + return Collections.emptyList(); + } - @Override - public List weightKeys(Layer layer) { - return Collections.emptyList(); - } + @Override + public List biasKeys(LayerConfiguration layer) { + return Collections.emptyList(); + } - @Override - public List biasKeys(Layer layer) { - return Collections.emptyList(); - } + @Override + public boolean 
isWeightParam(LayerConfiguration layer, String key) { + return false; + } - @Override - public boolean isWeightParam(Layer layer, String key) { - return false; - } + @Override + public boolean isBiasParam(LayerConfiguration layer, String key) { + return false; + } - @Override - public boolean isBiasParam(Layer layer, String key) { - return false; - } + @Override + public Map init(LayerConfiguration conf, INDArray paramsView, + boolean initializeParams) { + FrozenLayer fl_conf = (FrozenLayer) conf; + LayerConfiguration innerLayer = fl_conf.getInnerConfiguration(); + ParamInitializer initializer = innerLayer.initializer(); + fl_conf.setInnerConfiguration(innerLayer); + Map m = initializer.init(conf, paramsView, initializeParams); + return m; + } - @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - FrozenLayer fl = (FrozenLayer) conf.getLayer(); - Layer innerLayer = fl.getLayer(); - ParamInitializer initializer = innerLayer.initializer(); - conf.setLayer(innerLayer); - Map m = initializer.init(conf, paramsView, initializeParams); - conf.setLayer(fl); - - return m; - } - - @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { - FrozenLayer fl = (FrozenLayer) conf.getLayer(); - Layer innerLayer = fl.getLayer(); - ParamInitializer initializer = innerLayer.initializer(); - conf.setLayer(innerLayer); - Map m = initializer.getGradientsFromFlattened(conf, gradientView); - conf.setLayer(fl); - return m; - } + @Override + public Map getGradientsFromFlattened(LayerConfiguration conf, + INDArray gradientView) { + FrozenLayer fl = (FrozenLayer) conf; + LayerConfiguration innerLayer = fl.getInnerConfiguration(); + ParamInitializer initializer = innerLayer.initializer(); + fl.setInnerConfiguration(innerLayer); + Map m = initializer.getGradientsFromFlattened(conf, gradientView); + return m; + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/FrozenLayerWithBackpropParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/FrozenLayerWithBackpropParamInitializer.java index 5aa01b3e4..1328e28d9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/FrozenLayerWithBackpropParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/FrozenLayerWithBackpropParamInitializer.java @@ -20,10 +20,10 @@ package org.deeplearning4j.nn.params; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; -import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; import org.nd4j.linalg.api.ndarray.INDArray; @@ -31,7 +31,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; -public class FrozenLayerWithBackpropParamInitializer implements ParamInitializer { +public class FrozenLayerWithBackpropParamInitializer extends AbstractParamInitializer { private static final FrozenLayerWithBackpropParamInitializer INSTANCE = new FrozenLayerWithBackpropParamInitializer(); @@ -40,62 +40,54 @@ public class FrozenLayerWithBackpropParamInitializer implements ParamInitializer } @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public 
long numParams(Layer layer) { + public long numParams(LayerConfiguration layer) { FrozenLayerWithBackprop fl = (FrozenLayerWithBackprop) layer; ParamInitializer initializer = fl.getUnderlying().initializer(); return initializer.numParams(fl.getUnderlying()); } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { return Collections.emptyList(); } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { return Collections.emptyList(); } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { return Collections.emptyList(); } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return false; } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return false; } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - FrozenLayerWithBackprop fl = (FrozenLayerWithBackprop) conf.getLayer(); - Layer innerLayer = fl.getUnderlying(); + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + FrozenLayerWithBackprop fl = (FrozenLayerWithBackprop) conf; + LayerConfiguration innerLayer = fl.getUnderlying(); ParamInitializer initializer = innerLayer.initializer(); - conf.setLayer(innerLayer); + fl.setUnderlying(innerLayer); Map m = initializer.init(conf, paramsView, initializeParams); - conf.setLayer(fl); - return m; } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { - FrozenLayerWithBackprop fl = (FrozenLayerWithBackprop) conf.getLayer(); - Layer innerLayer = fl.getUnderlying(); + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { + FrozenLayerWithBackprop fl = (FrozenLayerWithBackprop) conf; + LayerConfiguration innerLayer = fl.getUnderlying(); ParamInitializer initializer = innerLayer.initializer(); - conf.setLayer(innerLayer); + fl.setUnderlying(innerLayer); Map m = initializer.getGradientsFromFlattened(conf, gradientView); - conf.setLayer(fl); return m; } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java index de437ee6d..5239a6c2c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java @@ -21,9 +21,10 @@ package org.deeplearning4j.nn.params; import lombok.val; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -33,7 +34,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.*; -public class GravesBidirectionalLSTMParamInitializer implements ParamInitializer { +public class GravesBidirectionalLSTMParamInitializer extends AbstractParamInitializer { private static final 
GravesBidirectionalLSTMParamInitializer INSTANCE = new GravesBidirectionalLSTMParamInitializer(); @@ -61,12 +62,7 @@ public class GravesBidirectionalLSTMParamInitializer implements ParamInitializer BIAS_KEY_BACKWARDS)); @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer l) { + public long numParams(LayerConfiguration l) { org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM layerConf = (org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) l; @@ -81,37 +77,37 @@ public class GravesBidirectionalLSTMParamInitializer implements ParamInitializer } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { return ALL_PARAM_KEYS; } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { return WEIGHT_KEYS; } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { return BIAS_KEYS; } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return RECURRENT_WEIGHT_KEY_FORWARDS.equals(key) || INPUT_WEIGHT_KEY_FORWARDS.equals(key) || RECURRENT_WEIGHT_KEY_BACKWARDS.equals(key) || INPUT_WEIGHT_KEY_BACKWARDS.equals(key); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return BIAS_KEY_FORWARDS.equals(key) || BIAS_KEY_BACKWARDS.equals(key); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { Map params = Collections.synchronizedMap(new LinkedHashMap()); org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM layerConf = - (org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) conf; double forgetGateInit = layerConf.getForgetGateBiasInit(); val nL = layerConf.getNOut(); //i.e., n neurons in this layer @@ -187,9 +183,9 @@ public class GravesBidirectionalLSTMParamInitializer implements ParamInitializer @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM layerConf = - (org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) conf; val nL = layerConf.getNOut(); //i.e., n neurons in this layer val nLast = layerConf.getNIn(); //i.e., n neurons in previous layer diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java index 37e4d1cdf..5c59e5f7e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java @@ -21,9 +21,10 @@ package org.deeplearning4j.nn.params; import lombok.val; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; +import 
org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -33,7 +34,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.*; -public class GravesLSTMParamInitializer implements ParamInitializer { +public class GravesLSTMParamInitializer extends AbstractParamInitializer { private static final GravesLSTMParamInitializer INSTANCE = new GravesLSTMParamInitializer(); @@ -47,12 +48,7 @@ public class GravesLSTMParamInitializer implements ParamInitializer { public final static String INPUT_WEIGHT_KEY = LSTMParamInitializer.INPUT_WEIGHT_KEY; @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer l) { + public long numParams(LayerConfiguration l) { org.deeplearning4j.nn.conf.layers.GravesLSTM layerConf = (org.deeplearning4j.nn.conf.layers.GravesLSTM) l; val nL = layerConf.getNOut(); //i.e., n neurons in this layer @@ -66,35 +62,35 @@ public class GravesLSTMParamInitializer implements ParamInitializer { } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { return Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY, BIAS_KEY); } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { return Arrays.asList(INPUT_WEIGHT_KEY, RECURRENT_WEIGHT_KEY); } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { return Collections.singletonList(BIAS_KEY); } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return RECURRENT_WEIGHT_KEY.equals(key) || INPUT_WEIGHT_KEY.equals(key); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return BIAS_KEY.equals(key); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { Map params = Collections.synchronizedMap(new LinkedHashMap()); org.deeplearning4j.nn.conf.layers.GravesLSTM layerConf = - (org.deeplearning4j.nn.conf.layers.GravesLSTM) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.GravesLSTM) conf; double forgetGateInit = layerConf.getForgetGateBiasInit(); val nL = layerConf.getNOut(); //i.e., n neurons in this layer @@ -157,9 +153,9 @@ public class GravesLSTMParamInitializer implements ParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { org.deeplearning4j.nn.conf.layers.GravesLSTM layerConf = - (org.deeplearning4j.nn.conf.layers.GravesLSTM) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.GravesLSTM) conf; val nL = layerConf.getNOut(); //i.e., n neurons in this layer val nLast = layerConf.getNIn(); //i.e., n neurons in previous layer diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java index 2a7418957..04f12ea32 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java +++ 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java @@ -20,11 +20,16 @@ package org.deeplearning4j.nn.params; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; import lombok.val; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LSTM; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -32,9 +37,7 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.INDArrayIndex; import org.nd4j.linalg.indexing.NDArrayIndex; -import java.util.*; - -public class LSTMParamInitializer implements ParamInitializer { +public class LSTMParamInitializer extends AbstractParamInitializer { private static final LSTMParamInitializer INSTANCE = new LSTMParamInitializer(); @@ -54,12 +57,7 @@ public class LSTMParamInitializer implements ParamInitializer { private static final List BIAS_KEYS = Collections.unmodifiableList(Collections.singletonList(BIAS_KEY)); @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer l) { + public long numParams(LayerConfiguration l) { LSTM layerConf = (LSTM) l; val nL = layerConf.getNOut(); //i.e., n neurons in this layer @@ -73,34 +71,34 @@ public class LSTMParamInitializer implements ParamInitializer { } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { return LAYER_PARAM_KEYS; } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { return WEIGHT_KEYS; } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { return BIAS_KEYS; } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return RECURRENT_WEIGHT_KEY.equals(key) || INPUT_WEIGHT_KEY.equals(key); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return BIAS_KEY.equals(key); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { Map params = Collections.synchronizedMap(new LinkedHashMap()); - org.deeplearning4j.nn.conf.layers.LSTM layerConf = (org.deeplearning4j.nn.conf.layers.LSTM) conf.getLayer(); + org.deeplearning4j.nn.conf.layers.LSTM layerConf = (org.deeplearning4j.nn.conf.layers.LSTM) conf; double forgetGateInit = layerConf.getForgetGateBiasInit(); val nL = layerConf.getNOut(); //i.e., n neurons in this layer @@ -162,8 +160,8 @@ public class LSTMParamInitializer implements ParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { - org.deeplearning4j.nn.conf.layers.LSTM layerConf = (org.deeplearning4j.nn.conf.layers.LSTM) conf.getLayer(); + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { + org.deeplearning4j.nn.conf.layers.LSTM layerConf = 
(org.deeplearning4j.nn.conf.layers.LSTM) conf; val nL = layerConf.getNOut(); //i.e., n neurons in this layer val nLast = layerConf.getNIn(); //i.e., n neurons in previous layer diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java index d0a93e368..3de33be57 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java @@ -21,10 +21,11 @@ package org.deeplearning4j.nn.params; import lombok.val; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.BaseLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.PReLULayer; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitUtil; @@ -36,7 +37,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -public class PReLUParamInitializer implements ParamInitializer { +public class PReLUParamInitializer extends AbstractParamInitializer { public final static String WEIGHT_KEY = "W"; private final long[] weightShape; @@ -58,14 +59,8 @@ public class PReLUParamInitializer implements ParamInitializer { return new PReLUParamInitializer(shape, sharedAxes); } - @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer l) { + public long numParams(LayerConfiguration l) { return numParams(weightShape); } @@ -78,34 +73,34 @@ public class PReLUParamInitializer implements ParamInitializer { } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { return weightKeys(layer); } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { return Collections.singletonList(WEIGHT_KEY); } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { return Collections.emptyList(); } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return WEIGHT_KEY.equals(key); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return false; } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - if (!(conf.getLayer() instanceof BaseLayer)) - throw new IllegalArgumentException("unsupported layer type: " + conf.getLayer().getClass().getName()); + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + if (!(conf instanceof BaseLayer)) + throw new IllegalArgumentException("unsupported layer type: " + conf.getClass().getName()); Map params = Collections.synchronizedMap(new LinkedHashMap()); @@ -123,7 +118,7 @@ public class PReLUParamInitializer implements ParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { val length = numParams(conf); INDArray 
weightGradientView = gradientView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, length)) @@ -135,10 +130,10 @@ public class PReLUParamInitializer implements ParamInitializer { } - protected INDArray createWeightMatrix(NeuralNetConfiguration conf, INDArray weightParamView, + protected INDArray createWeightMatrix(LayerConfiguration conf, INDArray weightParamView, boolean initializeParameters) { - PReLULayer layerConf = (PReLULayer) conf.getLayer(); + PReLULayer layerConf = (PReLULayer) conf; if (initializeParameters) { return layerConf.getWeightInitFn().init(layerConf.getNIn(), layerConf.getNOut(), weightShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightParamView); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PretrainParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PretrainParamInitializer.java index c794a452c..4eb87427a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PretrainParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PretrainParamInitializer.java @@ -22,6 +22,7 @@ package org.deeplearning4j.nn.params; import lombok.val; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; @@ -45,18 +46,18 @@ public class PretrainParamInitializer extends DefaultParamInitializer { public final static String VISIBLE_BIAS_KEY = "v" + DefaultParamInitializer.BIAS_KEY; @Override - public long numParams(NeuralNetConfiguration conf) { + public long numParams(LayerConfiguration conf) { org.deeplearning4j.nn.conf.layers.BasePretrainNetwork layerConf = - (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) conf; return super.numParams(conf) + layerConf.getNIn(); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { Map params = super.init(conf, paramsView, initializeParams); org.deeplearning4j.nn.conf.layers.BasePretrainNetwork layerConf = - (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) conf; val nIn = layerConf.getNIn(); val nOut = layerConf.getNOut(); val nWeightParams = nIn * nOut; @@ -69,10 +70,10 @@ public class PretrainParamInitializer extends DefaultParamInitializer { return params; } - protected INDArray createVisibleBias(NeuralNetConfiguration conf, INDArray visibleBiasView, + protected INDArray createVisibleBias(LayerConfiguration conf, INDArray visibleBiasView, boolean initializeParameters) { org.deeplearning4j.nn.conf.layers.BasePretrainNetwork layerConf = - (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) conf; if (initializeParameters) { INDArray ret = Nd4j.valueArrayOf(new long[]{1, layerConf.getNIn()}, layerConf.getVisibleBiasInit()); visibleBiasView.assign(ret); @@ -82,10 +83,10 @@ public class PretrainParamInitializer extends DefaultParamInitializer { @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { Map out = 
super.getGradientsFromFlattened(conf, gradientView); org.deeplearning4j.nn.conf.layers.FeedForwardLayer layerConf = - (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf.getLayer(); + (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; val nIn = layerConf.getNIn(); val nOut = layerConf.getNOut(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SameDiffParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SameDiffParamInitializer.java index 2b9c3484c..0846e0bf5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SameDiffParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SameDiffParamInitializer.java @@ -22,9 +22,10 @@ package org.deeplearning4j.nn.params; import lombok.extern.slf4j.Slf4j; import lombok.val; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex; import org.nd4j.linalg.api.ndarray.INDArray; @@ -38,7 +39,7 @@ import java.util.Map; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; @Slf4j -public class SameDiffParamInitializer implements ParamInitializer { +public class SameDiffParamInitializer extends AbstractParamInitializer { private static final SameDiffParamInitializer INSTANCE = new SameDiffParamInitializer(); @@ -47,12 +48,7 @@ public class SameDiffParamInitializer implements ParamInitializer { } @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer layer) { + public long numParams(LayerConfiguration layer) { AbstractSameDiffLayer sd = (AbstractSameDiffLayer)layer; Map m = sd.getLayerParams().getParamShapes(); int n = 0; @@ -63,36 +59,36 @@ public class SameDiffParamInitializer implements ParamInitializer { } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { AbstractSameDiffLayer sd = (AbstractSameDiffLayer)layer; return sd.getLayerParams().getParameterKeys(); } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { AbstractSameDiffLayer sd = (AbstractSameDiffLayer)layer; return sd.getLayerParams().getWeightParameterKeys(); } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { AbstractSameDiffLayer sd = (AbstractSameDiffLayer)layer; return sd.getLayerParams().getBiasParameterKeys(); } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return weightKeys(layer).contains(key); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return biasKeys(layer).contains(key); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - AbstractSameDiffLayer sd = (AbstractSameDiffLayer) conf.getLayer(); + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + AbstractSameDiffLayer sd = (AbstractSameDiffLayer) conf; Map out = 
subsetAndReshape(sd.getLayerParams().getParameterKeys(), sd.getLayerParams().getParamShapes(), paramsView, sd); if(initializeParams){ @@ -107,8 +103,8 @@ public class SameDiffParamInitializer implements ParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { - AbstractSameDiffLayer sd = (AbstractSameDiffLayer) conf.getLayer(); + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { + AbstractSameDiffLayer sd = (AbstractSameDiffLayer) conf; return subsetAndReshape(sd.getLayerParams().getParameterKeys(), sd.getLayerParams().getParamShapes(), gradientView, sd); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java index bb7dabb4e..9df032560 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java @@ -22,9 +22,10 @@ package org.deeplearning4j.nn.params; import lombok.val; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -32,7 +33,7 @@ import org.nd4j.linalg.indexing.NDArrayIndex; import java.util.*; -public class SeparableConvolutionParamInitializer implements ParamInitializer { +public class SeparableConvolutionParamInitializer extends AbstractParamInitializer { private static final SeparableConvolutionParamInitializer INSTANCE = new SeparableConvolutionParamInitializer(); @@ -45,12 +46,7 @@ public class SeparableConvolutionParamInitializer implements ParamInitializer { public final static String BIAS_KEY = DefaultParamInitializer.BIAS_KEY; @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer l) { + public long numParams(LayerConfiguration l) { SeparableConvolution2D layerConf = (SeparableConvolution2D) l; val depthWiseParams = numDepthWiseParams(layerConf); @@ -96,7 +92,7 @@ public class SeparableConvolutionParamInitializer implements ParamInitializer { } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { SeparableConvolution2D layerConf = (SeparableConvolution2D) layer; if(layerConf.hasBias()){ @@ -107,12 +103,12 @@ public class SeparableConvolutionParamInitializer implements ParamInitializer { } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { return Arrays.asList(DEPTH_WISE_WEIGHT_KEY, POINT_WISE_WEIGHT_KEY); } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { SeparableConvolution2D layerConf = (SeparableConvolution2D) layer; if(layerConf.hasBias()){ @@ -123,23 +119,23 @@ public class SeparableConvolutionParamInitializer implements ParamInitializer { } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return 
DEPTH_WISE_WEIGHT_KEY.equals(key) || POINT_WISE_WEIGHT_KEY.equals(key); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return BIAS_KEY.equals(key); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - SeparableConvolution2D layer = (SeparableConvolution2D) conf.getLayer(); + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + SeparableConvolution2D layer = (SeparableConvolution2D) conf; if (layer.getKernelSize().length != 2) throw new IllegalArgumentException("Filter size must be == 2"); Map params = Collections.synchronizedMap(new LinkedHashMap()); - SeparableConvolution2D layerConf = (SeparableConvolution2D) conf.getLayer(); + SeparableConvolution2D layerConf = (SeparableConvolution2D) conf; val depthWiseParams = numDepthWiseParams(layerConf); val biasParams = numBiasParams(layerConf); @@ -164,10 +160,10 @@ public class SeparableConvolutionParamInitializer implements ParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { SeparableConvolution2D layerConf = - (SeparableConvolution2D) conf.getLayer(); + (SeparableConvolution2D) conf; int[] kernel = layerConf.getKernelSize(); val nIn = layerConf.getNIn(); @@ -195,22 +191,22 @@ public class SeparableConvolutionParamInitializer implements ParamInitializer { return out; } - protected INDArray createBias(NeuralNetConfiguration conf, INDArray biasView, boolean initializeParams) { + protected INDArray createBias(LayerConfiguration conf, INDArray biasView, boolean initializeParams) { SeparableConvolution2D layerConf = - (SeparableConvolution2D) conf.getLayer(); + (SeparableConvolution2D) conf; if (initializeParams) biasView.assign(layerConf.getBiasInit()); return biasView; } - protected INDArray createDepthWiseWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, boolean initializeParams) { + protected INDArray createDepthWiseWeightMatrix(LayerConfiguration conf, INDArray weightView, boolean initializeParams) { /* Create a 4d weight matrix of: (channels multiplier, num input channels, kernel height, kernel width) Inputs to the convolution layer are: (batch size, num input feature maps, image height, image width) */ SeparableConvolution2D layerConf = - (SeparableConvolution2D) conf.getLayer(); + (SeparableConvolution2D) conf; int depthMultiplier = layerConf.getDepthMultiplier(); if (initializeParams) { @@ -233,14 +229,14 @@ public class SeparableConvolutionParamInitializer implements ParamInitializer { } } - protected INDArray createPointWiseWeightMatrix(NeuralNetConfiguration conf, INDArray weightView, + protected INDArray createPointWiseWeightMatrix(LayerConfiguration conf, INDArray weightView, boolean initializeParams) { /* Create a 4d weight matrix of: (num output channels, channels multiplier * num input channels, kernel height, kernel width) */ SeparableConvolution2D layerConf = - (SeparableConvolution2D) conf.getLayer(); + (SeparableConvolution2D) conf; int depthMultiplier = layerConf.getDepthMultiplier(); if (initializeParams) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java index f3fbf1e11..603492afa 100644 --- 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java @@ -21,9 +21,10 @@ package org.deeplearning4j.nn.params; import lombok.val; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; @@ -31,9 +32,8 @@ import org.nd4j.linalg.api.ndarray.INDArray; import java.util.*; import static org.nd4j.linalg.indexing.NDArrayIndex.interval; -import static org.nd4j.linalg.indexing.NDArrayIndex.point; -public class SimpleRnnParamInitializer implements ParamInitializer { +public class SimpleRnnParamInitializer extends AbstractParamInitializer { private static final SimpleRnnParamInitializer INSTANCE = new SimpleRnnParamInitializer(); @@ -51,12 +51,7 @@ public class SimpleRnnParamInitializer implements ParamInitializer { @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer layer) { + public long numParams(LayerConfiguration layer) { SimpleRnn c = (SimpleRnn)layer; val nIn = c.getNIn(); val nOut = c.getNOut(); @@ -64,7 +59,7 @@ public class SimpleRnnParamInitializer implements ParamInitializer { } @Override - public List paramKeys(Layer layer) { + public List paramKeys(LayerConfiguration layer) { final ArrayList keys = new ArrayList<>(3); keys.addAll(weightKeys(layer)); keys.addAll(biasKeys(layer)); @@ -72,7 +67,7 @@ public class SimpleRnnParamInitializer implements ParamInitializer { } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { final ArrayList keys = new ArrayList<>(WEIGHT_KEYS); if(hasLayerNorm(layer)){ @@ -83,23 +78,23 @@ public class SimpleRnnParamInitializer implements ParamInitializer { } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { return BIAS_KEYS; } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return WEIGHT_KEY.equals(key) || RECURRENT_WEIGHT_KEY.equals(key) || GAIN_KEY.equals(key); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return BIAS_KEY.equals(key); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - SimpleRnn c = (SimpleRnn)conf.getLayer(); + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { + SimpleRnn c = (SimpleRnn)conf; val nIn = c.getNIn(); val nOut = c.getNOut(); @@ -140,8 +135,8 @@ public class SimpleRnnParamInitializer implements ParamInitializer { } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { - SimpleRnn c = (SimpleRnn)conf.getLayer(); + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { + SimpleRnn c = (SimpleRnn)conf; val nIn = c.getNIn(); val nOut = c.getNOut(); @@ -172,7 +167,7 @@ public class SimpleRnnParamInitializer implements ParamInitializer { return m; } - protected boolean 
hasLayerNorm(Layer layer){ + protected boolean hasLayerNorm(LayerConfiguration layer){ if(layer instanceof SimpleRnn){ return ((SimpleRnn) layer).hasLayerNorm(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java index 399bf3a47..9284843d5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java @@ -22,7 +22,7 @@ package org.deeplearning4j.nn.params; import lombok.val; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.weights.IWeightInit; import org.nd4j.linalg.api.ndarray.INDArray; @@ -71,8 +71,8 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali @Override - public long numParams(NeuralNetConfiguration conf) { - VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer(); + public long numParams(LayerConfiguration conf) { + VariationalAutoencoder layer = (VariationalAutoencoder) conf; val nIn = layer.getNIn(); val nOut = layer.getNOut(); @@ -116,7 +116,7 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali } @Override - public List paramKeys(Layer l) { + public List paramKeys(LayerConfiguration l) { VariationalAutoencoder layer = (VariationalAutoencoder) l; int[] encoderLayerSizes = layer.getEncoderLayerSizes(); int[] decoderLayerSizes = layer.getDecoderLayerSizes(); @@ -154,7 +154,7 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali } @Override - public List weightKeys(Layer layer) { + public List weightKeys(LayerConfiguration layer) { List out = new ArrayList<>(); for(String s : paramKeys(layer)){ if(isWeightParam(layer, s)){ @@ -165,7 +165,7 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali } @Override - public List biasKeys(Layer layer) { + public List biasKeys(LayerConfiguration layer) { List out = new ArrayList<>(); for(String s : paramKeys(layer)){ if(isBiasParam(layer, s)){ @@ -176,24 +176,24 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali } @Override - public boolean isWeightParam(Layer layer, String key) { + public boolean isWeightParam(LayerConfiguration layer, String key) { return key.endsWith(WEIGHT_KEY_SUFFIX); } @Override - public boolean isBiasParam(Layer layer, String key) { + public boolean isBiasParam(LayerConfiguration layer, String key) { return key.endsWith(BIAS_KEY_SUFFIX); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { + public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { if (paramsView.length() != numParams(conf)) { throw new IllegalArgumentException("Incorrect paramsView length: Expected length " + numParams(conf) + ", got length " + paramsView.length()); } Map ret = new LinkedHashMap<>(); - VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer(); + VariationalAutoencoder layer = (VariationalAutoencoder) conf; val nIn = layer.getNIn(); val nOut = layer.getNOut(); @@ -316,9 +316,9 @@ public 
class VariationalAutoencoderParamInitializer extends DefaultParamInitiali } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { Map ret = new LinkedHashMap<>(); - VariationalAutoencoder layer = (VariationalAutoencoder) conf.getLayer(); + VariationalAutoencoder layer = (VariationalAutoencoder) conf; val nIn = layer.getNIn(); val nOut = layer.getNOut(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/WrapperLayerParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/WrapperLayerParamInitializer.java index 234226eb4..7cb7059c8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/WrapperLayerParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/WrapperLayerParamInitializer.java @@ -20,16 +20,17 @@ package org.deeplearning4j.nn.params; +import org.deeplearning4j.nn.api.AbstractParamInitializer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.List; import java.util.Map; -public class WrapperLayerParamInitializer implements ParamInitializer { +public class WrapperLayerParamInitializer extends AbstractParamInitializer { private static final WrapperLayerParamInitializer INSTANCE = new WrapperLayerParamInitializer(); @@ -42,67 +43,62 @@ public class WrapperLayerParamInitializer implements ParamInitializer { } @Override - public long numParams(NeuralNetConfiguration conf) { - return numParams(conf.getLayer()); - } - - @Override - public long numParams(Layer layer) { - Layer l = underlying(layer); + public long numParams(LayerConfiguration layer) { + LayerConfiguration l = underlying(layer); return l.initializer().numParams(l); } @Override - public List paramKeys(Layer layer) { - Layer l = underlying(layer); + public List paramKeys(LayerConfiguration layer) { + LayerConfiguration l = underlying(layer); return l.initializer().paramKeys(l); } @Override - public List weightKeys(Layer layer) { - Layer l = underlying(layer); + public List weightKeys(LayerConfiguration layer) { + LayerConfiguration l = underlying(layer); return l.initializer().weightKeys(l); } @Override - public List biasKeys(Layer layer) { - Layer l = underlying(layer); + public List biasKeys(LayerConfiguration layer) { + LayerConfiguration l = underlying(layer); return l.initializer().biasKeys(l); } @Override - public boolean isWeightParam(Layer layer, String key) { - Layer l = underlying(layer); + public boolean isWeightParam(LayerConfiguration layer, String key) { + LayerConfiguration l = underlying(layer); return l.initializer().isWeightParam(layer, key); } @Override - public boolean isBiasParam(Layer layer, String key) { - Layer l = underlying(layer); + public boolean isBiasParam(LayerConfiguration layer, String key) { + LayerConfiguration l = underlying(layer); return l.initializer().isBiasParam(layer, key); } @Override - public Map init(NeuralNetConfiguration conf, INDArray paramsView, boolean initializeParams) { - Layer orig = conf.getLayer(); - Layer l = underlying(conf.getLayer()); - conf.setLayer(l); + public Map init(LayerConfiguration conf, INDArray paramsView, 
boolean initializeParams) { + LayerConfiguration orig = conf; + LayerConfiguration l = underlying(conf); + Map m = l.initializer().init(conf, paramsView, initializeParams); - conf.setLayer(orig); + return m; } @Override - public Map getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) { - Layer orig = conf.getLayer(); - Layer l = underlying(conf.getLayer()); - conf.setLayer(l); + public Map getGradientsFromFlattened(LayerConfiguration conf, INDArray gradientView) { + LayerConfiguration orig = conf; + LayerConfiguration l = underlying(conf); + Map m = l.initializer().getGradientsFromFlattened(conf, gradientView); - conf.setLayer(orig); + return m; } - private Layer underlying(Layer layer){ + private LayerConfiguration underlying(LayerConfiguration layer){ while (layer instanceof BaseWrapperLayer) { layer = ((BaseWrapperLayer)layer).getUnderlying(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java index 3f2ddd88b..73a31b96b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java @@ -20,23 +20,40 @@ package org.deeplearning4j.nn.transferlearning; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.core.JsonProcessingException; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor; import lombok.ToString; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.layers.LayerConstraint; -import org.deeplearning4j.nn.conf.*; +import org.deeplearning4j.nn.conf.BackpropType; +import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; +import org.deeplearning4j.nn.conf.ConvolutionMode; +import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.Updater; +import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.IDropout; -import org.deeplearning4j.nn.conf.layers.*; +import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerValidation; +import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.conf.stepfunctions.StepFunction; import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.util.NetworkUtils; +import org.nd4j.common.primitives.Optional; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.learning.config.IUpdater; @@ -44,14 +61,6 @@ import org.nd4j.linalg.learning.regularization.L1Regularization; import org.nd4j.linalg.learning.regularization.L2Regularization; import org.nd4j.linalg.learning.regularization.Regularization; 
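// Illustrative sketch (not part of the patch), assuming the LSTM layer builder and the
// LSTMParamInitializer.getInstance() accessor, which are not shown in these hunks: after the
// refactor above, every param initializer extends AbstractParamInitializer and is handed a
// LayerConfiguration directly, instead of unwrapping it via NeuralNetConfiguration.getLayer().
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.params.LSTMParamInitializer;

public class ParamInitializerSketch {
    public static void main(String[] args) {
        // The layer configuration itself is now the only argument the initializer needs.
        LSTM lstm = new LSTM.Builder().nIn(128).nOut(64).build();

        LSTMParamInitializer init = LSTMParamInitializer.getInstance();
        long nParams = init.numParams(lstm);   // previously numParams(conf.getLayer())
        boolean isWeight = init.isWeightParam(lstm, LSTMParamInitializer.INPUT_WEIGHT_KEY);
        System.out.println("LSTM parameter count: " + nParams + ", W is a weight param: " + isWeight);
    }
}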
import org.nd4j.linalg.learning.regularization.WeightDecay; -import org.nd4j.common.primitives.Optional; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonTypeInfo; -import com.fasterxml.jackson.core.JsonProcessingException; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "type") @JsonInclude(JsonInclude.Include.NON_NULL) @@ -60,738 +69,794 @@ import java.util.List; @Data public class FineTuneConfiguration { - protected IActivation activationFn; - protected IWeightInit weightInitFn; - protected Double biasInit; - protected List regularization; - protected List regularizationBias; + protected IActivation activationFn; + protected IWeightInit weightInitFn; + protected Double biasInit; + protected List regularization; + protected List regularizationBias; + protected boolean removeL2 = false; //For: .l2(0.0) -> user means "no l2" so we should remove it if it is present in the original model... + protected boolean removeL2Bias = false; + protected boolean removeL1 = false; + protected boolean removeL1Bias = false; + protected boolean removeWD = false; + protected boolean removeWDBias = false; + protected Optional dropout; + protected Optional weightNoise; + protected IUpdater updater; + protected IUpdater biasUpdater; + protected Boolean miniBatch; + protected Integer maxNumLineSearchIterations; + protected Long seed; + protected OptimizationAlgorithm optimizationAlgo; + protected StepFunction stepFunction; + protected Boolean minimize; + protected Optional gradientNormalization; + protected Double gradientNormalizationThreshold; + protected ConvolutionMode convolutionMode; + protected ConvolutionLayer.AlgoMode cudnnAlgoMode; + protected Optional> constraints; + + protected Boolean pretrain; + protected Boolean backprop; + protected BackpropType backpropType; + protected Integer tbpttFwdLength; + protected Integer tbpttBackLength; + + protected WorkspaceMode trainingWorkspaceMode; + protected WorkspaceMode inferenceWorkspaceMode; + + public static Builder builder() { + return new Builder(); + } + + private static T get(Optional optional) { + if (optional == null) { + return null; + } + return optional.orElse(null); + } + + public static FineTuneConfiguration fromJson(String json) { + try { + return NeuralNetConfiguration.mapper().readValue(json, FineTuneConfiguration.class); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static FineTuneConfiguration fromYaml(String yaml) { + try { + return NeuralNetConfiguration.mapperYaml().readValue(yaml, FineTuneConfiguration.class); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + /** + * public NeuralNetConfiguration appliedNeuralNetConfiguration(NeuralNetConfiguration nnc) { + * applyToNeuralNetConfiguration(nnc); nnc = new + * NeuralNetConfiguration.NeuralNetConfigurationBuilder(nnc.clone()).build(); return nnc; } + **/ + + public void applyToLayerConfiguration(LayerConfiguration layerConfiguration) { + + Updater originalUpdater = null; + WeightInit origWeightInit = null; + + if (layerConfiguration != null) { + //As per NeuralNetConfiguration.configureLayer and LayerValidation.configureBaseLayer: only copy dropout to base layers + // this excludes things like subsampling and activation layers + if (dropout != null && layerConfiguration instanceof BaseLayer) { + IDropout d = dropout.orElse(null); + if (d != null) { + d = 
d.clone(); //Clone to avoid shared state between layers + } + layerConfiguration.setIDropout(d); + } + if (constraints != null) { + layerConfiguration.setConstraints(constraints.orElse(null)); + } + } + + if (layerConfiguration != null && layerConfiguration instanceof BaseLayer) { + BaseLayer bl = (BaseLayer) layerConfiguration; + if (activationFn != null) { + bl.setActivationFn(activationFn); + } + if (weightInitFn != null) { + bl.setWeightInitFn(weightInitFn); + } + if (biasInit != null) { + bl.setBiasInit(biasInit); + } + if (regularization != null && !regularization.isEmpty()) { + bl.setRegularization(regularization); + } + if (regularizationBias != null && !regularizationBias.isEmpty()) { + bl.setRegularizationBias(regularizationBias); + } + if (removeL2) { + NetworkUtils.removeInstances(bl.getRegularization(), L2Regularization.class); + } + if (removeL2Bias) { + NetworkUtils.removeInstances(bl.getRegularizationBias(), L2Regularization.class); + } + if (removeL1) { + NetworkUtils.removeInstances(bl.getRegularization(), L1Regularization.class); + } + if (removeL1Bias) { + NetworkUtils.removeInstances(bl.getRegularizationBias(), L1Regularization.class); + } + if (removeWD) { + NetworkUtils.removeInstances(bl.getRegularization(), WeightDecay.class); + } + if (removeWDBias) { + NetworkUtils.removeInstances(bl.getRegularizationBias(), WeightDecay.class); + } + if (gradientNormalization != null) { + bl.setGradientNormalization(gradientNormalization.orElse(null)); + } + if (gradientNormalizationThreshold != null) { + bl.setGradientNormalizationThreshold(gradientNormalizationThreshold); + } + if (updater != null) { + bl.setIUpdater(updater); + } + if (biasUpdater != null) { + bl.setBiasUpdater(biasUpdater); + } + if (weightNoise != null) { + bl.setWeightNoise(weightNoise.orElse(null)); + } + } + NeuralNetConfiguration nnc = layerConfiguration.getNetConfiguration(); + if (miniBatch != null) { + nnc.setMiniBatch(miniBatch); + } + if (maxNumLineSearchIterations != null) { + nnc.setMaxNumLineSearchIterations(maxNumLineSearchIterations); + } + if (seed != null) { + nnc.setSeed(seed); + } + if (optimizationAlgo != null) { + nnc.setOptimizationAlgo(optimizationAlgo); + } + if (stepFunction != null) { + nnc.setStepFunction(stepFunction); + } + if (minimize != null) { + nnc.setMinimize(minimize); + } + + if (convolutionMode != null && layerConfiguration instanceof ConvolutionLayer) { + ((ConvolutionLayer) layerConfiguration).setConvolutionMode(convolutionMode); + } + if (cudnnAlgoMode != null && layerConfiguration instanceof ConvolutionLayer) { + ((ConvolutionLayer) layerConfiguration).setCudnnAlgoMode(cudnnAlgoMode); + } + if (convolutionMode != null && layerConfiguration instanceof SubsamplingLayer) { + ((SubsamplingLayer) layerConfiguration).setConvolutionMode(convolutionMode); + } + + //Perform validation + if (layerConfiguration != null) { + LayerValidation.generalValidation(layerConfiguration.getLayerName(), layerConfiguration, get(dropout), regularization, + regularizationBias, + get(constraints), null, null); + } + } + + + public void applyToComputationGraphConfiguration(ComputationGraphConfiguration conf) { + if (backpropType != null) { + conf.setBackpropType(backpropType); + } + if (tbpttFwdLength != null) { + conf.setTbpttFwdLength(tbpttFwdLength); + } + if (tbpttBackLength != null) { + conf.setTbpttBackLength(tbpttBackLength); + } + } + + public NeuralNetConfiguration appliedNeuralNetConfigurationBuilder() { + NeuralNetConfiguration.NeuralNetConfigurationBuilder confBuilder = 
NeuralNetConfiguration.builder(); + + if (activationFn != null) { + confBuilder.activationFn(activationFn); + } + if (weightInitFn != null) { + confBuilder.weightInitFn(weightInitFn); + } + if (biasInit != null) { + confBuilder.biasInit(biasInit); + } + if (regularization != null) { + confBuilder.regularization(regularization); + } + if (regularizationBias != null) { + confBuilder.regularizationBias(regularizationBias); + } + if (dropout != null) { + confBuilder.idropOut(dropout.orElse(null)); + } + if (updater != null) { + confBuilder.updater(updater); + } + if (biasUpdater != null) { + confBuilder.biasUpdater(biasUpdater); + } + if (miniBatch != null) { + confBuilder.miniBatch(miniBatch); + } + if (maxNumLineSearchIterations != null) { + confBuilder.maxNumLineSearchIterations(maxNumLineSearchIterations); + } + if (seed != null) { + confBuilder.seed(seed); + } + if (optimizationAlgo != null) { + confBuilder.optimizationAlgo(optimizationAlgo); + } + if (stepFunction != null) { + confBuilder.stepFunction(stepFunction); + } + if (minimize != null) { + confBuilder.minimize(minimize); + } + if (gradientNormalization != null) { + confBuilder.gradientNormalization(gradientNormalization.orElse(null)); + } + if (gradientNormalizationThreshold != null) { + confBuilder.gradientNormalizationThreshold(gradientNormalizationThreshold); + } + if (trainingWorkspaceMode != null) { + confBuilder.trainingWorkspaceMode(trainingWorkspaceMode); + } + if (inferenceWorkspaceMode != null) { + confBuilder.inferenceWorkspaceMode(inferenceWorkspaceMode); + } + return confBuilder.build(); + } + + public String toJson() { + try { + return NeuralNetConfiguration.mapper().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + public String toYaml() { + try { + return NeuralNetConfiguration.mapperYaml().writeValueAsString(this); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + + /* + * Can't use Lombok @Builder annotation due to optionals (otherwise we have a bunch of ugly .x(Optional value) + * methods - lombok builder doesn't support excluding fields? :( + * Note the use of optional here: gives us 3 states... + * 1. Null: not set + * 2. Optional (empty): set to null + * 3. Optional (not empty): set to specific value + * + * Obviously, having null only makes sense for some things (dropout, etc) whereas null for other things doesn't + * make sense + */ + @ToString + public static class Builder { + + protected List regularization = new ArrayList<>(); + protected List regularizationBias = new ArrayList<>(); protected boolean removeL2 = false; //For: .l2(0.0) -> user means "no l2" so we should remove it if it is present in the original model... 
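// Illustrative sketch (not part of the patch): the reworked FineTuneConfiguration above keeps
// its static builder() factory plus toJson()/fromJson()/toYaml(), so a fine-tune configuration
// can be built, serialized and restored. The Builder methods chained here (updater, activation,
// build) are assumed from the builder defined further down in this class; exact signatures may differ.
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.Adam;

public class FineTuneConfigSketch {
    public static void main(String[] args) {
        FineTuneConfiguration ftc = FineTuneConfiguration.builder()
                .updater(new Adam(1e-4))      // fine-tuning learning-rate setup
                .activation(Activation.RELU)  // applied to BaseLayer configurations only
                .build();

        String json = ftc.toJson();                                   // Jackson round trip
        FineTuneConfiguration restored = FineTuneConfiguration.fromJson(json);
        System.out.println(restored.toYaml());
    }
}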
protected boolean removeL2Bias = false; protected boolean removeL1 = false; protected boolean removeL1Bias = false; protected boolean removeWD = false; protected boolean removeWDBias = false; - protected Optional dropout; - protected Optional weightNoise; - protected IUpdater updater; - protected IUpdater biasUpdater; - protected Boolean miniBatch; - protected Integer maxNumLineSearchIterations; - protected Long seed; - protected OptimizationAlgorithm optimizationAlgo; - protected StepFunction stepFunction; - protected Boolean minimize; - protected Optional gradientNormalization; - protected Double gradientNormalizationThreshold; - protected ConvolutionMode convolutionMode; - protected ConvolutionLayer.AlgoMode cudnnAlgoMode; - protected Optional> constraints; + private IActivation activation; + private IWeightInit weightInitFn; + private Double biasInit; + private Optional dropout; + private Optional weightNoise; + private IUpdater updater; + private IUpdater biasUpdater; + private Boolean miniBatch; + private Integer maxNumLineSearchIterations; + private Long seed; + private OptimizationAlgorithm optimizationAlgo; + private StepFunction stepFunction; + private Boolean minimize; + private Optional gradientNormalization; + private Double gradientNormalizationThreshold; + private ConvolutionMode convolutionMode; + private ConvolutionLayer.AlgoMode cudnnAlgoMode; + private Optional> constraints; + private Boolean pretrain; + private Boolean backprop; + private BackpropType backpropType; + private Integer tbpttFwdLength; + private Integer tbpttBackLength; + private WorkspaceMode trainingWorkspaceMode; + private WorkspaceMode inferenceWorkspaceMode; - protected Boolean pretrain; - protected Boolean backprop; - protected BackpropType backpropType; - protected Integer tbpttFwdLength; - protected Integer tbpttBackLength; + public Builder() { - protected WorkspaceMode trainingWorkspaceMode; - protected WorkspaceMode inferenceWorkspaceMode; - - public static Builder builder() { - return new Builder(); } - /* - * Can't use Lombok @Builder annotation due to optionals (otherwise we have a bunch of ugly .x(Optional value) - * methods - lombok builder doesn't support excluding fields? :( - * Note the use of optional here: gives us 3 states... - * 1. Null: not set - * 2. Optional (empty): set to null - * 3. Optional (not empty): set to specific value - * - * Obviously, having null only makes sense for some things (dropout, etc) whereas null for other things doesn't - * make sense + /** + * Activation function / neuron non-linearity */ - @ToString - public static class Builder { - private IActivation activation; - private IWeightInit weightInitFn; - private Double biasInit; - protected List regularization = new ArrayList<>(); - protected List regularizationBias = new ArrayList<>(); - protected boolean removeL2 = false; //For: .l2(0.0) -> user means "no l2" so we should remove it if it is present in the original model... 
- protected boolean removeL2Bias = false; - protected boolean removeL1 = false; - protected boolean removeL1Bias = false; - protected boolean removeWD = false; - protected boolean removeWDBias = false; - private Optional dropout; - private Optional weightNoise; - private IUpdater updater; - private IUpdater biasUpdater; - private Boolean miniBatch; - private Integer maxNumLineSearchIterations; - private Long seed; - private OptimizationAlgorithm optimizationAlgo; - private StepFunction stepFunction; - private Boolean minimize; - private Optional gradientNormalization; - private Double gradientNormalizationThreshold; - private ConvolutionMode convolutionMode; - private ConvolutionLayer.AlgoMode cudnnAlgoMode; - private Optional> constraints; - private Boolean pretrain; - private Boolean backprop; - private BackpropType backpropType; - private Integer tbpttFwdLength; - private Integer tbpttBackLength; - private WorkspaceMode trainingWorkspaceMode; - private WorkspaceMode inferenceWorkspaceMode; + public Builder activation(IActivation activationFn) { + this.activation = activationFn; + return this; + } - public Builder() { + /** + * Activation function / neuron non-linearity + */ + public Builder activation(Activation activation) { + this.activation = activation.getActivationFunction(); + return this; + } - } + /** + * Weight initialization scheme to use, for initial weight values + * + * @see IWeightInit + */ + public Builder weightInit(IWeightInit weightInit) { + this.weightInitFn = weightInit; + return this; + } - /** - * Activation function / neuron non-linearity - */ - public Builder activation(IActivation activationFn) { - this.activation = activationFn; - return this; - } + /** + * Weight initialization scheme to use, for initial weight values + * + * @see WeightInit + */ + public Builder weightInit(WeightInit weightInit) { + if (weightInit == WeightInit.DISTRIBUTION) { + throw new UnsupportedOperationException( + "Not supported!, User weightInit(Distribution distribution) instead!"); + } - /** - * Activation function / neuron non-linearity - */ - public Builder activation(Activation activation) { - this.activation = activation.getActivationFunction(); - return this; - } - - /** - * Weight initialization scheme to use, for initial weight values - * - * @see IWeightInit - */ - public Builder weightInit(IWeightInit weightInit) { - this.weightInitFn = weightInit; - return this; - } - - /** - * Weight initialization scheme to use, for initial weight values - * - * @see WeightInit - */ - public Builder weightInit(WeightInit weightInit) { - if(weightInit == WeightInit.DISTRIBUTION) { - throw new UnsupportedOperationException("Not supported!, User weightInit(Distribution distribution) instead!"); - } - - this.weightInitFn = weightInit.getWeightInitFunction(); - return this; - } - - - /** - * Set weight initialization scheme to random sampling via the specified distribution. - * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))} - * - * @param distribution Distribution to use for weight initialization - */ - public Builder weightInit(Distribution distribution){ - return weightInit(new WeightInitDistribution(distribution)); - } - - /** - * Constant for bias initialization. Default: 0.0 - * - * @param biasInit Constant for bias initialization - */ - public Builder biasInit(double biasInit) { - this.biasInit = biasInit; - return this; - } - - /** - * Distribution to sample initial weights from. 
- * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))} - */ - @Deprecated - public Builder dist(Distribution dist) { - return weightInit(dist); - } - - /** - * L1 regularization coefficient for the weights (excluding biases) - */ - public Builder l1(double l1) { - NetworkUtils.removeInstances(regularization, L1Regularization.class); - if(l1 > 0.0) { - regularization.add(new L1Regularization(l1)); - } - return this; - } - - /** - * L2 regularization coefficient for the weights (excluding biases)
- * Note: Generally, {@link WeightDecay} (set via {@link #weightDecay(double,boolean)}) should be preferred to - * L2 regularization. See {@link WeightDecay} javadoc for further details.
- */ - public Builder l2(double l2) { - NetworkUtils.removeInstances(regularization, L2Regularization.class); - if(l2 > 0.0) { - NetworkUtils.removeInstancesWithWarning(regularization, WeightDecay.class, "WeightDecay regularization removed: incompatible with added L2 regularization"); - regularization.add(new L2Regularization(l2)); - } else { - removeL2 = true; - } - return this; - } - - /** - * L1 regularization coefficient for the bias parameters - */ - public Builder l1Bias(double l1Bias) { - NetworkUtils.removeInstances(regularizationBias, L1Regularization.class); - if(l1Bias > 0.0) { - regularizationBias.add(new L1Regularization(l1Bias)); - } else { - removeL1Bias = true; - } - return this; - } - - /** - * L2 regularization coefficient for the bias parameters
- * Note: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double,boolean)}) should be preferred to - * L2 regularization. See {@link WeightDecay} javadoc for further details.
- */ - public Builder l2Bias(double l2Bias) { - NetworkUtils.removeInstances(regularizationBias, L2Regularization.class); - if(l2Bias > 0.0) { - NetworkUtils.removeInstancesWithWarning(regularizationBias, WeightDecay.class, "WeightDecay bias regularization removed: incompatible with added L2 regularization"); - regularizationBias.add(new L2Regularization(l2Bias)); - } else { - removeL2Bias = true; - } - return this; - } - - /** - * Add weight decay regularization for the network parameters (excluding biases).
- * This applies weight decay with multiplying the learning rate - see {@link WeightDecay} for more details.
- * - * @param coefficient Weight decay regularization coefficient - * @see #weightDecay(double, boolean) - */ - public Builder weightDecay(double coefficient) { - return weightDecay(coefficient, true); - } - - /** - * Add weight decay regularization for the network parameters (excluding biases). See {@link WeightDecay} for more details.
- * - * @param coefficient Weight decay regularization coefficient - * @param applyLR Whether the learning rate should be multiplied in when performing weight decay updates. See {@link WeightDecay} for more details. - * @see #weightDecay(double, boolean) - */ - public Builder weightDecay(double coefficient, boolean applyLR) { - //Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both - NetworkUtils.removeInstances(this.regularization, WeightDecay.class); - if(coefficient > 0.0) { - NetworkUtils.removeInstancesWithWarning(this.regularization, L2Regularization.class, "L2 regularization removed: incompatible with added WeightDecay regularization"); - this.regularization.add(new WeightDecay(coefficient, applyLR)); - } else { - removeWD = true; - } - return this; - } - - /** - * Weight decay for the biases only - see {@link #weightDecay(double)} for more details. - * This applies weight decay with multiplying the learning rate.
- * - * @param coefficient Weight decay regularization coefficient - * @see #weightDecayBias(double, boolean) - */ - public Builder weightDecayBias(double coefficient) { - return weightDecayBias(coefficient, true); - } - - /** - * Weight decay for the biases only - see {@link #weightDecay(double)} for more details
- * - * @param coefficient Weight decay regularization coefficient - */ - public Builder weightDecayBias(double coefficient, boolean applyLR) { - //Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both - NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class); - if(coefficient > 0) { - NetworkUtils.removeInstancesWithWarning(this.regularizationBias, L2Regularization.class, "L2 bias regularization removed: incompatible with added WeightDecay regularization"); - this.regularizationBias.add(new WeightDecay(coefficient, applyLR)); - } else { - removeWDBias = true; - } - return this; - } - - /** - * Set the dropout - * - * @param dropout Dropout, such as {@link Dropout}, {@link org.deeplearning4j.nn.conf.dropout.GaussianDropout}, - * {@link org.deeplearning4j.nn.conf.dropout.GaussianNoise} etc - */ - public Builder dropout(IDropout dropout) { - this.dropout = Optional.ofNullable(dropout); - return this; - } - - /** - * Dropout probability. This is the probability of retaining each input activation value for a layer. - * dropOut(x) will keep an input activation with probability x, and set to 0 with probability 1-x.
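// Illustrative sketch (not part of the patch): the l2()/weightDecay() builder methods above are
// mutually exclusive - configuring one removes the other with a logged warning, and a coefficient
// of 0.0 marks the regularization for removal from the original model. Builder.build() is assumed.
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;

public class RegularizationSketch {
    public static void main(String[] args) {
        // weightDecay(1e-4) replaces any previously configured L2 term on the weights.
        FineTuneConfiguration decay = FineTuneConfiguration.builder()
                .l2(1e-4)           // removed again by the weightDecay call below
                .weightDecay(1e-4)  // applyLR defaults to true
                .build();

        // l2(0.0) means "remove L2 from the original model" (the removeL2 flag above).
        FineTuneConfiguration noL2 = FineTuneConfiguration.builder()
                .l2(0.0)
                .build();

        System.out.println(decay.toJson());
        System.out.println(noL2.toJson());
    }
}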
- * dropOut(0.0) is a special value / special case - when set to 0.0., dropout is disabled (not applied). Note - * that a dropout value of 1.0 is functionally equivalent to no dropout: i.e., 100% probability of retaining - * each input activation.
- *

- * Note 1: Dropout is applied at training time only - and is automatically not applied at test time - * (for evaluation, etc)
- * Note 2: This sets the probability per-layer. Care should be taken when setting lower values for - * complex networks (too much information may be lost with aggressive (very low) dropout values).
- * Note 3: Frequently, dropout is not applied to (or, has higher retain probability for) input (first layer) - * layers. Dropout is also often not applied to output layers. This needs to be handled MANUALLY by the user - * - set .dropout(0) on those layers when using global dropout setting.
- * Note 4: Implementation detail (most users can ignore): DL4J uses inverted dropout, as described here: - * http://cs231n.github.io/neural-networks-2/ - *

- * - * @param inputRetainProbability Dropout probability (probability of retaining each input activation value for a layer) - * @see #dropout(IDropout) - */ - public Builder dropOut(double inputRetainProbability){ - if(inputRetainProbability == 0.0){ - return dropout(null); - } - return dropout(new Dropout(inputRetainProbability)); - } - - /** - * Set the weight noise (such as {@link org.deeplearning4j.nn.conf.weightnoise.DropConnect} and - * {@link org.deeplearning4j.nn.conf.weightnoise.WeightNoise}) - * - * @param weightNoise Weight noise instance to use - */ - public Builder weightNoise(IWeightNoise weightNoise) { - this.weightNoise = Optional.ofNullable(weightNoise); - return this; - } - - /** - * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} - * or {@link org.nd4j.linalg.learning.config.Nesterovs} - * - * @param updater Updater to use - */ - public Builder updater(IUpdater updater) { - this.updater = updater; - return this; - } - - /** - * @deprecated Use {@link #updater(IUpdater)} - */ - @Deprecated - public Builder updater(Updater updater) { - return updater(updater.getIUpdaterWithDefaultConfig()); - } - - /** - * Gradient updater configuration, for the biases only. If not set, biases will use the updater as - * set by {@link #updater(IUpdater)} - * - * @param biasUpdater Updater to use for bias parameters - */ - public Builder biasUpdater(IUpdater biasUpdater) { - this.biasUpdater = biasUpdater; - return this; - } - - /** - * Whether scores and gradients should be divided by the minibatch size.
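// Illustrative sketch (not part of the patch): per the dropOut(double) javadoc above, the argument
// is the retain probability and 0.0 is the special "disable dropout" value; dropout(IDropout)
// accepts alternative schemes such as GaussianDropout. Builder.build() is assumed.
import org.deeplearning4j.nn.conf.dropout.GaussianDropout;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;

public class DropoutSketch {
    public static void main(String[] args) {
        // Keep each input activation with probability 0.5 (training time only).
        FineTuneConfiguration standard = FineTuneConfiguration.builder().dropOut(0.5).build();

        // 0.0 maps to dropout(null), i.e. dropout is switched off entirely.
        FineTuneConfiguration disabled = FineTuneConfiguration.builder().dropOut(0.0).build();

        // Any IDropout implementation can be supplied directly.
        FineTuneConfiguration gaussian = FineTuneConfiguration.builder()
                .dropout(new GaussianDropout(0.1))
                .build();

        System.out.println(standard.toJson() + disabled.toJson() + gaussian.toJson());
    }
}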
- * Most users should leave this as the default value of true. - */ - public Builder miniBatch(boolean miniBatch) { - this.miniBatch = miniBatch; - return this; - } - - public Builder maxNumLineSearchIterations(int maxNumLineSearchIterations) { - this.maxNumLineSearchIterations = maxNumLineSearchIterations; - return this; - } - - /** - * RNG seed for reproducibility - * @param seed RNG seed to use - */ - public Builder seed(long seed) { - this.seed = seed; - return this; - } - - /** - * RNG seed for reproducibility - * @param seed RNG seed to use - */ - public Builder seed(int seed){ - return seed((long)seed); - } - - public Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) { - this.optimizationAlgo = optimizationAlgo; - return this; - } - - public Builder stepFunction(StepFunction stepFunction) { - this.stepFunction = stepFunction; - return this; - } - - public Builder minimize(boolean minimize) { - this.minimize = minimize; - return this; - } - - /** - * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping etc. - * See {@link GradientNormalization} for details - * - * @param gradientNormalization Type of normalization to use. Defaults to None. - * @see GradientNormalization - */ - public Builder gradientNormalization(GradientNormalization gradientNormalization) { - this.gradientNormalization = Optional.ofNullable(gradientNormalization); - return this; - } - - /** - * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer, - * GradientNormalization.ClipL2PerParamType, and GradientNormalization.ClipElementWiseAbsoluteValue
- * Not used otherwise.
- * L2 threshold for first two types of clipping, or absolute value threshold for last type of clipping - */ - public Builder gradientNormalizationThreshold(double gradientNormalizationThreshold) { - this.gradientNormalizationThreshold = gradientNormalizationThreshold; - return this; - } - - /** - * Sets the convolution mode for convolutional layers, which impacts padding and output sizes. - * See {@link ConvolutionMode} for details. Defaults to ConvolutionMode.TRUNCATE
- * @param convolutionMode Convolution mode to use - */ - public Builder convolutionMode(ConvolutionMode convolutionMode) { - this.convolutionMode = convolutionMode; - return this; - } - - /** - * Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage of cuDNN. - * See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. - */ - public Builder cudnnAlgoMode(ConvolutionLayer.AlgoMode cudnnAlgoMode) { - this.cudnnAlgoMode = cudnnAlgoMode; - return this; - } - - /** - * Set constraints to be applied to all layers. Default: no constraints.
- * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, - * etc). These constraints are applied at each iteration, after the parameters have been updated. - * - * @param constraints Constraints to apply to all parameters of all layers - */ - public Builder constraints(List constraints) { - this.constraints = Optional.ofNullable(constraints); - return this; - } - - public Builder pretrain(boolean pretrain) { - this.pretrain = pretrain; - return this; - } - - public Builder backprop(boolean backprop) { - this.backprop = backprop; - return this; - } - - /** - * The type of backprop. Default setting is used for most networks (MLP, CNN etc), - * but optionally truncated BPTT can be used for training recurrent neural networks. - * If using TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength() - * - * @param backpropType Type of backprop. Default: BackpropType.Standard - */ - public Builder backpropType(BackpropType backpropType) { - this.backpropType = backpropType; - return this; - } - - /** - * When doing truncated BPTT: how many steps of forward pass should we do - * before doing (truncated) backprop?
- * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
- * Typically tBPTTForwardLength parameter is same as the tBPTTBackwardLength parameter, - * but may be larger than it in some circumstances (but never smaller)
- * Ideally your training data time series length should be divisible by this - * This is the k1 parameter on pg23 of - * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf - * - * @param tbpttFwdLength Forward length > 0, >= backwardLength - */ - public Builder tbpttFwdLength(int tbpttFwdLength) { - this.tbpttFwdLength = tbpttFwdLength; - return this; - } - - /** - * When doing truncated BPTT: how many steps of backward should we do?
- * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
- * This is the k2 parameter on pg23 of - * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf - * - * @param tbpttBackLength <= forwardLength - */ - public Builder tbpttBackLength(int tbpttBackLength) { - this.tbpttBackLength = tbpttBackLength; - return this; - } - - /** - * This method defines Workspace mode being used during training: - * NONE: workspace won't be used - * ENABLED: workspaces will be used for training (reduced memory and better performance) - * - * @param trainingWorkspaceMode Workspace mode for training - * @return Builder - */ - public Builder trainingWorkspaceMode(WorkspaceMode trainingWorkspaceMode) { - this.trainingWorkspaceMode = trainingWorkspaceMode; - return this; - } - - /** - * This method defines Workspace mode being used during inference:
- * NONE: workspace won't be used
- * ENABLED: workspaces will be used for inference (reduced memory and better performance) - * - * @param inferenceWorkspaceMode Workspace mode for inference - * @return Builder - */ - public Builder inferenceWorkspaceMode(WorkspaceMode inferenceWorkspaceMode) { - this.inferenceWorkspaceMode = inferenceWorkspaceMode; - return this; - } - - public FineTuneConfiguration build() { - return new FineTuneConfiguration(activation, weightInitFn, biasInit, regularization, regularizationBias, - removeL2, removeL2Bias, removeL1, removeL1Bias, removeWD, removeWDBias, dropout, - weightNoise, updater, biasUpdater, miniBatch, maxNumLineSearchIterations, seed, optimizationAlgo, stepFunction, - minimize, gradientNormalization, gradientNormalizationThreshold, convolutionMode, cudnnAlgoMode, constraints, - pretrain, backprop, backpropType, tbpttFwdLength, tbpttBackLength, trainingWorkspaceMode, inferenceWorkspaceMode); - } + this.weightInitFn = weightInit.getWeightInitFunction(); + return this; } - public NeuralNetConfiguration appliedNeuralNetConfiguration(NeuralNetConfiguration nnc) { - applyToNeuralNetConfiguration(nnc); - nnc = new NeuralNetConfiguration.Builder(nnc.clone()).build(); - return nnc; + /** + * Set weight initialization scheme to random sampling via the specified distribution. + * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))} + * + * @param distribution Distribution to use for weight initialization + */ + public Builder weightInit(Distribution distribution) { + return weightInit(new WeightInitDistribution(distribution)); } - public void applyToNeuralNetConfiguration(NeuralNetConfiguration nnc) { - - Layer l = nnc.getLayer(); - Updater originalUpdater = null; - WeightInit origWeightInit = null; - - if (l != null) { - //As per NeuralNetConfiguration.configureLayer and LayerValidation.configureBaseLayer: only copy dropout to base layers - // this excludes things like subsampling and activation layers - if (dropout != null && l instanceof BaseLayer) { - IDropout d = dropout.orElse(null); - if(d != null) - d = d.clone(); //Clone to avoid shared state between layers - l.setIDropout(d); - } - if(constraints != null) - l.setConstraints(constraints.orElse(null)); - } - - if (l != null && l instanceof BaseLayer) { - BaseLayer bl = (BaseLayer) l; - if (activationFn != null) - bl.setActivationFn(activationFn); - if (weightInitFn != null) - bl.setWeightInitFn(weightInitFn); - if (biasInit != null) - bl.setBiasInit(biasInit); - if (regularization != null && !regularization.isEmpty()) - bl.setRegularization(regularization); - if (regularizationBias != null && !regularizationBias.isEmpty()) - bl.setRegularizationBias(regularizationBias); - if (removeL2) - NetworkUtils.removeInstances(bl.getRegularization(), L2Regularization.class); - if (removeL2Bias) - NetworkUtils.removeInstances(bl.getRegularizationBias(), L2Regularization.class); - if (removeL1) - NetworkUtils.removeInstances(bl.getRegularization(), L1Regularization.class); - if (removeL1Bias) - NetworkUtils.removeInstances(bl.getRegularizationBias(), L1Regularization.class); - if (removeWD) - NetworkUtils.removeInstances(bl.getRegularization(), WeightDecay.class); - if (removeWDBias) - NetworkUtils.removeInstances(bl.getRegularizationBias(), WeightDecay.class); - if (gradientNormalization != null) - bl.setGradientNormalization(gradientNormalization.orElse(null)); - if (gradientNormalizationThreshold != null) - bl.setGradientNormalizationThreshold(gradientNormalizationThreshold); - if (updater != null){ - 
bl.setIUpdater(updater); - } - if (biasUpdater != null){ - bl.setBiasUpdater(biasUpdater); - } - if (weightNoise != null){ - bl.setWeightNoise(weightNoise.orElse(null)); - } - } - if (miniBatch != null) - nnc.setMiniBatch(miniBatch); - if (maxNumLineSearchIterations != null) - nnc.setMaxNumLineSearchIterations(maxNumLineSearchIterations); - if (seed != null) - nnc.setSeed(seed); - if (optimizationAlgo != null) - nnc.setOptimizationAlgo(optimizationAlgo); - if (stepFunction != null) - nnc.setStepFunction(stepFunction); - if (minimize != null) - nnc.setMinimize(minimize); - - if (convolutionMode != null && l instanceof ConvolutionLayer) { - ((ConvolutionLayer) l).setConvolutionMode(convolutionMode); - } - if (cudnnAlgoMode != null && l instanceof ConvolutionLayer) { - ((ConvolutionLayer) l).setCudnnAlgoMode(cudnnAlgoMode); - } - if (convolutionMode != null && l instanceof SubsamplingLayer) { - ((SubsamplingLayer) l).setConvolutionMode(convolutionMode); - } - - //Perform validation - if (l != null) { - LayerValidation.generalValidation(l.getLayerName(), l, get(dropout), regularization, regularizationBias, - get(constraints), null, null); - } + /** + * Constant for bias initialization. Default: 0.0 + * + * @param biasInit Constant for bias initialization + */ + public Builder biasInit(double biasInit) { + this.biasInit = biasInit; + return this; } - private static T get(Optional optional){ - if(optional == null){ - return null; - } - return optional.orElse(null); + /** + * Distribution to sample initial weights from. Equivalent to: + * {@code .weightInit(new WeightInitDistribution(distribution))} + */ + @Deprecated + public Builder dist(Distribution dist) { + return weightInit(dist); } - public void applyToMultiLayerConfiguration(MultiLayerConfiguration conf) { - if (backpropType != null) - conf.setBackpropType(backpropType); - if (tbpttFwdLength != null) - conf.setTbpttFwdLength(tbpttFwdLength); - if (tbpttBackLength != null) - conf.setTbpttBackLength(tbpttBackLength); + /** + * L1 regularization coefficient for the weights (excluding biases) + */ + public Builder l1(double l1) { + NetworkUtils.removeInstances(regularization, L1Regularization.class); + if (l1 > 0.0) { + regularization.add(new L1Regularization(l1)); + } + return this; } - public void applyToComputationGraphConfiguration(ComputationGraphConfiguration conf) { - if (backpropType != null) - conf.setBackpropType(backpropType); - if (tbpttFwdLength != null) - conf.setTbpttFwdLength(tbpttFwdLength); - if (tbpttBackLength != null) - conf.setTbpttBackLength(tbpttBackLength); + /** + * L2 regularization coefficient for the weights (excluding biases)
+ * Note: Generally, {@link WeightDecay} (set via {@link #weightDecay(double, boolean)}) + * should be preferred to + * L2 regularization. See the {@link WeightDecay} javadoc for further details.
+ */ + public Builder l2(double l2) { + NetworkUtils.removeInstances(regularization, L2Regularization.class); + if (l2 > 0.0) { + NetworkUtils.removeInstancesWithWarning(regularization, WeightDecay.class, + "WeightDecay regularization removed: incompatible with added L2 regularization"); + regularization.add(new L2Regularization(l2)); + } else { + removeL2 = true; + } + return this; } - public NeuralNetConfiguration.Builder appliedNeuralNetConfigurationBuilder() { - NeuralNetConfiguration.Builder confBuilder = new NeuralNetConfiguration.Builder(); - if (activationFn != null) - confBuilder.setActivationFn(activationFn); - if (weightInitFn != null) - confBuilder.setWeightInitFn(weightInitFn); - if (biasInit != null) - confBuilder.setBiasInit(biasInit); - if (regularization != null) - confBuilder.setRegularization(regularization); - if (regularizationBias != null) - confBuilder.setRegularizationBias(regularizationBias); - if (dropout != null) - confBuilder.setIdropOut(dropout.orElse(null)); - if (updater != null) - confBuilder.updater(updater); - if(biasUpdater != null) - confBuilder.biasUpdater(biasUpdater); - if (miniBatch != null) - confBuilder.setMiniBatch(miniBatch); - if (maxNumLineSearchIterations != null) - confBuilder.setMaxNumLineSearchIterations(maxNumLineSearchIterations); - if (seed != null) - confBuilder.setSeed(seed); - if (optimizationAlgo != null) - confBuilder.setOptimizationAlgo(optimizationAlgo); - if (stepFunction != null) - confBuilder.setStepFunction(stepFunction); - if (minimize != null) - confBuilder.setMinimize(minimize); - if (gradientNormalization != null) - confBuilder.setGradientNormalization(gradientNormalization.orElse(null)); - if (gradientNormalizationThreshold != null) - confBuilder.setGradientNormalizationThreshold(gradientNormalizationThreshold); - if (trainingWorkspaceMode != null) - confBuilder.trainingWorkspaceMode(trainingWorkspaceMode); - if (inferenceWorkspaceMode != null) - confBuilder.inferenceWorkspaceMode(inferenceWorkspaceMode); - return confBuilder; + /** + * L1 regularization coefficient for the bias parameters + */ + public Builder l1Bias(double l1Bias) { + NetworkUtils.removeInstances(regularizationBias, L1Regularization.class); + if (l1Bias > 0.0) { + regularizationBias.add(new L1Regularization(l1Bias)); + } else { + removeL1Bias = true; + } + return this; } - - public String toJson() { - try { - return NeuralNetConfiguration.mapper().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } + /** + * L2 regularization coefficient for the bias parameters
+ * Note: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double, boolean)}) + * should be preferred to + * L2 regularization. See the {@link WeightDecay} javadoc for further details.
+ */ + public Builder l2Bias(double l2Bias) { + NetworkUtils.removeInstances(regularizationBias, L2Regularization.class); + if (l2Bias > 0.0) { + NetworkUtils.removeInstancesWithWarning(regularizationBias, WeightDecay.class, + "WeightDecay bias regularization removed: incompatible with added L2 regularization"); + regularizationBias.add(new L2Regularization(l2Bias)); + } else { + removeL2Bias = true; + } + return this; } - public String toYaml() { - try { - return NeuralNetConfiguration.mapperYaml().writeValueAsString(this); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } + /** + * Add weight decay regularization for the network parameters (excluding biases).
This + * applies weight decay with the learning rate multiplied in - see {@link WeightDecay} for + * more details.
+ * + * @param coefficient Weight decay regularization coefficient + * @see #weightDecay(double, boolean) + */ + public Builder weightDecay(double coefficient) { + return weightDecay(coefficient, true); } - public static FineTuneConfiguration fromJson(String json) { - try { - return NeuralNetConfiguration.mapper().readValue(json, FineTuneConfiguration.class); - } catch (IOException e) { - throw new RuntimeException(e); - } + /** + * Add weight decay regularization for the network parameters (excluding biases). See + * {@link WeightDecay} for more details.
+ * + * @param coefficient Weight decay regularization coefficient + * @param applyLR Whether the learning rate should be multiplied in when performing weight + * decay updates. See {@link WeightDecay} for more details. + * @see #weightDecay(double, boolean) + */ + public Builder weightDecay(double coefficient, boolean applyLR) { + //Remove any existing weight decay regularization and replace it; also remove L2 - it doesn't make sense to use both + NetworkUtils.removeInstances(this.regularization, WeightDecay.class); + if (coefficient > 0.0) { + NetworkUtils.removeInstancesWithWarning(this.regularization, L2Regularization.class, + "L2 regularization removed: incompatible with added WeightDecay regularization"); + this.regularization.add(new WeightDecay(coefficient, applyLR)); + } else { + removeWD = true; + } + return this; + } - public static FineTuneConfiguration fromYaml(String yaml) { - try { - return NeuralNetConfiguration.mapperYaml().readValue(yaml, FineTuneConfiguration.class); - } catch (IOException e) { - throw new RuntimeException(e); - } + /** + * Weight decay for the biases only - see {@link #weightDecay(double)} for more details. This + * applies weight decay with the learning rate multiplied in.
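A minimal usage sketch of the l2()/weightDecay() interplay described above, assuming the FineTuneConfiguration.Builder introduced in this patch (the coefficients are illustrative only, not recommended values):

import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;

public class WeightDecayVsL2Sketch {
  public static void main(String[] args) {
    FineTuneConfiguration.Builder b = new FineTuneConfiguration.Builder();
    b.l2(1e-4);          // adds an L2Regularization(1e-4) term for the weights
    b.weightDecay(5e-5); // removes the L2 term again (logging a warning) and adds WeightDecay(5e-5, true)
    // A non-positive coefficient clears that regularization type instead of adding it:
    b.weightDecay(0.0);  // marks any existing weight decay for removal when the config is applied
    FineTuneConfiguration ftc = b.build();
    System.out.println(ftc);
  }
}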
+ * + * @param coefficient Weight decay regularization coefficient + * @see #weightDecayBias(double, boolean) + */ + public Builder weightDecayBias(double coefficient) { + return weightDecayBias(coefficient, true); } + + /** + * Weight decay for the biases only - see {@link #weightDecay(double)} for more details
+ * + * @param coefficient Weight decay regularization coefficient + * @param applyLR Whether the learning rate should be multiplied in when performing weight decay updates + */ + public Builder weightDecayBias(double coefficient, boolean applyLR) { + //Remove any existing weight decay regularization and replace it; also remove L2 - it doesn't make sense to use both + NetworkUtils.removeInstances(this.regularizationBias, WeightDecay.class); + if (coefficient > 0) { + NetworkUtils.removeInstancesWithWarning(this.regularizationBias, L2Regularization.class, + "L2 bias regularization removed: incompatible with added WeightDecay regularization"); + this.regularizationBias.add(new WeightDecay(coefficient, applyLR)); + } else { + removeWDBias = true; + } + return this; + } + + /** + * Set the dropout + * + * @param dropout Dropout, such as {@link Dropout}, + * {@link org.deeplearning4j.nn.conf.dropout.GaussianDropout}, + * {@link org.deeplearning4j.nn.conf.dropout.GaussianNoise} etc + */ + public Builder dropout(IDropout dropout) { + this.dropout = Optional.ofNullable(dropout); + return this; + } + + /** + * Dropout probability. This is the probability of retaining each input activation + * value for a layer. dropOut(x) will keep an input activation with probability x, and set to 0 + * with probability 1-x.
dropOut(0.0) is a special value / special case - when set to 0.0, + * dropout is disabled (not applied). Note that a dropout value of 1.0 is functionally + * equivalent to no dropout: i.e., 100% probability of retaining each input activation.
+ *

+ * Note 1: Dropout is applied at training time only - and is automatically not applied at test + * time (for evaluation, etc)
Note 2: This sets the probability per-layer. Care should be + * taken when setting lower values for complex networks (too much information may be lost with + * aggressive (very low) dropout values).
Note 3: Frequently, dropout is not applied to (or + * has a higher retain probability for) the input (first) layer. Dropout is also often not + * applied to output layers. This needs to be handled MANUALLY by the user - set .dropout(0) on + * those layers when using a global dropout setting.
Note 4: Implementation detail (most users + * can ignore): DL4J uses inverted dropout, as described here: + * http://cs231n.github.io/neural-networks-2/ + *
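A minimal sketch of the retain-probability semantics described in the notes above, assuming the FineTuneConfiguration.Builder from this patch (the 0.8 retain probability is an arbitrary example value):

import org.deeplearning4j.nn.conf.dropout.Dropout;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;

public class DropoutRetainProbabilitySketch {
  public static void main(String[] args) {
    // Retain each input activation with probability 0.8 (i.e. drop it with probability 0.2):
    FineTuneConfiguration withDropout = new FineTuneConfiguration.Builder()
        .dropOut(0.8)
        .build();

    // Equivalent explicit form using the IDropout variant:
    FineTuneConfiguration sameThing = new FineTuneConfiguration.Builder()
        .dropout(new Dropout(0.8))
        .build();

    // The special case: dropOut(0.0) disables dropout entirely (the IDropout is set to null):
    FineTuneConfiguration noDropout = new FineTuneConfiguration.Builder()
        .dropOut(0.0)
        .build();

    System.out.println(withDropout + "\n" + sameThing + "\n" + noDropout);
  }
}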

+ * + * @param inputRetainProbability Dropout probability (probability of retaining each input + * activation value for a layer) + * @see #dropout(IDropout) + */ + public Builder dropOut(double inputRetainProbability) { + if (inputRetainProbability == 0.0) { + return dropout(null); + } + return dropout(new Dropout(inputRetainProbability)); + } + + /** + * Set the weight noise (such as {@link org.deeplearning4j.nn.conf.weightnoise.DropConnect} and + * {@link org.deeplearning4j.nn.conf.weightnoise.WeightNoise}) + * + * @param weightNoise Weight noise instance to use + */ + public Builder weightNoise(IWeightNoise weightNoise) { + this.weightNoise = Optional.ofNullable(weightNoise); + return this; + } + + /** + * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} or + * {@link org.nd4j.linalg.learning.config.Nesterovs} + * + * @param updater Updater to use + */ + public Builder updater(IUpdater updater) { + this.updater = updater; + return this; + } + + /** + * @deprecated Use {@link #updater(IUpdater)} + */ + @Deprecated + public Builder updater(Updater updater) { + return updater(updater.getIUpdaterWithDefaultConfig()); + } + + /** + * Gradient updater configuration, for the biases only. If not set, biases will use the updater + * as set by {@link #updater(IUpdater)} + * + * @param biasUpdater Updater to use for bias parameters + */ + public Builder biasUpdater(IUpdater biasUpdater) { + this.biasUpdater = biasUpdater; + return this; + } + + /** + * Whether scores and gradients should be divided by the minibatch size.
Most users should + * leave this as the default value of true. + */ + public Builder miniBatch(boolean miniBatch) { + this.miniBatch = miniBatch; + return this; + } + + public Builder maxNumLineSearchIterations(int maxNumLineSearchIterations) { + this.maxNumLineSearchIterations = maxNumLineSearchIterations; + return this; + } + + /** + * RNG seed for reproducibility + * + * @param seed RNG seed to use + */ + public Builder seed(long seed) { + this.seed = seed; + return this; + } + + /** + * RNG seed for reproducibility + * + * @param seed RNG seed to use + */ + public Builder seed(int seed) { + return seed((long) seed); + } + + public Builder optimizationAlgo(OptimizationAlgorithm optimizationAlgo) { + this.optimizationAlgo = optimizationAlgo; + return this; + } + + public Builder stepFunction(StepFunction stepFunction) { + this.stepFunction = stepFunction; + return this; + } + + public Builder minimize(boolean minimize) { + this.minimize = minimize; + return this; + } + + /** + * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping + * etc. See {@link GradientNormalization} for details. + * + * @param gradientNormalization Type of normalization to use. Defaults to None. + * @see GradientNormalization + */ + public Builder gradientNormalization(GradientNormalization gradientNormalization) { + this.gradientNormalization = Optional.ofNullable(gradientNormalization); + return this; + } + + /** + * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer, + * GradientNormalization.ClipL2PerParamType, and + * GradientNormalization.ClipElementWiseAbsoluteValue
Not used otherwise.
L2 threshold + * for first two types of clipping, or absolute value threshold for last type of clipping + */ + public Builder gradientNormalizationThreshold(double gradientNormalizationThreshold) { + this.gradientNormalizationThreshold = gradientNormalizationThreshold; + return this; + } + + /** + * Sets the convolution mode for convolutional layers, which impacts padding and output sizes. + * See {@link ConvolutionMode} for details. Defaults to ConvolutionMode.TRUNCATE
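A brief sketch of the gradient clipping configuration described above, again assuming the FineTuneConfiguration.Builder from this patch (the threshold of 1.0 is illustrative):

import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;

public class GradientClippingSketch {
  public static void main(String[] args) {
    // Clip the per-layer gradient L2 norm at 1.0; the threshold is only consulted
    // for the Clip* strategies, as noted in the javadoc above.
    FineTuneConfiguration ftc = new FineTuneConfiguration.Builder()
        .gradientNormalization(GradientNormalization.ClipL2PerLayer)
        .gradientNormalizationThreshold(1.0)
        .build();
    System.out.println(ftc);
  }
}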
+ * + * @param convolutionMode Convolution mode to use + */ + public Builder convolutionMode(ConvolutionMode convolutionMode) { + this.convolutionMode = convolutionMode; + return this; + } + + /** + * Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage + * of cuDNN. See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", + * but "NO_WORKSPACE" uses less memory. + */ + public Builder cudnnAlgoMode(ConvolutionLayer.AlgoMode cudnnAlgoMode) { + this.cudnnAlgoMode = cudnnAlgoMode; + return this; + } + + /** + * Set constraints to be applied to all layers. Default: no constraints.
Constraints can be + * used to enforce certain conditions (non-negativity of parameters, max-norm regularization, + * etc). These constraints are applied at each iteration, after the parameters have been + * updated. + * + * @param constraints Constraints to apply to all parameters of all layers + */ + public Builder constraints(List constraints) { + this.constraints = Optional.ofNullable(constraints); + return this; + } + + public Builder pretrain(boolean pretrain) { + this.pretrain = pretrain; + return this; + } + + public Builder backprop(boolean backprop) { + this.backprop = backprop; + return this; + } + + /** + * The type of backprop. Default setting is used for most networks (MLP, CNN etc), but + * optionally truncated BPTT can be used for training recurrent neural networks. If using + * TruncatedBPTT make sure you set both tBPTTForwardLength() and tBPTTBackwardLength() + * + * @param backpropType Type of backprop. Default: BackpropType.Standard + */ + public Builder backpropType(BackpropType backpropType) { + this.backpropType = backpropType; + return this; + } + + /** + * When doing truncated BPTT: how many steps of forward pass should we do before doing + * (truncated) backprop?
Only applicable when doing + * backpropType(BackpropType.TruncatedBPTT)
Typically the tBPTTForwardLength parameter is the same + * as the tBPTTBackwardLength parameter, but may be larger than it in some circumstances (but + * never smaller).
Ideally your training data time series length should be divisible by this. + * This is the k1 parameter on pg23 of + * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf + * + * @param tbpttFwdLength Forward length > 0, >= backwardLength + */ + public Builder tbpttFwdLength(int tbpttFwdLength) { + this.tbpttFwdLength = tbpttFwdLength; + return this; + } + + /** + * When doing truncated BPTT: how many steps of backward should we do?
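A hedged sketch of the truncated-BPTT settings described here, assuming the FineTuneConfiguration.Builder from this patch (the segment length of 50 is illustrative; the forward and backward lengths are usually set to the same value):

import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.transferlearning.FineTuneConfiguration;

public class TruncatedBpttSketch {
  public static void main(String[] args) {
    // k1 = forward length, k2 = backward length; only meaningful with BackpropType.TruncatedBPTT.
    // k1 must be >= k2, and ideally the time series length is divisible by k1.
    FineTuneConfiguration ftc = new FineTuneConfiguration.Builder()
        .backpropType(BackpropType.TruncatedBPTT)
        .tbpttFwdLength(50)
        .tbpttBackLength(50)
        .build();
    System.out.println(ftc);
  }
}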
Only applicable when + * doing backpropType(BackpropType.TruncatedBPTT)
This is the k2 parameter on pg23 of + * http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf + * + * @param tbpttBackLength <= forwardLength + */ + public Builder tbpttBackLength(int tbpttBackLength) { + this.tbpttBackLength = tbpttBackLength; + return this; + } + + /** + * This method defines Workspace mode being used during training: NONE: workspace won't be used + * ENABLED: workspaces will be used for training (reduced memory and better performance) + * + * @param trainingWorkspaceMode Workspace mode for training + * @return Builder + */ + public Builder trainingWorkspaceMode(WorkspaceMode trainingWorkspaceMode) { + this.trainingWorkspaceMode = trainingWorkspaceMode; + return this; + } + + /** + * This method defines Workspace mode being used during inference:
NONE: workspace won't be + * used
ENABLED: workspaces will be used for inference (reduced memory and better + * performance) + * + * @param inferenceWorkspaceMode Workspace mode for inference + * @return Builder + */ + public Builder inferenceWorkspaceMode(WorkspaceMode inferenceWorkspaceMode) { + this.inferenceWorkspaceMode = inferenceWorkspaceMode; + return this; + } + + public FineTuneConfiguration build() { + return new FineTuneConfiguration(activation, weightInitFn, biasInit, regularization, + regularizationBias, + removeL2, removeL2Bias, removeL1, removeL1Bias, removeWD, removeWDBias, dropout, + weightNoise, updater, biasUpdater, miniBatch, maxNumLineSearchIterations, seed, + optimizationAlgo, stepFunction, + minimize, gradientNormalization, gradientNormalizationThreshold, convolutionMode, + cudnnAlgoMode, constraints, + pretrain, backprop, backpropType, tbpttFwdLength, tbpttBackLength, trainingWorkspaceMode, + inferenceWorkspaceMode); + } + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java index b941cf636..8cc50854b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java @@ -28,7 +28,7 @@ import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.VertexIndices; import org.deeplearning4j.nn.graph.vertex.impl.FrozenVertex; @@ -51,7 +51,7 @@ import java.util.*; public class TransferLearning { public static class Builder { - private final MultiLayerConfiguration origConf; + private final NeuralNetConfiguration origConf; private final MultiLayerNetwork origModel; private MultiLayerNetwork editedModel; @@ -64,9 +64,9 @@ public class TransferLearning { new HashMap<>(); private final Map> nInEditedMap = new HashMap<>(); private final List editedParams = new ArrayList<>(); - private final List editedConfs = new ArrayList<>(); + private final List editedConfs = new ArrayList<>(); private final List appendParams = new ArrayList<>(); //these could be new arrays, and views from origParams - private final List appendConfs = new ArrayList<>(); + private final List appendConfs = new ArrayList<>(); private Map inputPreProcessors = new HashMap<>(); @@ -80,8 +80,8 @@ public class TransferLearning { */ public Builder(MultiLayerNetwork origModel) { this.origModel = origModel; - this.origConf = origModel.getLayerWiseConfigurations().clone(); - this.dataType = origModel.getLayerWiseConfigurations().getDataType(); + this.origConf = origModel.getNetConfiguration().clone(); + this.dataType = origModel.getNetConfiguration().getDataType(); this.inputPreProcessors = origConf.getInputPreProcessors(); } @@ -299,31 +299,31 @@ public class TransferLearning { * At the very least an outputLayer must be added (output layer should be added last - as per the note on order) * Learning configs (like updaters, learning rate etc) specified with the layer here will be honored * - * @param layer layer conf to add (similar to the NeuralNetConfiguration .list().layer(...) 
+ * @param layerConf layer conf to add (similar to the NeuralNetConfiguration .list().layer(...) * @return Builder */ - public Builder addLayer(Layer layer) { + public Builder addLayer(LayerConfiguration layerConf) { if (!prepDone) { doPrep(); } - // Use the fineTune config to create the required NeuralNetConfiguration + Layer instances + // Use the fineTune config to create the required NeuralNetConfiguration + LayerConfiguration instances //instantiate dummy layer to get the params //Build a nn config builder with settings from finetune. Set layer with the added layer //Issue: fine tune config has .learningRate(x), then I add a layer with .learningRate(y)... //We don't want that to be overridden - NeuralNetConfiguration layerConf = - finetuneConfiguration.appliedNeuralNetConfigurationBuilder().layer(layer).build(); + NeuralNetConfiguration netConf = + finetuneConfiguration.appliedNeuralNetConfigurationBuilder(); - val numParams = layer.initializer().numParams(layerConf); + val numParams = layerConf.initializer().numParams(layerConf); INDArray params; if (numParams > 0) { - params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), 1, numParams); - org.deeplearning4j.nn.api.Layer someLayer = layer.instantiate(layerConf, null, 0, params, true, dataType); + params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); + org.deeplearning4j.nn.api.Layer someLayer = layerConf.instantiate(layerConf.getNetConfiguration(), null, 0, params, true, dataType); appendParams.add(someLayer.params()); - appendConfs.add(someLayer.conf()); + appendConfs.add(someLayer.getLayerConfiguration()); } else { appendConfs.add(layerConf); @@ -364,27 +364,27 @@ public class TransferLearning { if (frozenTill != -1) { org.deeplearning4j.nn.api.Layer[] layers = editedModel.getLayers(); for (int i = frozenTill; i >= 0; i--) { - //Complication here: inner Layer (implementation) NeuralNetConfiguration.layer (config) should keep + //Complication here: inner LayerConfiguration (implementation) NeuralNetConfiguration.layer (config) should keep // the original layer config. 
While network NNC should have the frozen layer, for to/from JSON etc - NeuralNetConfiguration origNNC = editedModel.getLayerWiseConfigurations().getConf(i); - NeuralNetConfiguration layerNNC = origNNC.clone(); - layers[i].setConf(layerNNC); + LayerConfiguration origNNC = editedModel.getNetConfiguration().getFlattenedLayerConfigurations().get(i); + LayerConfiguration layerNNC = origNNC.clone(); + layers[i].setLayerConfiguration(layerNNC); layers[i] = new FrozenLayer(layers[i]); if (origNNC.getVariables() != null) { - List vars = origNNC.variables(true); + List vars = origNNC.getVariables(); origNNC.clearVariables(); layerNNC.clearVariables(); for (String s : vars) { - origNNC.variables(false).add(s); - layerNNC.variables(false).add(s); + origNNC.addVariable(s); + layerNNC.addVariable(s); } } - Layer origLayerConf = editedModel.getLayerWiseConfigurations().getConf(i).getLayer(); - Layer newLayerConf = new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(origLayerConf); + LayerConfiguration origLayerConf = editedModel.getNetConfiguration().getFlattenedLayerConfigurations().get(i); + LayerConfiguration newLayerConf = new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(origLayerConf); newLayerConf.setLayerName(origLayerConf.getLayerName()); - editedModel.getLayerWiseConfigurations().getConf(i).setLayer(newLayerConf); + editedModel.getNetConfiguration().getNetConfigurations().get(i).setLayer(newLayerConf); } editedModel.setLayers(layers); } @@ -441,15 +441,14 @@ public class TransferLearning { private void fineTuneConfigurationBuild() { - - for (int i = 0; i < origConf.getConfs().size(); i++) { - NeuralNetConfiguration layerConf; + for (int i = 0; i < origConf.getFlattenedLayerConfigurations().size(); i++) { + LayerConfiguration layerConf; if (finetuneConfiguration != null) { - NeuralNetConfiguration nnc = origConf.getConf(i).clone(); - finetuneConfiguration.applyToNeuralNetConfiguration(nnc); + LayerConfiguration nnc = origConf.getFlattenedLayerConfigurations().get(i).clone(); + finetuneConfiguration.applyToLayerConfiguration(nnc); layerConf = nnc; } else { - layerConf = origConf.getConf(i).clone(); + layerConf = origConf.getFlattenedLayerConfigurations().get(i).clone(); } editedConfs.add(layerConf); } @@ -458,16 +457,16 @@ public class TransferLearning { private void nInReplaceBuild(int layerNum, int nIn, IWeightInit init) { Preconditions.checkArgument(layerNum >= 0 && layerNum < editedConfs.size(), "Invalid layer index: must be 0 to " + "numLayers-1 = %s includive, got %s", editedConfs.size(), layerNum); - NeuralNetConfiguration layerConf = editedConfs.get(layerNum); - Layer layerImpl = layerConf.getLayer(); //not a clone need to modify nOut in place + LayerConfiguration layerConf = editedConfs.get(layerNum); + LayerConfiguration layerImpl = layerConf; //not a clone need to modify nOut in place Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nInReplace can only be applide on FeedForward layers;" + "got layer of type %s", layerImpl.getClass().getSimpleName()); FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl; layerImplF.setWeightInitFn(init); layerImplF.setNIn(nIn); long numParams = layerImpl.initializer().numParams(layerConf); - INDArray params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), 1, numParams); - org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf, null, 0, params, true, dataType); + INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); + 
org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf.getNetConfiguration(), null, 0, params, true, dataType); editedParams.set(layerNum, someLayer.params()); } @@ -476,29 +475,29 @@ public class TransferLearning { Preconditions.checkArgument(layerNum >= 0 && layerNum < editedConfs.size(), "Invalid layer index: must be 0 to " + "numLayers-1 = %s includive, got %s", editedConfs.size(), layerNum); - NeuralNetConfiguration layerConf = editedConfs.get(layerNum); - Layer layerImpl = layerConf.getLayer(); //not a clone need to modify nOut in place + LayerConfiguration layerConf = editedConfs.get(layerNum); + LayerConfiguration layerImpl = layerConf; //not a clone need to modify nOut in place Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nOutReplace can only be applide on FeedForward layers;" + "got layer of type %s", layerImpl.getClass().getSimpleName()); FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl; layerImplF.setWeightInitFn(scheme); layerImplF.setNOut(nOut); long numParams = layerImpl.initializer().numParams(layerConf); - INDArray params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), 1, numParams); - org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf, null, 0, params, true, dataType); + INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); + org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf.getNetConfiguration(), null, 0, params, true, dataType); editedParams.set(layerNum, someLayer.params()); if (layerNum + 1 < editedConfs.size()) { layerConf = editedConfs.get(layerNum + 1); - layerImpl = layerConf.getLayer(); //modify in place + layerImpl = layerConf; //modify in place if(layerImpl instanceof FeedForwardLayer) { layerImplF = (FeedForwardLayer) layerImpl; layerImplF.setWeightInitFn(schemeNext); layerImplF.setNIn(nOut); numParams = layerImpl.initializer().numParams(layerConf); if (numParams > 0) { - params = Nd4j.create(origModel.getLayerWiseConfigurations().getDataType(), 1, numParams); - someLayer = layerImpl.instantiate(layerConf, null, 0, params, true, dataType); + params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); + someLayer = layerImpl.instantiate(layerConf.getNetConfiguration(), null, 0, params, true, dataType); editedParams.set(layerNum + 1, someLayer.params()); } } @@ -526,27 +525,27 @@ public class TransferLearning { } } - private MultiLayerConfiguration constructConf() { + private NeuralNetConfiguration constructConf() { //use the editedConfs list to make a new config - List allConfs = new ArrayList<>(); + List allConfs = new ArrayList<>(); allConfs.addAll(editedConfs); allConfs.addAll(appendConfs); //Set default layer names, if not set - as per NeuralNetConfiguration.ListBuilder.build() for (int i = 0; i < allConfs.size(); i++) { - if (allConfs.get(i).getLayer().getLayerName() == null) { - allConfs.get(i).getLayer().setLayerName("layer" + i); + if (allConfs.get(i).getLayerName() == null) { + allConfs.get(i).setLayerName("layer" + i); } } - MultiLayerConfiguration conf = new MultiLayerConfiguration.Builder().inputPreProcessors(inputPreProcessors) - .setInputType(this.inputType).confs(allConfs) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().inputPreProcessors(inputPreProcessors) + .inputType(this.inputType) + .layersFromList(allConfs) + .validateOutputLayerConfig(validateOutputLayerConfig == null || validateOutputLayerConfig) .dataType(origConf.getDataType()) 
.build(); - if (finetuneConfiguration != null) { - finetuneConfiguration.applyToMultiLayerConfiguration(conf); - } + return conf; } } @@ -590,10 +589,10 @@ public class TransferLearning { for (Map.Entry gv : vertices.entrySet()) { if (gv.getValue() instanceof LayerVertex) { LayerVertex lv = (LayerVertex) gv.getValue(); - NeuralNetConfiguration nnc = lv.getLayerConf().clone(); - fineTuneConfiguration.applyToNeuralNetConfiguration(nnc); + NeuralNetConfiguration nnc = lv.getNetConfiguration().clone(); + fineTuneConfiguration.applyToLayerConfiguration(lv.getLayerConfiguration()); vertices.put(gv.getKey(), new LayerVertex(nnc, lv.getPreProcessor())); - nnc.getLayer().setLayerName(gv.getKey()); + lv.getLayerConfiguration().setLayerName(gv.getKey()); } } @@ -725,14 +724,14 @@ public class TransferLearning { * @return GraphBuilder */ public GraphBuilder nInReplace(String layerName, int nIn, IWeightInit scheme) { - Preconditions.checkState(origGraph.getVertex(layerName) != null, "Layer with name %s not found", + Preconditions.checkState(origGraph.getVertex(layerName) != null, "LayerConfiguration with name %s not found", layerName); Preconditions.checkState(origGraph.getVertex(layerName).hasLayer(), "nInReplace can only be applied" + " on vertices with layers. Vertex %s does not have a layer", layerName); initBuilderIfReq(); - NeuralNetConfiguration layerConf = origGraph.getLayer(layerName).conf(); - Layer layerImpl = layerConf.getLayer().clone(); + LayerConfiguration layerConf = origGraph.getLayer(layerName).getLayerConfiguration(); + LayerConfiguration layerImpl = layerConf.clone(); Preconditions.checkState(layerImpl instanceof FeedForwardLayer, "Can only use nInReplace on FeedForward layers;" + "got layer of type %s for layer name %s", layerImpl.getClass().getSimpleName(), layerName); @@ -744,7 +743,7 @@ public class TransferLearning { if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex && nInFromNewConfig.containsKey(layerName)){ - Layer l = ((LayerVertex)editedConfigBuilder.getVertices().get(layerName)).getLayerConf().getLayer(); + LayerConfiguration l = ((LayerVertex)editedConfigBuilder.getVertices().get(layerName)).getLayerConfiguration(); if(l instanceof FeedForwardLayer){ layerImplF.setNIn(nInFromNewConfig.get(layerName)); } @@ -764,8 +763,8 @@ public class TransferLearning { if (origGraph.getVertex(layerName).hasLayer()) { - NeuralNetConfiguration layerConf = origGraph.getLayer(layerName).conf(); - Layer layerImpl = layerConf.getLayer().clone(); + LayerConfiguration layerConf = origGraph.getLayer(layerName).getLayerConfiguration(); + LayerConfiguration layerImpl = layerConf.clone(); layerImpl.resetLayerDefaultConfig(); FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl; layerImplF.setWeightInitFn(scheme); @@ -773,7 +772,7 @@ public class TransferLearning { if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex && nInFromNewConfig.containsKey(layerName)){ - Layer l = ((LayerVertex)editedConfigBuilder.getVertices().get(layerName)).getLayerConf().getLayer(); + LayerConfiguration l = ((LayerVertex)editedConfigBuilder.getVertices().get(layerName)).getLayerConfiguration(); if(l instanceof FeedForwardLayer){ layerImplF.setNIn(nInFromNewConfig.get(layerName)); } @@ -802,10 +801,10 @@ public class TransferLearning { throw new UnsupportedOperationException( "Cannot modify nOut of a layer vertex that feeds non-layer vertices. 
Use removeVertexKeepConnections followed by addVertex instead"); } - layerConf = origGraph.getLayer(fanoutVertexName).conf(); - if(!(layerConf.getLayer() instanceof FeedForwardLayer)) + layerConf = origGraph.getLayer(fanoutVertexName).getLayerConfiguration(); + if(!(layerConf instanceof FeedForwardLayer)) continue; - layerImpl = layerConf.getLayer().clone(); + layerImpl = layerConf.clone(); layerImplF = (FeedForwardLayer) layerImpl; layerImplF.setWeightInitFn(schemeNext); layerImplF.setNIn(nOut); @@ -859,7 +858,7 @@ public class TransferLearning { * @param layerInputs * @return */ - public GraphBuilder addLayer(String layerName, Layer layer, String... layerInputs) { + public GraphBuilder addLayer(String layerName, LayerConfiguration layer, String... layerInputs) { initBuilderIfReq(); editedConfigBuilder.addLayer(layerName, layer, null, layerInputs); editedVertices.add(layerName); @@ -874,7 +873,7 @@ public class TransferLearning { * @param layerInputs * @return */ - public GraphBuilder addLayer(String layerName, Layer layer, InputPreProcessor preProcessor, + public GraphBuilder addLayer(String layerName, LayerConfiguration layer, InputPreProcessor preProcessor, String... layerInputs) { initBuilderIfReq(); editedConfigBuilder.addLayer(layerName, layer, preProcessor, layerInputs); @@ -1009,24 +1008,24 @@ public class TransferLearning { String layerName = gv.getVertexName(); LayerVertex currLayerVertex = (LayerVertex) newConfig.getVertices().get(layerName); - Layer origLayerConf = currLayerVertex.getLayerConf().getLayer(); - Layer newLayerConf = new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(origLayerConf); + LayerConfiguration origLayerConf = currLayerVertex.getLayerConfiguration(); + LayerConfiguration newLayerConf = new org.deeplearning4j.nn.conf.layers.misc.FrozenLayer(origLayerConf); newLayerConf.setLayerName(origLayerConf.getLayerName()); - //Complication here(and reason for clone on next line): inner Layer (implementation) + //Complication here(and reason for clone on next line): inner LayerConfiguration (implementation) // NeuralNetConfiguration.layer (config) should keep the original layer config. 
While network // NNC should have the frozen layer - NeuralNetConfiguration newNNC = currLayerVertex.getLayerConf().clone(); - currLayerVertex.setLayerConf(newNNC); - currLayerVertex.getLayerConf().setLayer(newLayerConf); + NeuralNetConfiguration newNNC = currLayerVertex.getNetConfiguration().clone(); + currLayerVertex.setNetConfiguration(newNNC); + currLayerVertex.getNetConfiguration().setLayer(newLayerConf); //Make sure the underlying layer doesn't change: - List vars = currLayerVertex.getLayerConf().variables(true); - currLayerVertex.getLayerConf().clearVariables(); + List vars = currLayerVertex.getNetConfiguration().netWideVariables(true); + currLayerVertex.getNetConfiguration().clearNetWideVariable(); for (String s : vars) { - newNNC.variables(false).add(s); + newNNC.netWideVariables(false).add(s); } - //We also need to place the layer in the CompGraph Layer[] (replacing the old one) + //We also need to place the layer in the CompGraph LayerConfiguration[] (replacing the old one) //This could no doubt be done more efficiently org.deeplearning4j.nn.api.Layer[] layers = newGraph.getLayers(); for (int j = 0; j < layers.length; j++) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java index a6f7d6c4f..effc48ad4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java @@ -21,7 +21,6 @@ package org.deeplearning4j.nn.transferlearning; import org.apache.commons.lang3.ArrayUtils; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.GraphVertex; @@ -179,7 +178,7 @@ public class TransferLearningHelper { org.deeplearning4j.nn.api.Layer l = gv.getLayer(); gv.setLayerAsFrozen(); - //We also need to place the layer in the CompGraph Layer[] (replacing the old one) + //We also need to place the layer in the CompGraph LayerConfiguration[] (replacing the old one) //This could no doubt be done more efficiently org.deeplearning4j.nn.api.Layer[] layers = origGraph.getLayers(); for (int j = 0; j < layers.length; j++) { @@ -282,16 +281,16 @@ public class TransferLearningHelper { } List allConfs = new ArrayList<>(); for (int i = frozenInputLayer + 1; i < origMLN.getnLayers(); i++) { - allConfs.add(origMLN.getLayer(i).conf()); + allConfs.add(origMLN.getLayer(i).getNetConfiguration()); } - MultiLayerConfiguration c = origMLN.getLayerWiseConfigurations(); + NeuralNetConfiguration c = origMLN.getNetConfiguration(); - unFrozenSubsetMLN = new MultiLayerNetwork(new MultiLayerConfiguration.Builder() + unFrozenSubsetMLN = new MultiLayerNetwork(NeuralNetConfiguration.builder() .inputPreProcessors(c.getInputPreProcessors()) - .backpropType(c.getBackpropType()).tBPTTForwardLength(c.getTbpttFwdLength()) - .tBPTTBackwardLength(c.getTbpttBackLength()).confs(allConfs) - .dataType(origMLN.getLayerWiseConfigurations().getDataType()) + .backpropType(c.getBackpropType()).tbpttFwdLength(c.getTbpttFwdLength()) + .tbpttBackLength(c.getTbpttBackLength()).confs(allConfs) + .dataType(origMLN.getNetConfiguration().getDataType()) .build()); unFrozenSubsetMLN.init(); //copy over params diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java index 91d24de46..dfcef372c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java @@ -21,7 +21,7 @@ package org.deeplearning4j.nn.updater; import lombok.Getter; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.Trainable; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.conf.GradientNormalization; @@ -44,7 +44,7 @@ import org.nd4j.linalg.learning.config.IUpdater; import java.util.*; @Getter -public abstract class BaseMultiLayerUpdater implements Updater { +public abstract class BaseMultiLayerUpdater implements Updater { protected final T network; protected Map layersByName; @@ -81,7 +81,7 @@ public abstract class BaseMultiLayerUpdater implements Updater int paramsViewSoFar = 0; int currentUpdaterOffset = 0; for (int i = 0; i < layers.length; i++) { - Map layerParamTable = layers[i].paramTable(false); + Map layerParamTable = layers[i].getParamTable(false); if (layerParamTable != null) { List variables = new ArrayList<>(layerParamTable.keySet()); //Is from a set, but iteration order should be fixed per layer as it's a from a LinkedHashSet for (int j = 0; j < variables.size(); j++) { @@ -351,8 +351,8 @@ public abstract class BaseMultiLayerUpdater implements Updater long currentStart = 0; long currentEnd = 0; for(Trainable t : getOrderedLayers()){ - Set layerParams = t.paramTable(false).keySet(); - Map paramTable = t.paramTable(false); + Set layerParams = t.getParamTable(false).keySet(); + Map paramTable = t.getParamTable(false); for(String s : layerParams) { if(t.updaterDivideByMinibatch(s)){ long l = paramTable.get(s).length(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java index dea50edd9..f27e7dcfa 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java @@ -42,7 +42,7 @@ public class LayerUpdater extends BaseMultiLayerUpdater { } layersByName = new HashMap<>(); - layersByName.put(layer.conf().getLayer().getLayerName(), layer); + layersByName.put(layer.getLayerConfiguration().getLayerName(), layer); } @Override @@ -62,7 +62,7 @@ public class LayerUpdater extends BaseMultiLayerUpdater { @Override protected boolean isMiniBatch() { - return network.conf().isMiniBatch(); + return network.getNetConfiguration().isMiniBatch(); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/MultiLayerUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/MultiLayerUpdater.java index 58f64f66f..f43aa85d2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/MultiLayerUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/MultiLayerUpdater.java @@ -71,7 +71,7 @@ public class MultiLayerUpdater extends BaseMultiLayerUpdater @Override protected boolean isMiniBatch() { - return network.conf().isMiniBatch(); + return network.getNetConfiguration().isMiniBatch(); } @Override diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterCreator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterCreator.java index 3194de852..14850eafb 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterCreator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterCreator.java @@ -21,7 +21,7 @@ package org.deeplearning4j.nn.updater; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater; @@ -35,7 +35,7 @@ public class UpdaterCreator { private UpdaterCreator() {} - public static org.deeplearning4j.nn.api.Updater getUpdater(Model layer) { + public static org.deeplearning4j.nn.api.Updater getUpdater(IModel layer) { if (layer instanceof MultiLayerNetwork) { return new MultiLayerUpdater((MultiLayerNetwork) layer); } else if (layer instanceof ComputationGraph) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/graph/ComputationGraphUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/graph/ComputationGraphUpdater.java index 6af2901d6..1c39f52e1 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/graph/ComputationGraphUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/graph/ComputationGraphUpdater.java @@ -20,7 +20,6 @@ package org.deeplearning4j.nn.updater.graph; -import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Trainable; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.GraphVertex; @@ -90,8 +89,10 @@ public class ComputationGraphUpdater extends BaseMultiLayerUpdater listeners; - private Model model; + private IModel model; private ConvexOptimizer optimizer; private StepFunction stepFunction; @@ -90,7 +90,7 @@ public class Solver { public static class Builder { private NeuralNetConfiguration conf; - private Model model; + private IModel model; private final List listeners = new ArrayList<>(); public Builder configure(NeuralNetConfiguration conf) { @@ -112,7 +112,7 @@ public class Solver { return this; } - public Builder model(Model model) { + public Builder model(IModel model) { this.model = model; return this; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/BaseTrainingListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/BaseTrainingListener.java index c7d755187..d72b836f5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/BaseTrainingListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/BaseTrainingListener.java @@ -20,7 +20,7 @@ package org.deeplearning4j.optimize.api; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.List; @@ -29,43 +29,43 @@ import java.util.Map; public abstract class BaseTrainingListener implements TrainingListener { @Override - public void onEpochStart(Model model) { + public void onEpochStart(IModel model) { //No op } @Override - public void onEpochEnd(Model model) { + public void onEpochEnd(IModel model) { //No op } @Override - public void onForwardPass(Model model, List activations) { + public void 
onForwardPass(IModel model, List activations) { //No op } @Override - public void onForwardPass(Model model, Map activations) { + public void onForwardPass(IModel model, Map activations) { //No op } @Override - public void onGradientCalculation(Model model) { + public void onGradientCalculation(IModel model) { //No op } @Override - public void onBackwardPass(Model model) { + public void onBackwardPass(IModel model) { //No op } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { //No op } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/ConvexOptimizer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/ConvexOptimizer.java index c32a0fac3..0d6999fce 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/ConvexOptimizer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/ConvexOptimizer.java @@ -20,7 +20,7 @@ package org.deeplearning4j.optimize.api; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; @@ -128,6 +128,6 @@ public interface ConvexOptimizer extends Serializable { * @param batchSize batchSize for update * @paramType paramType to update */ - void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize, LayerWorkspaceMgr workspaceMgr); + void updateGradientAccordingToParams(Gradient gradient, IModel model, int batchSize, LayerWorkspaceMgr workspaceMgr); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/IterationListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/IterationListener.java index 085d734b1..309f478fe 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/IterationListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/IterationListener.java @@ -21,7 +21,7 @@ package org.deeplearning4j.optimize.api; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import java.io.Serializable; @@ -33,6 +33,6 @@ public abstract class IterationListener extends BaseTrainingListener implements * @param iteration the iteration * @param model the model iterating */ - public abstract void iterationDone(Model model, int iteration, int epoch); + public abstract void iterationDone(IModel model, int iteration, int epoch); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/TrainingListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/TrainingListener.java index 81a2d8465..20fe978dc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/TrainingListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/api/TrainingListener.java @@ -20,7 +20,7 @@ package org.deeplearning4j.optimize.api; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; @@ -35,19 +35,19 @@ public interface TrainingListener { * @param iteration the iteration * @param model the model iterating */ - void iterationDone(Model model, int iteration, int epoch); + void 
iterationDone(IModel model, int iteration, int epoch); /** * Called once at the start of each epoch, when using methods such as {@link org.deeplearning4j.nn.multilayer.MultiLayerNetwork#fit(DataSetIterator)}, * {@link org.deeplearning4j.nn.graph.ComputationGraph#fit(DataSetIterator)} or {@link org.deeplearning4j.nn.graph.ComputationGraph#fit(MultiDataSetIterator)} */ - void onEpochStart(Model model); + void onEpochStart(IModel model); /** * Called once at the end of each epoch, when using methods such as {@link org.deeplearning4j.nn.multilayer.MultiLayerNetwork#fit(DataSetIterator)}, * {@link org.deeplearning4j.nn.graph.ComputationGraph#fit(DataSetIterator)} or {@link org.deeplearning4j.nn.graph.ComputationGraph#fit(MultiDataSetIterator)} */ - void onEpochEnd(Model model); + void onEpochEnd(IModel model); /** * Called once per iteration (forward pass) for activations (usually for a {@link org.deeplearning4j.nn.multilayer.MultiLayerNetwork}), @@ -56,7 +56,7 @@ public interface TrainingListener { * @param model Model * @param activations ILayer activations (including input) */ - void onForwardPass(Model model, List activations); + void onForwardPass(IModel model, List activations); /** * Called once per iteration (forward pass) for activations (usually for a {@link org.deeplearning4j.nn.graph.ComputationGraph}), @@ -65,30 +65,30 @@ public interface TrainingListener { * @param model Model * @param activations ILayer activations (including input) */ - void onForwardPass(Model model, Map activations); + void onForwardPass(IModel model, Map activations); /** * Called once per iteration (backward pass) before the gradients are updated - * Gradients are available via {@link Model#gradient()}. + * Gradients are available via {@link IModel#gradient()}. * Note that gradients will likely be updated in-place - thus they should be copied or processed synchronously * in this method. *
- * For updates (gradients post learning rate/momentum/rmsprop etc) see {@link #onBackwardPass(Model)} + * For updates (gradients post learning rate/momentum/rmsprop etc) see {@link #onBackwardPass(IModel)} * * @param model Model */ - void onGradientCalculation(Model model); + void onGradientCalculation(IModel model); /** * Called once per iteration (backward pass) after gradients have been calculated, and updated - * Gradients are available via {@link Model#gradient()}. + * Gradients are available via {@link IModel#gradient()}. *
- * Unlike {@link #onGradientCalculation(Model)} the gradients at this point will be post-update, rather than + * Unlike {@link #onGradientCalculation(IModel)} the gradients at this point will be post-update, rather than * raw (pre-update) gradients at that method call. * * @param model Model */ - void onBackwardPass(Model model); + void onBackwardPass(IModel model); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java index 550e4425b..120099da4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CheckpointListener.java @@ -23,8 +23,8 @@ package org.deeplearning4j.optimize.listeners; import com.google.common.io.Files; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.io.IOUtils; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.BaseTrainingListener; @@ -109,7 +109,7 @@ public class CheckpointListener extends BaseTrainingListener implements Serializ } @Override - public void onEpochEnd(Model model) { + public void onEpochEnd(IModel model) { int epochsDone = getEpoch(model) + 1; if(saveEveryNEpochs != null && epochsDone > 0 && epochsDone % saveEveryNEpochs == 0){ //Save: @@ -119,7 +119,7 @@ public class CheckpointListener extends BaseTrainingListener implements Serializ } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { if (startTime < 0) { startTime = System.currentTimeMillis(); startIter = iteration; @@ -164,7 +164,7 @@ public class CheckpointListener extends BaseTrainingListener implements Serializ } } - private void saveCheckpoint(Model model) { + private void saveCheckpoint(IModel model) { try{ saveCheckpointHelper(model); } catch (Exception e){ @@ -172,7 +172,7 @@ public class CheckpointListener extends BaseTrainingListener implements Serializ } } - private void saveCheckpointHelper(Model model) throws Exception { + private void saveCheckpointHelper(IModel model) throws Exception { if(!checkpointRecordFile.exists()){ checkpointRecordFile.createNewFile(); write(Checkpoint.getFileHeader() + "\n", checkpointRecordFile); @@ -243,27 +243,27 @@ public class CheckpointListener extends BaseTrainingListener implements Serializ return str; } - protected static int getIter(Model model) { + protected static int getIter(IModel model) { if (model instanceof MultiLayerNetwork) { - return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getIterationCount(); + return ((MultiLayerNetwork) model).getNetConfiguration().getIterationCount(); } else if (model instanceof ComputationGraph) { return ((ComputationGraph) model).getComputationGraphConfiguration().getIterationCount(); } else { - return model.conf().getIterationCount(); + return model.getNetConfiguration().getIterationCount(); } } - protected static int getEpoch(Model model) { + protected static int getEpoch(IModel model) { if (model instanceof MultiLayerNetwork) { - return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getEpochCount(); + return ((MultiLayerNetwork) model).getNetConfiguration().getEpochCount(); } else if (model 
instanceof ComputationGraph) { return ((ComputationGraph) model).getComputationGraphConfiguration().getEpochCount(); } else { - return model.conf().getEpochCount(); + return model.getNetConfiguration().getEpochCount(); } } - protected static String getModelType(Model model){ + protected static String getModelType(IModel model){ if(model.getClass() == MultiLayerNetwork.class){ return "MultiLayerNetwork"; } else if(model.getClass() == ComputationGraph.class){ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java index 51f798e26..0692387cf 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresIterationListener.java @@ -20,7 +20,7 @@ package org.deeplearning4j.optimize.listeners; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.optimize.api.BaseTrainingListener; import java.io.File; @@ -132,7 +132,7 @@ public class CollectScoresIterationListener extends BaseTrainingListener { } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { if (++iterationCount % frequency == 0) { double score = model.score(); scoreVsIter.reallocateGuard(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresListener.java index 4f6d17b3c..558b3eb92 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/CollectScoresListener.java @@ -25,7 +25,7 @@ import it.unimi.dsi.fastutil.ints.IntArrayList; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.optimize.api.BaseTrainingListener; @@ -53,7 +53,7 @@ public class CollectScoresListener extends BaseTrainingListener implements Seria } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { if(iteration % frequency == 0){ double score = model.score(); listIteration.add(iteration); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ComposableIterationListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ComposableIterationListener.java index 3b82fc6b2..4b67fcede 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ComposableIterationListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ComposableIterationListener.java @@ -20,7 +20,7 @@ package org.deeplearning4j.optimize.listeners; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.deeplearning4j.optimize.api.TrainingListener; @@ -42,7 +42,7 @@ public class ComposableIterationListener extends BaseTrainingListener implements } @Override - public void iterationDone(Model model, 
int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { for (TrainingListener listener : listeners) listener.iterationDone(model, iteration, epoch); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/EvaluativeListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/EvaluativeListener.java index a05d14a87..f98dd0aad 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/EvaluativeListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/EvaluativeListener.java @@ -24,8 +24,8 @@ import lombok.Getter; import lombok.NonNull; import lombok.Setter; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.exception.DL4JInvalidInputException; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.BaseTrainingListener; @@ -39,8 +39,6 @@ import org.nd4j.linalg.dataset.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import java.util.List; -import java.util.Map; import java.util.concurrent.atomic.AtomicLong; @Slf4j @@ -193,24 +191,24 @@ public class EvaluativeListener extends BaseTrainingListener { * @param iteration the iteration */ @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { if (invocationType == InvocationType.ITERATION_END) invokeListener(model); } @Override - public void onEpochStart(Model model) { + public void onEpochStart(IModel model) { if (invocationType == InvocationType.EPOCH_START) invokeListener(model); } @Override - public void onEpochEnd(Model model) { + public void onEpochEnd(IModel model) { if (invocationType == InvocationType.EPOCH_END) invokeListener(model); } - protected void invokeListener(Model model) { + protected void invokeListener(IModel model) { if (iterationCount.get() == null) iterationCount.set(new AtomicLong(0)); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/FailureTestingListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/FailureTestingListener.java index c05626511..d6ac11b41 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/FailureTestingListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/FailureTestingListener.java @@ -25,7 +25,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.TrainingListener; @@ -51,41 +51,41 @@ public class FailureTestingListener implements TrainingListener, Serializable { } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { call(CallType.ITER_DONE, model); } @Override - public void onEpochStart(Model model) { + public void onEpochStart(IModel model) { call(CallType.EPOCH_START, model); } @Override - public void onEpochEnd(Model model) { + public void 
onEpochEnd(IModel model) { call(CallType.EPOCH_END, model); } @Override - public void onForwardPass(Model model, List activations) { + public void onForwardPass(IModel model, List activations) { call(CallType.FORWARD_PASS, model); } @Override - public void onForwardPass(Model model, Map activations) { + public void onForwardPass(IModel model, Map activations) { call(CallType.FORWARD_PASS, model); } @Override - public void onGradientCalculation(Model model) { + public void onGradientCalculation(IModel model) { call(CallType.GRADIENT_CALC, model); } @Override - public void onBackwardPass(Model model) { + public void onBackwardPass(IModel model) { call(CallType.BACKWARD_PASS, model); } - protected void call(CallType callType, Model model){ + protected void call(CallType callType, IModel model){ if(!trigger.initialized()){ trigger.initialize(); } @@ -149,7 +149,7 @@ public class FailureTestingListener implements TrainingListener, Serializable { * @param model Model * @return */ - public abstract boolean triggerFailure(CallType callType, int iteration, int epoch, Model model); + public abstract boolean triggerFailure(CallType callType, int iteration, int epoch, IModel model); public boolean initialized(){ return initialized; @@ -170,7 +170,7 @@ public class FailureTestingListener implements TrainingListener, Serializable { } @Override - public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { + public boolean triggerFailure(CallType callType, int iteration, int epoch, IModel model) { boolean b = true; for(FailureTrigger ft : triggers) b &= ft.triggerFailure(callType, iteration, epoch, model); @@ -191,7 +191,7 @@ public class FailureTestingListener implements TrainingListener, Serializable { } @Override - public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { + public boolean triggerFailure(CallType callType, int iteration, int epoch, IModel model) { boolean b = false; for(FailureTrigger ft : triggers) b |= ft.triggerFailure(callType, iteration, epoch, model); @@ -213,7 +213,7 @@ public class FailureTestingListener implements TrainingListener, Serializable { } @Override - public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { + public boolean triggerFailure(CallType callType, int iteration, int epoch, IModel model) { return (this.callType == CallType.ANY || callType == this.callType) && rng.nextDouble() < probability; } @@ -237,7 +237,7 @@ public class FailureTestingListener implements TrainingListener, Serializable { } @Override - public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { + public boolean triggerFailure(CallType callType, int iteration, int epoch, IModel model) { return (System.currentTimeMillis() - initTime) > msSinceInit; } @@ -260,7 +260,7 @@ public class FailureTestingListener implements TrainingListener, Serializable { @Override - public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { + public boolean triggerFailure(CallType callType, int iteration, int epoch, IModel model) { return shouldFail; } @@ -284,7 +284,7 @@ public class FailureTestingListener implements TrainingListener, Serializable { @Override - public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { + public boolean triggerFailure(CallType callType, int iteration, int epoch, IModel model) { return shouldFail; } @@ -314,7 +314,7 @@ public class FailureTestingListener implements TrainingListener, Serializable { } 
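As an illustration of the migrated listener API in the hunks above, here is a minimal sketch of a custom listener written against the IModel-based signatures; the class name and the ten-iteration reporting interval are hypothetical, while BaseTrainingListener, IModel and model.score() are taken from the surrounding changes.

import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.optimize.api.BaseTrainingListener;

public class ScoreEveryTenIterationsListener extends BaseTrainingListener {

    @Override
    public void iterationDone(IModel model, int iteration, int epoch) {
        // model.score() is the same accessor the bundled listeners call after the migration
        if (iteration % 10 == 0) {
            System.out.println("epoch " + epoch + ", iteration " + iteration + ", score " + model.score());
        }
    }
}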
@Override - public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) { + public boolean triggerFailure(CallType callType, int iteration, int epoch, IModel model) { return (isEpoch && epoch == count) || (!isEpoch && iteration == count); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java index 68402f40e..ff76fbfc0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/PerformanceListener.java @@ -22,14 +22,12 @@ package org.deeplearning4j.optimize.listeners; import com.google.common.base.Preconditions; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.ObjectInputStream; @@ -78,7 +76,7 @@ public class PerformanceListener extends BaseTrainingListener implements Seriali } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { // we update lastTime on every iteration // just to simplify things if (lastTime.get() == null) diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreIterationListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreIterationListener.java index 6568d2f67..2d8cc1829 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreIterationListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreIterationListener.java @@ -21,10 +21,8 @@ package org.deeplearning4j.optimize.listeners; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.optimize.api.BaseTrainingListener; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.Serializable; @@ -43,7 +41,7 @@ public class ScoreIterationListener extends BaseTrainingListener implements Seri public ScoreIterationListener() {} @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { if (printIterations <= 0) printIterations = 1; if (iteration % printIterations == 0) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreToChartListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreToChartListener.java index 2fc2999d6..1a8620a48 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreToChartListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/ScoreToChartListener.java @@ -26,7 +26,7 @@ import lombok.extern.slf4j.Slf4j; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.Response; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import 
org.deeplearning4j.optimize.api.BaseTrainingListener; @Slf4j @@ -40,7 +40,7 @@ public class ScoreToChartListener extends BaseTrainingListener { } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { double score = model.score(); String nurl = url+"s="+score+"&n="+seriesName; OkHttpClient client = new OkHttpClient(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/SleepyTrainingListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/SleepyTrainingListener.java index 4c262a64c..834778001 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/SleepyTrainingListener.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/SleepyTrainingListener.java @@ -22,7 +22,7 @@ package org.deeplearning4j.optimize.listeners; import lombok.*; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.optimize.api.BaseTrainingListener; import org.nd4j.common.util.ThreadUtils; import org.nd4j.linalg.api.ndarray.INDArray; @@ -160,7 +160,7 @@ public class SleepyTrainingListener extends BaseTrainingListener implements Seri } @Override - public void onEpochStart(Model model) { + public void onEpochStart(IModel model) { sleep(lastES.get(), timerES); if (lastES.get() == null) @@ -170,7 +170,7 @@ public class SleepyTrainingListener extends BaseTrainingListener implements Seri } @Override - public void onEpochEnd(Model model) { + public void onEpochEnd(IModel model) { sleep(lastEE.get(), timerEE); if (lastEE.get() == null) @@ -180,7 +180,7 @@ public class SleepyTrainingListener extends BaseTrainingListener implements Seri } @Override - public void onForwardPass(Model model, List activations) { + public void onForwardPass(IModel model, List activations) { sleep(lastFF.get(), timerFF); if (lastFF.get() == null) @@ -190,7 +190,7 @@ public class SleepyTrainingListener extends BaseTrainingListener implements Seri } @Override - public void onForwardPass(Model model, Map activations) { + public void onForwardPass(IModel model, Map activations) { sleep(lastFF.get(), timerFF); if (lastFF.get() == null) @@ -200,7 +200,7 @@ public class SleepyTrainingListener extends BaseTrainingListener implements Seri } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { sleep(lastIteration.get(), timerIteration); if (lastIteration.get() == null) @@ -210,7 +210,7 @@ public class SleepyTrainingListener extends BaseTrainingListener implements Seri } @Override - public void onBackwardPass(Model model) { + public void onBackwardPass(IModel model) { sleep(lastBP.get(), timerBP); if (lastBP.get() == null) @@ -220,7 +220,7 @@ public class SleepyTrainingListener extends BaseTrainingListener implements Seri } @Override - public void onGradientCalculation(Model model) { + public void onGradientCalculation(IModel model) { // } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/TimeIterationListener.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/TimeIterationListener.java index cc48c216b..8a947e4c0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/TimeIterationListener.java +++ 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/TimeIterationListener.java @@ -21,10 +21,8 @@ package org.deeplearning4j.optimize.listeners; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.optimize.api.BaseTrainingListener; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import java.io.Serializable; import java.util.Date; @@ -46,7 +44,7 @@ public class TimeIterationListener extends BaseTrainingListener implements Seria } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { long currentIteration = iterationCounter.incrementAndGet(); long elapsed = System.currentTimeMillis() - start; long remaining = (iterationCount - currentIteration) * elapsed / currentIteration; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/callbacks/EvaluationCallback.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/callbacks/EvaluationCallback.java index 6cb756bb8..5f14ef3b0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/callbacks/EvaluationCallback.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/callbacks/EvaluationCallback.java @@ -20,11 +20,11 @@ package org.deeplearning4j.optimize.listeners.callbacks; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.optimize.listeners.EvaluativeListener; import org.nd4j.evaluation.IEvaluation; public interface EvaluationCallback { - void call(EvaluativeListener listener, Model model, long invocationsCount, IEvaluation[] evaluations); + void call(EvaluativeListener listener, IModel model, long invocationsCount, IEvaluation[] evaluations); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/callbacks/ModelSavingCallback.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/callbacks/ModelSavingCallback.java index df46f1fc6..cb0fe44a0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/callbacks/ModelSavingCallback.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/listeners/callbacks/ModelSavingCallback.java @@ -23,7 +23,7 @@ package org.deeplearning4j.optimize.listeners.callbacks; import lombok.NonNull; import org.apache.commons.io.FilenameUtils; import org.deeplearning4j.exception.DL4JInvalidConfigException; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.optimize.listeners.EvaluativeListener; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.evaluation.IEvaluation; @@ -66,7 +66,7 @@ public class ModelSavingCallback implements EvaluationCallback { } @Override - public void call(EvaluativeListener listener, Model model, long invocationsCount, IEvaluation[] evaluations) { + public void call(EvaluativeListener listener, IModel model, long invocationsCount, IEvaluation[] evaluations) { String temp = template.replaceAll("%d", "" + invocationsCount); @@ -81,7 +81,7 @@ public class ModelSavingCallback implements EvaluationCallback { * @param model * @param filename */ - protected void save(Model model, String filename) { + protected void save(IModel model, String filename) { try { ModelSerializer.writeModel(model, filename, true); } catch (IOException e) { diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BackTrackLineSearch.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BackTrackLineSearch.java index 18e64c081..1f391a3ef 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BackTrackLineSearch.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BackTrackLineSearch.java @@ -20,9 +20,9 @@ package org.deeplearning4j.optimize.solvers; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.math3.util.FastMath; import org.deeplearning4j.exception.InvalidStepException; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.stepfunctions.NegativeGradientStepFunction; import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.deeplearning4j.optimize.api.LineOptimizer; @@ -33,7 +33,6 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarSetValue; import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps; -import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.slf4j.Logger; @@ -44,7 +43,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.abs; public class BackTrackLineSearch implements LineOptimizer { private static final Logger log = LoggerFactory.getLogger(BackTrackLineSearch.class); - private final Model layer; + private final IModel layer; private final StepFunction stepFunction; private final ConvexOptimizer optimizer; private int maxIterations; @@ -64,18 +63,18 @@ public class BackTrackLineSearch implements LineOptimizer { * @param stepFunction * @param optimizer */ - public BackTrackLineSearch(Model layer, StepFunction stepFunction, ConvexOptimizer optimizer) { + public BackTrackLineSearch(IModel layer, StepFunction stepFunction, ConvexOptimizer optimizer) { this.layer = layer; this.stepFunction = stepFunction; this.optimizer = optimizer; - this.maxIterations = layer.conf().getMaxNumLineSearchIterations(); + this.maxIterations = layer.getNetConfiguration().getMaxNumLineSearchIterations(); } /** * @param optimizable * @param optimizer */ - public BackTrackLineSearch(Model optimizable, ConvexOptimizer optimizer) { + public BackTrackLineSearch(IModel optimizable, ConvexOptimizer optimizer) { this(optimizable, new NegativeDefaultStepFunction(), optimizer); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BaseOptimizer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BaseOptimizer.java index 42ce490e5..b5e06a3c3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BaseOptimizer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/BaseOptimizer.java @@ -21,12 +21,11 @@ package org.deeplearning4j.optimize.solvers; import lombok.Getter; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.exception.InvalidStepException; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -64,7 +63,7 @@ public abstract 
class BaseOptimizer implements ConvexOptimizer { @Getter protected StepFunction stepFunction; protected Collection trainingListeners = new ArrayList<>(); - protected Model model; + protected IModel model; protected BackTrackLineSearch lineMaximizer; protected Updater updater; protected ComputationGraphUpdater computationGraphUpdater; @@ -90,7 +89,7 @@ public abstract class BaseOptimizer implements ConvexOptimizer { * @param model */ public BaseOptimizer(NeuralNetConfiguration conf, StepFunction stepFunction, - Collection trainingListeners, Model model) { + Collection trainingListeners, IModel model) { this.conf = conf; this.stepFunction = (stepFunction != null ? stepFunction : getDefaultStepFunctionForOptimizer(this.getClass())); this.trainingListeners = trainingListeners != null ? trainingListeners : new ArrayList(); @@ -289,7 +288,7 @@ public abstract class BaseOptimizer implements ConvexOptimizer { @Override - public void updateGradientAccordingToParams(Gradient gradient, Model model, int batchSize, LayerWorkspaceMgr workspaceMgr) { + public void updateGradientAccordingToParams(Gradient gradient, IModel model, int batchSize, LayerWorkspaceMgr workspaceMgr) { if (model instanceof ComputationGraph) { ComputationGraph graph = (ComputationGraph) model; if (computationGraphUpdater == null) { @@ -316,7 +315,7 @@ public abstract class BaseOptimizer implements ConvexOptimizer { */ @Override public void setupSearchState(Pair pair) { - INDArray gradient = pair.getFirst().gradient(conf.variables()); + INDArray gradient = pair.getFirst().gradient(conf.netWideVariables()); INDArray params = model.params().dup(); //Need dup here: params returns an array that isn't a copy (hence changes to this are problematic for line search methods) searchState.put(GRADIENT_KEY, gradient); searchState.put(SCORE_KEY, pair.getSecond()); @@ -332,39 +331,39 @@ public abstract class BaseOptimizer implements ConvexOptimizer { } } - public static int getIterationCount(Model model) { + public static int getIterationCount(IModel model) { if (model instanceof MultiLayerNetwork) { - return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getIterationCount(); + return ((MultiLayerNetwork) model).getNetConfiguration().getIterationCount(); } else if (model instanceof ComputationGraph) { return ((ComputationGraph) model).getComputationGraphConfiguration().getIterationCount(); } else { - return model.conf().getIterationCount(); + return model.getNetConfiguration().getIterationCount(); } } - public static void incrementIterationCount(Model model, int incrementBy) { + public static void incrementIterationCount(IModel model, int incrementBy) { if (model instanceof MultiLayerNetwork) { - MultiLayerConfiguration conf = ((MultiLayerNetwork) model).getLayerWiseConfigurations(); + NeuralNetConfiguration conf = ((MultiLayerNetwork) model).getNetConfiguration(); conf.setIterationCount(conf.getIterationCount() + incrementBy); } else if (model instanceof ComputationGraph) { ComputationGraphConfiguration conf = ((ComputationGraph) model).getComputationGraphConfiguration(); conf.setIterationCount(conf.getIterationCount() + incrementBy); } else { - model.conf().setIterationCount(model.conf().getIterationCount() + incrementBy); + model.getNetConfiguration().setIterationCount(model.getNetConfiguration().getIterationCount() + incrementBy); } } - public static int getEpochCount(Model model){ + public static int getEpochCount(IModel model){ if (model instanceof MultiLayerNetwork) { - return ((MultiLayerNetwork) 
model).getLayerWiseConfigurations().getEpochCount(); + return ((MultiLayerNetwork) model).getNetConfiguration().getEpochCount(); } else if (model instanceof ComputationGraph) { return ((ComputationGraph) model).getComputationGraphConfiguration().getEpochCount(); } else { - return model.conf().getEpochCount(); + return model.getNetConfiguration().getEpochCount(); } } - public static void applyConstraints(Model model){ + public static void applyConstraints(IModel model){ int iter = getIterationCount(model); int epoch = getEpochCount(model); model.applyConstraints(iter, epoch); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/ConjugateGradient.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/ConjugateGradient.java index b07ade04a..614075e20 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/ConjugateGradient.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/ConjugateGradient.java @@ -20,7 +20,7 @@ package org.deeplearning4j.optimize.solvers; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.optimize.api.StepFunction; import org.deeplearning4j.optimize.api.TrainingListener; @@ -38,7 +38,7 @@ public class ConjugateGradient extends BaseOptimizer { public ConjugateGradient(NeuralNetConfiguration conf, StepFunction stepFunction, - Collection trainingListeners, Model model) { + Collection trainingListeners, IModel model) { super(conf, stepFunction, trainingListeners, model); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java index 5760ee337..3a8fa9bdc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java @@ -20,7 +20,7 @@ package org.deeplearning4j.optimize.solvers; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.optimize.api.StepFunction; @@ -42,7 +42,7 @@ public class LBFGS extends BaseOptimizer { private final int m = 4; public LBFGS(NeuralNetConfiguration conf, StepFunction stepFunction, - Collection trainingListeners, Model model) { + Collection trainingListeners, IModel model) { super(conf, stepFunction, trainingListeners, model); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LineGradientDescent.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LineGradientDescent.java index 2afc53453..78ebf3231 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LineGradientDescent.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LineGradientDescent.java @@ -20,7 +20,7 @@ package org.deeplearning4j.optimize.solvers; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.optimize.api.StepFunction; import org.deeplearning4j.optimize.api.TrainingListener; @@ -33,7 +33,7 @@ public class LineGradientDescent extends BaseOptimizer { private static final long serialVersionUID = 6336124657542062284L; 
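As a sketch of how the counter helpers above are intended to be used after the switch from conf()/getLayerWiseConfigurations() to getNetConfiguration(), assuming only the static methods visible in the BaseOptimizer hunk; the wrapper class and method name are hypothetical.

import net.brutex.ai.dnn.api.IModel;
import org.deeplearning4j.optimize.solvers.BaseOptimizer;

public class TrainingCounters {

    /** Summarises the iteration and epoch counters of any IModel (MultiLayerNetwork, ComputationGraph, ...). */
    public static String summarize(IModel model) {
        int iteration = BaseOptimizer.getIterationCount(model); // dispatches on the concrete model type
        int epoch = BaseOptimizer.getEpochCount(model);         // falls back to getNetConfiguration() for plain models
        return "iteration=" + iteration + ", epoch=" + epoch;
    }
}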
public LineGradientDescent(NeuralNetConfiguration conf, StepFunction stepFunction, - Collection trainingListeners, Model model) { + Collection trainingListeners, IModel model) { super(conf, stepFunction, trainingListeners, model); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/StochasticGradientDescent.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/StochasticGradientDescent.java index fbee9c2a3..ee7070f01 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/StochasticGradientDescent.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/StochasticGradientDescent.java @@ -21,7 +21,7 @@ package org.deeplearning4j.optimize.solvers; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -41,7 +41,7 @@ public class StochasticGradientDescent extends BaseOptimizer { public StochasticGradientDescent(NeuralNetConfiguration conf, StepFunction stepFunction, - Collection trainingListeners, Model model) { + Collection trainingListeners, IModel model) { super(conf, stepFunction, trainingListeners, model); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java index 490acc178..7684caa6a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java @@ -24,8 +24,8 @@ import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.exception.DL4JInvalidConfigException; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.optimize.api.StepFunction; import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor; import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm; @@ -171,7 +171,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist } - public static long getOptimalBufferSize(Model model, int numWorkers, int queueSize) { + public static long getOptimalBufferSize(IModel model, int numWorkers, int queueSize) { return getOptimalBufferSize(model.params().length(), numWorkers, queueSize); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java index 53bed93a2..cd3bd3f2c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/Convolution1DUtils.java @@ -27,7 +27,6 @@ import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.*; -import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.nd4j.common.base.Preconditions; import 
org.nd4j.linalg.api.ndarray.INDArray; @@ -62,7 +61,7 @@ public class Convolution1DUtils { * @return true if the input layer has an rnn format * false otherwise */ - public static boolean hasRnnDataFormat(Layer layer) { + public static boolean hasRnnDataFormat(LayerConfiguration layer) { return layer instanceof Convolution1D || layer instanceof Convolution1DLayer || layer instanceof Subsampling1DLayer || @@ -78,7 +77,7 @@ public class Convolution1DUtils { * @param layer the layer to get the format for * @return the format for the layer */ - public static RNNFormat getRnnFormatFromLayer(Layer layer) { + public static RNNFormat getRnnFormatFromLayer(LayerConfiguration layer) { Preconditions.checkState(hasRnnDataFormat(layer),"ILayer of type " + layer.getClass().getName() + " and name " + layer.getLayerName() + " does not have an RNNFormat"); if(layer instanceof SimpleRnn) { SimpleRnn simpleRnn = (SimpleRnn) layer; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java index 616f1c620..e7adaa86a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ConvolutionUtils.java @@ -52,7 +52,7 @@ public class ConvolutionUtils { public static final String NCHW_NHWC_ERROR_MSG = "Note: Convolution layers can be configured for either NCHW (channels first)" + " or NHWC (channels last) format for input images and activations.\n" + "Layers can be configured using .dataFormat(CNN2DFormat.NCHW/NHWC) when constructing the layer, or for the entire net using" + - " .setInputType(InputType.convolutional(height, width, depth, CNN2DForman.NCHW/NHWC)).\n" + + " .inputType(InputType.convolutional(height, width, depth, CNN2DForman.NCHW/NHWC)).\n" + "ImageRecordReader and NativeImageLoader can also be configured to load image data in either NCHW or NHWC format which must match the network"; @@ -176,7 +176,7 @@ public class ConvolutionUtils { * @param layer the layer to check * @return true if the layer is one of the above types, false otherwise */ - public static boolean layerHasConvolutionLayout(Layer layer) { + public static boolean layerHasConvolutionLayout(LayerConfiguration layer) { return layer instanceof ConvolutionLayer || layer instanceof SubsamplingLayer || layer instanceof SpaceToBatchLayer || @@ -191,15 +191,15 @@ public class ConvolutionUtils { /** * Get the format for a given layer. 
- * {@link #layerHasConvolutionLayout(Layer)} - * should return true on the given {@link Layer} + * {@link #layerHasConvolutionLayout(LayerConfiguration)} + * should return true on the given {@link LayerConfiguration} * type or an {@link IllegalArgumentException} * will be thrown * @param layer the input layer * @return the {@link CNN2DFormat} for the given * layer */ - public static CNN2DFormat getFormatForLayer(Layer layer) { + public static CNN2DFormat getFormatForLayer(LayerConfiguration layer) { if(layer instanceof Convolution1DLayer) { Convolution1DLayer convolution1DLayer = (Convolution1DLayer) layer; return convolution1DLayer.getCnn2dDataFormat(); @@ -520,9 +520,9 @@ public class ConvolutionUtils { * @param conf the configuration to get height and width from * @return the configuration to get height and width from */ - public static int[] getHeightAndWidth(NeuralNetConfiguration conf) { + public static int[] getHeightAndWidth(LayerConfiguration conf) { return getHeightAndWidth( - ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf.getLayer()).getKernelSize()); + ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf).getKernelSize()); } @@ -531,8 +531,8 @@ public class ConvolutionUtils { * the number of kernels from * @return the number of kernels/filters to apply */ - public static long numFeatureMap(NeuralNetConfiguration conf) { - return ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf.getLayer()).getNOut(); + public static long numFeatureMap(LayerConfiguration conf) { + return ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer) conf).getNOut(); } /** diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java index 5227ad77f..56f7d3b7f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java @@ -41,12 +41,12 @@ import java.util.Set; import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.io.FileUtils; import org.apache.commons.lang3.exception.ExceptionUtils; import org.bytedeco.javacpp.Pointer; import org.deeplearning4j.common.config.DL4JSystemProperties; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.BackpropType; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -136,7 +136,7 @@ public class CrashReportingUtil { * @param net Net to generate the crash dump for. May not be null * @param e Throwable/exception. Stack trace will be included in the network output */ - public static void writeMemoryCrashDump(@NonNull Model net, @NonNull Throwable e){ + public static void writeMemoryCrashDump(@NonNull IModel net, @NonNull Throwable e){ if(!crashDumpsEnabled){ return; } @@ -189,7 +189,7 @@ public class CrashReportingUtil { * @param net Net to generate the report for * @return Report as a String */ - public static String generateMemoryStatus(Model net, int minibatch, InputType... inputTypes){ + public static String generateMemoryStatus(IModel net, int minibatch, InputType... 
inputTypes){ MultiLayerNetwork mln = null; ComputationGraph cg = null; boolean isMLN; @@ -310,12 +310,12 @@ public class CrashReportingUtil { //Workspaces, backprop type, layer info, activation info, helper info if(isMLN) { - sb.append(f("Backprop Type", mln.getLayerWiseConfigurations().getBackpropType())); - if(mln.getLayerWiseConfigurations().getBackpropType() == BackpropType.TruncatedBPTT){ - sb.append(f("TBPTT Length", mln.getLayerWiseConfigurations().getTbpttFwdLength() + "/" + mln.getLayerWiseConfigurations().getTbpttBackLength())); + sb.append(f("Backprop Type", mln.getNetConfiguration().getBackpropType())); + if(mln.getNetConfiguration().getBackpropType() == BackpropType.TruncatedBPTT){ + sb.append(f("TBPTT Length", mln.getNetConfiguration().getTbpttFwdLength() + "/" + mln.getNetConfiguration().getTbpttBackLength())); } - sb.append(f("Workspace Mode: Training", mln.getLayerWiseConfigurations().getTrainingWorkspaceMode())); - sb.append(f("Workspace Mode: Inference", mln.getLayerWiseConfigurations().getInferenceWorkspaceMode())); + sb.append(f("Workspace Mode: Training", mln.getNetConfiguration().getTrainingWorkspaceMode())); + sb.append(f("Workspace Mode: Inference", mln.getNetConfiguration().getInferenceWorkspaceMode())); appendLayerInformation(sb, mln.getLayers(), bytesPerElement); appendHelperInformation(sb, mln.getLayers()); appendActivationShapes(mln, (inputTypes == null || inputTypes.length == 0 ? null : inputTypes[0]), minibatch, sb, bytesPerElement); @@ -470,7 +470,7 @@ public class CrashReportingUtil { sb.append(String.format(format, "Idx", "Name", "ILayer Type", "ILayer # Parameters", "ILayer Parameter Memory")).append("\n"); for(Layer layer : layers){ long numParams = layer.numParams(); - sb.append(String.format(format, layer.getIndex(), layer.conf().getLayer().getLayerName(), + sb.append(String.format(format, layer.getIndex(), layer.getLayerConfiguration().getLayerName(), layer.getClass().getSimpleName(), numParams, fBytes(numParams * bytesPerElement))).append("\n"); } @@ -503,7 +503,7 @@ public class CrashReportingUtil { } int idx = l.getIndex(); - String layerName = l.conf().getLayer().getLayerName(); + String layerName = l.getLayerConfiguration().getLayerName(); if(layerName == null) layerName = String.valueOf(idx); @@ -549,7 +549,7 @@ public class CrashReportingUtil { sb.append(f("Current Minibatch Size", minibatch)); sb.append(f("Input Shape", Arrays.toString(inputShape))); - List inputTypes = net.getLayerWiseConfigurations().getLayerActivationTypes(inputType); + List inputTypes = net.getNetConfiguration().getLayerActivationTypes(inputType); String format = "%-3s %-20s %-20s %-42s %-20s %-12s %-12s"; sb.append(String.format(format, "Idx", "Name", "ILayer Type", "Activations Type", "Activations Shape", "# Elements", "Memory")).append("\n"); @@ -567,7 +567,7 @@ public class CrashReportingUtil { bytes = 0; } totalActivationBytes += bytes; - sb.append(String.format(format, i, layers[i].conf().getLayer().getLayerName(), layers[i].getClass().getSimpleName(), + sb.append(String.format(format, i, layers[i].getLayerConfiguration().getLayerName(), layers[i].getClass().getSimpleName(), inputTypes.get(i), Arrays.toString(shape), (numElements < 0 ? 
"" : String.valueOf(numElements)), fBytes(bytes))).append("\n"); last = bytes; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/DL4JModelValidator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/DL4JModelValidator.java index 6413a5eb4..739d3482c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/DL4JModelValidator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/DL4JModelValidator.java @@ -22,9 +22,9 @@ package org.deeplearning4j.util; import lombok.NonNull; import org.apache.commons.io.IOUtils; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.common.validation.Nd4jCommonValidator; @@ -47,7 +47,7 @@ public class DL4JModelValidator { /** * Validate whether the file represents a valid MultiLayerNetwork saved previously with {@link MultiLayerNetwork#save(File)} - * or {@link ModelSerializer#writeModel(Model, File, boolean)}, to be read with {@link MultiLayerNetwork#load(File, boolean)} + * or {@link ModelSerializer#writeModel(IModel, File, boolean)}, to be read with {@link MultiLayerNetwork#load(File, boolean)} * * @param f File that should represent an saved MultiLayerNetwork * @return Result of validation @@ -80,14 +80,14 @@ public class DL4JModelValidator { } try{ - MultiLayerConfiguration.fromJson(config); + NeuralNetConfiguration.fromJson(config); } catch (Throwable t){ return ValidationResult.builder() .formatType("MultiLayerNetwork") .formatClass(MultiLayerNetwork.class) .valid(false) .path(Nd4jCommonValidator.getPath(f)) - .issues(Collections.singletonList("Zip file JSON model configuration does not appear to represent a valid MultiLayerConfiguration")) + .issues(Collections.singletonList("Zip file JSON model configuration does not appear to represent a valid NeuralNetConfiguration")) .exception(t) .build(); } @@ -104,7 +104,7 @@ public class DL4JModelValidator { /** * Validate whether the file represents a valid ComputationGraph saved previously with {@link ComputationGraph#save(File)} - * or {@link ModelSerializer#writeModel(Model, File, boolean)}, to be read with {@link ComputationGraph#load(File, boolean)} + * or {@link ModelSerializer#writeModel(IModel, File, boolean)}, to be read with {@link ComputationGraph#load(File, boolean)} * * @param f File that should represent an saved MultiLayerNetwork * @return Result of validation diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java index e636334fd..e763d30bf 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java @@ -20,6 +20,7 @@ package org.deeplearning4j.util; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.io.input.CloseShieldInputStream; import org.deeplearning4j.common.util.DL4JFileUtils; import com.google.common.io.Files; @@ -28,10 +29,9 @@ import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; import org.apache.commons.io.output.CloseShieldOutputStream; import org.deeplearning4j.nn.api.Layer; -import 
org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.common.base.Preconditions; @@ -74,7 +74,7 @@ public class ModelSerializer { * @param saveUpdater whether to save the updater or not * @throws IOException */ - public static void writeModel(@NonNull Model model, @NonNull File file, boolean saveUpdater) throws IOException { + public static void writeModel(@NonNull IModel model, @NonNull File file, boolean saveUpdater) throws IOException { writeModel(model,file,saveUpdater,null); } @@ -88,7 +88,7 @@ public class ModelSerializer { * @param dataNormalization the normalizer to save (optional) * @throws IOException */ - public static void writeModel(@NonNull Model model, @NonNull File file, boolean saveUpdater,DataNormalization dataNormalization) throws IOException { + public static void writeModel(@NonNull IModel model, @NonNull File file, boolean saveUpdater,DataNormalization dataNormalization) throws IOException { try (BufferedOutputStream stream = new BufferedOutputStream(new FileOutputStream(file))) { writeModel(model, stream, saveUpdater,dataNormalization); } @@ -103,7 +103,7 @@ public class ModelSerializer { * or not * @throws IOException */ - public static void writeModel(@NonNull Model model, @NonNull String path, boolean saveUpdater) throws IOException { + public static void writeModel(@NonNull IModel model, @NonNull String path, boolean saveUpdater) throws IOException { try (BufferedOutputStream stream = new BufferedOutputStream(new FileOutputStream(path))) { writeModel(model, stream, saveUpdater); } @@ -116,7 +116,7 @@ public class ModelSerializer { * @param saveUpdater whether to save the updater for the model or not * @throws IOException */ - public static void writeModel(@NonNull Model model, @NonNull OutputStream stream, boolean saveUpdater) + public static void writeModel(@NonNull IModel model, @NonNull OutputStream stream, boolean saveUpdater) throws IOException { writeModel(model,stream,saveUpdater,null); } @@ -132,14 +132,14 @@ public class ModelSerializer { * @param dataNormalization the normalizer ot save (may be null) * @throws IOException */ - public static void writeModel(@NonNull Model model, @NonNull OutputStream stream, boolean saveUpdater,DataNormalization dataNormalization) + public static void writeModel(@NonNull IModel model, @NonNull OutputStream stream, boolean saveUpdater,DataNormalization dataNormalization) throws IOException { ZipOutputStream zipfile = new ZipOutputStream(new CloseShieldOutputStream(stream)); // Save configuration as JSON String json = ""; if (model instanceof MultiLayerNetwork) { - json = ((MultiLayerNetwork) model).getLayerWiseConfigurations().toJson(); + json = ((MultiLayerNetwork) model).getNetConfiguration().toJson(); } else if (model instanceof ComputationGraph) { json = ((ComputationGraph) model).getComputationGraphConfiguration().toJson(); } @@ -318,20 +318,20 @@ public class ModelSerializer { if (gotConfig && gotCoefficients) { - MultiLayerConfiguration confFromJson; + NeuralNetConfiguration confFromJson; try{ - confFromJson = MultiLayerConfiguration.fromJson(json); + confFromJson = NeuralNetConfiguration.fromJson(json); } catch (Exception e){ ComputationGraphConfiguration cg; try{ cg = 
ComputationGraphConfiguration.fromJson(json); } catch (Exception e2){ //Invalid, and not a compgraph - throw new RuntimeException("Error deserializing JSON MultiLayerConfiguration. Saved model JSON is" + - " not a valid MultiLayerConfiguration", e); + throw new RuntimeException("Error deserializing JSON NeuralNetConfiguration. Saved model JSON is" + + " not a valid NeuralNetConfiguration", e); } if(cg.getNetworkInputs() != null && cg.getVertices() != null) { - throw new RuntimeException("Error deserializing JSON MultiLayerConfiguration. Saved model appears to be " + + throw new RuntimeException("Error deserializing JSON NeuralNetConfiguration. Saved model appears to be " + "a ComputationGraph - use ModelSerializer.restoreComputationGraph instead"); } else { throw e; @@ -554,7 +554,7 @@ public class ModelSerializer { throw e; } try{ - MultiLayerConfiguration.fromJson(json); + NeuralNetConfiguration.fromJson(json); } catch (Exception e2){ //Invalid, and not a compgraph throw new RuntimeException("Error deserializing JSON ComputationGraphConfiguration. Saved model JSON is" + @@ -652,7 +652,7 @@ public class ModelSerializer { * @param model * @return */ - public static Task taskByModel(Model model) { + public static Task taskByModel(IModel model) { Task task = new Task(); try { task.setArchitectureType(Task.ArchitectureType.RECURRENT); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java index 4348be74a..900a516cd 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java @@ -21,14 +21,13 @@ package org.deeplearning4j.util; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.Trainable; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.BaseLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.GraphVertex; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -61,21 +60,21 @@ public class NetworkUtils { // by definition the identical for a MLN and "single stack" computation graph. This also has to hold // for the updater state... 
- ComputationGraphConfiguration.GraphBuilder b = new NeuralNetConfiguration.Builder() - .dataType(net.getLayerWiseConfigurations().getDataType()) + ComputationGraphConfiguration.GraphBuilder b = NeuralNetConfiguration.builder() + .dataType(net.getNetConfiguration().getDataType()) .graphBuilder(); - MultiLayerConfiguration origConf = net.getLayerWiseConfigurations().clone(); + NeuralNetConfiguration origConf = net.getNetConfiguration().clone(); int layerIdx = 0; String lastLayer = "in"; b.addInputs("in"); - for (NeuralNetConfiguration c : origConf.getConfs()) { + for (NeuralNetConfiguration c : origConf.getNetConfigurations()) { String currLayer = String.valueOf(layerIdx); InputPreProcessor preproc = origConf.getInputPreProcess(layerIdx); - b.addLayer(currLayer, c.getLayer(), preproc, lastLayer); + b.addLayer(currLayer, c.getFlattenedLayerConfigurations().get(layerIdx), preproc, lastLayer); lastLayer = currLayer; layerIdx++; @@ -123,7 +122,7 @@ public class NetworkUtils { private static void setLearningRate(MultiLayerNetwork net, int layerNumber, double newLr, ISchedule newLrSchedule, boolean refreshUpdater) { - Layer l = net.getLayer(layerNumber).conf().getLayer(); + LayerConfiguration l = net.getLayer(layerNumber).getLayerConfiguration(); if (l instanceof BaseLayer) { BaseLayer bl = (BaseLayer) l; IUpdater u = bl.getIUpdater(); @@ -155,8 +154,8 @@ public class NetworkUtils { /** * Set the learning rate schedule for all layers in the network to the specified schedule. * This schedule will replace any/all existing schedules, and also any fixed learning rate values.
- * Note that the iteration/epoch counts will not be reset. Use {@link MultiLayerConfiguration#setIterationCount(int)} - * and {@link MultiLayerConfiguration#setEpochCount(int)} if this is required + * Note that the iteration/epoch counts will not be reset. Use {@link NeuralNetConfiguration#setIterationCount(int)} + * and {@link NeuralNetConfiguration#setEpochCount(int)} if this is required * * @param newLrSchedule New learning rate schedule for all layers */ @@ -184,8 +183,8 @@ public class NetworkUtils { * Note also that {@link #setLearningRate(MultiLayerNetwork, ISchedule)} should also be used in preference, when all layers need * to be set to a new LR schedule.
* This schedule will replace any/all existing schedules, and also any fixed learning rate values.
- * Note also that the iteration/epoch counts will not be reset. Use {@link MultiLayerConfiguration#setIterationCount(int)} - * and {@link MultiLayerConfiguration#setEpochCount(int)} if this is required + * Note also that the iteration/epoch counts will not be reset. Use {@link NeuralNetConfiguration#setIterationCount(int)} + * and {@link NeuralNetConfiguration#setEpochCount(int)} if this is required * * @param layerNumber Number of the layer to set the LR schedule for * @param lrSchedule New learning rate for a single layer @@ -203,7 +202,7 @@ public class NetworkUtils { * @return Learning rate for the specified layer, or null */ public static Double getLearningRate(MultiLayerNetwork net, int layerNumber) { - Layer l = net.getLayer(layerNumber).conf().getLayer(); + LayerConfiguration l = net.getLayer(layerNumber).getLayerConfiguration(); int iter = net.getIterationCount(); int epoch = net.getEpochCount(); if (l instanceof BaseLayer) { @@ -238,14 +237,14 @@ public class NetworkUtils { private static void setLearningRate(ComputationGraph net, double newLr, ISchedule lrSchedule) { org.deeplearning4j.nn.api.Layer[] layers = net.getLayers(); for (int i = 0; i < layers.length; i++) { - setLearningRate(net, layers[i].conf().getLayer().getLayerName(), newLr, lrSchedule, false); + setLearningRate(net, layers[i].getLayerConfiguration().getLayerName(), newLr, lrSchedule, false); } refreshUpdater(net); } private static void setLearningRate(ComputationGraph net, String layerName, double newLr, ISchedule newLrSchedule, boolean refreshUpdater) { - Layer l = net.getLayer(layerName).conf().getLayer(); + LayerConfiguration l = net.getLayer(layerName).getLayerConfiguration(); if (l instanceof BaseLayer) { BaseLayer bl = (BaseLayer) l; IUpdater u = bl.getIUpdater(); @@ -325,7 +324,7 @@ public class NetworkUtils { * @return Learning rate for the specified layer, or null */ public static Double getLearningRate(ComputationGraph net, String layerName) { - Layer l = net.getLayer(layerName).conf().getLayer(); + LayerConfiguration l = net.getLayer(layerName).getLayerConfiguration(); int iter = net.getComputationGraphConfiguration().getIterationCount(); int epoch = net.getComputationGraphConfiguration().getEpochCount(); if (l instanceof BaseLayer) { @@ -353,7 +352,7 @@ public class NetworkUtils { * @see org.deeplearning4j.nn.graph.ComputationGraph#outputSingle(INDArray...) 
* @see org.deeplearning4j.nn.multilayer.MultiLayerNetwork#output(INDArray) */ - public static INDArray output(Model model, INDArray input) { + public static INDArray output(IModel model, INDArray input) { if (model instanceof MultiLayerNetwork) { final MultiLayerNetwork multiLayerNetwork = (MultiLayerNetwork) model; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java index fb3d9ea64..76f06b556 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java @@ -58,7 +58,7 @@ public class OutputLayerUtil { OUTSIDE_ZERO_ONE_RANGE.add(ActivationThresholdedReLU.class); } - private static final String COMMON_MSG = "\nThis configuration validation check can be disabled for MultiLayerConfiguration" + + private static final String COMMON_MSG = "\nThis configuration validation check can be disabled for NeuralNetConfiguration" + " and ComputationGraphConfiguration using validateOutputLayerConfig(false), however this is not recommended."; @@ -70,7 +70,7 @@ public class OutputLayerUtil { * @param layerName Name of the layer * @param layer ILayer */ - public static void validateOutputLayer(String layerName, Layer layer){ + public static void validateOutputLayer(String layerName, LayerConfiguration layer){ IActivation activation; ILossFunction loss; long nOut; @@ -166,7 +166,7 @@ public class OutputLayerUtil { * @param outputLayer Output layer * @param classifierEval Class for the classifier evaluation */ - public static void validateOutputLayerForClassifierEvaluation(Layer outputLayer, Class classifierEval){ + public static void validateOutputLayerForClassifierEvaluation(LayerConfiguration outputLayer, Class classifierEval){ if(outputLayer instanceof Yolo2OutputLayer){ throw new IllegalStateException("Classifier evaluation using " + classifierEval.getSimpleName() + " class cannot be applied for object" + " detection evaluation using Yolo2OutputLayer: " + classifierEval.getSimpleName() + " class is for classifier evaluation only."); @@ -182,7 +182,7 @@ public class OutputLayerUtil { throw new IllegalStateException("Classifier evaluation using " + classifierEval.getSimpleName() + " class cannot be applied to output" + " layers with activation functions that are not probabilities (in range 0 to 1). Output layer type: " + outputLayer.getClass().getSimpleName() + " has activation function " + bl.getActivationFn().getClass().getSimpleName() + - ". This check can be disabled using MultiLayerNetwork.getLayerWiseConfigurations().setValidateOutputLayerConfig(false)" + + ". 
This check can be disabled using MultiLayerNetwork.getConfiguration().setValidateOutputLayerConfig(false)" + " or ComputationGraph.getConfiguration().setValidateOutputLayerConfig(false)"); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java index eb5814b49..4723211b9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/TimeSeriesUtils.java @@ -23,7 +23,7 @@ package org.deeplearning4j.util; import lombok.val; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; -import org.deeplearning4j.nn.conf.layers.Layer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; @@ -442,7 +442,7 @@ public class TimeSeriesUtils { * LastTimeStep, etc * @param layer ILayer to get the RNNFormat from */ - public static RNNFormat getFormatFromRnnLayer(Layer layer){ + public static RNNFormat getFormatFromRnnLayer(LayerConfiguration layer){ if(layer instanceof BaseRecurrentLayer){ return ((BaseRecurrentLayer) layer).getRnnDataFormat(); } else if(layer instanceof MaskZeroLayer){ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/resources/simplelogger.properties b/cavis-dnn/cavis-dnn-nn/src/main/resources/simplelogger.properties new file mode 100644 index 000000000..93090cbc4 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/resources/simplelogger.properties @@ -0,0 +1,22 @@ +# +# +# ****************************************************************************** +# * +# * This program and the accompanying materials are made available under the +# * terms of the Apache License, Version 2.0 which is available at +# * https://www.apache.org/licenses/LICENSE-2.0. +# * +# * See the NOTICE file distributed with this work for additional +# * information regarding copyright ownership. +# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# * License for the specific language governing permissions and limitations +# * under the License. 
+# * +# * SPDX-License-Identifier: Apache-2.0 +# ***************************************************************************** +# +# + +org.slf4j.simpleLogger.defaultLogLevel = trace \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java index 06c322a57..9ca79badc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java +++ b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java @@ -21,18 +21,17 @@ package net.brutex.ai.dnn.api; -import static net.brutex.ai.dnn.api.dnn.*; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.Iterator; import java.util.concurrent.atomic.AtomicInteger; import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.datasets.iterator.FloatsDataSetIterator; -import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.Updater; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ActivationLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.weights.WeightInitXavier; import org.junit.jupiter.api.Test; import org.nd4j.common.primitives.Pair; @@ -53,8 +52,10 @@ class dnnTest { assertTrue(iterator.hasNext()); + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().build(); + /** - * MultiLayerConfiguration confxx = new NeuralNetConfiguration.Builder() + * NeuralNetConfiguration confxx = NeuralNetConfiguration.builder() * .seed(42) * .updater(UPDATER) * .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) @@ -62,7 +63,7 @@ class dnnTest { * .weightInit(WeightInit.XAVIER) * .activation(Activation.IDENTITY) * .list(genLayers()) - * .setInputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) + * .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) * // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS)) * .build(); */ @@ -76,20 +77,18 @@ class dnnTest { * new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), * new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH) */ - dnn.conf() + NN.net() .seed(42) .updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() ) .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer) .gradientNormalizationThreshold( 100 ) - .weightInit( new WeightInitXavier() ) - .activation( new ActivationIdentity() ) + .weightInitFn( new WeightInitXavier() ) + .activationFn( new ActivationIdentity() ) .inputType( InputType.convolutional( 28, 28, 1)) - .layer( dnn.DenseLayer(10,30).build() ) + .layer( new DenseLayer.Builder().nIn(10).nOut(20).build() ) .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build() ) - ; - } protected static Iterable> floatIterable(final int totalRows, final int numColumns) { diff --git a/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/conf/layer/FFLayerTest.java b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/conf/layer/FFLayerTest.java index 2fa944000..8430ec35d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/conf/layer/FFLayerTest.java +++ b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/conf/layer/FFLayerTest.java @@ -21,23 +21,12 @@ package net.brutex.ai.dnn.conf.layer; 
-import net.brutex.ai.dnn.api.IModel; -import net.brutex.ai.dnn.api.INeuralNetworkConfiguration; -import net.brutex.ai.dnn.api.ILayerConfiguration; import org.junit.jupiter.api.Test; class FFLayerTest { @Test void instantiate() { - ILayerConfiguration ff_conf = FeedForwardLayerConfiguration.builder().build(); - INeuralNetworkConfiguration net_conf = net.brutex.ai.dnn.conf.NeuralNetworkConfiguration.builder() - .layerConfiguration(ff_conf) - .build(); - IModel network = net.brutex.ai.dnn.impl.network.NeuralNetwork.builder().name("Test Network") - .configuration(net_conf) - .build(); - ff_conf.instantiate(network); } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainer.java b/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainer.java index ce14ae0b6..7af10085b 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainer.java @@ -25,7 +25,7 @@ import lombok.Builder; import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.parallelism.ParallelWrapper; @@ -82,23 +82,23 @@ public class ParameterServerTrainer extends DefaultTrainer { } @Override - public Model getModel() { + public IModel getModel() { return super.getModel(); } @Override - public void updateModel(@NonNull Model model) { + public void updateModel(@NonNull IModel model) { super.updateModel(model); } public static class ParameterServerTrainerBuilder extends DefaultTrainerBuilder { @Override - public ParameterServerTrainerBuilder originalModel(Model originalModel) { + public ParameterServerTrainerBuilder originalModel(IModel originalModel) { return (ParameterServerTrainerBuilder) super.originalModel(originalModel); } @Override - public ParameterServerTrainerBuilder replicatedModel(Model replicatedModel) { + public ParameterServerTrainerBuilder replicatedModel(IModel replicatedModel) { return (ParameterServerTrainerBuilder) super.replicatedModel(replicatedModel); } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainerContext.java b/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainerContext.java index 47d04d303..89f8d71a9 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainerContext.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainerContext.java @@ -21,7 +21,7 @@ package org.deeplearning4j.parallelism.parameterserver; import io.aeron.driver.MediaDriver; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.parallelism.ParallelWrapper; @@ -47,7 +47,7 @@ public class 
ParameterServerTrainerContext implements TrainerContext { * @param args the arguments to initialize with (maybe null) */ @Override - public void init(Model model, Object... args) { + public void init(IModel model, Object... args) { mediaDriverContext = new MediaDriver.Context(); mediaDriver = MediaDriver.launchEmbedded(mediaDriverContext); parameterServerNode = new ParameterServerNode(mediaDriver, statusServerPort, numWorkers); @@ -73,7 +73,7 @@ public class ParameterServerTrainerContext implements TrainerContext { * @return the created training instance */ @Override - public Trainer create(String uuid, int threadId, Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, + public Trainer create(String uuid, int threadId, IModel model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, WorkspaceMode mode, int averagingFrequency) { return ParameterServerTrainer.builder().originalModel(model).parameterServerClient(ParameterServerClient .builder().aeron(parameterServerNode.getAeron()) @@ -86,12 +86,12 @@ public class ParameterServerTrainerContext implements TrainerContext { } @Override - public void finalizeRound(Model originalModel, Model... models) { + public void finalizeRound(IModel originalModel, IModel... models) { // no-op } @Override - public void finalizeTraining(Model originalModel, Model... models) { + public void finalizeTraining(IModel originalModel, IModel... models) { // no-op } } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java b/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java index d92cdf753..cfeaf0821 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/test/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerParallelWrapperTest.java @@ -23,8 +23,8 @@ package org.deeplearning4j.parallelism.parameterserver; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -56,7 +56,7 @@ public class ParameterServerParallelWrapperTest extends BaseDL4JTest { DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345); log.info("Build model...."); - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .l2(0.0005) .weightInit(WeightInit.XAVIER) .updater(new Nesterovs(0.01, 0.9)).list() @@ -73,9 +73,9 @@ public class ParameterServerParallelWrapperTest extends BaseDL4JTest { .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()) .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)); + 
.inputType(InputType.convolutionalFlat(28, 28, 1)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java index e1f8b9273..a7b4a98bc 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java @@ -22,6 +22,7 @@ package org.deeplearning4j.parallelism; import com.google.common.util.concurrent.AtomicDouble; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration; import org.deeplearning4j.earlystopping.EarlyStoppingResult; import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener; @@ -29,7 +30,6 @@ import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator; import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition; import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition; import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.BaseTrainingListener; @@ -45,7 +45,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @Slf4j -public class EarlyStoppingParallelTrainer implements IEarlyStoppingTrainer { +public class EarlyStoppingParallelTrainer implements IEarlyStoppingTrainer { protected T model; @@ -314,7 +314,7 @@ public class EarlyStoppingParallelTrainer implements IEarlyStop * with each averaging step, and thus averaging is considered analogous to an iteration. 
* @param */ - private class AveragingTrainingListener extends BaseTrainingListener { + private class AveragingTrainingListener extends BaseTrainingListener { private final Logger log = LoggerFactory.getLogger(AveragingTrainingListener.class); private final IterationTerminationCondition terminationReason = null; private final EarlyStoppingParallelTrainer trainer; @@ -325,7 +325,7 @@ public class EarlyStoppingParallelTrainer implements IEarlyStop } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { //Check per-iteration termination conditions double latestScore = model.score(); trainer.setLatestScore(latestScore); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java index 9f32446ae..33009e994 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java @@ -23,10 +23,10 @@ package org.deeplearning4j.parallelism; import lombok.*; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.ModelAdapter; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.parallelism.inference.LoadBalanceMode; @@ -69,14 +69,14 @@ public class InplaceParallelInference extends ParallelInference { } @Override - public synchronized void updateModel(@NonNull Model model) { + public synchronized void updateModel(@NonNull IModel model) { for (val h:holders) h.updateModel(model); } @Override - protected synchronized Model[] getCurrentModelsFromWorkers() { - val models = new Model[holders.size()]; + protected synchronized IModel[] getCurrentModelsFromWorkers() { + val models = new IModel[holders.size()]; int cnt = 0; for (val h:holders) { models[cnt++] = h.sourceModel; @@ -101,7 +101,7 @@ public class InplaceParallelInference extends ParallelInference { */ public T output(@NonNull ModelAdapter adapter, INDArray[] input, INDArray[] inputMasks, INDArray[] labelsMasks) { val holder = selector.getModelForThisThread(); - Model model = null; + IModel model = null; boolean acquired = false; try { model = holder.acquireModel(); @@ -158,9 +158,9 @@ public class InplaceParallelInference extends ParallelInference { @AllArgsConstructor @lombok.Builder protected static class ModelHolder { - protected Model sourceModel; + protected IModel sourceModel; @lombok.Builder.Default protected int workers = 4; - @lombok.Builder.Default protected List replicas = new ArrayList<>(); + @lombok.Builder.Default protected List replicas = new ArrayList<>(); @lombok.Builder.Default protected boolean rootDevice = true; @lombok.Builder.Default protected LoadBalanceMode loadBalanceMode = LoadBalanceMode.ROUND_ROBIN; protected int targetDeviceId; @@ -169,7 +169,7 @@ public class InplaceParallelInference extends ParallelInference { protected final ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock(); // this queue is used in FIFO mode - 
protected final BlockingQueue queue = new LinkedBlockingQueue<>(); + protected final BlockingQueue queue = new LinkedBlockingQueue<>(); @lombok.Builder.Default protected transient boolean isCG = false; @lombok.Builder.Default protected transient boolean isMLN = false; @@ -204,7 +204,7 @@ public class InplaceParallelInference extends ParallelInference { if (loadBalanceMode == LoadBalanceMode.FIFO) queue.add(model); } else if (sourceModel instanceof MultiLayerNetwork) { - val model = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(((MultiLayerNetwork) sourceModel).getLayerWiseConfigurations().toJson())); + val model = new MultiLayerNetwork(NeuralNetConfiguration.fromJson(((MultiLayerNetwork) sourceModel).getConfiguration().toJson())); model.init(params, false); Nd4j.getExecutioner().commit(); @@ -217,7 +217,7 @@ public class InplaceParallelInference extends ParallelInference { } - protected Model acquireModel() throws InterruptedException { + protected IModel acquireModel() throws InterruptedException { try { modelLock.readLock().lock(); @@ -235,7 +235,7 @@ public class InplaceParallelInference extends ParallelInference { } } - protected void releaseModel(Model model) { + protected void releaseModel(IModel model) { try { modelLock.readLock().lock(); @@ -290,7 +290,7 @@ public class InplaceParallelInference extends ParallelInference { } } - protected void updateModel(@NonNull Model model) { + protected void updateModel(@NonNull IModel model) { try { modelLock.writeLock().lock(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java index 8547e7b9f..ea2e02ad7 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java @@ -23,10 +23,10 @@ package org.deeplearning4j.parallelism; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.ModelAdapter; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.parallelism.inference.InferenceMode; @@ -52,7 +52,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; @Slf4j public class ParallelInference { - protected Model model; + protected IModel model; protected long nanos; protected int workers; protected int batchLimit; @@ -86,7 +86,7 @@ public class ParallelInference { * * @param model */ - public void updateModel(@NonNull Model model) { + public void updateModel(@NonNull IModel model) { if (zoo != null) { for (val w: zoo) w.updateModel(model); @@ -102,11 +102,11 @@ public class ParallelInference { * * @return */ - protected Model[] getCurrentModelsFromWorkers() { + protected IModel[] getCurrentModelsFromWorkers() { if (zoo == null) - return new Model[0]; + return new IModel[0]; - val models = new Model[zoo.length]; + val models = new IModel[zoo.length]; int cnt = 0; for (val w:zoo) { models[cnt++] = w.replicatedModel; @@ -284,14 +284,14 @@ public class ParallelInference { public static class Builder { - private 
final Model model; + private final IModel model; private int workers = DEFAULT_NUM_WORKERS; private int batchLimit = DEFAULT_BATCH_LIMIT; private InferenceMode inferenceMode = DEFAULT_INFERENCE_MODE; private int queueLimit = DEFAULT_QUEUE_LIMIT; protected LoadBalanceMode loadBalanceMode = LoadBalanceMode.FIFO; - public Builder(@NonNull Model model) { + public Builder(@NonNull IModel model) { this.model = model; } @@ -416,15 +416,15 @@ public class ParallelInference { private final BlockingQueue inputQueue; private final AtomicBoolean shouldWork = new AtomicBoolean(true); private final AtomicBoolean isStopped = new AtomicBoolean(false); - private Model protoModel; - private Model replicatedModel; + private IModel protoModel; + private IModel replicatedModel; private final AtomicLong counter = new AtomicLong(0); private final boolean rootDevice; private final int deviceId; private final ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock(); - private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice, int deviceId) { + private InferenceWorker(int id, @NonNull IModel model, @NonNull BlockingQueue inputQueue, boolean rootDevice, int deviceId) { this.inputQueue = inputQueue; this.protoModel = model; this.rootDevice = rootDevice; @@ -439,7 +439,7 @@ public class ParallelInference { return counter.get(); } - protected void updateModel(@NonNull Model model) { + protected void updateModel(@NonNull IModel model) { try { modelLock.writeLock().lock(); this.protoModel = model; @@ -471,8 +471,8 @@ public class ParallelInference { } } else if (protoModel instanceof MultiLayerNetwork) { if (!rootDevice) { - this.replicatedModel = new MultiLayerNetwork(MultiLayerConfiguration.fromJson( - ((MultiLayerNetwork) protoModel).getLayerWiseConfigurations().toJson())); + this.replicatedModel = new MultiLayerNetwork(NeuralNetConfiguration.fromJson( + ((MultiLayerNetwork) protoModel).getConfiguration().toJson())); this.replicatedModel.init(); synchronized (locker) { diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java index 8da3b5262..e2a621508 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java @@ -22,6 +22,7 @@ package org.deeplearning4j.parallelism; import lombok.*; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.core.storage.listener.RoutingIterationListener; import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler; @@ -32,7 +33,6 @@ import org.deeplearning4j.datasets.iterator.DummyBlockDataSetIterator; import org.deeplearning4j.datasets.iterator.DummyBlockMultiDataSetIterator; import org.deeplearning4j.datasets.iterator.callbacks.InterleavedDataSetCallback; import org.deeplearning4j.exception.DL4JInvalidConfigException; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -49,7 +49,6 @@ import org.deeplearning4j.parallelism.factory.DefaultTrainerContext; import org.deeplearning4j.parallelism.factory.SymmetricTrainerContext; import 
org.deeplearning4j.parallelism.factory.TrainerContext; import org.deeplearning4j.parallelism.trainer.Trainer; -import org.nd4j.common.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; @@ -93,7 +92,7 @@ public class ParallelWrapper implements AutoCloseable { protected AtomicBoolean exceptionEncountered; protected Throwable exception; protected final String uuid = java.util.UUID.randomUUID().toString(); - protected Model model; + protected IModel model; protected int workers = 2; protected int prefetchSize = 2; protected int averagingFrequency = 1; @@ -131,7 +130,7 @@ public class ParallelWrapper implements AutoCloseable { } }; - protected ParallelWrapper(Model model, int workers, int prefetchSize) { + protected ParallelWrapper(IModel model, int workers, int prefetchSize) { this.model = model; this.workers = workers; this.prefetchSize = prefetchSize; @@ -669,7 +668,7 @@ public class ParallelWrapper implements AutoCloseable { } } - public static class Builder { + public static class Builder { protected TrainingMode trainingMode = TrainingMode.AVERAGING; protected T model; protected int workers = Nd4j.getAffinityManager().getNumberOfDevices(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContext.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContext.java index 4aea543eb..dc9fa3982 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContext.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContext.java @@ -20,7 +20,7 @@ package org.deeplearning4j.parallelism.factory; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.parallelism.ParallelWrapper; @@ -35,7 +35,7 @@ public class DefaultTrainerContext implements TrainerContext { * @param args the arguments to initialize with (maybe null) */ @Override - public void init(Model model, Object... args) { + public void init(IModel model, Object... args) { } @@ -53,7 +53,7 @@ public class DefaultTrainerContext implements TrainerContext { * @return the created training instance */ @Override - public Trainer create(String uuid, int threadId, Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, + public Trainer create(String uuid, int threadId, IModel model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, WorkspaceMode mode, int averagingFrequency) { DefaultTrainer trainer = DefaultTrainer.builder().originalModel(model).replicatedModel(model).threadId(threadId) @@ -68,14 +68,14 @@ public class DefaultTrainerContext implements TrainerContext { } @Override - public void finalizeRound(Model originalModel, Model... models) { + public void finalizeRound(IModel originalModel, IModel... models) { // apply averaging // TODO: move averaging here } @Override - public void finalizeTraining(Model originalModel, Model... models) { + public void finalizeTraining(IModel originalModel, IModel... 
models) { finalizeRound(originalModel, models); } } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContext.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContext.java index 3febe09c0..663cb148c 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContext.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContext.java @@ -21,11 +21,10 @@ package org.deeplearning4j.parallelism.factory; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.parallelism.ParallelWrapper; -import org.deeplearning4j.parallelism.trainer.DefaultTrainer; import org.deeplearning4j.parallelism.trainer.SymmetricTrainer; import org.deeplearning4j.parallelism.trainer.Trainer; @@ -38,7 +37,7 @@ public class SymmetricTrainerContext implements TrainerContext { * @param args the arguments to initialize with (maybe null) */ @Override - public void init(Model model, Object... args) { + public void init(IModel model, Object... args) { } @@ -56,7 +55,7 @@ public class SymmetricTrainerContext implements TrainerContext { * @return the created training instance */ @Override - public Trainer create(String uuid, int threadId, Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, + public Trainer create(String uuid, int threadId, IModel model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, WorkspaceMode mode, int averagingFrequency) { SymmetricTrainer trainer = new SymmetricTrainer(model, uuid, threadId, mode, wrapper, useMDS); @@ -68,12 +67,12 @@ public class SymmetricTrainerContext implements TrainerContext { } @Override - public void finalizeRound(Model originalModel, Model... models) { + public void finalizeRound(IModel originalModel, IModel... models) { // no-op } @Override - public void finalizeTraining(Model originalModel, Model... models) { + public void finalizeTraining(IModel originalModel, IModel... models) { // we CAN avarage here, but for now we'll just push first model params to original model originalModel.setParams(models[0].params()); } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/TrainerContext.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/TrainerContext.java index cc1bd53f7..57cdd76fd 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/TrainerContext.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/TrainerContext.java @@ -20,7 +20,7 @@ package org.deeplearning4j.parallelism.factory; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.parallelism.ParallelWrapper; import org.deeplearning4j.parallelism.trainer.Trainer; @@ -33,7 +33,7 @@ public interface TrainerContext { * @param model * @param args the arguments to initialize with (maybe null) */ - void init(Model model, Object... args); + void init(IModel model, Object... 
args); /** * Create a {@link Trainer} @@ -47,7 +47,7 @@ public interface TrainerContext { * for coordination with the {@link ParallelWrapper} 's {@link org.deeplearning4j.optimize.api.TrainingListener} * @return the created training instance */ - Trainer create(String uuid, int threadId, Model model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, + Trainer create(String uuid, int threadId, IModel model, int rootDevice, boolean useMDS, ParallelWrapper wrapper, WorkspaceMode workspaceMode, int averagingFrequency); @@ -57,7 +57,7 @@ public interface TrainerContext { * @param originalModel * @param models */ - void finalizeRound(Model originalModel, Model... models); + void finalizeRound(IModel originalModel, IModel... models); /** * This method is called @@ -65,5 +65,5 @@ public interface TrainerContext { * @param originalModel * @param models */ - void finalizeTraining(Model originalModel, Model... models); + void finalizeTraining(IModel originalModel, IModel... models); } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/main/ParallelWrapperMain.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/main/ParallelWrapperMain.java index c0f6c9785..26e76ed61 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/main/ParallelWrapperMain.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/main/ParallelWrapperMain.java @@ -25,10 +25,10 @@ import com.beust.jcommander.Parameter; import com.beust.jcommander.ParameterException; import lombok.Data; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.common.config.DL4JClassLoading; import org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.parallelism.ParallelWrapper; import org.deeplearning4j.core.util.ModelGuesser; @@ -101,7 +101,7 @@ public class ParallelWrapperMain { public void run() throws Exception { - Model model = ModelGuesser.loadModelGuess(modelPath); + IModel model = ModelGuesser.loadModelGuess(modelPath); // ParallelWrapper will take care of load balancing between GPUs. wrapper = new ParallelWrapper.Builder(model) // DataSets prefetching options. 
Set this value with respect to number of actual devices diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java index be706234f..dd7cda946 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java @@ -22,12 +22,12 @@ package org.deeplearning4j.parallelism.trainer; import lombok.*; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.core.storage.listener.RoutingIterationListener; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -56,7 +56,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; @AllArgsConstructor public class DefaultTrainer extends Thread implements Trainer { - protected Model replicatedModel; + protected IModel replicatedModel; // TODO: make queue size configurable @Builder.Default @@ -89,7 +89,7 @@ public class DefaultTrainer extends Thread implements Trainer { protected WorkspaceMode workspaceMode; protected int averagingFrequency; protected int threadId; - protected Model originalModel; + protected IModel originalModel; protected final ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock(); @@ -135,12 +135,12 @@ public class DefaultTrainer extends Thread implements Trainer { } @Override - public Model getModel() { + public IModel getModel() { return replicatedModel; } @Override - public void updateModel(@NonNull Model model) { + public void updateModel(@NonNull IModel model) { this.shouldUpdate.set(true); try { modelLock.writeLock().lock(); @@ -295,8 +295,8 @@ public class DefaultTrainer extends Thread implements Trainer { // however, we don't need clone or anything here if (originalModel instanceof MultiLayerNetwork) { if (!onRootModel) { - MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson( - ((MultiLayerNetwork) originalModel).getLayerWiseConfigurations().toJson()); + NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson( + ((MultiLayerNetwork) originalModel).getConfiguration().toJson()); conf.setTrainingWorkspaceMode(workspaceMode); this.replicatedModel = new MultiLayerNetwork(conf); @@ -323,7 +323,7 @@ public class DefaultTrainer extends Thread implements Trainer { if (!((MultiLayerNetwork) replicatedModel).isInitCalled()) this.replicatedModel.init(); - ((MultiLayerNetwork) replicatedModel).getLayerWiseConfigurations() + ((MultiLayerNetwork) replicatedModel).getConfiguration() .setTrainingWorkspaceMode(workspaceMode); } } else if (originalModel instanceof ComputationGraph) { diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/SymmetricTrainer.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/SymmetricTrainer.java index 96e02cad8..a3a3c57db 100644 --- 
a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/SymmetricTrainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/SymmetricTrainer.java @@ -22,7 +22,7 @@ package org.deeplearning4j.parallelism.trainer; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -34,7 +34,7 @@ import org.deeplearning4j.parallelism.ParallelWrapper; public class SymmetricTrainer extends DefaultTrainer implements CommunicativeTrainer { protected GradientsAccumulator accumulator; - public SymmetricTrainer(@NonNull Model originalModel, String uuid, int threadIdx, @NonNull WorkspaceMode mode, + public SymmetricTrainer(@NonNull IModel originalModel, String uuid, int threadIdx, @NonNull WorkspaceMode mode, @NonNull ParallelWrapper wrapper, boolean useMDS) { super(); this.uuid = uuid + "_thread_" + threadIdx; diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/Trainer.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/Trainer.java index 51e9be570..bdc773b0f 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/Trainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/Trainer.java @@ -21,7 +21,7 @@ package org.deeplearning4j.parallelism.trainer; import lombok.NonNull; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -54,17 +54,17 @@ public interface Trainer extends Runnable { /** * THe current model for the trainer - * @return the current {@link Model} + * @return the current {@link IModel} * for the worker */ - Model getModel(); + IModel getModel(); /** - * Update the current {@link Model} + * Update the current {@link IModel} * for the worker * @param model the new model for this worker */ - void updateModel(@NonNull Model model); + void updateModel(@NonNull IModel model); boolean isRunning(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java index ecb28ef9b..e64e6d06f 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java @@ -40,7 +40,7 @@ public class InplaceParallelInferenceTest extends BaseDL4JTest { public void testUpdateModel() { int nIn = 5; - val conf = new NeuralNetConfiguration.Builder() + val conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("out0", new OutputLayer.Builder().nIn(nIn).nOut(4).activation(Activation.SOFTMAX).build(), "in") @@ -68,7 +68,7 @@ public class InplaceParallelInferenceTest extends BaseDL4JTest { assertEquals(net.params(), m.params()); } - val conf2 = new NeuralNetConfiguration.Builder() + val conf2 = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("out0", new 
OutputLayer.Builder().nIn(nIn).nOut(4).activation(Activation.SOFTMAX).build(), "in") @@ -101,7 +101,7 @@ public class InplaceParallelInferenceTest extends BaseDL4JTest { public void testOutput_RoundRobin_1() throws Exception { int nIn = 5; - val conf = new NeuralNetConfiguration.Builder() + val conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("out0", new OutputLayer.Builder().nIn(nIn).nOut(4).activation(Activation.SOFTMAX).build(), "in") @@ -134,7 +134,7 @@ public class InplaceParallelInferenceTest extends BaseDL4JTest { public void testOutput_FIFO_1() throws Exception { int nIn = 5; - val conf = new NeuralNetConfiguration.Builder() + val conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("out0", new OutputLayer.Builder().nIn(nIn).nOut(4).activation(Activation.SOFTMAX).build(), "in") diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java index 3919bfbc7..5f1ac9a7a 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java @@ -23,12 +23,11 @@ package org.deeplearning4j.parallelism; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.exception.DL4JInvalidInputException; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -414,7 +413,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { int nIn = 10; int[] tsLengths = {3,5,7,10,50,100}; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .list() @@ -459,7 +458,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { int nIn = 10; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .list() @@ -527,7 +526,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { {1,nIn,40,45}, }; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .list() @@ -575,7 +574,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { int nIn = 3; int[] defaultShape = new int[]{1, nIn, 16, 16}; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .convolutionMode(ConvolutionMode.Same) @@ -625,7 +624,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { int nIn = 10; int wrongNIn = 5; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .list() @@ -689,7 +688,7 @@ public class 
ParallelInferenceTest extends BaseDL4JTest { int nIn = 10; int tsLength = 16; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .seed(12345) .list() @@ -757,7 +756,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { public void testModelUpdate_1() throws Exception { int nIn = 5; - val conf = new NeuralNetConfiguration.Builder() + val conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("out0", new OutputLayer.Builder().nIn(nIn).nOut(4).activation(Activation.SOFTMAX).build(), "in") @@ -782,12 +781,12 @@ public class ParallelInferenceTest extends BaseDL4JTest { assertNotEquals(0, output.length); } - Model[] modelsBefore = inf.getCurrentModelsFromWorkers(); + IModel[] modelsBefore = inf.getCurrentModelsFromWorkers(); assertEquals(4, modelsBefore.length); boolean passed = false; int cnt0 = 0; - for (Model m : modelsBefore) { + for (IModel m : modelsBefore) { // model can be null for some of the workers yet, due to race condition if (m != null) { Thread.sleep(500); @@ -799,7 +798,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { assertTrue(passed); - val conf2 = new NeuralNetConfiguration.Builder() + val conf2 = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("out0", new OutputLayer.Builder().nIn(nIn).nOut(4).build(), "in") @@ -830,7 +829,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { int nIn = 5; - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .layer("out0", new OutputLayer.Builder().nIn(nIn).nOut(4).activation(Activation.SOFTMAX).build(), "in") diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java index 458b9dab1..b74262dd2 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java @@ -25,8 +25,8 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.eval.Evaluation; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -37,8 +37,6 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -72,7 +70,7 @@ public class ParallelWrapperTest extends BaseDL4JTest { log.info("F: {}; L: {};", t0.getFeatures().shape(), t0.getLabels().shape()); log.info("Build 
model...."); - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .l2(0.0005) //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75) .weightInit(WeightInit.XAVIER) @@ -90,9 +88,9 @@ public class ParallelWrapperTest extends BaseDL4JTest { .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()) .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, nChannels)); + .inputType(InputType.convolutionalFlat(28, 28, nChannels)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java index eb3ccfef8..799f2dfd7 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java @@ -25,9 +25,8 @@ import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.core.storage.listener.RoutingIterationListener; import org.deeplearning4j.datasets.iterator.ExistingDataSetIterator; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -56,11 +55,11 @@ public class TestListeners extends BaseDL4JTest { public void testListeners() { TestListener.clearCounts(); - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list().layer(0, + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list().layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(10).nOut(10) .activation(Activation.TANH).build()); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); @@ -71,7 +70,7 @@ public class TestListeners extends BaseDL4JTest { public void testListenersGraph() { TestListener.clearCounts(); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder() .addInputs("in").addLayer("0", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(10).nOut(10) .activation(Activation.TANH).build(), @@ -88,11 +87,11 @@ public class TestListeners extends BaseDL4JTest { public void testListenersViaModel() { TestListener.clearCounts(); - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().list().layer(0, + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().list().layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(10).nOut(10) 
.activation(Activation.TANH).build()); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); @@ -109,7 +108,7 @@ public class TestListeners extends BaseDL4JTest { public void testListenersViaModelGraph() { TestListener.clearCounts(); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder() .addInputs("in").addLayer("0", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(10).nOut(10) .activation(Activation.TANH).build(), @@ -128,7 +127,7 @@ public class TestListeners extends BaseDL4JTest { assertEquals(2, ss.listWorkerIDsForSession(ss.listSessionIDs().get(0)).size()); } - private static void testListenersForModel(Model model, List listeners) { + private static void testListenersForModel(IModel model, List listeners) { int nWorkers = 2; ParallelWrapper wrapper = new ParallelWrapper.Builder(model).workers(nWorkers).averagingFrequency(1) @@ -176,26 +175,26 @@ public class TestListeners extends BaseDL4JTest { } @Override - public void onEpochStart(Model model) {} + public void onEpochStart(IModel model) {} @Override - public void onEpochEnd(Model model) {} + public void onEpochEnd(IModel model) {} @Override - public void onForwardPass(Model model, List activations) { + public void onForwardPass(IModel model, List activations) { forwardPassCount.incrementAndGet(); } @Override - public void onForwardPass(Model model, Map activations) { + public void onForwardPass(IModel model, Map activations) { forwardPassCount.incrementAndGet(); } @Override - public void onGradientCalculation(Model model) {} + public void onGradientCalculation(IModel model) {} @Override - public void onBackwardPass(Model model) { + public void onBackwardPass(IModel model) { backwardPassCount.getAndIncrement(); } @@ -233,7 +232,7 @@ public class TestListeners extends BaseDL4JTest { } @Override - public void iterationDone(Model model, int iteration, int epoch) {} + public void iterationDone(IModel model, int iteration, int epoch) {} } } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java index 2eaf2e850..a3b97339b 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java @@ -32,7 +32,6 @@ import org.deeplearning4j.earlystopping.termination.MaxScoreIterationTermination import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition; import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -56,7 +55,7 @@ public class TestParallelEarlyStopping extends BaseDL4JTest { // be properly designed // @Test // public void testEarlyStoppingIris(){ - // MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + // NeuralNetConfiguration conf = NeuralNetConfiguration.builder() // 
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) // .updater(Updater.SGD) // .weightInit(WeightInit.XAVIER) @@ -101,7 +100,7 @@ public class TestParallelEarlyStopping extends BaseDL4JTest { @Test public void testEarlyStoppingEveryNEpoch() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd()).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) @@ -133,7 +132,7 @@ public class TestParallelEarlyStopping extends BaseDL4JTest { //Test poor tuning (high LR): should terminate on MaxScoreIterationTerminationCondition Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(1.0)) //Intentionally huge LR .weightInit(WeightInit.XAVIER).list() diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java index 7bea67ef6..9d8fe7c70 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java @@ -31,7 +31,6 @@ import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator; import org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition; import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -56,7 +55,7 @@ public class TestParallelEarlyStoppingUI extends BaseDL4JTest { public void testParallelStatsListenerCompatibility() throws Exception { UIServer uiServer = UIServer.getInstance(); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd()).weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java index 3a85b4b34..306338b11 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/DefaultTrainerContextTest.java @@ -21,8 +21,8 @@ package org.deeplearning4j.parallelism.factory; import lombok.val; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.WorkspaceMode; import 
org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -33,7 +33,6 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.parallelism.ParallelWrapper; -import org.deeplearning4j.parallelism.trainer.SymmetricTrainer; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.Nesterovs; @@ -50,11 +49,11 @@ public class DefaultTrainerContextTest extends BaseDL4JTest { @Test public void testEqualUuid1() { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .l2(0.0005) //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75) .weightInit(WeightInit.XAVIER) - .updater(new Nesterovs(0.01, 0.9)).list() + .updater(new Nesterovs(0.01, 0.9)) .layer(0, new ConvolutionLayer.Builder(5, 5) //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build()) @@ -68,9 +67,9 @@ public class DefaultTrainerContextTest extends BaseDL4JTest { .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()) .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, nChannels)); + .inputType(InputType.convolutionalFlat(28, 28, nChannels)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java index ec82896df..b61f820f7 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContextTest.java @@ -21,8 +21,8 @@ package org.deeplearning4j.parallelism.factory; import lombok.val; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -50,7 +50,7 @@ public class SymmetricTrainerContextTest extends BaseDL4JTest { @Test public void testEqualUuid1() { - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .l2(0.0005) //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75) .weightInit(WeightInit.XAVIER) @@ -68,9 +68,9 @@ public class SymmetricTrainerContextTest extends BaseDL4JTest { .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()) .layer(5, new 
OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, nChannels)); + .inputType(InputType.convolutionalFlat(28, 28, nChannels)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java index 315788855..da27c4c63 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/main/ParallelWrapperMainTest.java @@ -22,8 +22,8 @@ package org.deeplearning4j.parallelism.main; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -66,7 +66,7 @@ public class ParallelWrapperMainTest extends BaseDL4JTest { DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345); log.info("Build model...."); - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .l2(0.0005) .weightInit(WeightInit.XAVIER) .updater(new Nesterovs(0.01, 0.9)).list() @@ -83,9 +83,9 @@ public class ParallelWrapperMainTest extends BaseDL4JTest { .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()) .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, nChannels)); + .inputType(InputType.convolutionalFlat(28, 28, nChannels)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); File tempModel = new File(testDir, "tmpmodel.zip"); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/TrainingHook.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/TrainingHook.java index 8ea9738db..2adec7f64 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/TrainingHook.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/TrainingHook.java @@ -20,7 +20,7 @@ package org.deeplearning4j.spark.api; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -33,7 +33,7 @@ public interface TrainingHook extends Serializable { * that was used for the update * @param model themodel that was update */ - void preUpdate(DataSet minibatch, Model model); + void preUpdate(DataSet minibatch, 
IModel model); /** * A hook method for post update @@ -41,7 +41,7 @@ public interface TrainingHook extends Serializable { * that was usd for the update * @param model the model that was updated */ - void postUpdate(DataSet minibatch, Model model); + void postUpdate(DataSet minibatch, IModel model); /** * A hook method for pre update. @@ -49,7 +49,7 @@ public interface TrainingHook extends Serializable { * that was used for the update * @param model the model that was update */ - void preUpdate(MultiDataSet minibatch, Model model); + void preUpdate(MultiDataSet minibatch, IModel model); /** * A hook method for post update @@ -57,6 +57,6 @@ public interface TrainingHook extends Serializable { * that was usd for the update * @param model the model that was updated */ - void postUpdate(MultiDataSet minibatch, Model model); + void postUpdate(MultiDataSet minibatch, IModel model); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/worker/NetBroadcastTuple.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/worker/NetBroadcastTuple.java index 9fa317026..f0d69c039 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/worker/NetBroadcastTuple.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/api/worker/NetBroadcastTuple.java @@ -22,7 +22,7 @@ package org.deeplearning4j.spark.api.worker; import lombok.Data; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.nd4j.linalg.api.ndarray.INDArray; import java.io.Serializable; @@ -31,13 +31,13 @@ import java.util.concurrent.atomic.AtomicInteger; @Data public class NetBroadcastTuple implements Serializable { - private final MultiLayerConfiguration configuration; + private final NeuralNetConfiguration configuration; private final ComputationGraphConfiguration graphConfiguration; private final INDArray parameters; private final INDArray updaterState; private final AtomicInteger counter; - public NetBroadcastTuple(MultiLayerConfiguration configuration, INDArray parameters, INDArray updaterState) { + public NetBroadcastTuple(NeuralNetConfiguration configuration, INDArray parameters, INDArray updaterState) { this(configuration, null, parameters, updaterState); } @@ -47,12 +47,12 @@ public class NetBroadcastTuple implements Serializable { } - public NetBroadcastTuple(MultiLayerConfiguration configuration, ComputationGraphConfiguration graphConfiguration, + public NetBroadcastTuple(NeuralNetConfiguration configuration, ComputationGraphConfiguration graphConfiguration, INDArray parameters, INDArray updaterState) { this(configuration, graphConfiguration, parameters, updaterState, new AtomicInteger(0)); } - public NetBroadcastTuple(MultiLayerConfiguration configuration, ComputationGraphConfiguration graphConfiguration, + public NetBroadcastTuple(NeuralNetConfiguration configuration, ComputationGraphConfiguration graphConfiguration, INDArray parameters, INDArray updaterState, AtomicInteger counter) { this.configuration = configuration; this.graphConfiguration = graphConfiguration; diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java index 5ed1848b7..8d799c1b2 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/earlystopping/BaseSparkEarlyStoppingTrainer.java @@ -29,7 +29,7 @@ import org.deeplearning4j.earlystopping.scorecalc.ScoreCalculator; import org.deeplearning4j.earlystopping.termination.EpochTerminationCondition; import org.deeplearning4j.earlystopping.termination.IterationTerminationCondition; import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; import org.slf4j.Logger; @@ -39,7 +39,7 @@ import java.io.IOException; import java.util.LinkedHashMap; import java.util.Map; -public abstract class BaseSparkEarlyStoppingTrainer implements IEarlyStoppingTrainer { +public abstract class BaseSparkEarlyStoppingTrainer implements IEarlyStoppingTrainer { private static final Logger log = LoggerFactory.getLogger(BaseSparkEarlyStoppingTrainer.class); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java index ed302a351..e1bfc277f 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeReconstructionProbWithKeyFunction.java @@ -31,7 +31,7 @@ public abstract class BaseVaeReconstructionProbWithKeyFunction extends BaseVa /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param useLogProbability If true: use log probability. False: use raw probability. 
* @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java index 4140b8a53..cfcc93b78 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/common/score/BaseVaeScoreWithKeyFunction.java @@ -45,7 +45,7 @@ public abstract class BaseVaeScoreWithKeyFunction implements PairFlatMapFunct /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param batchSize Batch size to use when scoring */ public BaseVaeScoreWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java index a38322234..beb8d7972 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/evaluation/EvaluationRunner.java @@ -22,12 +22,12 @@ package org.deeplearning4j.spark.impl.evaluation; import lombok.*; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.datasets.iterator.IteratorDataSetIterator; import org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.common.base.Preconditions; @@ -124,14 +124,14 @@ public class EvaluationRunner { EvaluationFuture f = new EvaluationFuture(); f.setResult(evals); try { - Model m; + IModel m; if (isCG) { ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(json.getValue()); ComputationGraph cg = new ComputationGraph(conf); cg.init(deviceLocalParams.get(), false); m = cg; } else { - MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(json.getValue()); + NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(json.getValue()); MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(deviceLocalParams.get(), false); m = net; @@ -176,7 +176,7 @@ public class EvaluationRunner { return f; } - private static void doEval(Model m, IEvaluation[] e, Iterator ds, Iterator mds, int evalBatchSize){ + private static void doEval(IModel m, IEvaluation[] e, Iterator ds, Iterator mds, int evalBatchSize){ if(m instanceof MultiLayerNetwork){ MultiLayerNetwork mln = (MultiLayerNetwork)m; if(ds != null){ diff --git 
a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java index e460ddc2f..84c7cf753 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java @@ -632,7 +632,7 @@ public class SparkComputationGraph extends SparkListenable { * @return {@link RegressionEvaluation} instance with regression performance */ public T evaluateRegression(JavaRDD data, int minibatchSize) { - val nOut = ((FeedForwardLayer) network.getOutputLayer(0).conf().getLayer()).getNOut(); + val nOut = ((FeedForwardLayer) network.getOutputLayer(0).getLayerConfiguration()).getNOut(); return (T)doEvaluation(data, new org.deeplearning4j.eval.RegressionEvaluation(nOut), minibatchSize); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java index b7da3d143..f6794f27d 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionErrorWithKeyFunction.java @@ -33,7 +33,7 @@ public class CGVaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWith /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param batchSize Batch size to use when scoring */ public CGVaeReconstructionErrorWithKeyFunction(Broadcast params, Broadcast jsonConfig, diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java index 43defe37f..b5413e0dc 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/scoring/CGVaeReconstructionProbWithKeyFunction.java @@ -33,7 +33,7 @@ public class CGVaeReconstructionProbWithKeyFunction extends BaseVaeReconstruc /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param useLogProbability If true: use log probability. False: use raw probability. 
* @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java index d8e1c1437..2e50414da 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java @@ -35,7 +35,7 @@ import org.datavec.spark.util.BroadcastHadoopConfigHolder; import org.deeplearning4j.core.loader.DataSetLoader; import org.deeplearning4j.core.loader.MultiDataSetLoader; import org.deeplearning4j.core.loader.impl.SerializedDataSetLoader; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.spark.api.TrainingMaster; @@ -80,7 +80,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { public static final int DEFAULT_ROC_THRESHOLD_STEPS = 32; public static final int DEFAULT_EVAL_WORKERS = 4; private final transient JavaSparkContext sc; - private final MultiLayerConfiguration conf; + private final NeuralNetConfiguration conf; private MultiLayerNetwork network; private double lastScore; private int defaultEvaluationWorkers = DEFAULT_EVAL_WORKERS; @@ -104,7 +104,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { * @param sparkContext the spark context to use * @param conf the configuration of the network */ - public SparkDl4jMultiLayer(SparkContext sparkContext, MultiLayerConfiguration conf, + public SparkDl4jMultiLayer(SparkContext sparkContext, NeuralNetConfiguration conf, TrainingMaster trainingMaster) { this(new JavaSparkContext(sparkContext), initNetwork(conf), trainingMaster); } @@ -115,14 +115,14 @@ public class SparkDl4jMultiLayer extends SparkListenable { * @param sc the spark context to use * @param conf the configuration of the network */ - public SparkDl4jMultiLayer(JavaSparkContext sc, MultiLayerConfiguration conf, TrainingMaster trainingMaster) { + public SparkDl4jMultiLayer(JavaSparkContext sc, NeuralNetConfiguration conf, TrainingMaster trainingMaster) { this(sc.sc(), conf, trainingMaster); } public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerNetwork network, TrainingMaster trainingMaster) { sc = javaSparkContext; - this.conf = network.getLayerWiseConfigurations().clone(); + this.conf = network.getConfiguration().clone(); this.network = network; if (!network.isInitCalled()) network.init(); @@ -132,7 +132,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { SparkUtils.checkKryoConfiguration(javaSparkContext, log); } - private static MultiLayerNetwork initNetwork(MultiLayerConfiguration conf) { + private static MultiLayerNetwork initNetwork(NeuralNetConfiguration conf) { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); return net; @@ -315,8 +315,8 @@ public class SparkDl4jMultiLayer extends SparkListenable { * @return the multi layer network that was fitDataSet */ public MultiLayerNetwork fitLabeledPoint(JavaRDD rdd) { - int nLayers = 
network.getLayerWiseConfigurations().getConfs().size(); - FeedForwardLayer ffl = (FeedForwardLayer) network.getLayerWiseConfigurations().getConf(nLayers - 1).getLayer(); + int nLayers = network.getConfiguration().getFlattenedLayerConfigurations().size(); + FeedForwardLayer ffl = (FeedForwardLayer) network.getConfiguration().getFlattenedLayerConfigurations().get(nLayers - 1); JavaRDD ds = MLLibUtil.fromLabeledPoint(sc, rdd, ffl.getNOut()); return fit(ds); } @@ -577,7 +577,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { * @return {@link RegressionEvaluation} instance with regression performance */ public T evaluateRegression(JavaRDD data, int minibatchSize) { - long nOut = ((FeedForwardLayer) network.getOutputLayer().conf().getLayer()).getNOut(); + long nOut = ((FeedForwardLayer) network.getOutputLayer().getLayerConfiguration()).getNOut(); return (T)doEvaluation(data, new org.deeplearning4j.eval.RegressionEvaluation(nOut), minibatchSize); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java index 510f2e4d4..c064c81d0 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/FeedForwardWithKeyFunction.java @@ -22,7 +22,7 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSetUtil; @@ -49,7 +49,7 @@ public class FeedForwardWithKeyFunction /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param batchSize Batch size to use for forward pass (use > 1 for efficiency) */ public FeedForwardWithKeyFunction(Broadcast params, Broadcast jsonConfig, int batchSize) { @@ -65,7 +65,7 @@ public class FeedForwardWithKeyFunction return Collections.emptyIterator(); } - MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); + MultiLayerNetwork network = new MultiLayerNetwork(NeuralNetConfiguration.fromJson(jsonConfig.getValue())); network.init(); INDArray val = params.value().unsafeDuplication(); if (val.length() != network.numParams(false)) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java index 6c3878da5..b6a21d181 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/ScoreExamplesFunction.java @@ -21,9 +21,8 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; 
import org.apache.spark.api.java.function.DoubleFlatMapFunction; -import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.broadcast.Broadcast; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.DataSet; @@ -60,7 +59,7 @@ public class ScoreExamplesFunction implements DoubleFlatMapFunction implements PairFlatMapFunction implements PairFlatMapFunction, DataSetIterator iter = new IteratorDataSetIterator(dataSetIterator, minibatchSize); //Does batching where appropriate - MultiLayerNetwork network = new MultiLayerNetwork(MultiLayerConfiguration.fromJson(json)); + MultiLayerNetwork network = new MultiLayerNetwork(NeuralNetConfiguration.fromJson(json)); network.init(); INDArray val = params.value().unsafeDuplication(); //.value() object will be shared by all executors on each machine -> OK, as params are not modified by score function if (val.length() != network.numParams(false)) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java index a0bcca02b..d9901cbe0 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionErrorWithKeyFunction.java @@ -22,21 +22,18 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.spark.impl.common.score.BaseVaeScoreWithKeyFunction; import org.nd4j.linalg.api.ndarray.INDArray; -import scala.Tuple2; - -import java.util.Iterator; public class VaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKeyFunction { /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param batchSize Batch size to use when scoring */ public VaeReconstructionErrorWithKeyFunction(Broadcast params, Broadcast jsonConfig, @@ -47,7 +44,7 @@ public class VaeReconstructionErrorWithKeyFunction extends BaseVaeScoreWithKe @Override public VariationalAutoencoder getVaeLayer() { MultiLayerNetwork network = - new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); + new MultiLayerNetwork(NeuralNetConfiguration.fromJson(jsonConfig.getValue())); network.init(); INDArray val = params.value().unsafeDuplication(); if (val.length() != network.numParams(false)) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java index 
d65084dc5..b7cdbd403 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/scoring/VaeReconstructionProbWithKeyFunction.java @@ -22,7 +22,7 @@ package org.deeplearning4j.spark.impl.multilayer.scoring; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.spark.impl.common.score.BaseVaeReconstructionProbWithKeyFunction; @@ -34,7 +34,7 @@ public class VaeReconstructionProbWithKeyFunction extends BaseVaeReconstructi /** * @param params MultiLayerNetwork parameters - * @param jsonConfig MultiLayerConfiguration, as json + * @param jsonConfig NeuralNetConfiguration, as json * @param useLogProbability If true: use log probability. False: use raw probability. * @param batchSize Batch size to use when scoring * @param numSamples Number of samples to use when calling {@link VariationalAutoencoder#reconstructionLogProbability(INDArray, int)} @@ -47,7 +47,7 @@ public class VaeReconstructionProbWithKeyFunction extends BaseVaeReconstructi @Override public VariationalAutoencoder getVaeLayer() { MultiLayerNetwork network = - new MultiLayerNetwork(MultiLayerConfiguration.fromJson(jsonConfig.getValue())); + new MultiLayerNetwork(NeuralNetConfiguration.fromJson(jsonConfig.getValue())); network.init(); INDArray val = params.value().unsafeDuplication(); if (val.length() != network.numParams(false)) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java index 4a0252b28..1dc1d4f1b 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java @@ -41,7 +41,7 @@ import org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.core.storage.StatsStorageRouterProvider; import org.deeplearning4j.core.storage.StorageMetaData; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.TrainingListener; @@ -275,7 +275,7 @@ public class ParameterAveragingTrainingMaster @Override public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network) { - NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getLayerWiseConfigurations(), + NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getConfiguration(), network.getNetwork().params(), network.getNetwork().getUpdater().getStateViewArray()); if (collectTrainingStats) @@ -727,7 +727,7 @@ public class ParameterAveragingTrainingMaster if (params != null) { 
//Params may be null for edge case (empty RDD) if (network != null) { - MultiLayerConfiguration conf = network.getNetwork().getLayerWiseConfigurations(); + NeuralNetConfiguration conf = network.getNetwork().getConfiguration(); int numUpdates = averagingFrequency; conf.setIterationCount(conf.getIterationCount() + numUpdates); } else { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java index 87374a584..4820e938f 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java @@ -21,13 +21,13 @@ package org.deeplearning4j.spark.impl.paramavg; import lombok.val; +import net.brutex.ai.dnn.api.IModel; import org.apache.spark.broadcast.Broadcast; import org.deeplearning4j.core.storage.Persistable; import org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.core.storage.StatsStorageRouterProvider; import org.deeplearning4j.core.storage.StorageMetaData; import org.deeplearning4j.core.storage.listener.RoutingIterationListener; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.util.ComputationGraphUtil; @@ -159,7 +159,7 @@ public class ParameterAveragingTrainingWorker extends BaseTrainingWorker list = new ArrayList<>(trainingListeners.size()); for (TrainingListener l : trainingListeners) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java index 5a8ac5d7e..e8412455a 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java @@ -20,13 +20,10 @@ package org.deeplearning4j.spark; -import org.apache.hadoop.conf.Configuration; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.datavec.spark.util.SerializableHadoopConfig; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.spark.impl.multilayer.SparkDl4jMultiLayer; import org.deeplearning4j.spark.impl.paramavg.ParameterAveragingTrainingMaster; @@ -124,8 +121,8 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable return 4; } - protected MultiLayerConfiguration getBasicConf() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) + protected NeuralNetConfiguration getBasicConf() { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123) .updater(new Nesterovs(0.1, 0.9)).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3) .activation(Activation.TANH).build()) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java index f4e9f674e..09d147283 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSpark.java @@ -35,7 +35,6 @@ import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationC import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition; import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -68,7 +67,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { //Spark tests don't run on windows return; } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd()).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) @@ -123,7 +122,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(10.0)) //Intentionally huge LR .weightInit(WeightInit.XAVIER).list() @@ -163,7 +162,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(1e-6)).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) @@ -209,7 +208,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.0)).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) @@ -246,7 +245,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { //Spark tests don't run on windows return; } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd()).weightInit(WeightInit.XAVIER).list() .layer(0, new OutputLayer.Builder().nIn(4).nOut(3) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java index 39618055e..f0e1fefb1 100644 --- 
a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java @@ -71,7 +71,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { //Spark tests don't run on windows return; } - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3) @@ -124,7 +124,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(2.0)) //Intentionally huge LR .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") @@ -165,7 +165,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(1e-6)).weightInit(WeightInit.XAVIER).graphBuilder() .addInputs("in") @@ -213,7 +213,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { return; } Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(0.0)).weightInit(WeightInit.XAVIER).graphBuilder() .addInputs("in") @@ -253,7 +253,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { //Spark tests don't run on windows return; } - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd()).weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestKryo.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestKryo.java index da5d7822a..7815303f0 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestKryo.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestKryo.java @@ -22,7 +22,6 @@ package org.deeplearning4j.spark; import org.apache.spark.serializer.SerializerInstance; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.conf.graph.*; @@ -68,16 +67,16 @@ public class TestKryo extends 
BaseSparkKryoTest { Map m = new HashMap<>(); m.put(0, 0.5); m.put(10, 0.1); - MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder() - .updater(new Nadam(new MapSchedule(ScheduleType.ITERATION,m))).list().layer(0, new OutputLayer.Builder().nIn(10).nOut(10).build()) + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder() + .updater(new Nadam(new MapSchedule(ScheduleType.ITERATION,m))).layer(0, new OutputLayer.Builder().nIn(10).nOut(10).build()) .build(); testSerialization(mlc, si); - ComputationGraphConfiguration cgc = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration cgc = ((NeuralNetConfiguration.NeuralNetConfigurationBuilder)NeuralNetConfiguration.builder() .dist(new UniformDistribution(-1, 1)) - .updater(new Adam(new MapSchedule(ScheduleType.ITERATION,m))) + .updater(new Adam(new MapSchedule(ScheduleType.ITERATION,m)))) .graphBuilder() .addInputs("in").addLayer("out", new OutputLayer.Builder().nIn(10).nOut(10).build(), "in") .setOutputs("out").build(); @@ -86,7 +85,7 @@ public class TestKryo extends BaseSparkKryoTest { //Check main layers: - Layer[] layers = new Layer[] {new OutputLayer.Builder().nIn(10).nOut(10).build(), + LayerConfiguration[] layers = new LayerConfiguration[] {new OutputLayer.Builder().nIn(10).nOut(10).build(), new RnnOutputLayer.Builder().nIn(10).nOut(10).build(), new LossLayer.Builder().build(), new CenterLossOutputLayer.Builder().nIn(10).nOut(10).build(), new DenseLayer.Builder().nIn(10).nOut(10).build(), @@ -97,7 +96,7 @@ public class TestKryo extends BaseSparkKryoTest { new LSTM.Builder().nIn(10).nOut(10).build(), new DropoutLayer.Builder(0.5).build(), new BatchNormalization.Builder().build(), new LocalResponseNormalization.Builder().build()}; - for (Layer l : layers) { + for (LayerConfiguration l : layers) { testSerialization(l, si); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java index 714c3ffb6..cc32d9723 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/datavec/TestPreProcessedData.java @@ -30,7 +30,6 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.spark.BaseSparkTest; @@ -84,7 +83,7 @@ public class TestPreProcessedData extends BaseSparkTest { iter.next().save(f2); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(Updater.RMSPROP) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(Updater.RMSPROP.getIUpdaterWithDefaultConfig()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(3) .activation(Activation.TANH).build()) @@ -134,7 +133,7 @@ public class TestPreProcessedData extends BaseSparkTest { iter.next().save(f2); } - ComputationGraphConfiguration conf = new 
NeuralNetConfiguration.Builder().updater(Updater.RMSPROP) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().updater(Updater.RMSPROP.getIUpdaterWithDefaultConfig()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(3) @@ -188,7 +187,7 @@ public class TestPreProcessedData extends BaseSparkTest { mds.save(f2); } - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(Updater.RMSPROP) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().updater(Updater.RMSPROP.getIUpdaterWithDefaultConfig()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(3) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java index ec2195081..402ecb46a 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/TestKryoWarning.java @@ -23,7 +23,6 @@ package org.deeplearning4j.spark.impl; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.spark.api.TrainingMaster; @@ -40,7 +39,7 @@ public class TestKryoWarning { try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().list() .layer(0, new OutputLayer.Builder().nIn(10).nOut(10).build()) .build(); @@ -57,7 +56,7 @@ public class TestKryoWarning { try { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("0", new OutputLayer.Builder().nIn(10).nOut(10).build(), "in").setOutputs("0") .build(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java index b3c96333d..d8b0ddb0a 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/TestCustomLayer.java @@ -22,7 +22,6 @@ package org.deeplearning4j.spark.impl.customlayer; import com.sun.jna.Platform; import org.apache.spark.api.java.JavaRDD; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -51,8 +50,8 @@ public class TestCustomLayer extends BaseSparkTest { } //Basic test - checks whether exceptions etc are thrown with custom layers + spark 
//Custom layers are tested more extensively in dl4j core - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().updater(new Sgd(0.1)).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new CustomLayer(3.14159)).layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java index 189e1f529..a9a8e1293 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java @@ -57,9 +57,9 @@ public class CustomLayer extends FeedForwardLayer { ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); - Map paramTable = initializer().init(conf, layerParamsView, initializeParams); + Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setConf(conf); + ret.setLayerConfiguration(conf); return ret; } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java index 579effe1a..109add55d 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java @@ -77,7 +77,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { public static ComputationGraph getBasicNetIris2Class() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .graphBuilder().addInputs("in") .addLayer("l0", new DenseLayer.Builder().nIn(4).nOut(10).build(), "in") .addLayer("l1", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) @@ -104,7 +104,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { while (iter.hasNext()) list.add(iter.next()); - ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration config = NeuralNetConfiguration.builder() .updater(new Sgd(0.1)) .graphBuilder().addInputs("in") .addLayer("dense", new DenseLayer.Builder().nIn(4).nOut(2).build(), "in").addLayer("out", @@ -138,7 +138,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { @Test public void testDistributedScoring() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.1) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.1) .seed(123).updater(new Nesterovs(0.1, 0.9)).graphBuilder() .addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3) @@ -217,7 +217,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { //@Ignore("AB 2019/05/23 - Failing on CI only - 
passing locally. Possible precision or threading issue") public void testSeedRepeatability() throws Exception { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(Updater.RMSPROP) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(Updater.RMSPROP.getIUpdaterWithDefaultConfig()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(4) @@ -414,7 +414,7 @@ public class TestSparkComputationGraph extends BaseSparkTest { JavaRDD rdd = sc.parallelize(l); // simple model - val modelConf = new NeuralNetConfiguration.Builder() + val modelConf = NeuralNetConfiguration.builder() .updater(new Adam(0.01)) .weightInit(WeightInit.XAVIER_UNIFORM) .biasInit(0) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java index c899fae04..2e01cc17d 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java @@ -23,6 +23,8 @@ package org.deeplearning4j.spark.impl.misc; import org.apache.spark.api.java.JavaRDD; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.DenseLayer.Builder; +import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.FrozenLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -45,6 +47,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; import static org.junit.jupiter.api.Assertions.*; @@ -53,7 +56,7 @@ public class TestFrozenLayers extends BaseSparkTest { @Test public void testSparkFrozenLayers() { - NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + NeuralNetConfiguration.NeuralNetConfigurationBuilder overallConf = NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.TANH); FineTuneConfiguration finetune = new FineTuneConfiguration.Builder().updater(new Sgd(0.1)).build(); @@ -61,12 +64,12 @@ public class TestFrozenLayers extends BaseSparkTest { int nIn = 6; int nOut = 3; - MultiLayerNetwork origModel = new MultiLayerNetwork(overallConf.clone().list() - .layer(0, new DenseLayer.Builder().nIn(6).nOut(5).build()) - .layer(1, new DenseLayer.Builder().nIn(5).nOut(4).build()) - .layer(2, new DenseLayer.Builder().nIn(4).nOut(3).build()) - .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder( - LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) + MultiLayerNetwork origModel = new MultiLayerNetwork((NeuralNetConfiguration) overallConf.clone() + .layer(0, new Builder().nIn(6).nOut(5).build()) + .layer(1, new Builder().nIn(5).nOut(4).build()) + .layer(2, new Builder().nIn(4).nOut(3).build()) + .layer(3, new OutputLayer.Builder( + LossFunction.MCXENT).activation(Activation.SOFTMAX).nIn(3).nOut(3) .build()) .build()); origModel.init(); @@ 
-74,7 +77,7 @@ public class TestFrozenLayers extends BaseSparkTest { MultiLayerNetwork withFrozen = new TransferLearning.Builder(origModel).fineTuneConfiguration(finetune) .setFeatureExtractor(1).build(); - Map m = withFrozen.paramTable(); + Map m = withFrozen.getParamTable(); Map pCopy = new HashMap<>(); for (Map.Entry entry : m.entrySet()) { pCopy.put(entry.getKey(), entry.getValue().dup()); @@ -110,7 +113,7 @@ public class TestFrozenLayers extends BaseSparkTest { MultiLayerNetwork fitted = sNet.getNetwork(); - Map fittedParams = fitted.paramTable(); + Map fittedParams = fitted.getParamTable(); for (Map.Entry entry : fittedParams.entrySet()) { INDArray orig = pCopy.get(entry.getKey()); @@ -136,7 +139,7 @@ public class TestFrozenLayers extends BaseSparkTest { int nIn = 6; int nOut = 3; - ComputationGraph origModel = new ComputationGraph(new NeuralNetConfiguration.Builder().updater(new Sgd(0.1)) + ComputationGraph origModel = new ComputationGraph(NeuralNetConfiguration.builder().updater(new Sgd(0.1)) .activation(Activation.TANH).graphBuilder().addInputs("in") .addLayer("0", new DenseLayer.Builder().nIn(6).nOut(5).build(), "in") .addLayer("1", new DenseLayer.Builder().nIn(5).nOut(4).build(), "0") diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java index 8b5a8b46c..6b22acca7 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java @@ -23,7 +23,6 @@ package org.deeplearning4j.spark.impl.multilayer; import org.apache.spark.api.java.JavaPairRDD; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution; @@ -57,7 +56,7 @@ public class TestMiscFunctions extends BaseSparkTest { @Test public void testFeedForwardWithKey() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build()) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(3).nOut(3) .activation(Activation.SOFTMAX).build()) @@ -107,7 +106,7 @@ public class TestMiscFunctions extends BaseSparkTest { @Test public void testFeedForwardWithKeyInputMask() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) .list() .layer( new LSTM.Builder().nIn(4).nOut(3).build()) .layer(new GlobalPoolingLayer(PoolingType.AVG)) @@ -162,7 +161,7 @@ public class TestMiscFunctions extends BaseSparkTest { @Test public void testFeedForwardWithKeyGraph() { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER) 
.graphBuilder().addInputs("in1", "in2") .addLayer("0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "in1") .addLayer("1", new DenseLayer.Builder().nIn(4).nOut(3).build(), "in2").addLayer("2", @@ -220,7 +219,7 @@ public class TestMiscFunctions extends BaseSparkTest { int nIn = 10; - MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder().list() .layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .reconstructionDistribution( new GaussianReconstructionDistribution(Activation.IDENTITY)) @@ -259,7 +258,7 @@ public class TestMiscFunctions extends BaseSparkTest { int nIn = 10; - MultiLayerConfiguration mlc = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration mlc = NeuralNetConfiguration.builder() .list().layer(0, new org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder.Builder() .reconstructionDistribution(new LossFunctionWrapper( diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java index d2c0d66bc..7de0dc285 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestSparkDl4jMultiLayer.java @@ -25,7 +25,6 @@ import lombok.extern.slf4j.Slf4j; import org.apache.spark.api.java.JavaRDD; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -41,7 +40,6 @@ import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Adam; -import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.ArrayList; @@ -96,7 +94,7 @@ public class TestSparkDl4jMultiLayer extends BaseSparkTest { //---------------------------------- //Create network configuration and conduct network training - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .dataType(DataType.FLOAT) .seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java index 050e6279c..277c4a133 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java @@ -26,7 +26,6 @@ import org.apache.spark.api.java.JavaRDD; import 
org.apache.spark.api.java.JavaSparkContext; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -51,7 +50,6 @@ import org.nd4j.linalg.learning.config.Sgd; import org.nd4j.linalg.lossfunctions.LossFunctions; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import static org.junit.jupiter.api.Assertions.*; @@ -63,9 +61,9 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { } - private static MultiLayerConfiguration getConf(int seed, IUpdater updater) { + private static NeuralNetConfiguration getConf(int seed, IUpdater updater) { Nd4j.getRandom().setSeed(seed); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()).layer(1, new OutputLayer.Builder() @@ -74,9 +72,9 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { return conf; } - private static MultiLayerConfiguration getConfCNN(int seed, IUpdater updater) { + private static NeuralNetConfiguration getConfCNN(int seed, IUpdater updater) { Nd4j.getRandom().setSeed(seed); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).list() .layer(0, new ConvolutionLayer.Builder().nOut(3).kernelSize(2, 2).stride(1, 1).padding(0, 0) @@ -85,13 +83,13 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { .activation(Activation.TANH).build()) .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MSE).nOut(10) .build()) - .setInputType(InputType.convolutional(10, 10, 3)).build(); + .inputType(InputType.convolutional(10, 10, 3)).build(); return conf; } private static ComputationGraphConfiguration getGraphConf(int seed, IUpdater updater) { Nd4j.getRandom().setSeed(seed); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder() .addInputs("in") @@ -105,7 +103,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { private static ComputationGraphConfiguration getGraphConfCNN(int seed, IUpdater updater) { Nd4j.getRandom().setSeed(seed); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).updater(updater).seed(seed).graphBuilder() .addInputs("in") diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java index 
c2c24a617..8376638f3 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java @@ -37,7 +37,6 @@ import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.conf.layers.BatchNormalization; @@ -121,7 +120,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { .toJavaRDD().map(new TestFn()); DataSet d = new IrisDataSetIterator(150, 150).next(); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER) .activation(Activation.RELU).build()) @@ -156,8 +155,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { .getAbsolutePath()) .toJavaRDD().map(new TestFn()); - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(123) .updater(new Adam(1e-6)) .weightInit(WeightInit.XAVIER) .list() @@ -211,14 +210,14 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { MultiLayerNetwork netCopy = sparkNet.getNetwork().clone(); netCopy.fit(data); - IUpdater expectedUpdater = ((BaseLayer) netCopy.conf().getLayer()).getIUpdater(); - double expectedLR = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getIUpdater()).getLearningRate(); - double expectedMomentum = ((Nesterovs)((BaseLayer) netCopy.conf().getLayer()).getIUpdater()).getMomentum(); + IUpdater expectedUpdater = ((BaseLayer) netCopy.getLayerConfiguration()).getIUpdater(); + double expectedLR = ((Nesterovs)((BaseLayer) netCopy.getLayerConfiguration()).getIUpdater()).getLearningRate(); + double expectedMomentum = ((Nesterovs)((BaseLayer) netCopy.getLayerConfiguration()).getIUpdater()).getMomentum(); - IUpdater actualUpdater = ((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater(); + IUpdater actualUpdater = ((BaseLayer) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater(); sparkNet.fit(sparkData); - double actualLR = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater()).getLearningRate(); - double actualMomentum = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().conf().getLayer()).getIUpdater()).getMomentum(); + double actualLR = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater()).getLearningRate(); + double actualMomentum = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater()).getMomentum(); assertEquals(expectedUpdater, actualUpdater); assertEquals(expectedLR, actualLR, 0.01); @@ -269,7 +268,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { //Spark tests don't run on windows return; } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + 
NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3) .activation(Activation.TANH).build()) @@ -294,7 +293,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { @Test public void testDistributedScoring() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.1) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.1) .seed(123).updater(new Nesterovs(0.1, 0.9)).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(nIn).nOut(3) .activation(Activation.TANH).build()) @@ -383,7 +382,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { list.add(iter.next()); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) .activation(Activation.TANH).build()) @@ -447,7 +446,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) .activation(Activation.TANH).build()) @@ -517,7 +516,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) .activation(Activation.TANH).build()) @@ -605,7 +604,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) @@ -678,7 +677,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { //Spark tests don't run on windows return; } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .weightInit(WeightInit.XAVIER).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(4).nOut(4) @@ -763,7 +762,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { list.add(iter.next()); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().updater(new 
RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) .activation(Activation.TANH).build()) @@ -785,13 +784,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { JavaRDD rdd = sc.parallelize(list); - assertEquals(0, sparkNet.getNetwork().getLayerWiseConfigurations().getIterationCount()); + assertEquals(0, sparkNet.getNetwork().getConfiguration().getIterationCount()); sparkNet.fit(rdd); assertEquals(minibatchesPerWorkerPerEpoch, - sparkNet.getNetwork().getLayerWiseConfigurations().getIterationCount()); + sparkNet.getNetwork().getConfiguration().getIterationCount()); sparkNet.fit(rdd); assertEquals(2 * minibatchesPerWorkerPerEpoch, - sparkNet.getNetwork().getLayerWiseConfigurations().getIterationCount()); + sparkNet.getNetwork().getConfiguration().getIterationCount()); sparkNet.getTrainingMaster().deleteTempFiles(sc); } @@ -813,7 +812,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { list.add(iter.next()); } - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().updater(new RmsProp()) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().updater(new RmsProp()) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .graphBuilder().addInputs("in") .addLayer("0", new org.deeplearning4j.nn.conf.layers.DenseLayer.Builder().nIn(28 * 28).nOut(50) @@ -854,7 +853,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { int nIn = 8; Nd4j.getRandom().setSeed(12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new RmsProp()) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(new RmsProp()) .weightInit(WeightInit.XAVIER).list() .layer(0, new VariationalAutoencoder.Builder().nIn(8).nOut(10).encoderLayerSizes(12) .decoderLayerSizes(13).reconstructionDistribution( @@ -890,7 +889,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { int nIn = 8; Nd4j.getRandom().setSeed(12345); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345).updater(new RmsProp()) + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().seed(12345).updater(new RmsProp()) .weightInit(WeightInit.XAVIER).graphBuilder().addInputs("in") .addLayer("0", new VariationalAutoencoder.Builder().nIn(8).nOut(10).encoderLayerSizes(12) .decoderLayerSizes(13).reconstructionDistribution( @@ -930,8 +929,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { int nOut = 2; int layerSize = 10; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(layerSize).build()) .layer(1, new OutputLayer.Builder().nIn(layerSize).nOut(nOut) .activation(Activation.SOFTMAX).lossFunction( @@ -985,8 +984,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { int nOut = 3; int layerSize = 10; - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().weightInit(WeightInit.XAVIER).list() .layer(0, new DenseLayer.Builder().nIn(nIn).nOut(layerSize).build()) .layer(1, new 
OutputLayer.Builder().nIn(layerSize).nOut(nOut) .activation(Activation.SOFTMAX).lossFunction( @@ -1039,12 +1038,12 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { //Spark tests don't run on windows return; } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .list() .layer(new OutputLayer.Builder().nIn(4).nOut(3).build()) .build(); - ComputationGraphConfiguration conf2 = new NeuralNetConfiguration.Builder() + ComputationGraphConfiguration conf2 = NeuralNetConfiguration.builder() .graphBuilder() .addInputs("in") .addLayer("out", new OutputLayer.Builder().nIn(4).nOut(3).build(), "in") @@ -1075,11 +1074,11 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { for(int i=0; i<3; i++ ){ - assertEquals(i, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount()); + assertEquals(i, sn1.getNetwork().getConfiguration().getEpochCount()); assertEquals(i, sn2.getNetwork().getComputationGraphConfiguration().getEpochCount()); sn1.fit(rdd); sn2.fit(rdd); - assertEquals(i+1, sn1.getNetwork().getLayerWiseConfigurations().getEpochCount()); + assertEquals(i+1, sn1.getNetwork().getConfiguration().getEpochCount()); assertEquals(i+1, sn2.getNetwork().getComputationGraphConfiguration().getEpochCount()); } } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java index 5d33e82c6..5b735e5a2 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/stats/TestTrainingStatsCollection.java @@ -22,11 +22,9 @@ package org.deeplearning4j.spark.impl.stats; import com.sun.jna.Platform; import org.apache.commons.io.FilenameUtils; -import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -68,7 +66,7 @@ public class TestTrainingStatsCollection extends BaseSparkTest { try { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new OutputLayer.Builder().nIn(10).nOut(10).build()) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java index aadf69cdd..1104f8667 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/ui/TestListeners.java @@ -27,7 +27,6 @@ import org.deeplearning4j.core.storage.Persistable; import org.deeplearning4j.core.storage.StatsStorage; import 
org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -60,7 +59,7 @@ public class TestListeners extends BaseSparkTest { JavaSparkContext sc = getContext(); int nExecutors = numExecutors(); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(123) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new DenseLayer.Builder().nIn(4).nOut(100).weightInit(WeightInit.XAVIER) .activation(Activation.RELU).build()) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerTrainingHook.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerTrainingHook.java index 402560c73..060b88dc1 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerTrainingHook.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/ParameterServerTrainingHook.java @@ -20,7 +20,7 @@ package org.deeplearning4j.spark.parameterserver; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.spark.api.TrainingHook; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.MultiDataSet; @@ -39,7 +39,7 @@ public class ParameterServerTrainingHook implements TrainingHook { * @param model themodel that was update */ @Override - public void preUpdate(DataSet minibatch, Model model) { + public void preUpdate(DataSet minibatch, IModel model) { //pull } @@ -51,7 +51,7 @@ public class ParameterServerTrainingHook implements TrainingHook { * @param model the model that was updated */ @Override - public void postUpdate(DataSet minibatch, Model model) { + public void postUpdate(DataSet minibatch, IModel model) { //push } @@ -63,7 +63,7 @@ public class ParameterServerTrainingHook implements TrainingHook { * @param model themodel that was update */ @Override - public void preUpdate(MultiDataSet minibatch, Model model) { + public void preUpdate(MultiDataSet minibatch, IModel model) { //pull } @@ -75,7 +75,7 @@ public class ParameterServerTrainingHook implements TrainingHook { * @param model the model that was updated */ @Override - public void postUpdate(MultiDataSet minibatch, Model model) { + public void postUpdate(MultiDataSet minibatch, IModel model) { //push } } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java index 7e521f0c1..0265837bd 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java @@ -27,7 +27,7 @@ import 
org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.core.storage.listener.RoutingIterationListener; import org.deeplearning4j.common.config.DL4JEnvironmentVars; import org.deeplearning4j.exception.DL4JInvalidConfigException; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -89,7 +89,7 @@ public class SharedTrainingWrapper { protected ThreadLocal iteratorDataSetCount = new ThreadLocal<>(); //Using AtomicInteger because it's mutable, not because it's atomic protected ThreadLocal observer = new ThreadLocal<>(); protected EncodedGradientsAccumulator accumulator; - protected Model originalModel; + protected IModel originalModel; protected UpdatesConsumer consumer; @@ -200,7 +200,7 @@ public class SharedTrainingWrapper { SharedTrainingConfiguration trainingConfiguration = worker.getBroadcastConfiguration().getValue(); VoidConfiguration voidConfiguration = worker.getBroadcastConfiguration().getValue().getVoidConfiguration(); - Model model = null; + IModel model = null; /* Plan is simple here: if there's defined field in SharedTrainingConfiguration - use that. @@ -425,7 +425,7 @@ public class SharedTrainingWrapper { .setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode()); ((ComputationGraph) originalModel).setGradientsAccumulator(accumulator); } else if (model instanceof MultiLayerNetwork) { - ((MultiLayerNetwork) originalModel).getLayerWiseConfigurations() + ((MultiLayerNetwork) originalModel).getConfiguration() .setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode()); ((MultiLayerNetwork) originalModel).setGradientsAccumulator(accumulator); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java index 1a11d70a5..ef252470b 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java @@ -262,7 +262,7 @@ public class SharedTrainingMaster extends BaseTrainingMaster iterations = Collections.newSetFromMap(new ConcurrentHashMap<>()); private static final Set epochs = Collections.newSetFromMap(new ConcurrentHashMap<>()); @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { iterations.add(iteration); epochs.add(epoch); } diff --git a/cavis-dnn/cavis-dnn-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java b/cavis-dnn/cavis-dnn-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java index d8efd7dbb..920f58979 100644 --- a/cavis-dnn/cavis-dnn-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java +++ b/cavis-dnn/cavis-dnn-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java @@ -22,11 +22,11 @@ import lombok.AllArgsConstructor; import lombok.Data; import lombok.Setter; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.clustering.algorithm.Distance; import org.deeplearning4j.clustering.sptree.DataPoint; 
import org.deeplearning4j.clustering.sptree.SpTree; import org.deeplearning4j.clustering.vptree.VPTree; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -64,7 +64,7 @@ import static org.nd4j.linalg.ops.transforms.Transforms.sign; */ @Slf4j @Data -public class BarnesHutTsne implements Model { +public class BarnesHutTsne implements IModel { public final static String workspaceCache = "LOOP_CACHE"; @@ -897,12 +897,12 @@ public class BarnesHutTsne implements Model { } @Override - public NeuralNetConfiguration conf() { + public NeuralNetConfiguration getNetConfiguration() { return null; } @Override - public void setConf(NeuralNetConfiguration conf) { + public void setLayerConfiguration(NeuralNetConfiguration layerConfiguration) { } diff --git a/cavis-ui/cavis-ui-common/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java b/cavis-ui/cavis-ui-common/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java index 4770b2d76..44caf37c4 100644 --- a/cavis-ui/cavis-ui-common/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java +++ b/cavis-ui/cavis-ui-common/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java @@ -23,12 +23,12 @@ package org.deeplearning4j.ui.weights; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; +import net.brutex.ai.dnn.api.IModel; import org.datavec.image.loader.ImageLoader; import org.deeplearning4j.core.storage.Persistable; import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.optimize.api.BaseTrainingListener; @@ -40,8 +40,6 @@ import org.deeplearning4j.ui.model.weights.ConvolutionListenerPersistable; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.common.io.ClassPathResource; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import javax.imageio.ImageIO; import java.awt.*; @@ -60,7 +58,7 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { } private int freq = 10; - private static final Logger log = LoggerFactory.getLogger(ConvolutionalIterationListener.class); + private int minibatchNum = 0; private boolean openBrowser = true; private final String path; @@ -125,12 +123,12 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { * @param iteration the iteration number */ @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { } @Override - public void onForwardPass(Model model, Map activations) { + public void onForwardPass(IModel model, Map activations) { int iteration = (model instanceof MultiLayerNetwork ? ((MultiLayerNetwork)model).getIterationCount() : ((ComputationGraph)model).getIterationCount()); if (iteration % freq == 0) { @@ -147,7 +145,7 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { throw new RuntimeException("layers.length != activations.size(). 
Got layers.length="+layers.length+", activations.size()="+activations.size()); for( int i=0; i activations) { + public void onForwardPass(IModel model, List activations) { int iteration = (model instanceof MultiLayerNetwork ? ((MultiLayerNetwork)model).getIterationCount() : ((ComputationGraph)model).getIterationCount()); if (iteration % freq == 0) { diff --git a/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/ManualTests.java b/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/ManualTests.java index 81cd4e5b1..7c5de3bbb 100644 --- a/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/ManualTests.java +++ b/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/ManualTests.java @@ -29,8 +29,8 @@ import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.Word2Vec; import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -125,10 +125,10 @@ public class ManualTests { outputNum, useSubset, true, 1.0, new Random(seed)); log.info("Build model...."); - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .activation(Activation.RELU).weightInit(WeightInit.XAVIER) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) - .updater(new AdaGrad(0.01)).weightNoise(new DropConnect(0.5)).list() + .updater(new AdaGrad(0.01)).weightNoise(new DropConnect(0.5)) .layer(0, new ConvolutionLayer.Builder(4, 4).name("cnn1").nIn(nChannels).stride(1, 1).nOut(20) .build()) .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[] {2, 2}) @@ -144,7 +144,7 @@ public class ManualTests { .layer(8, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(numRows, numColumns, nChannels)); + .inputType(InputType.convolutional(numRows, numColumns, nChannels)); MultiLayerNetwork model = new MultiLayerNetwork(builder.build()); model.init(); @@ -246,10 +246,10 @@ public class ManualTests { DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345); log.info("Build model...."); - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .l2(0.0005) .weightInit(WeightInit.XAVIER) - .updater(new Nesterovs(0.01, 0.9)).list() + .updater(new Nesterovs(0.01, 0.9)) .layer(0, new ConvolutionLayer.Builder(5, 5) //nIn and nOut specify depth. 
nIn here is the nChannels and nOut is the number of filters to be applied .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build()) @@ -263,9 +263,9 @@ public class ManualTests { .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()) .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(28, 28, nChannels)); + .inputType(InputType.convolutional(28, 28, nChannels)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); /* @@ -320,10 +320,10 @@ public class ManualTests { DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345); log.info("Build model...."); - MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed) + NeuralNetConfiguration.NeuralNetConfigurationBuilder builder = NeuralNetConfiguration.builder().seed(seed) .l2(0.0005) .weightInit(WeightInit.XAVIER) - .updater(new Nesterovs(0.01, 0.9)).list() + .updater(new Nesterovs(0.01, 0.9)) .layer(0, new FrozenLayer(new ConvolutionLayer.Builder(5, 5) //nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build())) @@ -332,9 +332,9 @@ public class ManualTests { .layer(2, new FrozenLayer(new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build())) .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, nChannels)); + .inputType(InputType.convolutionalFlat(28, 28, nChannels)); - MultiLayerConfiguration conf = builder.build(); + NeuralNetConfiguration conf = builder.build(); MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); diff --git a/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java b/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java index 442f3bf01..e545ff53b 100644 --- a/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java +++ b/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java @@ -21,7 +21,6 @@ package org.deeplearning4j.ui.weights; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -54,9 +53,9 @@ public class TestConvolutionalListener { DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) // Training iterations as above + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) // Training iterations as above .l2(0.0005).weightInit(WeightInit.XAVIER) - .updater(new Nesterovs(0.01, 0.9)).list() + .updater(new Nesterovs(0.01, 0.9)) .layer(0, new ConvolutionLayer.Builder(5, 5) //nIn and nOut specify depth. 
nIn here is the nChannels and nOut is the number of filters to be applied .nIn(nChannels).stride(1, 1).nOut(20).activation(Activation.IDENTITY).build()) @@ -70,7 +69,7 @@ public class TestConvolutionalListener { .layer(4, new DenseLayer.Builder().activation(Activation.RELU).nOut(500).build()) .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) .nOut(outputNum).activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutionalFlat(28, 28, 1)) //See note below + .inputType(InputType.convolutionalFlat(28, 28, 1)) //See note below .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); diff --git a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java index e660a8e04..b9a7e985d 100644 --- a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java +++ b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java @@ -21,6 +21,7 @@ package org.deeplearning4j.ui.model.stats; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.io.IOUtils; import org.bytedeco.javacpp.Pointer; import org.deeplearning4j.common.config.DL4JClassLoading; @@ -28,7 +29,6 @@ import org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.core.storage.StorageMetaData; import org.deeplearning4j.core.storage.listener.RoutingIterationListener; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -85,7 +85,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { private Map meanMagGradients; private static class ModelInfo implements Serializable { - private final Model model; + private final IModel model; private long initTime; private long lastReportTime = -1; private int lastReportIteration = -1; @@ -97,12 +97,12 @@ public abstract class BaseStatsListener implements RoutingIterationListener { private int iterCount = 0; - private ModelInfo(Model model) { + private ModelInfo(IModel model) { this.model = model; } } - private ModelInfo getModelInfo(Model model) { + private ModelInfo getModelInfo(IModel model) { ModelInfo mi = null; for (ModelInfo m : modelInfos) { if (m.model == model) { @@ -218,7 +218,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { return sessionID; } - private String getSessionID(Model model) { + private String getSessionID(IModel model) { if (model instanceof MultiLayerNetwork || model instanceof ComputationGraph) return sessionID; if (model instanceof Layer) { @@ -231,17 +231,17 @@ public abstract class BaseStatsListener implements RoutingIterationListener { } @Override - public void onEpochStart(Model model) { + public void onEpochStart(IModel model) { } @Override - public void onEpochEnd(Model model) { + public void onEpochEnd(IModel model) { } @Override - public void onForwardPass(Model model, List activations) { + public void onForwardPass(IModel model, List activations) { int iterCount = getModelInfo(model).iterCount; if (calcFromActivations() && (iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0)) { //Assumption: we have input, layer 0, layer 1, ... 
@@ -257,7 +257,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { } @Override - public void onForwardPass(Model model, Map activations) { + public void onForwardPass(IModel model, Map activations) { int iterCount = getModelInfo(model).iterCount; if (calcFromActivations() && updateConfig.reportingFrequency() > 0 && (iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0)) { @@ -277,7 +277,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { } @Override - public void onGradientCalculation(Model model) { + public void onGradientCalculation(IModel model) { int iterCount = getModelInfo(model).iterCount; if (calcFromGradients() && updateConfig.reportingFrequency() > 0 && (iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0)) { @@ -311,12 +311,12 @@ public abstract class BaseStatsListener implements RoutingIterationListener { } @Override - public void onBackwardPass(Model model) { + public void onBackwardPass(IModel model) { //No op } @Override - public void iterationDone(Model model, int iteration, int epoch) { + public void iterationDone(IModel model, int iteration, int epoch) { ModelInfo modelInfo = getModelInfo(model); boolean backpropParamsOnly = backpropParamsOnly(model); @@ -426,10 +426,10 @@ public abstract class BaseStatsListener implements RoutingIterationListener { //Need to append "0_", "1_" etc to param names from layers... int layerIdx = 0; for (Layer l : ((MultiLayerNetwork) model).getLayers()) { - NeuralNetConfiguration conf = l.conf(); - List paramkeys = l.conf().getLayer().initializer().paramKeys(l.conf().getLayer()); + NeuralNetConfiguration conf = l.getNetConfiguration(); + List paramkeys = l.getLayerConfiguration().initializer().paramKeys(l.getLayerConfiguration()); for (String s : paramkeys) { - double lr = conf.getLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); + double lr = conf.getFirstLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); if (Double.isNaN(lr)) { //Edge case: No-Op updater, AdaDelta etc - don't have a LR hence return NaN for IUpdater.getLearningRate lr = 0.0; @@ -440,11 +440,11 @@ public abstract class BaseStatsListener implements RoutingIterationListener { } } else if (model instanceof ComputationGraph) { for (Layer l : ((ComputationGraph) model).getLayers()) { - NeuralNetConfiguration conf = l.conf(); - String layerName = conf.getLayer().getLayerName(); - List paramkeys = l.conf().getLayer().initializer().paramKeys(l.conf().getLayer()); + NeuralNetConfiguration conf = l.getNetConfiguration(); + String layerName = conf.getFirstLayer().getLayerName(); + List paramkeys = l.getLayerConfiguration().initializer().paramKeys(l.getLayerConfiguration()); for (String s : paramkeys) { - double lr = conf.getLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); + double lr = conf.getFirstLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); if (Double.isNaN(lr)) { //Edge case: No-Op updater, AdaDelta etc - don't have a LR hence return NaN for IUpdater.getLearningRate lr = 0.0; @@ -454,9 +454,9 @@ public abstract class BaseStatsListener implements RoutingIterationListener { } } else if (model instanceof Layer) { Layer l = (Layer) model; - List paramkeys = l.conf().getLayer().initializer().paramKeys(l.conf().getLayer()); + List paramkeys = l.getLayerConfiguration().initializer().paramKeys(l.getLayerConfiguration()); for (String s : 
paramkeys) { - double lr = l.conf().getLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); + double lr = l.getLayerConfiguration().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); lrs.put(s, lr); } } @@ -575,7 +575,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { return System.currentTimeMillis(); } - private void doInit(Model model) { + private void doInit(IModel model) { boolean backpropParamsOnly = backpropParamsOnly(model); long initTime = System.currentTimeMillis(); //TODO support NTP StatsInitializationReport initReport = getNewInitializationReport(); @@ -652,7 +652,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { long numParams; if (model instanceof MultiLayerNetwork) { MultiLayerNetwork net = ((MultiLayerNetwork) model); - jsonConf = net.getLayerWiseConfigurations().toJson(); + jsonConf = net.getConfiguration().toJson(); numLayers = net.getnLayers(); numParams = net.numParams(); } else if (model instanceof ComputationGraph) { @@ -662,7 +662,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { numParams = cg.numParams(); } else if (model instanceof Layer) { Layer l = (Layer) model; - jsonConf = l.conf().toJson(); + jsonConf = l.getNetConfiguration().toJson(); numLayers = 1; numParams = l.numParams(); } else { @@ -707,7 +707,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { } } - private void updateExamplesMinibatchesCounts(Model model) { + private void updateExamplesMinibatchesCounts(IModel model) { ModelInfo modelInfo = getModelInfo(model); int examplesThisMinibatch = 0; if (model instanceof MultiLayerNetwork) { @@ -723,7 +723,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { modelInfo.totalMinibatches++; } - private boolean backpropParamsOnly(Model model) { + private boolean backpropParamsOnly(IModel model) { //For pretrain layers (VAE, AE) we *do* want pretrain params also; for MLN and CG we only want backprop params // as we only have backprop gradients return model instanceof MultiLayerNetwork || model instanceof ComputationGraph; diff --git a/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java b/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java index 56952d870..9b1a4801e 100644 --- a/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java +++ b/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java @@ -24,7 +24,6 @@ import org.deeplearning4j.core.storage.Persistable; import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -51,10 +50,10 @@ public class TestStatsListener extends BaseDL4JTest { DataSet ds = new IrisDataSetIterator(150, 150).next(); - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) - .list().layer(0, + .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) 
.activation(Activation.SOFTMAX).nIn(4).nOut(3).build()) .build(); diff --git a/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java b/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java index 3cf4ec7d9..d5b1a116b 100644 --- a/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java +++ b/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java @@ -21,7 +21,6 @@ package org.deeplearning4j.ui.stats; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -43,7 +42,7 @@ public class TestTransferStatsCollection extends BaseDL4JTest { @Test public void test() throws IOException { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(0, new DenseLayer.Builder().nIn(10).nOut(10).build()) .layer(1, new OutputLayer.Builder().activation(Activation.SOFTMAX).nIn(10).nOut(10).build()).build(); diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index 858648018..7e384dec5 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -38,7 +38,6 @@ import org.deeplearning4j.core.storage.StatsStorageEvent; import org.deeplearning4j.core.storage.StatsStorageListener; import org.deeplearning4j.common.config.DL4JSystemProperties; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex; @@ -872,7 +871,7 @@ public class TrainModule implements UIModule { .end(json); } - private TrainModuleUtils.GraphInfo getGraphInfo(Triple conf) { if (conf == null) { return null; @@ -881,7 +880,7 @@ public class TrainModule implements UIModule { if (conf.getFirst() != null) { return TrainModuleUtils.buildGraphInfo(conf.getFirst()); } else if (conf.getSecond() != null) { - return TrainModuleUtils.buildGraphInfo(conf.getSecond()); + return TrainModuleUtils.buildGraphInfo(conf.getSecond().getDefaultConfiguration()); } else if (conf.getThird() != null) { return TrainModuleUtils.buildGraphInfo(conf.getThird()); } else { @@ -889,7 +888,7 @@ public class TrainModule implements UIModule { } } - private Triple getConfig(String sessionId) { + private Triple getConfig(String sessionId) { boolean noData = (sessionId == null || !knownSessionIDs.containsKey(sessionId)); StatsStorage ss = (noData ? null : knownSessionIDs.get(sessionId)); List allStatic = (noData ? 
Collections.EMPTY_LIST @@ -902,7 +901,7 @@ public class TrainModule implements UIModule { String config = p.getModelConfigJson(); if (modelClass.endsWith("MultiLayerNetwork")) { - MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(config); + NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(config); return new Triple<>(conf, null, null); } else if (modelClass.endsWith("ComputationGraph")) { ComputationGraphConfiguration conf = ComputationGraphConfiguration.fromJson(config); @@ -940,7 +939,7 @@ public class TrainModule implements UIModule { Map result = new HashMap<>(); result.put("updateTimestamp", lastUpdateTime); - Triple conf = getConfig(sessionId); + Triple conf = getConfig(sessionId); if (conf == null) { rc.response() .putHeader("content-type", "application/json") @@ -1097,7 +1096,7 @@ public class TrainModule implements UIModule { .end(asJson(ret)); } - private static String getLayerType(Layer layer) { + private static String getLayerType(LayerConfiguration layer) { String layerType = "n/a"; if (layer != null) { try { @@ -1124,14 +1123,14 @@ public class TrainModule implements UIModule { //TODO error handling... String layerType = ""; - Layer layer = null; + LayerConfiguration layer = null; NeuralNetConfiguration nnc = null; if (modelClass.endsWith("MultiLayerNetwork")) { - MultiLayerConfiguration conf = MultiLayerConfiguration.fromJson(configJson); + NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(configJson); int confIdx = layerIdx - 1; //-1 because of input if (confIdx >= 0) { - nnc = conf.getConf(confIdx); - layer = nnc.getLayer(); + nnc = conf.getNetConfigurations().get(confIdx); + layer = nnc.getFirstLayer(); } else { //Input layer layerType = "Input"; @@ -1144,8 +1143,8 @@ public class TrainModule implements UIModule { Map vertices = conf.getVertices(); if (vertices.containsKey(vertexName) && vertices.get(vertexName) instanceof LayerVertex) { LayerVertex lv = (LayerVertex) vertices.get(vertexName); - nnc = lv.getLayerConf(); - layer = nnc.getLayer(); + nnc = lv.getNetConfiguration(); + layer = nnc.getFirstLayer(); } else if (conf.getNetworkInputs().contains(vertexName)) { layerType = "Input"; } else { @@ -1178,7 +1177,7 @@ public class TrainModule implements UIModule { if (layer instanceof BaseLayer) { BaseLayer bl = (BaseLayer) layer; activationFn = bl.getActivationFn().toString(); - long nParams = layer.initializer().numParams(nnc); + long nParams = layer.initializer().numParams(nnc.getFirstLayer()); layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerNParams"), String.valueOf(nParams)}); if (nParams > 0) { diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java index 274e670f6..34b6563f1 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java @@ -24,7 +24,6 @@ import com.fasterxml.jackson.annotation.JsonIgnore; import lombok.AllArgsConstructor; import lombok.Data; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex; @@ -50,7 +49,7 @@ public class TrainModuleUtils { private List 
originalVertexName; } - public static GraphInfo buildGraphInfo(MultiLayerConfiguration config) { + public static GraphInfo buildGraphInfo(NeuralNetConfiguration config) { List vertexNames = new ArrayList<>(); List originalVertexName = new ArrayList<>(); List layerTypes = new ArrayList<>(); @@ -63,17 +62,17 @@ public class TrainModuleUtils { layerInfo.add(Collections.emptyMap()); - List list = config.getConfs(); + List list = config.getNetConfigurations(); int layerIdx = 1; for (NeuralNetConfiguration c : list) { - Layer layer = c.getLayer(); + LayerConfiguration layer = c.getFirstLayer(); String layerName = layer.getLayerName(); if (layerName == null) layerName = "layer" + layerIdx; vertexNames.add(layerName); originalVertexName.add(String.valueOf(layerIdx - 1)); - String layerType = c.getLayer().getClass().getSimpleName().replaceAll("Layer$", ""); + String layerType = c.getFirstLayer().getClass().getSimpleName().replaceAll("Layer$", ""); layerTypes.add(layerType); layerInputs.add(Collections.singletonList(layerIdx - 1)); @@ -87,6 +86,7 @@ public class TrainModuleUtils { return new GraphInfo(vertexNames, layerTypes, layerInputs, layerInfo, originalVertexName); } + /** public static GraphInfo buildGraphInfo(ComputationGraphConfiguration config) { List layerNames = new ArrayList<>(); List layerTypes = new ArrayList<>(); @@ -129,7 +129,7 @@ public class TrainModuleUtils { if (gv instanceof LayerVertex) { NeuralNetConfiguration c = ((LayerVertex) gv).getLayerConf(); - Layer layer = c.getLayer(); + LayerConfiguration layer = c.getFirstLayer(); String layerType = layer.getClass().getSimpleName().replaceAll("Layer$", ""); layerTypes.add(layerType); @@ -148,7 +148,9 @@ public class TrainModuleUtils { return new GraphInfo(layerNames, layerTypes, layerInputs, layerInfo, originalVertexName); } + **/ + /** public static GraphInfo buildGraphInfo(NeuralNetConfiguration config) { List vertexNames = new ArrayList<>(); @@ -162,9 +164,9 @@ public class TrainModuleUtils { layerInputs.add(Collections.emptyList()); layerInfo.add(Collections.emptyMap()); - if (config.getLayer() instanceof VariationalAutoencoder) { + if (config.getFirstLayer() instanceof VariationalAutoencoder) { //Special case like this is a bit ugly - but it works - VariationalAutoencoder va = (VariationalAutoencoder) config.getLayer(); + VariationalAutoencoder va = (VariationalAutoencoder) config.getFirstLayer(); int[] encLayerSizes = va.getEncoderLayerSizes(); int[] decLayerSizes = va.getDecoderLayerSizes(); @@ -240,14 +242,14 @@ public class TrainModuleUtils { } else { //VAE or similar... 
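The TrainModuleUtils hunks above read the migrated configuration through getNetConfigurations() and fetch each layer with getFirstLayer(), rather than the old getConfs()/getLayer() pair. A minimal sketch of that traversal on its own, using only accessors shown in these hunks (getNetConfigurations(), getFirstLayer(), getLayerName()) plus the LayerConfiguration type this patch series introduces; the class name is illustrative and the API is the fork's as it appears here.

import java.util.List;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LayerConfiguration;

public class ConfigTraversalSketch {

    // Mirrors the loop in buildGraphInfo(NeuralNetConfiguration): one nested
    // NeuralNetConfiguration per layer, with the layer itself exposed via getFirstLayer().
    public static void printLayers(NeuralNetConfiguration config) {
        List<NeuralNetConfiguration> layerConfs = config.getNetConfigurations();
        int layerIdx = 1;
        for (NeuralNetConfiguration c : layerConfs) {
            LayerConfiguration layer = c.getFirstLayer();
            String layerName = layer.getLayerName();
            if (layerName == null) {
                layerName = "layer" + layerIdx;
            }
            String layerType = layer.getClass().getSimpleName().replaceAll("Layer$", "");
            System.out.println(layerName + " : " + layerType);
            layerIdx++;
        }
    }
}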
- Layer layer = config.getLayer(); + LayerConfiguration layer = config.getFirstLayer(); String layerName = layer.getLayerName(); if (layerName == null) layerName = "layer0"; vertexNames.add(layerName); originalVertexName.add("0"); - String layerType = config.getLayer().getClass().getSimpleName().replaceAll("Layer$", ""); + String layerType = config.getFirstLayer().getClass().getSimpleName().replaceAll("Layer$", ""); layerTypes.add(layerType); layerInputs.add(Collections.singletonList(0)); @@ -256,20 +258,18 @@ public class TrainModuleUtils { Map map = getLayerInfo(config, layer); layerInfo.add(map); } - - return new GraphInfo(vertexNames, layerTypes, layerInputs, layerInfo, originalVertexName); } +**/ - - private static Map getLayerInfo(NeuralNetConfiguration c, Layer layer) { + private static Map getLayerInfo(NeuralNetConfiguration c, LayerConfiguration layer) { Map map = new LinkedHashMap<>(); if (layer instanceof FeedForwardLayer) { FeedForwardLayer layer1 = (FeedForwardLayer) layer; map.put("Input size", String.valueOf(layer1.getNIn())); map.put("Output size", String.valueOf(layer1.getNOut())); - map.put("Num Parameters", String.valueOf(layer1.initializer().numParams(c))); + map.put("Num Parameters", String.valueOf(layer1.initializer().numParams(layer))); map.put("Activation Function", layer1.getActivationFn().toString()); } diff --git a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java index 37a7aab14..8bae39055 100644 --- a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java +++ b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java @@ -27,7 +27,6 @@ import org.deeplearning4j.core.storage.impl.CollectionStatsStorageRouter; import org.deeplearning4j.core.storage.impl.RemoteUIStatsStorageRouter; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -133,8 +132,8 @@ public class TestRemoteReceiver extends BaseDL4JTest { public void testRemoteFull() throws Exception { //Use this in conjunction with startRemoteUI() - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build()) .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(4).nOut(3).build()) diff --git a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java index 988b9b502..694a557bc 100644 --- a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java +++ b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java @@ -31,7 +31,6 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import 
org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -60,7 +59,6 @@ import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicReference; import static org.junit.jupiter.api.Assertions.*; @@ -94,10 +92,10 @@ public class TestVertxUI extends BaseDL4JTest { UIServer uiServer = UIServer.getInstance(); uiServer.attach(ss); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(new Sgd(1e-5)) - .list().layer(0, + .layer(0, new VariationalAutoencoder.Builder().nIn(4).nOut(3).encoderLayerSizes(10, 11) .decoderLayerSizes(12, 13).weightInit(WeightInit.XAVIER) .pzxActivationFunction(Activation.IDENTITY) @@ -135,8 +133,8 @@ public class TestVertxUI extends BaseDL4JTest { UIServer uiServer = UIServer.getInstance(); uiServer.attach(ss); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build()) .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(4).nOut(3).build()) @@ -163,7 +161,7 @@ public class TestVertxUI extends BaseDL4JTest { UIServer uiServer = UIServer.getInstance(); uiServer.attach(ss); - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("L0", new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build(), "in") .addLayer("L1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) @@ -185,7 +183,7 @@ public class TestVertxUI extends BaseDL4JTest { @Test public void testAutoAttach() throws Exception { - ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder().addInputs("in") + ComputationGraphConfiguration conf = NeuralNetConfiguration.builder().graphBuilder().addInputs("in") .addLayer("L0", new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(4).build(), "in") .addLayer("L1", new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) diff --git a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java index d6f11df5e..bc1ae16a8 100644 --- a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java +++ b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java @@ -28,7 +28,6 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; 
import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -92,12 +91,12 @@ public class TestVertxUIManual extends BaseDL4JTest { int numInputs = 4; int outputNum = 3; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .activation(Activation.TANH) .weightInit(WeightInit.XAVIER) .updater(new Sgd(0.03)) .l2(1e-4) - .list() + .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(3) .build()) .layer(1, new DenseLayer.Builder().nIn(3).nOut(3) @@ -192,8 +191,8 @@ public class TestVertxUIManual extends BaseDL4JTest { ss = new InMemoryStatsStorage(); String sessionId = Integer.toString(session); statsProvider.put(sessionId, ss); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build()) .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build()) diff --git a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java index 5a774dceb..7da17dafd 100644 --- a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java +++ b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java @@ -27,7 +27,6 @@ import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -87,9 +86,9 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { Thread training = new Thread(() -> { int layerSize = sid + 4; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .updater(new Adam(1e-2)) - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build()) .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build()) @@ -153,8 +152,8 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { InMemoryStatsStorage ss = new InMemoryStatsStorage(); String sessionId = Integer.toString(session); statsStorageForSession.put(sessionId, ss); - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() - .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build()) .layer(1, new 
OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build()) diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/InstantiableModel.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/InstantiableModel.java index 6045f7ca1..fa72dea55 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/InstantiableModel.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/InstantiableModel.java @@ -20,20 +20,20 @@ package org.deeplearning4j.zoo; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; public interface InstantiableModel { void setInputShape(int[][] inputShape); - M init(); + M init(); /** * @deprecated No longer used, will be removed in a future release */ @Deprecated ModelMetaData metaData(); - Class modelType(); + Class modelType(); String pretrainedUrl(PretrainedType pretrainedType); diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ZooModel.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ZooModel.java index 958edec33..da2bc3a78 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ZooModel.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/ZooModel.java @@ -21,10 +21,10 @@ package org.deeplearning4j.zoo; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.apache.commons.io.FileUtils; import org.deeplearning4j.common.resources.DL4JResources; import org.deeplearning4j.common.resources.ResourceType; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.util.ModelSerializer; @@ -48,7 +48,7 @@ public abstract class ZooModel implements InstantiableModel { * @return * @throws IOException */ - public Model initPretrained() throws IOException { + public IModel initPretrained() throws IOException { return initPretrained(PretrainedType.IMAGENET); } @@ -59,7 +59,7 @@ public abstract class ZooModel implements InstantiableModel { * @return * @throws IOException */ - public M initPretrained(PretrainedType pretrainedType) throws IOException { + public M initPretrained(PretrainedType pretrainedType) throws IOException { String remoteUrl = pretrainedUrl(pretrainedType); if (remoteUrl == null) throw new UnsupportedOperationException( diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/AlexNet.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/AlexNet.java index b65441942..9e55c5b26 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/AlexNet.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/AlexNet.java @@ -22,13 +22,14 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.zoo.ModelMetaData; import org.deeplearning4j.zoo.PretrainedType; import 
org.deeplearning4j.zoo.ZooModel; @@ -64,15 +65,16 @@ public class AlexNet extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return MultiLayerNetwork.class; } - public MultiLayerConfiguration conf() { + public NeuralNetConfiguration conf() { double nonZeroBias = 1; - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed) - .weightInit(new NormalDistribution(0.0, 0.01)) + NeuralNetConfiguration conf = NeuralNetConfiguration.builder() + .seed(seed) + .weightInit( WeightInit.NORMAL) //new NormalDistribution(0.0, 0.01)) .activation(Activation.RELU) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) @@ -84,7 +86,7 @@ public class AlexNet extends ZooModel { .cacheMode(cacheMode) .l2(5 * 1e-4) .miniBatch(false) - .list() + .layer(0, new ConvolutionLayer.Builder(new int[]{11,11}, new int[]{4, 4}) .name("cnn1") .cudnnAlgoMode(ConvolutionLayer.AlgoMode.PREFER_FASTEST) @@ -158,15 +160,16 @@ public class AlexNet extends ZooModel { .build()) - .setInputType(InputType.convolutional(inputShape[2], inputShape[1], inputShape[0])) - .build(); + .inputType( InputType.convolutional(inputShape[2], inputShape[1], inputShape[0]) ) + .build() + ; return conf; } @Override public MultiLayerNetwork init() { - MultiLayerConfiguration conf = conf(); + NeuralNetConfiguration conf = conf(); MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); return network; diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Darknet19.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Darknet19.java index 739493bd8..f2b07ec58 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Darknet19.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Darknet19.java @@ -22,9 +22,8 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; -import lombok.NoArgsConstructor; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.common.resources.DL4JResources; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -80,22 +79,22 @@ public class Darknet19 extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } public ComputationGraphConfiguration conf() { - GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder() + GraphBuilder graphBuilder = ((NeuralNetConfiguration.NeuralNetConfigurationBuilder)NeuralNetConfiguration.builder() .seed(seed) .updater(updater) .weightInit(weightInit) - .l2(0.00001) + .l2(0.00001) .activation(Activation.IDENTITY) .cacheMode(cacheMode) .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) - .cudnnAlgoMode(cudnnAlgoMode) - .graphBuilder() + .cudnnAlgoMode(cudnnAlgoMode)) + .graphBuilder() .addInputs("input") .setInputTypes(InputType.convolutional(inputShape[2], inputShape[1], inputShape[0])); diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/FaceNetNN4Small2.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/FaceNetNN4Small2.java index 487401625..07ce6b985 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/FaceNetNN4Small2.java +++ 
b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/FaceNetNN4Small2.java @@ -22,8 +22,7 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; -import lombok.NoArgsConstructor; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.graph.L2NormalizeVertex; @@ -69,13 +68,13 @@ public class FaceNetNN4Small2 extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } public ComputationGraphConfiguration conf() { - ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(seed) + ComputationGraphConfiguration.GraphBuilder graph = ((NeuralNetConfiguration.NeuralNetConfigurationBuilder)NeuralNetConfiguration.builder().seed(seed) .activation(Activation.IDENTITY) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) @@ -86,7 +85,7 @@ public class FaceNetNN4Small2 extends ZooModel { .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) .cudnnAlgoMode(cudnnAlgoMode) - .convolutionMode(ConvolutionMode.Same) + .convolutionMode(ConvolutionMode.Same)) .graphBuilder(); diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/InceptionResNetV1.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/InceptionResNetV1.java index 50f14da0b..2d5d69dda 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/InceptionResNetV1.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/InceptionResNetV1.java @@ -22,18 +22,15 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; -import lombok.NoArgsConstructor; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; -import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.TruncatedNormalDistribution; import org.deeplearning4j.nn.conf.graph.L2NormalizeVertex; import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.zoo.ModelMetaData; import org.deeplearning4j.zoo.PretrainedType; import org.deeplearning4j.zoo.ZooModel; @@ -69,7 +66,7 @@ public class InceptionResNetV1 extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } @@ -102,7 +99,8 @@ public class InceptionResNetV1 extends ZooModel { public ComputationGraphConfiguration.GraphBuilder graphBuilder(String input) { - ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(seed) + ComputationGraphConfiguration.GraphBuilder graph = ((NeuralNetConfiguration.NeuralNetConfigurationBuilder)NeuralNetConfiguration.builder() + .seed(seed) .activation(Activation.RELU) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) @@ -112,7 +110,7 @@ public class InceptionResNetV1 extends ZooModel { .cacheMode(cacheMode) .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) - 
.convolutionMode(ConvolutionMode.Truncate).graphBuilder(); + .convolutionMode(ConvolutionMode.Truncate)).graphBuilder(); graph diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/LeNet.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/LeNet.java index 64a6f8c92..6dc75af5f 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/LeNet.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/LeNet.java @@ -22,9 +22,8 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; -import lombok.NoArgsConstructor; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.common.resources.DL4JResources; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -74,12 +73,12 @@ public class LeNet extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return MultiLayerNetwork.class; } - public MultiLayerConfiguration conf() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(seed) + public NeuralNetConfiguration conf() { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(seed) .activation(Activation.IDENTITY) .weightInit(WeightInit.XAVIER) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) @@ -89,7 +88,7 @@ public class LeNet extends ZooModel { .inferenceWorkspaceMode(workspaceMode) .cudnnAlgoMode(cudnnAlgoMode) .convolutionMode(ConvolutionMode.Same) - .list() + // block 1 .layer(new ConvolutionLayer.Builder() .name("cnn1") @@ -128,14 +127,14 @@ public class LeNet extends ZooModel { .nOut(numClasses) .activation(Activation.SOFTMAX) // radial basis function required .build()) - .setInputType(InputType.convolutionalFlat(inputShape[2], inputShape[1], inputShape[0])) + .inputType(InputType.convolutionalFlat(inputShape[2], inputShape[1], inputShape[0])) .build(); return conf; } @Override - public Model init() { + public IModel init() { MultiLayerNetwork network = new MultiLayerNetwork(conf()); network.init(); return network; diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/NASNet.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/NASNet.java index 0e78f819e..35f617773 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/NASNet.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/NASNet.java @@ -22,8 +22,8 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.common.resources.DL4JResources; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -86,7 +86,7 @@ public class NASNet extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } @@ -110,7 +110,7 @@ public class NASNet extends ZooModel { } int filters = (int) Math.floor(penultimateFilters / 24); - ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(seed) + ComputationGraphConfiguration.GraphBuilder graph = ((NeuralNetConfiguration.NeuralNetConfigurationBuilder)NeuralNetConfiguration.builder().seed(seed) 
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) .weightInit(weightInit) @@ -120,7 +120,7 @@ public class NASNet extends ZooModel { .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) .cudnnAlgoMode(cudnnAlgoMode) - .convolutionMode(ConvolutionMode.Truncate) + .convolutionMode(ConvolutionMode.Truncate)) .graphBuilder(); if(!skipReduction) { diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/ResNet50.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/ResNet50.java index 2453bb21c..f530e0781 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/ResNet50.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/ResNet50.java @@ -22,19 +22,16 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; -import lombok.NoArgsConstructor; import org.deeplearning4j.common.resources.DL4JResources; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; -import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.TruncatedNormalDistribution; import org.deeplearning4j.nn.conf.graph.ElementWiseVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.IWeightInit; -import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.zoo.ModelMetaData; import org.deeplearning4j.zoo.PretrainedType; @@ -77,7 +74,7 @@ public class ResNet50 extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } @@ -175,11 +172,11 @@ public class ResNet50 extends ZooModel { public ComputationGraphConfiguration.GraphBuilder graphBuilder() { - ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(seed) + ComputationGraphConfiguration.GraphBuilder graph = ((NeuralNetConfiguration.NeuralNetConfigurationBuilder)NeuralNetConfiguration.builder().seed(seed) .activation(Activation.IDENTITY) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) - .weightInit(weightInit) + .weightInitFn(weightInit) .l1(1e-7) .l2(5e-5) .miniBatch(true) @@ -187,7 +184,7 @@ public class ResNet50 extends ZooModel { .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) .cudnnAlgoMode(cudnnAlgoMode) - .convolutionMode(ConvolutionMode.Truncate) + .convolutionMode(ConvolutionMode.Truncate)) .graphBuilder(); diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/SimpleCNN.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/SimpleCNN.java index 17f22d1f4..f5b1c41ee 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/SimpleCNN.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/SimpleCNN.java @@ -22,8 +22,7 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; -import lombok.NoArgsConstructor; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; import 
org.deeplearning4j.nn.conf.inputs.InputType; @@ -63,13 +62,13 @@ public class SimpleCNN extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return MultiLayerNetwork.class; } - public MultiLayerConfiguration conf() { - MultiLayerConfiguration conf = - new NeuralNetConfiguration.Builder().seed(seed) + public NeuralNetConfiguration conf() { + NeuralNetConfiguration conf = + NeuralNetConfiguration.builder().seed(seed) .activation(Activation.IDENTITY) .weightInit(WeightInit.RELU) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) @@ -78,7 +77,7 @@ public class SimpleCNN extends ZooModel { .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) .convolutionMode(ConvolutionMode.Same) - .list() + // block 1 .layer(0, new ConvolutionLayer.Builder(new int[] {7, 7}).name("image_array") .nIn(inputShape[0]).nOut(16).build()) @@ -130,7 +129,7 @@ public class SimpleCNN extends ZooModel { .layer(31, new GlobalPoolingLayer.Builder(PoolingType.AVG).build()) .layer(32, new ActivationLayer.Builder().activation(Activation.SOFTMAX).build()) - .setInputType(InputType.convolutional(inputShape[2], inputShape[1], + .inputType(InputType.convolutional(inputShape[2], inputShape[1], inputShape[0])) .build(); @@ -138,7 +137,7 @@ public class SimpleCNN extends ZooModel { } @Override - public Model init() { + public IModel init() { MultiLayerNetwork network = new MultiLayerNetwork(conf()); network.init(); return network; diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/SqueezeNet.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/SqueezeNet.java index 2f77a2d4c..e63e36cea 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/SqueezeNet.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/SqueezeNet.java @@ -22,12 +22,10 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; -import lombok.NoArgsConstructor; import org.deeplearning4j.common.resources.DL4JResources; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; -import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; @@ -79,12 +77,12 @@ public class SqueezeNet extends ZooModel { public ComputationGraph initPretrained(PretrainedType pretrainedType) throws IOException { ComputationGraph cg = (ComputationGraph) super.initPretrained(pretrainedType); //Set collapse dimensions to true in global avg pooling - more useful for users [N,1000] rather than [N,1000,1,1] out. 
Also matches non-pretrain config - ((GlobalPoolingLayer)cg.getLayer("global_average_pooling2d_5").conf().getLayer()).setCollapseDimensions(true); + ((GlobalPoolingLayer)cg.getLayer("global_average_pooling2d_5").getLayerConfiguration()).setCollapseDimensions(true); return cg; } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } @@ -103,7 +101,7 @@ public class SqueezeNet extends ZooModel { public ComputationGraphConfiguration.GraphBuilder graphBuilder() { - ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(seed) + ComputationGraphConfiguration.GraphBuilder graph = ((NeuralNetConfiguration.NeuralNetConfigurationBuilder)NeuralNetConfiguration.builder().seed(seed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) .weightInit(weightInit) @@ -112,7 +110,7 @@ public class SqueezeNet extends ZooModel { .cacheMode(cacheMode) .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) - .convolutionMode(ConvolutionMode.Truncate) + .convolutionMode(ConvolutionMode.Truncate)) .graphBuilder(); diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java index 432c74231..962b8f677 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TextGenerationLSTM.java @@ -22,8 +22,7 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; -import lombok.NoArgsConstructor; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; @@ -66,12 +65,12 @@ public class TextGenerationLSTM extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return MultiLayerNetwork.class; } - public MultiLayerConfiguration conf() { - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + public NeuralNetConfiguration conf() { + NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .l2(0.001) .weightInit(WeightInit.XAVIER) @@ -80,21 +79,21 @@ public class TextGenerationLSTM extends ZooModel { .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) .cudnnAlgoMode(cudnnAlgoMode) - .list() + .layer(0, new GravesLSTM.Builder().nIn(inputShape[1]).nOut(256).activation(Activation.TANH) .build()) .layer(1, new GravesLSTM.Builder().nOut(256).activation(Activation.TANH).build()) .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT) .activation(Activation.SOFTMAX) //MCXENT + softmax for classification .nOut(totalUniqueCharacters).build()) - .backpropType(BackpropType.TruncatedBPTT).tBPTTForwardLength(50).tBPTTBackwardLength(50) + .backpropType(BackpropType.TruncatedBPTT).tbpttFwdLength(50).tbpttBackLength(50) .build(); return conf; } @Override - public Model init() { + public IModel init() { MultiLayerNetwork network = new MultiLayerNetwork(conf()); network.init(); return network; diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TinyYOLO.java 
b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TinyYOLO.java index abbdf06cf..e5281d33d 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TinyYOLO.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/TinyYOLO.java @@ -23,9 +23,8 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; -import lombok.NoArgsConstructor; import org.deeplearning4j.common.resources.DL4JResources; -import org.deeplearning4j.nn.api.Model; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder; @@ -80,14 +79,14 @@ public class TinyYOLO extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } public ComputationGraphConfiguration conf() { INDArray priors = Nd4j.create(priorBoxes); - GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder() + GraphBuilder graphBuilder = ((NeuralNetConfiguration.NeuralNetConfigurationBuilder)NeuralNetConfiguration.builder() .seed(seed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) @@ -98,7 +97,7 @@ public class TinyYOLO extends ZooModel { .cacheMode(cacheMode) .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) - .cudnnAlgoMode(cudnnAlgoMode) + .cudnnAlgoMode(cudnnAlgoMode)) .graphBuilder() .addInputs("input") .setInputTypes(InputType.convolutional(inputShape[2], inputShape[1], inputShape[0])); diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/UNet.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/UNet.java index ca8136f62..f9400ba8e 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/UNet.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/UNet.java @@ -22,11 +22,10 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.common.resources.DL4JResources; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; -import org.deeplearning4j.nn.conf.distribution.TruncatedNormalDistribution; import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; @@ -73,7 +72,7 @@ public class UNet extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } @@ -92,7 +91,7 @@ public class UNet extends ZooModel { public ComputationGraphConfiguration.GraphBuilder graphBuilder() { - ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(seed) + ComputationGraphConfiguration.GraphBuilder graph = ((NeuralNetConfiguration.NeuralNetConfigurationBuilder)NeuralNetConfiguration.builder().seed(seed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) .weightInit(weightInit) @@ -100,7 +99,7 @@ public class UNet extends ZooModel { .miniBatch(true) .cacheMode(cacheMode) .trainingWorkspaceMode(workspaceMode) - .inferenceWorkspaceMode(workspaceMode) + .inferenceWorkspaceMode(workspaceMode)) .graphBuilder(); diff --git 
a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/VGG16.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/VGG16.java index c52d8988d..2f6aa1cac 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/VGG16.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/VGG16.java @@ -22,8 +22,8 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.common.resources.DL4JResources; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; @@ -83,19 +83,19 @@ public class VGG16 extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } public ComputationGraphConfiguration conf() { ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().seed(seed) + ((NeuralNetConfiguration.NeuralNetConfigurationBuilder) NeuralNetConfiguration.builder().seed(seed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) .activation(Activation.RELU) .cacheMode(cacheMode) .trainingWorkspaceMode(workspaceMode) - .inferenceWorkspaceMode(workspaceMode) + .inferenceWorkspaceMode(workspaceMode)) .graphBuilder() .addInputs("in") // block 1 diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/VGG19.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/VGG19.java index ee2bb0725..5e846efda 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/VGG19.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/VGG19.java @@ -22,9 +22,8 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; -import lombok.NoArgsConstructor; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.common.resources.DL4JResources; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -33,7 +32,6 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.zoo.ModelMetaData; import org.deeplearning4j.zoo.PretrainedType; import org.deeplearning4j.zoo.ZooModel; @@ -74,19 +72,19 @@ public class VGG19 extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } public ComputationGraphConfiguration conf() { ComputationGraphConfiguration conf = - new NeuralNetConfiguration.Builder().seed(seed) + ((NeuralNetConfiguration.NeuralNetConfigurationBuilder) NeuralNetConfiguration.builder().seed(seed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) .activation(Activation.RELU) .cacheMode(cacheMode) .trainingWorkspaceMode(workspaceMode) - .inferenceWorkspaceMode(workspaceMode) + .inferenceWorkspaceMode(workspaceMode)) .graphBuilder() .addInputs("in") // block 1 diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Xception.java 
b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Xception.java index 4c851fa08..bbba3ff3c 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Xception.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Xception.java @@ -22,12 +22,10 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; -import lombok.NoArgsConstructor; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.common.resources.DL4JResources; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; -import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.graph.ElementWiseVertex; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; @@ -39,7 +37,6 @@ import org.deeplearning4j.zoo.ZooModel; import org.deeplearning4j.zoo.ZooType; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.learning.config.AdaDelta; -import org.nd4j.linalg.learning.config.AdaGrad; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -75,7 +72,7 @@ public class Xception extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } @@ -94,7 +91,7 @@ public class Xception extends ZooModel { public ComputationGraphConfiguration.GraphBuilder graphBuilder() { - ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(seed) + ComputationGraphConfiguration.GraphBuilder graph =((NeuralNetConfiguration.NeuralNetConfigurationBuilder) NeuralNetConfiguration.builder().seed(seed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) .weightInit(weightInit) @@ -103,7 +100,7 @@ public class Xception extends ZooModel { .cacheMode(cacheMode) .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) - .convolutionMode(ConvolutionMode.Truncate) + .convolutionMode(ConvolutionMode.Truncate)) .graphBuilder(); diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/YOLO2.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/YOLO2.java index 030a5c46b..3c28a36a0 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/YOLO2.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/YOLO2.java @@ -23,9 +23,8 @@ package org.deeplearning4j.zoo.model; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Getter; -import lombok.NoArgsConstructor; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.common.resources.DL4JResources; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder; @@ -87,14 +86,14 @@ public class YOLO2 extends ZooModel { } @Override - public Class modelType() { + public Class modelType() { return ComputationGraph.class; } public ComputationGraphConfiguration conf() { INDArray priors = Nd4j.create(priorBoxes); - GraphBuilder graphBuilder = new NeuralNetConfiguration.Builder() + GraphBuilder graphBuilder = ((NeuralNetConfiguration.NeuralNetConfigurationBuilder)NeuralNetConfiguration.builder() .seed(seed) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) 
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) @@ -105,7 +104,7 @@ public class YOLO2 extends ZooModel { .cacheMode(cacheMode) .trainingWorkspaceMode(workspaceMode) .inferenceWorkspaceMode(workspaceMode) - .cudnnAlgoMode(cudnnAlgoMode) + .cudnnAlgoMode(cudnnAlgoMode)) .graphBuilder() .addInputs("input") .setInputTypes(InputType.convolutional(inputShape[2], inputShape[1], inputShape[0])); diff --git a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java index f9e8b83a1..2bf9e7ed1 100644 --- a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java +++ b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java @@ -21,10 +21,10 @@ package org.deeplearning4j.zoo; import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.BenchmarkDataSetIterator; -import org.deeplearning4j.nn.api.Model; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LossLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -40,7 +40,6 @@ import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.LossFunctions; @@ -108,7 +107,7 @@ public class TestInstantiation extends BaseDL4JTest { new int[]{8, inputShape[0], inputShape[1], inputShape[2]}, numClasses, 1, gridWidth, gridHeight); - Model initializedModel = model.init(); + IModel initializedModel = model.init(); AsyncDataSetIterator async = new AsyncDataSetIterator(iter); if (initializedModel instanceof MultiLayerNetwork) { ((MultiLayerNetwork) initializedModel).fit(async); diff --git a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java index 44d9dff3c..240cabfcc 100644 --- a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java +++ b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java @@ -45,7 +45,7 @@ public class TestUtils { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); - assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); + assertEquals(net.getConfiguration(), restored.getConfiguration()); assertEquals(net.params(), restored.params()); return restored; From a5dfdcb18fde8a3566bab81878a304ac168f6efb Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 7 Apr 2023 17:05:32 +0200 Subject: [PATCH 122/126] Playing with some new code 2 - clean build Signed-off-by: brian --- .../spark/impl/misc/TestFrozenLayers.java | 8 +- .../src/test/java/net/brutex/gan/GAN.java | 4 +- .../integration/IntegrationTestRunner.java | 16 +- .../deeplearning4j/integration/TestUtils.java | 4 +- .../LayerHelperValidationUtil.java | 8 +- .../java/org/deeplearning4j/TestUtils.java | 4 +- .../org/deeplearning4j/eval/EvalTest.java | 2 +- .../gradientcheck/GradientCheckTests.java | 4 +- 
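The TestInstantiation change in the previous patch types the initialized network against net.brutex.ai.dnn.api.IModel instead of the old Model interface. A hedged sketch of the resulting call pattern, where model and iter stand for the ZooModel and DataSetIterator already present in that test and nothing else is new API:

    IModel initializedModel = model.init();
    if (initializedModel instanceof MultiLayerNetwork) {
        ((MultiLayerNetwork) initializedModel).fit(iter);
    } else if (initializedModel instanceof ComputationGraph) {
        ((ComputationGraph) initializedModel).fit(iter);
    }

Concrete behaviour is still reached through instanceof checks; only the declared supertype changes.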
.../nn/conf/NeuralNetConfigurationTest.java | 20 +- .../nn/conf/constraints/TestConstraints.java | 2 +- .../nn/conf/dropout/TestDropout.java | 6 +- .../nn/conf/graph/ElementWiseVertexTest.java | 6 +- .../nn/conf/graph/ShiftVertexTest.java | 2 +- .../nn/conf/layers/LayerBuilderTest.java | 10 +- .../conf/preprocessor/TestPreProcessors.java | 16 +- .../deeplearning4j/nn/dtypes/DTypeTests.java | 6 +- .../nn/graph/TestCompGraphUnsupervised.java | 8 +- .../nn/graph/TestComputationGraphNetwork.java | 26 +- .../nn/graph/TestSetGetParameters.java | 6 +- .../nn/layers/BaseLayerTest.java | 12 +- .../nn/layers/FrozenLayerTest.java | 2 +- .../nn/layers/OutputLayerTest.java | 4 +- .../nn/layers/RepeatVectorTest.java | 2 +- .../deeplearning4j/nn/layers/SeedTest.java | 4 +- .../layers/convolution/Convolution3DTest.java | 4 +- .../convolution/ConvolutionLayerTest.java | 8 +- .../layers/convolution/SpaceToDepthTest.java | 2 +- .../convolution/SubsamplingLayerTest.java | 2 +- .../convolution/TestConvolutionModes.java | 16 +- .../layers/convolution/Upsampling1DTest.java | 2 +- .../layers/convolution/Upsampling2DTest.java | 2 +- .../custom/testclasses/CustomLayer.java | 6 +- .../custom/testclasses/CustomLayerImpl.java | 3 +- .../custom/testclasses/CustomOutputLayer.java | 6 +- .../testclasses/CustomOutputLayerImpl.java | 3 +- .../layers/feedforward/dense/DenseTest.java | 4 +- .../normalization/BatchNormalizationTest.java | 4 +- .../nn/layers/ocnn/OCNNOutputLayerTest.java | 2 +- .../GravesBidirectionalLSTMTest.java | 32 +-- .../nn/layers/recurrent/GravesLSTMTest.java | 14 +- .../nn/layers/samediff/TestSameDiffConv.java | 2 +- .../nn/layers/samediff/TestSameDiffDense.java | 2 +- .../samediff/TestSameDiffDenseVertex.java | 4 +- .../nn/layers/variational/TestVAE.java | 12 +- .../nn/misc/WorkspaceTests.java | 8 +- .../nn/multilayer/MultiLayerTest.java | 20 +- .../nn/multilayer/MultiLayerTestRNN.java | 10 +- .../rl/TestMultiModelGradientApplication.java | 4 +- .../nn/transferlearning/TestFrozenLayers.java | 4 +- .../TestTransferLearningModelSerializer.java | 8 +- .../TransferLearningCompGraphTest.java | 12 +- .../TransferLearningMLNTest.java | 32 +-- .../nn/updater/TestGradientNormalization.java | 20 +- .../optimize/solver/TestOptimizers.java | 264 +++++++++++++++++- .../regressiontest/RegressionTest050.java | 6 +- .../regressiontest/RegressionTest060.java | 8 +- .../regressiontest/RegressionTest071.java | 8 +- .../regressiontest/RegressionTest080.java | 8 +- .../regressiontest/RegressionTest100a.java | 6 +- .../regressiontest/RegressionTest100b3.java | 8 +- .../regressiontest/RegressionTest100b4.java | 8 +- .../regressiontest/RegressionTest100b6.java | 8 +- .../customlayer100a/CustomLayer.java | 7 +- .../customlayer100a/CustomLayerImpl.java | 7 +- .../deeplearning4j/util/ModelGuesserTest.java | 4 +- .../util/ModelSerializerTest.java | 4 +- .../configurations/KerasModelImportTest.java | 3 +- .../models/word2vec/Word2VecTestsSmall.java | 2 +- .../java/net/brutex/ai/dnn/api/IModel.java | 10 + .../nn/conf/NeuralNetConfiguration.java | 11 + .../nn/conf/layers/LayerConfiguration.java | 3 + .../nn/multilayer/MultiLayerNetwork.java | 9 + .../EarlyStoppingParallelTrainer.java | 4 +- .../parallelism/InplaceParallelInference.java | 2 +- .../parallelism/ParallelInference.java | 2 +- .../parallelism/ParallelWrapper.java | 4 +- .../parallelism/trainer/DefaultTrainer.java | 6 +- .../impl/multilayer/SparkDl4jMultiLayer.java | 6 +- .../ParameterAveragingTrainingMaster.java | 4 +- .../ParameterAveragingTrainingWorker.java | 4 +- 
.../impl/customlayer/layer/CustomLayer.java | 6 +- .../customlayer/layer/CustomLayerImpl.java | 3 +- .../spark/impl/misc/TestFrozenLayers.java | 4 +- ...TestSparkMultiLayerParameterAveraging.java | 10 +- .../pw/SharedTrainingWrapper.java | 4 +- .../training/SharedTrainingMaster.java | 2 +- .../deeplearning4j/plot/BarnesHutTsne.java | 114 +++++++- .../ui/model/stats/BaseStatsListener.java | 23 +- .../ui/module/train/TrainModule.java | 8 +- .../ui/module/train/TrainModuleUtils.java | 10 +- .../main/resources/templates/SameDiffUI.html | 2 +- .../org/deeplearning4j/zoo/TestUtils.java | 2 +- 92 files changed, 716 insertions(+), 318 deletions(-) diff --git a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java index f0d15745d..87493404e 100644 --- a/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java +++ b/.old/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java @@ -74,7 +74,7 @@ public class TestFrozenLayers extends BaseSparkTest { MultiLayerNetwork withFrozen = new TransferLearning.Builder(origModel).fineTuneConfiguration(finetune) .setFeatureExtractor(1).build(); - Map m = withFrozen.paramTable(); + Map m = withFrozen.getParamTable(); Map pCopy = new HashMap<>(); for (Map.Entry entry : m.entrySet()) { pCopy.put(entry.getKey(), entry.getValue().dup()); @@ -110,7 +110,7 @@ public class TestFrozenLayers extends BaseSparkTest { MultiLayerNetwork fitted = sNet.getNetwork(); - Map fittedParams = fitted.paramTable(); + Map fittedParams = fitted.getParamTable(); for (Map.Entry entry : fittedParams.entrySet()) { INDArray orig = pCopy.get(entry.getKey()); @@ -151,7 +151,7 @@ public class TestFrozenLayers extends BaseSparkTest { ComputationGraph withFrozen = new TransferLearning.GraphBuilder(origModel).fineTuneConfiguration(finetune) .setFeatureExtractor("1").build(); - Map m = withFrozen.paramTable(); + Map m = withFrozen.getParamTable(); Map pCopy = new HashMap<>(); for (Map.Entry entry : m.entrySet()) { pCopy.put(entry.getKey(), entry.getValue().dup()); @@ -187,7 +187,7 @@ public class TestFrozenLayers extends BaseSparkTest { ComputationGraph fitted = sNet.getNetwork(); - Map fittedParams = fitted.paramTable(); + Map fittedParams = fitted.getParamTable(); for (Map.Entry entry : fittedParams.entrySet()) { INDArray orig = pCopy.get(entry.getKey()); diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java b/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java index 659c6ab32..b1e780d59 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java @@ -200,8 +200,8 @@ public class GAN { Layer[] disLayers = ganDiscriminator.getLayers(); Layer[] layers = ArrayUtils.addAll(genLayers, disLayers); - NeuralNetConfiguration genConf = generator.getConfiguration(); - NeuralNetConfiguration disConf = ganDiscriminator.getConfiguration(); + NeuralNetConfiguration genConf = generator.getNetConfiguration(); + NeuralNetConfiguration disConf = ganDiscriminator.getNetConfiguration(); LayerConfiguration[] confLayers = new LayerConfiguration[layers.length]; Map preProcessors = new HashMap<>(); diff --git 
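Two accessor renames run through the whole patch and are worth calling out once: parameter maps are now read with getParamTable() instead of paramTable(), and a MultiLayerNetwork's configuration with getNetConfiguration() instead of getConfiguration(). A small sketch of the frozen-layer check these tests perform, assuming net and trainingData are the network and data already set up in the surrounding test code:

    Map<String, INDArray> before = new HashMap<>();
    for (Map.Entry<String, INDArray> e : net.getParamTable().entrySet()) {
        before.put(e.getKey(), e.getValue().dup());   // defensive copy of every parameter
    }
    net.fit(trainingData);
    for (Map.Entry<String, INDArray> e : net.getParamTable().entrySet()) {
        INDArray orig = before.get(e.getKey());
        // frozen layers are expected to still equal 'orig'; trainable layers should have moved
    }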
a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java index 870f4022a..e68751c1b 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java @@ -190,7 +190,7 @@ public class IntegrationTestRunner { m = mln; MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true); - assertEquals(loaded.getConfiguration(), mln.getConfiguration(), "Configs not equal"); + assertEquals(loaded.getNetConfiguration(), mln.getNetConfiguration(), "Configs not equal"); assertEquals( loaded.params(), mln.params(), "Params not equal"); assertEquals( loaded.getParamTable(), mln.getParamTable(), "Param table not equal"); } else if(config instanceof ComputationGraphConfiguration ){ @@ -202,7 +202,7 @@ public class IntegrationTestRunner { ComputationGraph loaded = ComputationGraph.load(savedModel, true); assertEquals(loaded.getComputationGraphConfiguration(), cg.getComputationGraphConfiguration(), "Configs not equal" ); assertEquals( loaded.params(), cg.params(), "Params not equal"); - assertEquals(loaded.paramTable(), cg.paramTable(), "Param table not equal"); + assertEquals(loaded.getParamTable(), cg.getParamTable(), "Param table not equal"); } else if(config instanceof SameDiff){ sd = (SameDiff)config; SameDiff loaded = SameDiff.load(savedModel, true); @@ -426,8 +426,8 @@ public class IntegrationTestRunner { boolean isTbptt; int tbpttLength; if(modelType == ModelType.MLN){ - isTbptt = mln.getConfiguration().getBackpropType() == BackpropType.TruncatedBPTT; - tbpttLength = mln.getConfiguration().getTbpttFwdLength(); + isTbptt = mln.getNetConfiguration().getBackpropType() == BackpropType.TruncatedBPTT; + tbpttLength = mln.getNetConfiguration().getTbpttFwdLength(); } else if(modelType == ModelType.CG) { isTbptt = cg.getComputationGraphConfiguration().getBackpropType() == BackpropType.TruncatedBPTT; tbpttLength = cg.getComputationGraphConfiguration().getTbpttFwdLength(); @@ -606,7 +606,7 @@ public class IntegrationTestRunner { if (modelType == ModelType.MLN) { ModelSerializer.writeModel(m, f, true); MultiLayerNetwork restored = MultiLayerNetwork.load(f, true); - assertEquals(mln.getConfiguration(), restored.getConfiguration()); + assertEquals(mln.getNetConfiguration(), restored.getNetConfiguration()); assertEquals(mln.params(), restored.params()); } else if(modelType == ModelType.CG){ ModelSerializer.writeModel(m, f, true); @@ -742,7 +742,7 @@ public class IntegrationTestRunner { //Collect preprocessor coverage information: Collection preProcessors; if (isMLN) { - preProcessors = mln.getConfiguration().getInputPreProcessors().values(); + preProcessors = mln.getNetConfiguration().getInputPreProcessors().values(); } else { preProcessors = new ArrayList<>(); for (org.deeplearning4j.nn.conf.graph.GraphVertex gv : cg.getComputationGraphConfiguration().getVertices().values()) { @@ -834,7 +834,7 @@ public class IntegrationTestRunner { } else { paramPrefix = l.getLayerConfiguration().getLayerName() + "_"; } - Map paramTable = l.paramTable(); + Map paramTable = l.getParamTable(); for(Map.Entry e : paramTable.entrySet()){ out.put(paramPrefix + e.getKey(), e.getValue().dup()); } @@ -1088,7 +1088,7 @@ public class IntegrationTestRunner { if(pSoFar + n < i){ pSoFar += n; } else { - for(Map.Entry e : l.paramTable().entrySet()){ + for(Map.Entry e : 
l.getParamTable().entrySet()){ pSoFar += e.getValue().length(); if(pSoFar >= i){ pName = e.getKey(); diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java index bbe38a662..5bdae5d39 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java @@ -48,7 +48,7 @@ public class TestUtils { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); - assertEquals(net.getConfiguration(), restored.getConfiguration()); + assertEquals(net.getNetConfiguration(), restored.getNetConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen @@ -56,7 +56,7 @@ public class TestUtils { } //Also check the NeuralNetConfiguration is serializable (required by Spark etc) - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); serializeDeserializeJava(conf); return restored; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java index 8da3ff4e5..db11f8cc7 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java @@ -109,12 +109,12 @@ public class LayerHelperValidationUtil { } - MultiLayerNetwork net1NoHelper = new MultiLayerNetwork(netOrig.getConfiguration().clone()); + MultiLayerNetwork net1NoHelper = new MultiLayerNetwork(netOrig.getNetConfiguration().clone()); net1NoHelper.init(); log.info("Removing all layer helpers from network copy 1"); removeHelpers(net1NoHelper.getLayers(), null); - MultiLayerNetwork net2With = new MultiLayerNetwork(netOrig.getConfiguration().clone()); + MultiLayerNetwork net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone()); net2With.init(); net2With.params().assign(netOrig.params()); log.info("Removing all except for specified helpers from network copy 2: " + t.getAllowHelpersForClasses()); @@ -253,7 +253,7 @@ public class LayerHelperValidationUtil { Preconditions.checkNotNull(t.getData(), "DataSetIterator is not set (null)"); log.info("Testing run-to-run consistency of training with layer helper"); - net2With = new MultiLayerNetwork(netOrig.getConfiguration().clone()); + net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone()); net2With.init(); net2With.params().assign(netOrig.params()); log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses()); @@ -265,7 +265,7 @@ public class LayerHelperValidationUtil { for( int i=0; i<2; i++ ) { - net2With = new MultiLayerNetwork(netOrig.getConfiguration().clone()); + net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone()); net2With.init(); net2With.params().assign(netOrig.params()); log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java index 6e4456ef2..374724ae5 100644 --- 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -66,7 +66,7 @@ public class TestUtils { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); - assertEquals(net.getConfiguration(), restored.getConfiguration()); + assertEquals(net.getNetConfiguration(), restored.getNetConfiguration()); assertEquals(net.params(), restored.params()); } catch (IOException e){ //Should never happen @@ -74,7 +74,7 @@ public class TestUtils { } //Also check the NeuralNetConfiguration is serializable (required by Spark etc) - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); serializeDeserializeJava(conf); return restored; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java index 8f69cf1d9..8b5f5d46b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java @@ -622,7 +622,7 @@ public class EvalTest extends BaseDL4JTest { //Disable validation, and check same thing: - net.getConfiguration().setValidateOutputLayerConfig(false); + net.getNetConfiguration().setValidateOutputLayerConfig(false); net.evaluate(iter); net.evaluateROCMultiClass(iter, 0); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java index 90f927d66..6cefb32aa 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java @@ -511,7 +511,7 @@ public class GradientCheckTests extends BaseDL4JTest { ComputationGraph netGraph = new ComputationGraph(conf); netGraph.init(); - log.info("params before learning: " + netGraph.getLayer(1).paramTable()); + log.info("params before learning: " + netGraph.getLayer(1).getParamTable()); //Run a number of iterations of learning manually make some pseudo data //the ides is simple: since we do a element wise multiplication layer (just a scaling), we want the cos sim @@ -538,7 +538,7 @@ public class GradientCheckTests extends BaseDL4JTest { assertTrue( scoreAfter < 0.8 * scoreBefore, msg); // expectation in case linear regression(with only element wise multiplication layer): large weight for the fourth weight - log.info("params after learning: " + netGraph.getLayer(1).paramTable()); + log.info("params after learning: " + netGraph.getLayer(1).getParamTable()); boolean gradOK = GradientCheckUtil.checkGradients(new GradientCheckUtil.GraphConfig().net(netGraph).inputs(new INDArray[]{features}) .labels(new INDArray[]{labels})); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java index 64a9fba11..6a7ec6408 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java @@ -100,14 +100,14 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest { @Test 
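The LayerHelperValidationUtil and TestUtils hunks above all lean on the same idiom after the rename: take the network's NeuralNetConfiguration via getNetConfiguration(), clone it, rebuild an identical network, then copy the weights across. A minimal sketch, with original standing for an already-initialised MultiLayerNetwork:

    NeuralNetConfiguration conf = original.getNetConfiguration();   // renamed accessor
    MultiLayerNetwork copy = new MultiLayerNetwork(conf.clone());
    copy.init();
    copy.params().assign(original.params());   // same weights, fresh layer objects

The same conf object is also what the TestUtils round-trip checks push through Java serialization and JSON.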
public void testClone() { NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitUniform(), true); - BaseLayer bl = (BaseLayer) conf.getFirstLayer(); + BaseLayer bl = (BaseLayer) conf.getFlattenedLayerConfigurations().get(0); conf.setStepFunction(new DefaultStepFunction()); NeuralNetConfiguration conf2 = conf.clone(); assertEquals(conf, conf2); assertNotSame(conf, conf2); - assertNotSame(conf.getFirstLayer(), conf2.getFirstLayer()); + assertNotSame(conf.getFlattenedLayerConfigurations().get(0), conf2.getFlattenedLayerConfigurations().get(0)); assertNotSame(conf.getStepFunction(), conf2.getStepFunction()); } @@ -119,9 +119,9 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest { NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer).build(); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer model = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer model = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); INDArray modelWeights = model.getParam(DefaultParamInitializer.WEIGHT_KEY); @@ -130,9 +130,9 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest { NeuralNetConfiguration conf2 = NeuralNetConfiguration.builder().seed(123) .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT).layer(layer2).build(); - long numParams2 = conf2.getFirstLayer().initializer().numParams(conf); + long numParams2 = conf2.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params2 = Nd4j.create(1, numParams); - Layer model2 = conf2.getFirstLayer().instantiate(conf2, null, 0, params2, true, params.dataType()); + Layer model2 = conf2.getFlattenedLayerConfigurations().get(0).instantiate(conf2, null, 0, params2, true, params.dataType()); INDArray modelWeights2 = model2.getParam(DefaultParamInitializer.WEIGHT_KEY); assertEquals(modelWeights, modelWeights2); @@ -208,9 +208,9 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest { private static Layer getLayer(int nIn, int nOut, IWeightInit weightInit, boolean preTrain) { NeuralNetConfiguration conf = getConfig(nIn, nOut, weightInit, preTrain); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - return conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); } @@ -235,7 +235,7 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - ConvexOptimizer opt = new StochasticGradientDescent(net.getConfiguration(), + ConvexOptimizer opt = new StochasticGradientDescent(net.getNetConfiguration(), new NegativeDefaultStepFunction(), null, net); assertEquals(lr, ((Sgd)net.getLayer(0).getLayerConfiguration().getUpdaterByParam("W")).getLearningRate(), 1e-4); assertEquals(biasLr, ((Sgd)net.getLayer(0).getLayerConfiguration().getUpdaterByParam("b")).getLearningRate(), 1e-4); @@ -295,7 +295,7 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest { MultiLayerNetwork net = new 
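Throughout these test changes, conf.getFirstLayer() is replaced by conf.getFlattenedLayerConfigurations().get(0). The single-layer instantiation boilerplate that repeats in the hunks above therefore now reads roughly as follows, where conf is any single-layer NeuralNetConfiguration built in these tests:

    LayerConfiguration lc = conf.getFlattenedLayerConfigurations().get(0);
    long numParams = lc.initializer().numParams(conf);
    INDArray params = Nd4j.create(1, numParams);
    Layer layer = lc.instantiate(conf, null, 0, params, true, params.dataType());

Only the way the first LayerConfiguration is looked up changes; initializer(), numParams() and instantiate() keep the signatures already used here.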
MultiLayerNetwork(conf); net.init(); - ConvexOptimizer opt = new StochasticGradientDescent(net.getConfiguration(), + ConvexOptimizer opt = new StochasticGradientDescent(net.getNetConfiguration(), new NegativeDefaultStepFunction(), null, net); assertEquals(l1, TestUtils.getL1(net.getLayer(0).getLayerConfiguration().getRegularizationByParam("W")), 1e-4); List r = net.getLayer(0).getLayerConfiguration().getRegularizationByParam("b"); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java index afbb64726..d1aae72e9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/constraints/TestConstraints.java @@ -456,7 +456,7 @@ public class TestConstraints extends BaseDL4JTest { INDArray label = Nd4j.rand(1, 1); g.fit(new INDArray[]{in1, in2}, new INDArray[]{label}); - for(Map.Entry e : g.paramTable().entrySet()){ + for(Map.Entry e : g.getParamTable().entrySet()){ if(!e.getKey().contains("W")){ continue; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java index 26c266dc7..f574ae089 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/dropout/TestDropout.java @@ -82,9 +82,9 @@ public class TestDropout extends BaseDL4JTest { .setOutputs("2") .build(); - assertEquals(new Dropout(0.6), ((LayerVertex)conf2.getVertices().get("0")).getNetConfiguration().getFirstLayer().getIDropout()); - assertEquals(new Dropout(0.7), ((LayerVertex)conf2.getVertices().get("1")).getNetConfiguration().getFirstLayer().getIDropout()); - assertEquals(new AlphaDropout(0.5), ((LayerVertex)conf2.getVertices().get("2")).getNetConfiguration().getFirstLayer().getIDropout()); + assertEquals(new Dropout(0.6), ((LayerVertex)conf2.getVertices().get("0")).getLayerConfiguration().getIDropout()); + assertEquals(new Dropout(0.7), ((LayerVertex)conf2.getVertices().get("1")).getLayerConfiguration().getIDropout()); + assertEquals(new AlphaDropout(0.5), ((LayerVertex)conf2.getVertices().get("2")).getLayerConfiguration().getIDropout()); } @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java index 02babc8bc..c3ec4a87c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ElementWiseVertexTest.java @@ -232,7 +232,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest { cg.computeGradientAndScore(); // Let's figure out what our params are now. - Map params = cg.paramTable(); + Map params = cg.getParamTable(); INDArray dense1_W = nullsafe(params.get("dense1_W")); INDArray dense1_b = nullsafe(params.get("dense1_b")); INDArray dense2_W = nullsafe(params.get("dense2_W")); @@ -408,7 +408,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest { cg.computeGradientAndScore(); // Let's figure out what our params are now. 
- Map params = cg.paramTable(); + Map params = cg.getParamTable(); INDArray dense1_W = nullsafe(params.get("dense1_W")); INDArray dense1_b = nullsafe(params.get("dense1_b")); INDArray dense2_W = nullsafe(params.get("dense2_W")); @@ -578,7 +578,7 @@ public class ElementWiseVertexTest extends BaseDL4JTest { cg.computeGradientAndScore(); // Let's figure out what our params are now. - Map params = cg.paramTable(); + Map params = cg.getParamTable(); INDArray dense1_W = nullsafe(params.get("dense1_W")); INDArray dense1_b = nullsafe(params.get("dense1_b")); INDArray dense2_W = nullsafe(params.get("dense2_W")); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java index cf0e743e6..9cf99a89c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java @@ -159,7 +159,7 @@ public class ShiftVertexTest extends BaseDL4JTest { cg.setLabel(0, target); cg.computeGradientAndScore(); double score_dl4j = cg.score(); - Map weights = cg.paramTable(); + Map weights = cg.getParamTable(); Gradient g = cg.gradient(); Map gradients = g.gradientForVariable(); Map manual_gradients = new TreeMap(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java index e4e7ce73c..3ae5d8bd0 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java @@ -212,21 +212,21 @@ public class LayerBuilderTest extends BaseDL4JTest { try (ByteArrayInputStream bis = new ByteArrayInputStream(data); ObjectInput in = new ObjectInputStream(bis)) { confActual = (NeuralNetConfiguration) in.readObject(); } - assertEquals(confExpected.getFirstLayer(), confActual.getFirstLayer(), "unequal Java serialization"); + assertEquals(confExpected.getFlattenedLayerConfigurations().get(0), confActual.getFlattenedLayerConfigurations().get(0), "unequal Java serialization"); // check JSON String json = confExpected.toJson(); confActual = NeuralNetConfiguration.fromJson(json); - assertEquals(confExpected.getFirstLayer(), confActual.getFirstLayer(), "unequal JSON serialization"); + assertEquals(confExpected.getFlattenedLayerConfigurations().get(0), confActual.getFlattenedLayerConfigurations().get(0), "unequal JSON serialization"); // check YAML String yaml = confExpected.toYaml(); confActual = NeuralNetConfiguration.fromYaml(yaml); - assertEquals(confExpected.getFirstLayer(), confActual.getFirstLayer(), "unequal YAML serialization"); + assertEquals(confExpected.getFlattenedLayerConfigurations().get(0), confActual.getFlattenedLayerConfigurations().get(0), "unequal YAML serialization"); // check the layer's use of callSuper on equals method - confActual.getFirstLayer().setIDropout(new Dropout(new java.util.Random().nextDouble())); - assertNotEquals( confExpected.getFirstLayer(), confActual.getFirstLayer(), "broken equals method (missing callSuper?)"); + confActual.getFlattenedLayerConfigurations().get(0).setIDropout(new Dropout(new java.util.Random().nextDouble())); + assertNotEquals( confExpected, confActual, "broken equals method (missing callSuper?)"); } } diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java index 1f279a762..798762556 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java @@ -62,9 +62,9 @@ public class TestPreProcessors extends BaseDL4JTest { .nOut(layerSize).build()) .build(); - long numParams = nnc.getFirstLayer().initializer().numParams(nnc); + long numParams = nnc.getFlattenedLayerConfigurations().get(0).initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); - DenseLayer layer = (DenseLayer) nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType()); + DenseLayer layer = (DenseLayer) nnc.getFlattenedLayerConfigurations().get(0).instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); INDArray activations3dc = Nd4j.create(new int[] {miniBatchSize, layerSize, timeSeriesLength}, 'c'); @@ -147,9 +147,9 @@ public class TestPreProcessors extends BaseDL4JTest { .nOut(layerSize).build()) .build(); - val numParams = nnc.getFirstLayer().initializer().numParams(nnc); + val numParams = nnc.getFlattenedLayerConfigurations().get(0).initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); - DenseLayer layer = (DenseLayer) nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType()); + DenseLayer layer = (DenseLayer) nnc.getFlattenedLayerConfigurations().get(0).instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); INDArray rand = Nd4j.rand(miniBatchSize * timeSeriesLength, layerSize); @@ -232,10 +232,10 @@ public class TestPreProcessors extends BaseDL4JTest { .nOut(nChannels).build()) .build(); - val numParams = nnc.getFirstLayer().initializer().numParams(nnc); + val numParams = nnc.getFlattenedLayerConfigurations().get(0).initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); ConvolutionLayer layer = - (ConvolutionLayer) nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType()); + (ConvolutionLayer) nnc.getFlattenedLayerConfigurations().get(0).instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); INDArray activationsCnn = Nd4j.rand(miniBatchSize * timeSeriesLength, nChannels, @@ -314,10 +314,10 @@ public class TestPreProcessors extends BaseDL4JTest { .nOut(nChannels).build()) .build(); - val numParams = nnc.getFirstLayer().initializer().numParams(nnc); + val numParams = nnc.getFlattenedLayerConfigurations().get(0).initializer().numParams(nnc); INDArray params = Nd4j.create(1, numParams); ConvolutionLayer layer = - (ConvolutionLayer) nnc.getFirstLayer().instantiate(nnc, null, 0, params, true, params.dataType()); + (ConvolutionLayer) nnc.getFlattenedLayerConfigurations().get(0).instantiate(nnc, null, 0, params, true, params.dataType()); layer.setInputMiniBatchSize(miniBatchSize); val shape_rnn = new long[] {miniBatchSize, nChannels * inputHeight * inputWidth, diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index 9002ba2af..e37b7b7cb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java 
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -256,9 +256,9 @@ public class DTypeTests extends BaseDL4JTest { } public static void logUsedClasses(MultiLayerNetwork net) { - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); for (NeuralNetConfiguration nnc : conf.getNetConfigurations()) { - LayerConfiguration l = nnc.getFirstLayer(); + LayerConfiguration l = nnc.getFlattenedLayerConfigurations().get(0); seenLayers.add(l.getClass()); if (l instanceof BaseWrapperLayer) { BaseWrapperLayer bwl = (BaseWrapperLayer) l; @@ -281,7 +281,7 @@ public class DTypeTests extends BaseDL4JTest { for (GraphVertex gv : conf.getVertices().values()) { seenVertices.add(gv.getClass()); if (gv instanceof LayerVertex) { - seenLayers.add(((LayerVertex) gv).getNetConfiguration().getFirstLayer().getClass()); + seenLayers.add(((LayerVertex) gv).getLayerConfiguration().getClass()); InputPreProcessor ipp = ((LayerVertex) gv).getPreProcessor(); if (ipp != null) { seenPreprocs.add(ipp.getClass()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java index f4da77575..b24dc76ed 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java @@ -96,11 +96,11 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest { Map paramsBefore = new HashMap<>(); //Pretrain first layer - for(Map.Entry e : cg.paramTable().entrySet()){ + for(Map.Entry e : cg.getParamTable().entrySet()){ paramsBefore.put(e.getKey(), e.getValue().dup()); } cg.pretrainLayer("vae1", ds); - for(Map.Entry e : cg.paramTable().entrySet()){ + for(Map.Entry e : cg.getParamTable().entrySet()){ if(e.getKey().startsWith("vae1")){ assertNotEquals(paramsBefore.get(e.getKey()), e.getValue()); } else { @@ -113,11 +113,11 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest { //Pretrain second layer - for(Map.Entry e : cg.paramTable().entrySet()){ + for(Map.Entry e : cg.getParamTable().entrySet()){ paramsBefore.put(e.getKey(), e.getValue().dup()); } cg.pretrainLayer("vae2", ds); - for(Map.Entry e : cg.paramTable().entrySet()){ + for(Map.Entry e : cg.getParamTable().entrySet()){ if(e.getKey().startsWith("vae2")){ assertNotEquals(paramsBefore.get(e.getKey()), e.getValue()); } else { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index adf347260..7feb29ddb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -406,9 +406,9 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { .addLayer("rnn", new GravesLSTM.Builder().nOut(5).build(), "in") .addLayer("out", new RnnOutputLayer.Builder().nOut(5).activation(Activation.SOFTMAX).build(), "rnn").setOutputs("out").build(); - assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("rnn")).getNetConfiguration().getFirstLayer()) + assertEquals(5, ((FeedForwardLayer) ((LayerVertex) 
conf1.getVertices().get("rnn")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getNIn()); - assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("out")).getNetConfiguration().getFirstLayer()) + assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf1.getVertices().get("out")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getNIn()); LayerVertex lv1 = (LayerVertex) conf1.getVertices().get("rnn"); @@ -423,9 +423,9 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { .addLayer("out", new RnnOutputLayer.Builder().nOut(5).activation(Activation.SOFTMAX).build(), "ff") .setOutputs("out").build(); - assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("ff")).getNetConfiguration().getFirstLayer()) + assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("ff")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getNIn()); - assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("out")).getNetConfiguration().getFirstLayer()) + assertEquals(5, ((FeedForwardLayer) ((LayerVertex) conf2.getVertices().get("out")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getNIn()); lv1 = (LayerVertex) conf2.getVertices().get("ff"); @@ -460,7 +460,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { LayerVertex lv4 = (LayerVertex) conf3.getVertices().get("out"); assertNull(lv4.getPreProcessor()); //Check nIns: - assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getNetConfiguration().getFirstLayer()).getNIn()); + assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getNetConfiguration().getFlattenedLayerConfigurations().get(0)).getNIn()); //CNN->Dense, RNN->Dense, Dense->RNN ComputationGraphConfiguration conf4 = @@ -495,9 +495,9 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { LayerVertex lv5 = (LayerVertex) conf4.getVertices().get("out"); assertTrue(lv5.getPreProcessor() instanceof FeedForwardToRnnPreProcessor); //Check nIns: - assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getNetConfiguration().getFirstLayer()).getNIn()); - assertEquals(5, ((FeedForwardLayer) lv4.getNetConfiguration().getFirstLayer()).getNIn()); - assertEquals(20, ((FeedForwardLayer) lv5.getNetConfiguration().getFirstLayer()).getNIn()); //10+10 out of the merge vertex -> 20 in to output layer vertex + assertEquals(7 * 7 * 3, ((FeedForwardLayer) lv3.getNetConfiguration().getFlattenedLayerConfigurations().get(0)).getNIn()); + assertEquals(5, ((FeedForwardLayer) lv4.getNetConfiguration().getFlattenedLayerConfigurations().get(0)).getNIn()); + assertEquals(20, ((FeedForwardLayer) lv5.getNetConfiguration().getFlattenedLayerConfigurations().get(0)).getNIn()); //10+10 out of the merge vertex -> 20 in to output layer vertex //Input to 2 CNN layers: @@ -903,7 +903,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { .build(); LayerVertex lv = (LayerVertex) conf.getVertices().get("layer"); - FeedForwardLayer l = ((FeedForwardLayer) (lv).getNetConfiguration().getFirstLayer()); + FeedForwardLayer l = ((FeedForwardLayer) (lv).getNetConfiguration().getFlattenedLayerConfigurations().get(0)); assertEquals(3, l.getNIn()); assertNull(lv.getPreProcessor()); @@ -920,7 +920,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { .build(); lv = (LayerVertex) conf.getVertices().get("layer"); - l = ((FeedForwardLayer) (lv).getNetConfiguration().getFirstLayer()); + l = ((FeedForwardLayer) (lv).getNetConfiguration().getFlattenedLayerConfigurations().get(0)); 
assertEquals(3, l.getNIn()); assertNotNull(lv.getPreProcessor()); InputPreProcessor preProcessor = lv.getPreProcessor(); @@ -945,7 +945,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { //Check subsampling layer: lv = (LayerVertex) conf.getVertices().get("l0"); - SubsamplingLayer sl = ((SubsamplingLayer) (lv).getNetConfiguration().getFirstLayer()); + SubsamplingLayer sl = ((SubsamplingLayer) (lv).getNetConfiguration().getFlattenedLayerConfigurations().get(0)); assertNotNull(lv.getPreProcessor()); preProcessor = lv.getPreProcessor(); assertTrue(preProcessor instanceof FeedForwardToCnnPreProcessor); @@ -955,7 +955,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { assertEquals(3, preproc.getNumChannels()); //Check dense layer lv = (LayerVertex) conf.getVertices().get("layer"); - l = ((FeedForwardLayer) (lv).getNetConfiguration().getFirstLayer()); + l = ((FeedForwardLayer) (lv).getNetConfiguration().getFlattenedLayerConfigurations().get(0)); assertEquals(3, l.getNIn()); assertNull(lv.getPreProcessor()); @@ -1673,7 +1673,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph g = new ComputationGraph(conf2); g.init(); - g.setParamTable(cg.paramTable()); + g.setParamTable(cg.getParamTable()); int[] origOrder = g.topologicalSortOrder(); INDArray[] out4 = g.output(in); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java index ce8019133..2f752b316 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java @@ -72,9 +72,9 @@ public class TestSetGetParameters extends BaseDL4JTest { assertSame(params, net3.params()); //Same object due to clone - Map paramsMap = net.paramTable(); - Map paramsMap2 = net2.paramTable(); - Map paramsMap3 = net3.paramTable(); + Map paramsMap = net.getParamTable(); + Map paramsMap2 = net2.getParamTable(); + Map paramsMap3 = net3.getParamTable(); for (String s : paramsMap.keySet()) { assertEquals(paramsMap.get(s), paramsMap2.get(s)); assertEquals(paramsMap.get(s), paramsMap3.get(s)); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java index 3162ed209..189467ab4 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java @@ -57,10 +57,10 @@ public class BaseLayerTest extends BaseDL4JTest { @Test public void testSetExistingParamsConvolutionSingleLayer() { Layer layer = configureSingleLayer(); - assertNotEquals(paramTable, layer.paramTable()); + assertNotEquals(paramTable, layer.getParamTable()); layer.setParamTable(paramTable); - assertEquals(paramTable, layer.paramTable()); + assertEquals(paramTable, layer.getParamTable()); } @@ -69,9 +69,9 @@ public class BaseLayerTest extends BaseDL4JTest { MultiLayerNetwork net = configureMultiLayer(); for (Layer layer : net.getLayers()) { - assertNotEquals(paramTable, layer.paramTable()); + assertNotEquals(paramTable, layer.getParamTable()); layer.setParamTable(paramTable); - assertEquals(paramTable, layer.paramTable()); + assertEquals(paramTable, layer.getParamTable()); } } @@ -83,9 +83,9 @@ public class BaseLayerTest 
extends BaseDL4JTest { NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .layer(new ConvolutionLayer.Builder().nIn(nIn).nOut(nOut).build()).build(); - val numParams = conf.getFirstLayer().initializer().numParams(conf); + val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - return conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java index 1e83adaf2..2b8977ed0 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java @@ -133,7 +133,7 @@ public class FrozenLayerTest extends BaseDL4JTest { MultiLayerNetwork clonedModel = modelNow.clone(); //Check json - assertEquals(modelNow.getConfiguration().toJson(), clonedModel.getConfiguration().toJson()); + assertEquals(modelNow.getNetConfiguration().toJson(), clonedModel.getNetConfiguration().toJson()); //Check params assertEquals(modelNow.params(), clonedModel.params()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java index 0bdf441ac..0d4f0d710 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java @@ -64,9 +64,9 @@ public class OutputLayerTest extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - OutputLayer l = (OutputLayer) conf.getFirstLayer().instantiate(conf, + OutputLayer l = (OutputLayer) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType()); params = l.params(); l.setParamsTable(params); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java index 483e34572..a62ccdcf0 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/RepeatVectorTest.java @@ -43,7 +43,7 @@ public class RepeatVectorTest extends BaseDL4JTest { NeuralNetConfiguration conf = NeuralNetConfiguration.builder().seed(123) .dataType(DataType.DOUBLE) .layer(new RepeatVector.Builder(REPEAT).build()).build(); - return conf.getFirstLayer().instantiate(conf, null, 0, + return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, false, DataType.DOUBLE); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java index db7d4525c..6306c333b 100644 --- 
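The BaseLayerTest changes above show the same rename at the level of individual layers: a layer's parameter map is read with getParamTable() and overwritten with setParamTable(). A short sketch, where paramTable is an externally built Map<String, INDArray> as in that test:

    for (Layer layer : net.getLayers()) {
        Map<String, INDArray> current = layer.getParamTable();   // renamed read accessor
        layer.setParamTable(paramTable);                         // bulk overwrite, as before
    }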
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java @@ -52,9 +52,9 @@ public class SeedTest extends BaseDL4JTest { NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(layerType).seed(123).build(); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams)); layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java index d4a685a3a..c8137f4a6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Convolution3DTest.java @@ -90,9 +90,9 @@ public class Convolution3DTest extends BaseDL4JTest { .dataFormat(Convolution3D.DataFormat.NCDHW).convolutionMode(mode).hasBias(false) .build()) .build(); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.ones(1, numParams); - return conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); } public INDArray getData() throws Exception { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java index 0c58b8703..f234d3b78 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java @@ -258,9 +258,9 @@ public class ConvolutionLayerTest extends BaseDL4JTest { NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(cnn).build(); - val numParams = conf.getFirstLayer().initializer().numParams(conf); + val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); assertEquals(1, layer.getParam("b").size(0)); } @@ -319,9 +319,9 @@ public class ConvolutionLayerTest extends BaseDL4JTest { NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(layer).build(); - val numParams = conf.getFirstLayer().initializer().numParams(conf); + val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - return conf.getFirstLayer().instantiate(conf, 
null, 0, params, true, params.dataType()); + return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); } public Layer getMNISTConfig() { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java index ed8e8c99d..1c47e1b2d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SpaceToDepthTest.java @@ -62,7 +62,7 @@ public class SpaceToDepthTest extends BaseDL4JTest { NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new SpaceToDepthLayer.Builder(blockSize, dataFormat).build()).build(); - return conf.getFirstLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); + return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java index 9fda734eb..4cc8341cc 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/SubsamplingLayerTest.java @@ -172,7 +172,7 @@ public class SubsamplingLayerTest extends BaseDL4JTest { .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new SubsamplingLayer.Builder(pooling, new int[] {2, 2}).build()).build(); - return conf.getFirstLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); + return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } public INDArray getData() throws Exception { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java index 61f937cec..8cdc85768 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/TestConvolutionModes.java @@ -287,28 +287,28 @@ public class TestConvolutionModes extends BaseDL4JTest { .activation(Activation.SOFTMAX).nOut(3).build(), "7") .setOutputs("8").build(); - assertEquals(cm, ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("0")).getNetConfiguration().getFirstLayer()) + assertEquals(cm, ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("0")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getConvolutionMode()); assertEquals(ConvolutionMode.Strict, - ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("1")).getNetConfiguration().getFirstLayer()) + ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("1")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getConvolutionMode()); assertEquals(ConvolutionMode.Truncate, - ((ConvolutionLayer) ((LayerVertex) 
conf.getVertices().get("2")).getNetConfiguration().getFirstLayer()) + ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("2")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getConvolutionMode()); assertEquals(ConvolutionMode.Same, - ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("3")).getNetConfiguration().getFirstLayer()) + ((ConvolutionLayer) ((LayerVertex) conf.getVertices().get("3")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getConvolutionMode()); - assertEquals(cm, ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("4")).getNetConfiguration().getFirstLayer()) + assertEquals(cm, ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("4")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getConvolutionMode()); assertEquals(ConvolutionMode.Strict, - ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("5")).getNetConfiguration().getFirstLayer()) + ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("5")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getConvolutionMode()); assertEquals(ConvolutionMode.Truncate, - ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("6")).getNetConfiguration().getFirstLayer()) + ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("6")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getConvolutionMode()); assertEquals(ConvolutionMode.Same, - ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("7")).getNetConfiguration().getFirstLayer()) + ((SubsamplingLayer) ((LayerVertex) conf.getVertices().get("7")).getNetConfiguration().getFlattenedLayerConfigurations().get(0)) .getConvolutionMode()); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java index 5d74b94fa..064464d67 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling1DTest.java @@ -107,7 +107,7 @@ public class Upsampling1DTest extends BaseDL4JTest { NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new Upsampling1D.Builder(size).build()).build(); - return conf.getFirstLayer().instantiate(conf, null, 0, + return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java index bfb872ba8..286259904 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/Upsampling2DTest.java @@ -111,7 +111,7 @@ public class Upsampling2DTest extends BaseDL4JTest { NeuralNetConfiguration conf = NeuralNetConfiguration.builder() .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).seed(123) .layer(new Upsampling2D.Builder(size).build()).build(); - return conf.getFirstLayer().instantiate(conf, null, 0, null, true, Nd4j.defaultFloatingPointType()); + return conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, 
null, true, Nd4j.defaultFloatingPointType()); } public INDArray getData() throws Exception { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java index f3b201d63..ea59a8091 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; @@ -53,13 +54,14 @@ public class CustomLayer extends FeedForwardLayer { public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - CustomLayerImpl ret = new CustomLayerImpl(conf, networkDataType); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + CustomLayerImpl ret = new CustomLayerImpl(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setLayerConfiguration(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayerImpl.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayerImpl.java index e0f582a52..38a7d215b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayerImpl.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayerImpl.java @@ -21,11 +21,12 @@ package org.deeplearning4j.nn.layers.custom.testclasses; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.layers.BaseLayer; import org.nd4j.linalg.api.buffer.DataType; public class CustomLayerImpl extends BaseLayer { - public CustomLayerImpl(NeuralNetConfiguration conf, DataType dataType) { + public CustomLayerImpl(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java index b64a341d8..80c983589 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import 
org.deeplearning4j.nn.conf.layers.BaseOutputLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; @@ -51,13 +52,14 @@ public class CustomOutputLayer extends BaseOutputLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - CustomOutputLayerImpl ret = new CustomOutputLayerImpl(conf, networkDataType); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + CustomOutputLayerImpl ret = new CustomOutputLayerImpl(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setLayerConfiguration(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayerImpl.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayerImpl.java index 349adba9d..f48f35038 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayerImpl.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayerImpl.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.layers.custom.testclasses; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.layers.BaseOutputLayer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -28,7 +29,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; public class CustomOutputLayerImpl extends BaseOutputLayer { - public CustomOutputLayerImpl(NeuralNetConfiguration conf, DataType dataType) { + public CustomOutputLayerImpl(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java index 382476fc9..ba1129cef 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java @@ -53,9 +53,9 @@ public class DenseTest extends BaseDL4JTest { NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(build).build(); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, Nd4j.defaultFloatingPointType()); + Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, Nd4j.defaultFloatingPointType()); assertEquals(1, layer.getParam("b").size(0)); } diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java index e6f85611a..eb76c88f2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/normalization/BatchNormalizationTest.java @@ -130,12 +130,12 @@ public class BatchNormalizationTest extends BaseDL4JTest { BatchNormalization bN = b.build(); NeuralNetConfiguration conf = NeuralNetConfiguration.builder().layer(bN).build(); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = null; if (numParams > 0) { params = Nd4j.create(1, numParams); } - Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params == null ? Nd4j.defaultFloatingPointType() : params.dataType()); + Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params == null ? Nd4j.defaultFloatingPointType() : params.dataType()); if (numParams > 0) { layer.setBackpropGradientsViewArray(Nd4j.create(1, numParams)); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java index c989d0bf5..c0f6fa24c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java @@ -123,7 +123,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest { DataSet filtered = next.filterBy(new int[]{0, 1}); for (int i = 0; i < 10; i++) { network.setEpochCount(i); - network.getConfiguration().setEpochCount(i); + network.getNetConfiguration().setEpochCount(i); network.fit(filtered); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java index d51fc5280..7d8dd8977 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java @@ -68,10 +68,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { .nOut(nHiddenUnits).dataFormat(rnnDataFormat).activation(Activation.TANH).build()) .build(); - val numParams = conf.getFirstLayer().initializer().numParams(conf); + val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); final GravesBidirectionalLSTM layer = - (GravesBidirectionalLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + (GravesBidirectionalLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); //Data: has shape [miniBatchSize,nIn,timeSeriesLength]; //Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; @@ -135,11 +135,11 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { .dist(new UniformDistribution(0, 
1)).activation(Activation.TANH).build()) .build(); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); GravesBidirectionalLSTM lstm = - (GravesBidirectionalLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); - lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getFirstLayer().initializer().numParams(conf))); + (GravesBidirectionalLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); + lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf))); //Set input, do a forward pass: lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); assertNotNull(lstm.input()); @@ -207,10 +207,10 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) .build(); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); final GravesBidirectionalLSTM lstm = - (GravesBidirectionalLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + (GravesBidirectionalLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); final INDArray input = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); @@ -266,9 +266,9 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { .build(); - long numParams = confBidirectional.getFirstLayer().initializer().numParams(confBidirectional); + long numParams = confBidirectional.getFlattenedLayerConfigurations().get(0).initializer().numParams(confBidirectional); INDArray params = Nd4j.create(1, numParams); - final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getFirstLayer() + final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getFlattenedLayerConfigurations().get(0) .instantiate(confBidirectional, null, 0, params, true, params.dataType()); @@ -311,19 +311,19 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { .weightInit(WeightInit.ZERO).activation(Activation.TANH).build()) .build(); - long numParams = confForwards.getFirstLayer().initializer().numParams(confForwards); + long numParams = confForwards.getFlattenedLayerConfigurations().get(0).initializer().numParams(confForwards); INDArray params = Nd4j.create(1, numParams); - long numParamsBD = confBidirectional.getFirstLayer().initializer().numParams(confBidirectional); + long numParamsBD = confBidirectional.getFlattenedLayerConfigurations().get(0).initializer().numParams(confBidirectional); INDArray paramsBD = Nd4j.create(1, numParamsBD); - final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getFirstLayer() + final GravesBidirectionalLSTM bidirectionalLSTM = (GravesBidirectionalLSTM) confBidirectional.getFlattenedLayerConfigurations().get(0) .instantiate(confBidirectional, null, 0, paramsBD, true, params.dataType()); final GravesLSTM forwardsLSTM = - (GravesLSTM) confForwards.getFirstLayer().instantiate(confForwards, null, 0, params, true, params.dataType()); + (GravesLSTM) 
confForwards.getFlattenedLayerConfigurations().get(0).instantiate(confForwards, null, 0, params, true, params.dataType()); bidirectionalLSTM.setBackpropGradientsViewArray( - Nd4j.create(1, confBidirectional.getFirstLayer().initializer().numParams(confBidirectional))); + Nd4j.create(1, confBidirectional.getFlattenedLayerConfigurations().get(0).initializer().numParams(confBidirectional))); forwardsLSTM.setBackpropGradientsViewArray( - Nd4j.create(1, confForwards.getFirstLayer().initializer().numParams(confForwards))); + Nd4j.create(1, confForwards.getFlattenedLayerConfigurations().get(0).initializer().numParams(confForwards))); final INDArray sig = (rnnDataFormat == RNNFormat.NCW)?Nd4j.rand(miniBatchSize, nIn, timeSeriesLength): @@ -546,7 +546,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { net.init(); assertEquals(gateAfn, ((org.deeplearning4j.nn.conf.layers.GravesBidirectionalLSTM) net.getLayer(0).getNetConfiguration() - .getFirstLayer()).getGateActivationFn().toString()); + .getFlattenedLayerConfigurations().get(0)).getGateActivationFn().toString()); INDArray in = Nd4j.rand(3, 2, 5); INDArray labels = Nd4j.rand(3, 2, 5); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java index 2868c08d8..791ff8fa6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTMTest.java @@ -63,9 +63,9 @@ public class GravesLSTMTest extends BaseDL4JTest { .nOut(nHiddenUnits).activation(Activation.TANH).build()) .build(); - val numParams = conf.getFirstLayer().initializer().numParams(conf); + val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - GravesLSTM layer = (GravesLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + GravesLSTM layer = (GravesLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); //Data: has shape [miniBatchSize,nIn,timeSeriesLength]; //Output/activations has shape [miniBatchsize,nHiddenUnits,timeSeriesLength]; @@ -109,10 +109,10 @@ public class GravesLSTMTest extends BaseDL4JTest { .dist(new UniformDistribution(0, 1)).activation(Activation.TANH).build()) .build(); - val numParams = conf.getFirstLayer().initializer().numParams(conf); + val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - GravesLSTM lstm = (GravesLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); - lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getFirstLayer().initializer().numParams(conf))); + GravesLSTM lstm = (GravesLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); + lstm.setBackpropGradientsViewArray(Nd4j.create(1, conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf))); //Set input, do a forward pass: lstm.activate(inputData, false, LayerWorkspaceMgr.noWorkspaces()); assertNotNull(lstm.input()); @@ -160,9 +160,9 @@ public class GravesLSTMTest extends BaseDL4JTest { .activation(Activation.TANH).build()) .build(); - val numParams = conf.getFirstLayer().initializer().numParams(conf); + val numParams = 
conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - GravesLSTM lstm = (GravesLSTM) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + GravesLSTM lstm = (GravesLSTM) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); INDArray input = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); lstm.setInput(input, LayerWorkspaceMgr.noWorkspaces()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java index 690c07f37..f0b23e335 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffConv.java @@ -73,7 +73,7 @@ public class TestSameDiffConv extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - Map pt1 = net.getLayer(0).paramTable(); + Map pt1 = net.getLayer(0).getParamTable(); assertNotNull(pt1); assertEquals(2, pt1.size()); assertNotNull(pt1.get(ConvolutionParamInitializer.WEIGHT_KEY)); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java index 64d59c84b..60446d43f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java @@ -71,7 +71,7 @@ public class TestSameDiffDense extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - Map pt1 = net.getLayer(0).paramTable(); + Map pt1 = net.getLayer(0).getParamTable(); assertNotNull(pt1); assertEquals(2, pt1.size()); assertNotNull(pt1.get(DefaultParamInitializer.WEIGHT_KEY)); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java index f70c4de92..5e67862ff 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java @@ -104,7 +104,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest { //Check params: assertEquals(netStandard.params(), netSD.params()); - assertEquals(netStandard.paramTable(), netSD.paramTable()); + assertEquals(netStandard.getParamTable(), netSD.getParamTable()); INDArray in = Nd4j.rand(minibatch, nIn); INDArray l = TestUtils.randomOneHot(minibatch, nOut, 12345); @@ -159,7 +159,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest { netSD.fit(ds); netStandard.fit(ds); - assertEquals(netStandard.paramTable(), netSD.paramTable()); + assertEquals(netStandard.getParamTable(), netSD.getParamTable()); assertEquals(netStandard.params(), netSD.params()); assertEquals(netStandard.getFlattenedGradients(), netSD.getFlattenedGradients()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java index 
3da4abed5..639520492 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java @@ -63,7 +63,7 @@ public class TestVAE extends BaseDL4JTest { .build()) .build(); - LayerConfiguration c = mlc.getFirstLayer(); + LayerConfiguration c = mlc.getFlattenedLayerConfigurations().get(0); org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder vae = (VariationalAutoencoder) c; @@ -78,7 +78,7 @@ public class TestVAE extends BaseDL4JTest { System.out.println("Exp num params: " + expNumParams); assertEquals(expNumParams, net.getLayer(0).params().length()); - Map paramTable = net.getLayer(0).paramTable(); + Map paramTable = net.getLayer(0).getParamTable(); int count = 0; for (INDArray arr : paramTable.values()) { count += arr.length(); @@ -135,7 +135,7 @@ public class TestVAE extends BaseDL4JTest { net.init(); net.initGradientsView(); //TODO this should happen automatically - Map paramTable = net.getLayer(0).paramTable(); + Map paramTable = net.getLayer(0).getParamTable(); Map gradTable = ((org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0)) .getGradientViews(); @@ -175,7 +175,7 @@ public class TestVAE extends BaseDL4JTest { org.deeplearning4j.nn.layers.variational.VariationalAutoencoder layer = (org.deeplearning4j.nn.layers.variational.VariationalAutoencoder) net.getLayer(0); - Map layerParams = layer.paramTable(); + Map layerParams = layer.getParamTable(); Map layerGradViews = layer.getGradientViews(); layer.setInput(Nd4j.rand(3, 10), LayerWorkspaceMgr.noWorkspaces()); @@ -239,7 +239,7 @@ public class TestVAE extends BaseDL4JTest { net.pretrainLayer(0, input); //Get a snapshot of the pretrain params after fitting: - Map layerParams = layer.paramTable(); + Map layerParams = layer.getParamTable(); Map pretrainParamsBefore = new HashMap<>(); for (String s : layerParams.keySet()) { if (layer.isPretrainParam(s)) { @@ -255,7 +255,7 @@ public class TestVAE extends BaseDL4JTest { net.fit(features, labels); } - Map layerParamsAfter = layer.paramTable(); + Map layerParamsAfter = layer.getParamTable(); for (String s : pretrainParamsBefore.keySet()) { INDArray before = pretrainParamsBefore.get(s); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java index 9649adffd..794b45411 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java @@ -104,13 +104,13 @@ public class WorkspaceTests extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf.clone()); net.init(); - net.getConfiguration().setInferenceWorkspaceMode(WorkspaceMode.ENABLED); - net.getConfiguration().setTrainingWorkspaceMode(WorkspaceMode.ENABLED); + net.getNetConfiguration().setInferenceWorkspaceMode(WorkspaceMode.ENABLED); + net.getNetConfiguration().setTrainingWorkspaceMode(WorkspaceMode.ENABLED); MultiLayerNetwork net2 = new MultiLayerNetwork(conf.clone()); net2.init(); - net2.getConfiguration().setInferenceWorkspaceMode(WorkspaceMode.NONE); - net2.getConfiguration().setTrainingWorkspaceMode(WorkspaceMode.NONE); + net2.getNetConfiguration().setInferenceWorkspaceMode(WorkspaceMode.NONE); + net2.getNetConfiguration().setTrainingWorkspaceMode(WorkspaceMode.NONE); INDArray in = Nd4j.rand(1, 2, 
5, 5); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index cad0cfd50..4fb1c3fad 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -817,15 +817,15 @@ public class MultiLayerTest extends BaseDL4JTest { DataSetIterator iter = new IrisDataSetIterator(50, 150); - assertEquals(0, network.getConfiguration().getIterationCount()); + assertEquals(0, network.getNetConfiguration().getIterationCount()); network.fit(iter); - assertEquals(3, network.getConfiguration().getIterationCount()); + assertEquals(3, network.getNetConfiguration().getIterationCount()); iter.reset(); network.fit(iter); - assertEquals(6, network.getConfiguration().getIterationCount()); + assertEquals(6, network.getNetConfiguration().getIterationCount()); iter.reset(); network.fit(iter.next()); - assertEquals(7, network.getConfiguration().getIterationCount()); + assertEquals(7, network.getNetConfiguration().getIterationCount()); ByteArrayOutputStream baos = new ByteArrayOutputStream(); ModelSerializer.writeModel(network, baos, true); @@ -833,7 +833,7 @@ public class MultiLayerTest extends BaseDL4JTest { ByteArrayInputStream bais = new ByteArrayInputStream(asBytes); MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(bais, true); - assertEquals(7, net.getConfiguration().getIterationCount()); + assertEquals(7, net.getNetConfiguration().getIterationCount()); } @@ -1072,20 +1072,20 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(0, net.getConfiguration().getEpochCount()); + assertEquals(0, net.getNetConfiguration().getEpochCount()); DataSetIterator iter = new IrisDataSetIterator(150, 150); for (int i = 0; i < 4; i++) { - assertEquals(i, net.getConfiguration().getEpochCount()); + assertEquals(i, net.getNetConfiguration().getEpochCount()); net.fit(iter); - assertEquals(i + 1, net.getConfiguration().getEpochCount()); + assertEquals(i + 1, net.getNetConfiguration().getEpochCount()); } - assertEquals(4, net.getConfiguration().getEpochCount()); + assertEquals(4, net.getNetConfiguration().getEpochCount()); MultiLayerNetwork restored = TestUtils.testModelSerialization(net); - assertEquals(4, restored.getConfiguration().getEpochCount()); + assertEquals(4, restored.getNetConfiguration().getEpochCount()); } @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java index 1a6175cde..99c1c6077 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java @@ -86,7 +86,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { Layer layer = network.getLayer(0); assertTrue(layer instanceof GravesLSTM); - Map paramTable = layer.paramTable(); + Map paramTable = layer.getParamTable(); assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases INDArray recurrentWeights = paramTable.get(GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY); @@ -131,7 +131,7 @@ public class MultiLayerTestRNN extends BaseDL4JTest { Layer layer = network.getLayer(i); 
assertTrue(layer instanceof GravesLSTM); - Map paramTable = layer.paramTable(); + Map paramTable = layer.getParamTable(); assertEquals(3, paramTable.size()); //2 sets of weights, 1 set of biases int layerNIn = (i == 0 ? nIn : nHiddenUnits[i - 1]); @@ -458,9 +458,9 @@ public class MultiLayerTestRNN extends BaseDL4JTest { mlnTBPTT.clearTbpttState = false; - assertEquals(BackpropType.TruncatedBPTT, mlnTBPTT.getConfiguration().getBackpropType()); - assertEquals(timeSeriesLength, mlnTBPTT.getConfiguration().getTbpttFwdLength()); - assertEquals(timeSeriesLength, mlnTBPTT.getConfiguration().getTbpttBackLength()); + assertEquals(BackpropType.TruncatedBPTT, mlnTBPTT.getNetConfiguration().getBackpropType()); + assertEquals(timeSeriesLength, mlnTBPTT.getNetConfiguration().getTbpttFwdLength()); + assertEquals(timeSeriesLength, mlnTBPTT.getNetConfiguration().getTbpttBackLength()); INDArray inputData = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); INDArray labels = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java index fe80d1e24..19360abb7 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java @@ -124,8 +124,8 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest { net2GradUpd.getUpdater().getStateViewArray()); //Remove the next 2 lines: fails - as net 1 is 1 iteration ahead - net1GradCalc.getConfiguration().setIterationCount(0); - net2GradUpd.getConfiguration().setIterationCount(0); + net1GradCalc.getNetConfiguration().setIterationCount(0); + net2GradUpd.getNetConfiguration().setIterationCount(0); for (int i = 0; i < 100; i++) { net1GradCalc.fit(f, l); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java index d35b46911..5c5fb204e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestFrozenLayers.java @@ -127,7 +127,7 @@ public class TestFrozenLayers extends BaseDL4JTest { } Map paramsBefore = new LinkedHashMap<>(); - for(Map.Entry entry : transfer.paramTable().entrySet()){ + for(Map.Entry entry : transfer.getParamTable().entrySet()){ paramsBefore.put(entry.getKey(), entry.getValue().dup()); } @@ -137,7 +137,7 @@ public class TestFrozenLayers extends BaseDL4JTest { transfer.fit(new INDArray[]{f},new INDArray[]{l}); } - for(Map.Entry entry : transfer.paramTable().entrySet()){ + for(Map.Entry entry : transfer.getParamTable().entrySet()){ String s = msg + " - " + entry.getKey(); if(entry.getKey().startsWith("5_")){ //Non-frozen layer diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java index b328c8dff..6d6ce41c0 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TestTransferLearningModelSerializer.java @@ -70,9 +70,9 @@ public class TestTransferLearningModelSerializer extends BaseDL4JTest { assertTrue(withFrozen.getLayer(0) instanceof FrozenLayer); assertTrue(withFrozen.getLayer(1) instanceof FrozenLayer); - assertTrue(withFrozen.getConfiguration().getConf(0) + assertTrue(withFrozen.getNetConfiguration().getConf(0) .getLayer() instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); - assertTrue(withFrozen.getConfiguration().getConf(1) + assertTrue(withFrozen.getNetConfiguration().getConf(1) .getLayer() instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); MultiLayerNetwork restored = TestUtils.testModelSerialization(withFrozen); @@ -120,8 +120,8 @@ public class TestTransferLearningModelSerializer extends BaseDL4JTest { assertTrue(withFrozen.getLayer(1) instanceof FrozenLayer); Map m = withFrozen.getComputationGraphConfiguration().getVertices(); - LayerConfiguration l0 = ((LayerVertex) m.get("0")).getNetConfiguration().getFirstLayer(); - LayerConfiguration l1 = ((LayerVertex) m.get("1")).getNetConfiguration().getFirstLayer(); + LayerConfiguration l0 = ((LayerVertex) m.get("0")).getLayerConfiguration(); + LayerConfiguration l1 = ((LayerVertex) m.get("1")).getLayerConfiguration(); assertTrue(l0 instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); assertTrue(l1 instanceof org.deeplearning4j.nn.conf.layers.misc.FrozenLayer); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java index 0f75f1426..195ee2f6d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java @@ -605,13 +605,13 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { cg2.output(arr); - Map m = new HashMap<>(cg.paramTable()); + Map m = new HashMap<>(cg.getParamTable()); m.put("newOut_W", m.remove("out_W")); m.put("newOut_b", m.remove("out_b")); cg2.setParamTable(m); - Map p1 = cg.paramTable(); - Map p2 = cg2.paramTable(); + Map p1 = cg.getParamTable(); + Map p2 = cg2.getParamTable(); for(String s : p1.keySet()){ INDArray i1 = p1.get(s); INDArray i2 = p2.get(s.replaceAll("out", "newOut")); @@ -651,13 +651,13 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { cg2.output(arr); - Map m = new HashMap<>(cg.paramTable()); + Map m = new HashMap<>(cg.getParamTable()); m.put("newOut_W", m.remove("out_W")); m.put("newOut_b", m.remove("out_b")); cg2.setParamTable(m); - Map p1 = cg.paramTable(); - Map p2 = cg2.paramTable(); + Map p1 = cg.getParamTable(); + Map p2 = cg2.getParamTable(); for(String s : p1.keySet()){ INDArray i1 = p1.get(s); INDArray i2 = p2.get(s.replaceAll("out", "newOut")); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java index cda7da0b4..f33c48738 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java @@ -112,8 +112,8 @@ 
public class TransferLearningMLNTest extends BaseDL4JTest { assertEquals(expectedModel.params(), modelNow.params()); //Check json - NeuralNetConfiguration expectedConf = expectedModel.getConfiguration(); - assertEquals(expectedConf.toJson(), modelNow.getConfiguration().toJson()); + NeuralNetConfiguration expectedConf = expectedModel.getNetConfiguration(); + assertEquals(expectedConf.toJson(), modelNow.getNetConfiguration().toJson()); //Check params after fit modelNow.fit(randomData); @@ -160,9 +160,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest { //Will fail - expected because of dist and weight init changes //assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson()); - BaseLayer bl0 = ((BaseLayer) modelNow.getConfiguration().getConf(0).getLayer()); - BaseLayer bl1 = ((BaseLayer) modelNow.getConfiguration().getConf(1).getLayer()); - BaseLayer bl3 = ((BaseLayer) modelNow.getConfiguration().getConf(3).getLayer()); + BaseLayer bl0 = ((BaseLayer) modelNow.getNetConfiguration().getConf(0).getLayer()); + BaseLayer bl1 = ((BaseLayer) modelNow.getNetConfiguration().getConf(1).getLayer()); + BaseLayer bl3 = ((BaseLayer) modelNow.getNetConfiguration().getConf(3).getLayer()); assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class); try { assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()), @@ -357,18 +357,18 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .setInputPreProcessor(4, new FeedForwardToRnnPreProcessor()).build(); //modelNow should have the same architecture as modelExpectedArch - assertEquals(modelExpectedArch.getConfiguration().getConf(0).toJson(), - modelNow.getConfiguration().getConf(0).toJson()); + assertEquals(modelExpectedArch.getNetConfiguration().getConf(0).toJson(), + modelNow.getNetConfiguration().getConf(0).toJson()); //some learning related info the subsampling layer will not be overwritten //assertTrue(modelExpectedArch.getConfiguration().getConf(1).toJson().equals(modelNow.getConfiguration().getConf(1).toJson())); - assertEquals(modelExpectedArch.getConfiguration().getConf(2).toJson(), - modelNow.getConfiguration().getConf(2).toJson()); - assertEquals(modelExpectedArch.getConfiguration().getConf(3).toJson(), - modelNow.getConfiguration().getConf(3).toJson()); - assertEquals(modelExpectedArch.getConfiguration().getConf(4).toJson(), - modelNow.getConfiguration().getConf(4).toJson()); - assertEquals(modelExpectedArch.getConfiguration().getConf(5).toJson(), - modelNow.getConfiguration().getConf(5).toJson()); + assertEquals(modelExpectedArch.getNetConfiguration().getConf(2).toJson(), + modelNow.getNetConfiguration().getConf(2).toJson()); + assertEquals(modelExpectedArch.getNetConfiguration().getConf(3).toJson(), + modelNow.getNetConfiguration().getConf(3).toJson()); + assertEquals(modelExpectedArch.getNetConfiguration().getConf(4).toJson(), + modelNow.getNetConfiguration().getConf(4).toJson()); + assertEquals(modelExpectedArch.getNetConfiguration().getConf(5).toJson(), + modelNow.getNetConfiguration().getConf(5).toJson()); assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); @@ -530,7 +530,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest { assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); assertEquals(0.2, TestUtils.getL2(l1), 1e-6); - assertEquals(BackpropType.TruncatedBPTT, 
net2.getConfiguration().getBackpropType()); + assertEquals(BackpropType.TruncatedBPTT, net2.getNetConfiguration().getBackpropType()); } @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java index 63c936b17..54d3a3174 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestGradientNormalization.java @@ -52,9 +52,9 @@ public class TestGradientNormalization extends BaseDL4JTest { .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).build()) .build(); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); INDArray gradArray = Nd4j.rand(1, 220).muli(10).subi(5); layer.setBackpropGradientsViewArray(gradArray); INDArray weightGrad = Shape.newShapeNoCopy(gradArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 200)), @@ -98,9 +98,9 @@ public class TestGradientNormalization extends BaseDL4JTest { .gradientNormalization(GradientNormalization.RenormalizeL2PerParamType).build()) .build(); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(Nd4j.create(params.shape())); Updater updater = UpdaterCreator.getUpdater(layer); INDArray weightGrad = Nd4j.rand(10, 20); @@ -131,9 +131,9 @@ public class TestGradientNormalization extends BaseDL4JTest { .gradientNormalizationThreshold(threshold).build()) .build(); - long numParams = conf.getFirstLayer().initializer().numParams(conf); + long numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); INDArray gradArray = Nd4j.rand(1, 220).muli(10).subi(5); layer.setBackpropGradientsViewArray(gradArray); INDArray weightGrad = Shape.newShapeNoCopy(gradArray.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 200)), @@ -187,9 +187,9 @@ public class TestGradientNormalization extends BaseDL4JTest { .gradientNormalizationThreshold(threshold).build()) .build(); - val numParams = conf.getFirstLayer().initializer().numParams(conf); + val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); 
INDArray gradArray = Nd4j.rand(1, 220).muli(t == 0 ? 0.05 : 10).subi(t == 0 ? 0 : 5); layer.setBackpropGradientsViewArray(gradArray); INDArray weightGrad = @@ -242,9 +242,9 @@ public class TestGradientNormalization extends BaseDL4JTest { .gradientNormalizationThreshold(threshold).build()) .build(); - val numParams = conf.getFirstLayer().initializer().numParams(conf); + val numParams = conf.getFlattenedLayerConfigurations().get(0).initializer().numParams(conf); INDArray params = Nd4j.create(1, numParams); - Layer layer = conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); + Layer layer = conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(Nd4j.create(params.shape())); Updater updater = UpdaterCreator.getUpdater(layer); INDArray weightGrad = Nd4j.rand(10, 20).muli(0.05); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java index 7753fae33..69afb6330 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java @@ -20,6 +20,7 @@ package org.deeplearning4j.optimize.solver; +import lombok.NonNull; import lombok.val; import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.BaseDL4JTest; @@ -44,6 +45,7 @@ import org.deeplearning4j.optimize.solvers.LineGradientDescent; import org.deeplearning4j.optimize.solvers.StochasticGradientDescent; import org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction; import org.junit.jupiter.api.Test; +import org.nd4j.evaluation.IEvaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -52,7 +54,9 @@ import org.nd4j.linalg.api.ops.impl.transforms.strict.Sin; import org.nd4j.linalg.api.rng.DefaultRandom; import org.nd4j.linalg.api.rng.Random; import org.nd4j.linalg.dataset.DataSet; +import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; import org.nd4j.linalg.exception.ND4JArraySizeException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.conditions.Condition; @@ -317,6 +321,90 @@ public class TestOptimizers extends BaseDL4JTest { } + /** + * This method returns updater state (if applicable), null otherwise + * + * @return + */ + @Override + public INDArray updaterState() { + return null; + } + + /** + * This method fits model with a given DataSet + * + * @param dataSet + */ + @Override + public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet) { + + } + + /** + * This method fits model with a given MultiDataSet + * + * @param dataSet + */ + @Override + public void fit(MultiDataSet dataSet) { + + } + + /** + * This method fits model with a given DataSetIterator + * + * @param iterator + */ + @Override + public void fit(DataSetIterator iterator) { + + } + + /** + * This method fits model with a given MultiDataSetIterator + * + * @param iterator + */ + @Override + public void fit(MultiDataSetIterator iterator) { + + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public 
T[] doEvaluation(DataSetIterator iterator, + T... evaluations) { + return null; + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(MultiDataSetIterator iterator, + T... evaluations) { + return null; + } + + /** + * @param netConfiguration + */ + @Override + public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) { + + } + @Override public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { // Gradients: d(x^2)/dx = 2x @@ -464,6 +552,90 @@ public class TestOptimizers extends BaseDL4JTest { } + /** + * This method returns updater state (if applicable), null otherwise + * + * @return + */ + @Override + public INDArray updaterState() { + return null; + } + + /** + * This method fits model with a given DataSet + * + * @param dataSet + */ + @Override + public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet) { + + } + + /** + * This method fits model with a given MultiDataSet + * + * @param dataSet + */ + @Override + public void fit(MultiDataSet dataSet) { + + } + + /** + * This method fits model with a given DataSetIterator + * + * @param iterator + */ + @Override + public void fit(DataSetIterator iterator) { + + } + + /** + * This method fits model with a given MultiDataSetIterator + * + * @param iterator + */ + @Override + public void fit(MultiDataSetIterator iterator) { + + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(DataSetIterator iterator, + T... evaluations) { + return null; + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(MultiDataSetIterator iterator, + T... evaluations) { + return null; + } + + /** + * @param netConfiguration + */ + @Override + public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) { + + } + @Override public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { //Gradient decomposes due to sum, so: @@ -649,6 +821,90 @@ public class TestOptimizers extends BaseDL4JTest { return dist.sample(new int[] {1, nDimensions}); } + /** + * This method returns updater state (if applicable), null otherwise + * + * @return + */ + @Override + public INDArray updaterState() { + return null; + } + + /** + * This method fits model with a given DataSet + * + * @param dataSet + */ + @Override + public void fit(org.nd4j.linalg.dataset.api.DataSet dataSet) { + + } + + /** + * This method fits model with a given MultiDataSet + * + * @param dataSet + */ + @Override + public void fit(MultiDataSet dataSet) { + + } + + /** + * This method fits model with a given DataSetIterator + * + * @param iterator + */ + @Override + public void fit(DataSetIterator iterator) { + + } + + /** + * This method fits model with a given MultiDataSetIterator + * + * @param iterator + */ + @Override + public void fit(MultiDataSetIterator iterator) { + + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(DataSetIterator iterator, + T... 
evaluations) { + return null; + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(MultiDataSetIterator iterator, + T... evaluations) { + return null; + } + + /** + * @param netConfiguration + */ + @Override + public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) { + + } + @Override public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { val nDims = parameters.length(); @@ -912,7 +1168,7 @@ public class TestOptimizers extends BaseDL4JTest { } @Override - public void setLayerConfiguration(NeuralNetConfiguration layerConfiguration) { + public void setLayerConfiguration(LayerConfiguration layerConfiguration) { throw new UnsupportedOperationException(); } @@ -934,13 +1190,13 @@ public class TestOptimizers extends BaseDL4JTest { } @Override - public Map paramTable() { + public Map getParamTable() { return Collections.singletonMap("W", getParam("W")); } @Override - public Map paramTable(boolean backpropParamsOnly) { - return paramTable(); + public Map getParamTable(boolean backpropParamsOnly) { + return getParamTable(); } @Override diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java index 50c177332..df6f1e0cb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java @@ -65,7 +65,7 @@ public class RegressionTest050 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); @@ -99,7 +99,7 @@ public class RegressionTest050 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); @@ -138,7 +138,7 @@ public class RegressionTest050 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(3, conf.getNetConfigurations().size()); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java index 9b0870b0f..d6c88b4d3 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java @@ -67,7 +67,7 @@ public class RegressionTest060 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(2, 
conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); @@ -101,7 +101,7 @@ public class RegressionTest060 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); @@ -144,7 +144,7 @@ public class RegressionTest060 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(3, conf.getNetConfigurations().size()); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); @@ -190,7 +190,7 @@ public class RegressionTest060 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(3, conf.getNetConfigurations().size()); GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java index e21f75680..bf14dba46 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java @@ -68,7 +68,7 @@ public class RegressionTest071 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); @@ -102,7 +102,7 @@ public class RegressionTest071 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); @@ -145,7 +145,7 @@ public class RegressionTest071 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(3, conf.getNetConfigurations().size()); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); @@ -191,7 +191,7 @@ public class RegressionTest071 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(3, conf.getNetConfigurations().size()); GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java index 06af06ff4..4cc26f05a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java @@ -67,7 +67,7 @@ public class RegressionTest080 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); @@ -106,7 +106,7 @@ public class RegressionTest080 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(2, conf.getNetConfigurations().size()); DenseLayer l0 = (DenseLayer) conf.getConf(0).getLayer(); @@ -155,7 +155,7 @@ public class RegressionTest080 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(3, conf.getNetConfigurations().size()); ConvolutionLayer l0 = (ConvolutionLayer) conf.getConf(0).getLayer(); @@ -206,7 +206,7 @@ public class RegressionTest080 extends BaseDL4JTest { MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(f, true); - NeuralNetConfiguration conf = net.getConfiguration(); + NeuralNetConfiguration conf = net.getNetConfiguration(); assertEquals(3, conf.getNetConfigurations().size()); GravesLSTM l0 = (GravesLSTM) conf.getConf(0).getLayer(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index a847a85ef..6b6558c48 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -107,9 +107,9 @@ public class RegressionTest100a extends BaseDL4JTest { assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new RmsProp(0.1), l0.getIUpdater()); - assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); - assertEquals(50, net.getConfiguration().getTbpttBackLength()); - assertEquals(50, net.getConfiguration().getTbpttFwdLength()); + assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType()); + assertEquals(50, net.getNetConfiguration().getTbpttBackLength()); + assertEquals(50, net.getNetConfiguration().getTbpttFwdLength()); INDArray outExp; File f2 = Resources.asFile("regression_testing/100a/GravesLSTMCharModelingExample_Output_100a.bin"); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java index 23ae5d5bd..6d73c1074 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java @@ -108,7 +108,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { List activations = net.feedForward(in); - assertEquals(dt, net.getConfiguration().getDataType()); + assertEquals(dt, net.getNetConfiguration().getDataType()); assertEquals(dt, net.params().dataType()); 
assertEquals( outExp, outAct, dtype); } @@ -142,9 +142,9 @@ public class RegressionTest100b3 extends BaseDL4JTest { assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new Adam(0.005), l0.getIUpdater()); - assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); - assertEquals(50, net.getConfiguration().getTbpttBackLength()); - assertEquals(50, net.getConfiguration().getTbpttFwdLength()); + assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType()); + assertEquals(50, net.getNetConfiguration().getTbpttBackLength()); + assertEquals(50, net.getNetConfiguration().getTbpttFwdLength()); INDArray outExp; File f2 = Resources.asFile("regression_testing/100b3/GravesLSTMCharModelingExample_Output_100b3.bin"); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java index fbbe55592..bd2f231d2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -125,7 +125,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { INDArray outAct = net.output(in); assertEquals(dtype, outAct.dataType()); - assertEquals(dtype, net.getConfiguration().getDataType()); + assertEquals(dtype, net.getNetConfiguration().getDataType()); assertEquals(dtype, net.params().dataType()); boolean eq = outExp.equalsWithEps(outAct, 0.01); assertTrue(eq, "Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct); @@ -160,9 +160,9 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new Adam(0.005), l2.getIUpdater()); - assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); - assertEquals(50, net.getConfiguration().getTbpttBackLength()); - assertEquals(50, net.getConfiguration().getTbpttFwdLength()); + assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType()); + assertEquals(50, net.getNetConfiguration().getTbpttBackLength()); + assertEquals(50, net.getNetConfiguration().getTbpttFwdLength()); INDArray outExp; File f2 = Resources.asFile("regression_testing/100b4/GravesLSTMCharModelingExample_Output_100b4.bin"); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java index 979518196..bf13cff1b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java @@ -107,7 +107,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { INDArray outAct = net.output(in); assertEquals(dtype, outAct.dataType()); - assertEquals(dtype, net.getConfiguration().getDataType()); + assertEquals(dtype, net.getNetConfiguration().getDataType()); assertEquals(dtype, net.params().dataType()); boolean eq = outExp.equalsWithEps(outAct, 0.01); assertTrue( eq, "Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct); @@ -142,9 +142,9 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new Adam(0.005), 
l2.getIUpdater()); - assertEquals(BackpropType.TruncatedBPTT, net.getConfiguration().getBackpropType()); - assertEquals(50, net.getConfiguration().getTbpttBackLength()); - assertEquals(50, net.getConfiguration().getTbpttFwdLength()); + assertEquals(BackpropType.TruncatedBPTT, net.getNetConfiguration().getBackpropType()); + assertEquals(50, net.getNetConfiguration().getTbpttBackLength()); + assertEquals(50, net.getNetConfiguration().getTbpttFwdLength()); INDArray outExp; File f2 = Resources.asFile("regression_testing/100b6/GravesLSTMCharModelingExample_Output_100b6.bin"); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java index b8b3cdad6..72b55f9e6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.params.DefaultParamInitializer; @@ -68,11 +69,13 @@ public class CustomLayer extends FeedForwardLayer { @Override public Layer instantiate(NeuralNetConfiguration conf, Collection iterationListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { + + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(0); //The instantiate method is how we go from the configuration class (i.e., this class) to the implementation class // (i.e., a CustomLayerImpl instance) //For the most part, it's the same for each type of layer - CustomLayerImpl myCustomLayer = new CustomLayerImpl(conf, networkDataType); + CustomLayerImpl myCustomLayer = new CustomLayerImpl(lconf, networkDataType); myCustomLayer.setListeners(iterationListeners); //Set the iteration listeners, if any myCustomLayer.setIndex(layerIndex); //Integer index of the layer @@ -87,7 +90,7 @@ public class CustomLayer extends FeedForwardLayer { // are in turn a view of the 'layerParamsView' array. 
Map paramTable = initializer().init(this, layerParamsView, initializeParams); myCustomLayer.setParamTable(paramTable); - myCustomLayer.setLayerConfiguration(conf); + myCustomLayer.setLayerConfiguration(lconf); return myCustomLayer; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java index 42b91d908..d233a5da3 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java @@ -21,6 +21,7 @@ package org.deeplearning4j.regressiontest.customlayer100a; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.BaseLayer; @@ -35,7 +36,7 @@ import org.nd4j.common.primitives.Pair; public class CustomLayerImpl extends BaseLayer { //Generic parameter here: the configuration class type - public CustomLayerImpl(NeuralNetConfiguration conf, DataType dataType) { + public CustomLayerImpl(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } @@ -56,7 +57,7 @@ public class CustomLayerImpl extends BaseLayer { //Generic paramete INDArray secondHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns)); IActivation activation1 = layerConf().getActivationFn(); - IActivation activation2 = ((CustomLayer) layerConfiguration.getFirstLayer()).getSecondActivationFunction(); + IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction(); //IActivation function instances modify the activation functions in-place activation1.getActivation(firstHalf, training); @@ -105,7 +106,7 @@ public class CustomLayerImpl extends BaseLayer { //Generic paramete INDArray epsilonSecondHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns)); IActivation activation1 = layerConf().getActivationFn(); - IActivation activation2 = ((CustomLayer) layerConfiguration.getFirstLayer()).getSecondActivationFunction(); + IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction(); //IActivation backprop method modifies the 'firstHalf' and 'secondHalf' arrays in-place, to contain dL/dz activation1.backprop(firstHalf, epsilonFirstHalf); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java index 02a1fdaf5..4415c5455 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java @@ -155,7 +155,7 @@ public class ModelGuesserTest extends BaseDL4JTest { ModelSerializer.writeModel(net, tempFile, true); MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); - assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); + assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); @@ 
-172,7 +172,7 @@ public class ModelGuesserTest extends BaseDL4JTest { try (InputStream inputStream = new FileInputStream(tempFile)) { MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream); Assertions.assertNotNull(network); - assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); + assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java index 9f52ae300..5124e15ac 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java @@ -80,7 +80,7 @@ public class ModelSerializerTest extends BaseDL4JTest { MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile); - assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); + assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @@ -124,7 +124,7 @@ public class ModelSerializerTest extends BaseDL4JTest { MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(fis); - assertEquals(network.getConfiguration().toJson(), net.getConfiguration().toJson()); + assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson()); assertEquals(net.params(), network.params()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java index 20721371b..02c478093 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasModelImportTest.java @@ -24,7 +24,6 @@ import java.util.List; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.CNN2DFormat; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.SubsamplingLayer; @@ -57,7 +56,7 @@ public class KerasModelImportTest extends BaseDL4JTest { @Test public void testNCHWNWHCChangeImport() { MultiLayerNetwork model = loadModel("modelimport/keras/weights/conv2dnchw/simpleconv2d.hdf5"); - List layerConfigs = model.getConfiguration().getFlattenedLayerConfigurations(); + List layerConfigs = model.getNetConfiguration().getFlattenedLayerConfigurations(); ConvolutionLayer convolutionLayer = (ConvolutionLayer) layerConfigs.get(0); assertEquals(CNN2DFormat.NCHW,convolutionLayer.getCnn2dDataFormat()); SubsamplingLayer subsamplingLayer = (SubsamplingLayer) layerConfigs.get(1); diff --git 
a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java index 1d7144b20..49237b20c 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java @@ -208,7 +208,7 @@ public class Word2VecTestsSmall extends BaseDL4JTest { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); - assertEquals(net.getConfiguration(), restored.getConfiguration()); + assertEquals(net.getNetConfiguration(), restored.getNetConfiguration()); assertTrue(net.params().equalsWithEps(restored.params(), 2e-3)); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java index 2c31319fc..9f81fd3d8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java @@ -46,6 +46,15 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; public interface IModel { + /** + * The param table + * + * @return + */ + + Map getParamTable(); + Map getParamTable(boolean backpropOnly); + /** * This method returns updater state (if applicable), null otherwise * @@ -273,6 +282,7 @@ public interface IModel { * @param listeners new listeners */ void setListeners(TrainingListener... listeners); + void setListeners(Collection listeners); /** * Add TrainingListeners to the model diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java index 5c221222c..cb87e885c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java @@ -1126,6 +1126,17 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { return getFlattenedLayerConfigurations().get(index); } + /** + * Deprecated, do not use. 
Workaround for old tests + * and getFlattenedLayerConfigurations().get(0); + * @return + */ + @Deprecated + public LayerConfiguration getFirstLayer() { + log.warn("This getFirstLayer method is an ugly workaround and will be removed."); + return getFlattenedLayerConfigurations().get(0); + } + public static abstract class NeuralNetConfigurationBuilder> extends NeuralNetBaseBuilderConfigurationBuilder { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java index a41870c3d..bb98be57d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java @@ -258,6 +258,9 @@ public abstract class LayerConfiguration implements TrainingConfig, Serializable "Not supported: all layers with parameters should override this method"); } + @Getter + private IUpdater iUpdater; + @Override public void setDataType(DataType dataType) { //No-op for most layers diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 7da7f837c..575ee27e9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -2443,6 +2443,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } + /** + * @param listeners + */ + @Override + public void setListeners(Collection listeners) { + setListeners(listeners.toArray(new TrainingListener[]{})); + } + /** * @deprecated Use {@link #getListeners()} */ @@ -4525,4 +4533,5 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial public String toString() { return getNetConfiguration().toString(); } + } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java index a7b4a98bc..73261f155 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java @@ -94,13 +94,13 @@ public class EarlyStoppingParallelTrainer implements IEarlySto Collection listeners = ((MultiLayerNetwork) model).getListeners(); Collection newListeners = new LinkedList<>(listeners); newListeners.add(trainerListener); - model.setListeners(newListeners); + model.setListeners(newListeners.toArray(new TrainingListener[]{})); } else if (model instanceof ComputationGraph) { Collection listeners = ((ComputationGraph) model).getListeners(); Collection newListeners = new LinkedList<>(listeners); newListeners.add(trainerListener); - model.setListeners(newListeners); + model.setListeners(newListeners.toArray(new TrainingListener[]{})); } this.wrapper = new ParallelWrapper.Builder<>(model).workers(workers).prefetchBuffer(prefetchBuffer) diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java 
b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java index 33009e994..0c1515109 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java @@ -204,7 +204,7 @@ public class InplaceParallelInference extends ParallelInference { if (loadBalanceMode == LoadBalanceMode.FIFO) queue.add(model); } else if (sourceModel instanceof MultiLayerNetwork) { - val model = new MultiLayerNetwork(NeuralNetConfiguration.fromJson(((MultiLayerNetwork) sourceModel).getConfiguration().toJson())); + val model = new MultiLayerNetwork(NeuralNetConfiguration.fromJson(((MultiLayerNetwork) sourceModel).getNetConfiguration().toJson())); model.init(params, false); Nd4j.getExecutioner().commit(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java index ea2e02ad7..9d2c76a23 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java @@ -472,7 +472,7 @@ public class ParallelInference { } else if (protoModel instanceof MultiLayerNetwork) { if (!rootDevice) { this.replicatedModel = new MultiLayerNetwork(NeuralNetConfiguration.fromJson( - ((MultiLayerNetwork) protoModel).getConfiguration().toJson())); + ((MultiLayerNetwork) protoModel).getNetConfiguration().toJson())); this.replicatedModel.init(); synchronized (locker) { diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java index e2a621508..5a880872d 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java @@ -957,10 +957,10 @@ public class ParallelWrapper implements AutoCloseable { List modelListeners = null; if (model instanceof MultiLayerNetwork) { modelListeners = new ArrayList<>(((MultiLayerNetwork) model).getListeners()); - model.setListeners(Collections.emptyList()); + model.setListeners(new TrainingListener[]{}); } else if (model instanceof ComputationGraph) { modelListeners = new ArrayList<>(((ComputationGraph) model).getListeners()); - model.setListeners(Collections.emptyList()); + model.setListeners(new TrainingListener[]{}); } if (modelListeners != null && !modelListeners.isEmpty()) { diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java index dd7cda946..2a1cf4d4e 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java @@ -278,7 +278,7 @@ public class DefaultTrainer extends Thread implements Trainer { } configureListeners(uuid, oldListeners, replicatedListeners); - this.replicatedModel.setListeners(replicatedListeners); 
+ this.replicatedModel.setListeners(replicatedListeners.toArray(new TrainingListener[]{})); } @Override @@ -296,7 +296,7 @@ public class DefaultTrainer extends Thread implements Trainer { if (originalModel instanceof MultiLayerNetwork) { if (!onRootModel) { NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson( - ((MultiLayerNetwork) originalModel).getConfiguration().toJson()); + ((MultiLayerNetwork) originalModel).getNetConfiguration().toJson()); conf.setTrainingWorkspaceMode(workspaceMode); this.replicatedModel = new MultiLayerNetwork(conf); @@ -323,7 +323,7 @@ public class DefaultTrainer extends Thread implements Trainer { if (!((MultiLayerNetwork) replicatedModel).isInitCalled()) this.replicatedModel.init(); - ((MultiLayerNetwork) replicatedModel).getConfiguration() + ((MultiLayerNetwork) replicatedModel).getNetConfiguration() .setTrainingWorkspaceMode(workspaceMode); } } else if (originalModel instanceof ComputationGraph) { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java index 2e50414da..2a0c7b655 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java @@ -122,7 +122,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { public SparkDl4jMultiLayer(JavaSparkContext javaSparkContext, MultiLayerNetwork network, TrainingMaster trainingMaster) { sc = javaSparkContext; - this.conf = network.getConfiguration().clone(); + this.conf = network.getNetConfiguration().clone(); this.network = network; if (!network.isInitCalled()) network.init(); @@ -315,8 +315,8 @@ public class SparkDl4jMultiLayer extends SparkListenable { * @return the multi layer network that was fitDataSet */ public MultiLayerNetwork fitLabeledPoint(JavaRDD rdd) { - int nLayers = network.getConfiguration().getFlattenedLayerConfigurations().size(); - FeedForwardLayer ffl = (FeedForwardLayer) network.getConfiguration().getFlattenedLayerConfigurations().get(nLayers - 1); + int nLayers = network.getNetConfiguration().getFlattenedLayerConfigurations().size(); + FeedForwardLayer ffl = (FeedForwardLayer) network.getNetConfiguration().getFlattenedLayerConfigurations().get(nLayers - 1); JavaRDD ds = MLLibUtil.fromLabeledPoint(sc, rdd, ffl.getNOut()); return fit(ds); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java index 1dc1d4f1b..d3fb3355f 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java @@ -275,7 +275,7 @@ public class ParameterAveragingTrainingMaster @Override public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network) { - NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getConfiguration(), + NetBroadcastTuple tuple = new 
NetBroadcastTuple(network.getNetwork().getNetConfiguration(), network.getNetwork().params(), network.getNetwork().getUpdater().getStateViewArray()); if (collectTrainingStats) @@ -727,7 +727,7 @@ public class ParameterAveragingTrainingMaster if (params != null) { //Params may be null for edge case (empty RDD) if (network != null) { - NeuralNetConfiguration conf = network.getNetwork().getConfiguration(); + NeuralNetConfiguration conf = network.getNetwork().getNetConfiguration(); int numUpdates = averagingFrequency; conf.setIterationCount(conf.getIterationCount() + numUpdates); } else { diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java index 4820e938f..2322ba5c2 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java @@ -172,9 +172,9 @@ public class ParameterAveragingTrainingWorker extends BaseTrainingWorker trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { - CustomLayerImpl ret = new CustomLayerImpl(conf, networkDataType); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + CustomLayerImpl ret = new CustomLayerImpl(lconf, networkDataType); ret.setListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setLayerConfiguration(conf); + ret.setLayerConfiguration(lconf); return ret; } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java index 55b32d1dc..610f4079c 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayerImpl.java @@ -21,11 +21,12 @@ package org.deeplearning4j.spark.impl.customlayer.layer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.layers.BaseLayer; import org.nd4j.linalg.api.buffer.DataType; public class CustomLayerImpl extends BaseLayer { - public CustomLayerImpl(NeuralNetConfiguration conf, DataType dataType) { + public CustomLayerImpl(LayerConfiguration conf, DataType dataType) { super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java index 2e01cc17d..688135888 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/misc/TestFrozenLayers.java 
@@ -154,7 +154,7 @@ public class TestFrozenLayers extends BaseSparkTest { ComputationGraph withFrozen = new TransferLearning.GraphBuilder(origModel).fineTuneConfiguration(finetune) .setFeatureExtractor("1").build(); - Map m = withFrozen.paramTable(); + Map m = withFrozen.getParamTable(); Map pCopy = new HashMap<>(); for (Map.Entry entry : m.entrySet()) { pCopy.put(entry.getKey(), entry.getValue().dup()); @@ -190,7 +190,7 @@ public class TestFrozenLayers extends BaseSparkTest { ComputationGraph fitted = sNet.getNetwork(); - Map fittedParams = fitted.paramTable(); + Map fittedParams = fitted.getParamTable(); for (Map.Entry entry : fittedParams.entrySet()) { INDArray orig = pCopy.get(entry.getKey()); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java index 8376638f3..42fc1112c 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java @@ -784,13 +784,13 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { JavaRDD rdd = sc.parallelize(list); - assertEquals(0, sparkNet.getNetwork().getConfiguration().getIterationCount()); + assertEquals(0, sparkNet.getNetwork().getNetConfiguration().getIterationCount()); sparkNet.fit(rdd); assertEquals(minibatchesPerWorkerPerEpoch, - sparkNet.getNetwork().getConfiguration().getIterationCount()); + sparkNet.getNetwork().getNetConfiguration().getIterationCount()); sparkNet.fit(rdd); assertEquals(2 * minibatchesPerWorkerPerEpoch, - sparkNet.getNetwork().getConfiguration().getIterationCount()); + sparkNet.getNetwork().getNetConfiguration().getIterationCount()); sparkNet.getTrainingMaster().deleteTempFiles(sc); } @@ -1074,11 +1074,11 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { for(int i=0; i<3; i++ ){ - assertEquals(i, sn1.getNetwork().getConfiguration().getEpochCount()); + assertEquals(i, sn1.getNetwork().getNetConfiguration().getEpochCount()); assertEquals(i, sn2.getNetwork().getComputationGraphConfiguration().getEpochCount()); sn1.fit(rdd); sn2.fit(rdd); - assertEquals(i+1, sn1.getNetwork().getConfiguration().getEpochCount()); + assertEquals(i+1, sn1.getNetwork().getNetConfiguration().getEpochCount()); assertEquals(i+1, sn2.getNetwork().getComputationGraphConfiguration().getEpochCount()); } } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java index 0265837bd..5bb21442c 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java @@ -239,7 +239,7 @@ public class SharedTrainingWrapper { List listeners = worker.getListeners(); if(listeners != null){ - model.setListeners(listeners); + model.setListeners(listeners.toArray(new 
TrainingListener[]{})); StatsStorageRouter r = worker.getRouter(); if(r != null){ for(TrainingListener l : listeners){ @@ -425,7 +425,7 @@ public class SharedTrainingWrapper { .setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode()); ((ComputationGraph) originalModel).setGradientsAccumulator(accumulator); } else if (model instanceof MultiLayerNetwork) { - ((MultiLayerNetwork) originalModel).getConfiguration() + ((MultiLayerNetwork) originalModel).getNetConfiguration() .setTrainingWorkspaceMode(trainingConfiguration.getWorkspaceMode()); ((MultiLayerNetwork) originalModel).setGradientsAccumulator(accumulator); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java index ef252470b..bb291c0b8 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java @@ -262,7 +262,7 @@ public class SharedTrainingMaster extends BaseTrainingMaster T[] doEvaluation(DataSetIterator iterator, T... evaluations) { + return null; + } + + /** + * This method executes evaluation of the model against given iterator and evaluation + * implementations + * + * @param iterator + * @param evaluations + */ + @Override + public T[] doEvaluation(MultiDataSetIterator iterator, + T... evaluations) { + return null; + } + @Override public INDArray getParam(String param) { return null; } @Override - public void addListeners(TrainingListener... listener) { - // no-op + public void addListeners(TrainingListener... 
listener) {//no op } - @Override - public Map paramTable() { + public Map getParamTable() { return null; } - @Override - public Map paramTable(boolean backprapParamsOnly) { + public Map getParamTable(boolean backprapParamsOnly) { return null; } - @Override + public void setParamTable(Map paramTable) { } @@ -490,7 +569,7 @@ public class BarnesHutTsne implements IModel { * * @param listeners */ - @Override + public void setListeners(Collection listeners) { } @@ -901,8 +980,15 @@ public class BarnesHutTsne implements IModel { return null; } + /** + * @param netConfiguration + */ @Override - public void setLayerConfiguration(NeuralNetConfiguration layerConfiguration) { + public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) { + + } + + public void setLayerConfiguration(LayerConfiguration layerConfiguration) { } @@ -1060,4 +1146,14 @@ public class BarnesHutTsne implements IModel { public void close(){ //No-op } + + /** + * Get the TrainingListeners + * + * @return training listener + */ + @Override + public Collection getListeners() { + return null; + } } diff --git a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java index b9a7e985d..3797b6550 100644 --- a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java +++ b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java @@ -30,6 +30,7 @@ import org.deeplearning4j.core.storage.StorageMetaData; import org.deeplearning4j.core.storage.listener.RoutingIterationListener; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -426,10 +427,10 @@ public abstract class BaseStatsListener implements RoutingIterationListener { //Need to append "0_", "1_" etc to param names from layers... 
int layerIdx = 0; for (Layer l : ((MultiLayerNetwork) model).getLayers()) { - NeuralNetConfiguration conf = l.getNetConfiguration(); + LayerConfiguration conf = l.getLayerConfiguration(); List paramkeys = l.getLayerConfiguration().initializer().paramKeys(l.getLayerConfiguration()); for (String s : paramkeys) { - double lr = conf.getFirstLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); + double lr = conf.getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); if (Double.isNaN(lr)) { //Edge case: No-Op updater, AdaDelta etc - don't have a LR hence return NaN for IUpdater.getLearningRate lr = 0.0; @@ -440,11 +441,11 @@ public abstract class BaseStatsListener implements RoutingIterationListener { } } else if (model instanceof ComputationGraph) { for (Layer l : ((ComputationGraph) model).getLayers()) { - NeuralNetConfiguration conf = l.getNetConfiguration(); - String layerName = conf.getFirstLayer().getLayerName(); + LayerConfiguration conf = l.getLayerConfiguration(); + String layerName = conf.getLayerName(); List paramkeys = l.getLayerConfiguration().initializer().paramKeys(l.getLayerConfiguration()); for (String s : paramkeys) { - double lr = conf.getFirstLayer().getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); + double lr = conf.getUpdaterByParam(s).getLearningRate(l.getIterationCount(), l.getEpochCount()); if (Double.isNaN(lr)) { //Edge case: No-Op updater, AdaDelta etc - don't have a LR hence return NaN for IUpdater.getLearningRate lr = 0.0; @@ -467,7 +468,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { //--- Histograms --- if (updateConfig.collectHistograms(StatsType.Parameters)) { - Map paramHistograms = getHistograms(model.paramTable(backpropParamsOnly), + Map paramHistograms = getHistograms(model.getParamTable(backpropParamsOnly), updateConfig.numHistogramBins(StatsType.Parameters)); report.reportHistograms(StatsType.Parameters, paramHistograms); } @@ -490,7 +491,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { //--- Summary Stats: Mean, Variance, Mean Magnitudes --- if (updateConfig.collectMean(StatsType.Parameters)) { - Map meanParams = calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Mean); + Map meanParams = calculateSummaryStats(model.getParamTable(backpropParamsOnly), StatType.Mean); report.reportMean(StatsType.Parameters, meanParams); } @@ -511,7 +512,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { if (updateConfig.collectStdev(StatsType.Parameters)) { Map stdevParams = - calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.Stdev); + calculateSummaryStats(model.getParamTable(backpropParamsOnly), StatType.Stdev); report.reportStdev(StatsType.Parameters, stdevParams); } @@ -532,7 +533,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { if (updateConfig.collectMeanMagnitudes(StatsType.Parameters)) { Map meanMagParams = - calculateSummaryStats(model.paramTable(backpropParamsOnly), StatType.MeanMagnitude); + calculateSummaryStats(model.getParamTable(backpropParamsOnly), StatType.MeanMagnitude); report.reportMeanMagnitudes(StatsType.Parameters, meanMagParams); } @@ -652,7 +653,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { long numParams; if (model instanceof MultiLayerNetwork) { MultiLayerNetwork net = ((MultiLayerNetwork) model); - jsonConf = net.getConfiguration().toJson(); + jsonConf 
= net.getNetConfiguration().toJson(); numLayers = net.getnLayers(); numParams = net.numParams(); } else if (model instanceof ComputationGraph) { @@ -670,7 +671,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { + (model == null ? null : model.getClass())); } - Map paramMap = model.paramTable(backpropParamsOnly); + Map paramMap = model.getParamTable(backpropParamsOnly); String[] paramNames = new String[paramMap.size()]; int i = 0; for (String s : paramMap.keySet()) { //Assuming sensible iteration order - LinkedHashMaps are used in MLN/CG for example diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index 7e384dec5..2ca083a4d 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -1129,8 +1129,8 @@ public class TrainModule implements UIModule { NeuralNetConfiguration conf = NeuralNetConfiguration.fromJson(configJson); int confIdx = layerIdx - 1; //-1 because of input if (confIdx >= 0) { - nnc = conf.getNetConfigurations().get(confIdx); - layer = nnc.getFirstLayer(); + layer = conf.getFlattenedLayerConfigurations().get(confIdx); + nnc = layer.getNetConfiguration(); } else { //Input layer layerType = "Input"; @@ -1144,7 +1144,7 @@ public class TrainModule implements UIModule { if (vertices.containsKey(vertexName) && vertices.get(vertexName) instanceof LayerVertex) { LayerVertex lv = (LayerVertex) vertices.get(vertexName); nnc = lv.getNetConfiguration(); - layer = nnc.getFirstLayer(); + layer = lv.getLayerConfiguration(); } else if (conf.getNetworkInputs().contains(vertexName)) { layerType = "Input"; } else { @@ -1177,7 +1177,7 @@ public class TrainModule implements UIModule { if (layer instanceof BaseLayer) { BaseLayer bl = (BaseLayer) layer; activationFn = bl.getActivationFn().toString(); - long nParams = layer.initializer().numParams(nnc.getFirstLayer()); + long nParams = layer.initializer().numParams(bl.getLayer()); layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerNParams"), String.valueOf(nParams)}); if (nParams > 0) { diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java index 34b6563f1..aebfaffa7 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModuleUtils.java @@ -62,24 +62,24 @@ public class TrainModuleUtils { layerInfo.add(Collections.emptyMap()); - List list = config.getNetConfigurations(); + List list = config.getFlattenedLayerConfigurations(); int layerIdx = 1; - for (NeuralNetConfiguration c : list) { - LayerConfiguration layer = c.getFirstLayer(); + for (LayerConfiguration c : list) { + LayerConfiguration layer = c; String layerName = layer.getLayerName(); if (layerName == null) layerName = "layer" + layerIdx; vertexNames.add(layerName); originalVertexName.add(String.valueOf(layerIdx - 1)); - String layerType = c.getFirstLayer().getClass().getSimpleName().replaceAll("Layer$", ""); + String layerType = c.getClass().getSimpleName().replaceAll("Layer$", ""); layerTypes.add(layerType); layerInputs.add(Collections.singletonList(layerIdx - 1)); 
layerIdx++; //Extract layer info - Map map = getLayerInfo(c, layer); + Map map = getLayerInfo(c.getNetConfiguration(), layer); layerInfo.add(map); } diff --git a/cavis-ui/cavis-ui-vertx/src/main/resources/templates/SameDiffUI.html b/cavis-ui/cavis-ui-vertx/src/main/resources/templates/SameDiffUI.html index 951aabeb5..2ecadd3ee 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/resources/templates/SameDiffUI.html +++ b/cavis-ui/cavis-ui-vertx/src/main/resources/templates/SameDiffUI.html @@ -143,7 +143,7 @@ + Spread

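(Illustrative sketch, not part of any patch in this series: the hunks above repeat one mechanical rename across the test, UI and Spark call sites. The snippet below restates that migration using only class and method names visible in the diffs; the MigrationSketch wrapper, its arguments, and any import paths not shown in the hunks are assumptions.)

    import java.util.List;
    import java.util.Map;
    import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
    import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
    import org.deeplearning4j.optimize.api.TrainingListener;
    import org.nd4j.linalg.api.ndarray.INDArray;

    class MigrationSketch {
        // Old call sites, as they appear on the "-" lines of the hunks above:
        //   NeuralNetConfiguration conf = net.getConfiguration();
        //   Map<String, INDArray> params = net.paramTable(true);
        //   net.setListeners(listeners);                                  // Collection overload
        // New call sites, as they appear on the "+" lines:
        static void migrated(MultiLayerNetwork net, List<TrainingListener> listeners) {
            NeuralNetConfiguration conf = net.getNetConfiguration();        // was getConfiguration()
            Map<String, INDArray> params = net.getParamTable(true);         // was paramTable(true)
            net.setListeners(listeners.toArray(new TrainingListener[]{}));  // array form used at the changed call sites
            System.out.println(conf.toJson().length() + " config chars, " + params.size() + " param arrays");
        }
    }

The same one-for-one substitution covers the ComputationGraph call sites and the listener plumbing in the parallelism and parameter-server modules, as the corresponding hunks show.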
diff --git a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java index 240cabfcc..7a26046cd 100644 --- a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java +++ b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java @@ -45,7 +45,7 @@ public class TestUtils { ByteArrayInputStream bais = new ByteArrayInputStream(bytes); MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); - assertEquals(net.getConfiguration(), restored.getConfiguration()); + assertEquals(net.getNetConfiguration(), restored.getNetConfiguration()); assertEquals(net.params(), restored.params()); return restored; From 0f21ed9ec584ccdc3a57e17f28aa6a3b5bbf3713 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 14 Apr 2023 07:31:32 +0200 Subject: [PATCH 123/126] Playing with some new code 2 - clean build/test Signed-off-by: brian --- .../src/test/java/net/brutex/gan/App.java | 22 +- .../src/test/java/net/brutex/gan/GAN.java | 14 +- .../net/brutex/gan/MnistDCGANExample.java | 4 +- .../java/net/brutex/spark/TestServer2.java | 2 +- .../IntegrationTestBaselineGenerator.java | 8 +- .../integration/IntegrationTestRunner.java | 18 +- .../deeplearning4j/integration/TestUtils.java | 4 +- .../activations/impl/ActivationIdentity.java | 2 +- .../linalg/workspace/BaseWorkspaceMgr.java | 2 +- cavis-dnn/cavis-dnn-core/build.gradle | 1 + .../LayerHelperValidationUtil.java | 14 +- .../java/org/deeplearning4j/TestUtils.java | 18 +- .../iterator/DataSetIteratorTest.java | 5 +- .../earlystopping/TestEarlyStopping.java | 24 +- .../TestEarlyStoppingCompGraph.java | 12 +- .../org/deeplearning4j/eval/EvalTest.java | 8 +- .../gradientcheck/BNGradientCheckTest.java | 13 +- .../gradientcheck/CNNGradientCheckTest.java | 9 +- .../gradientcheck/GradientCheckTests.java | 24 +- .../LossFunctionGradientCheck.java | 2 +- .../MultiLayerNeuralNetConfigurationTest.java | 23 +- .../nn/conf/NeuralNetConfigurationTest.java | 4 +- .../nn/conf/graph/ShiftVertexTest.java | 2 +- .../nn/conf/layers/LayerConfigTest.java | 101 +- .../layers/LayerConfigValidationTest.java | 12 +- .../nn/conf/weightnoise/TestWeightNoise.java | 14 +- .../deeplearning4j/nn/dtypes/DTypeTests.java | 40 +- .../nn/graph/ComputationGraphTestRNN.java | 4 +- .../nn/graph/TestCompGraphCNN.java | 4 +- .../nn/graph/TestCompGraphUnsupervised.java | 6 +- .../nn/graph/TestComputationGraphNetwork.java | 24 +- .../nn/graph/TestSetGetParameters.java | 10 +- .../nn/graph/TestVariableLengthTSCG.java | 14 +- ...t.java => BaseLayerConfigurationTest.java} | 2 +- .../nn/layers/CacheModeTest.java | 12 +- .../nn/layers/CenterLossOutputLayerTest.java | 4 +- .../nn/layers/DropoutLayerTest.java | 4 +- .../nn/layers/FrozenLayerTest.java | 46 +- .../layers/FrozenLayerWithBackpropTest.java | 84 +- .../nn/layers/OutputLayerTest.java | 22 +- .../deeplearning4j/nn/layers/SeedTest.java | 8 +- .../convolution/ConvDataFormatTests.java | 12 +- .../convolution/ConvolutionLayerTest.java | 6 +- .../nn/layers/custom/TestCustomLayers.java | 4 +- .../custom/testclasses/CustomLayer.java | 2 +- .../custom/testclasses/CustomOutputLayer.java | 2 +- .../layers/feedforward/dense/DenseTest.java | 4 +- .../embedding/EmbeddingLayerTest.java | 30 +- .../objdetect/TestYolo2OutputLayer.java | 2 +- .../nn/layers/ocnn/OCNNOutputLayerTest.java | 10 +- .../layers/recurrent/BidirectionalTest.java | 20 +- .../GravesBidirectionalLSTMTest.java | 2 +- 
.../layers/recurrent/RnnDataFormatTests.java | 12 +- .../nn/layers/recurrent/TestRnnLayers.java | 7 +- .../layers/recurrent/TestTimeDistributed.java | 2 +- .../nn/layers/samediff/TestSameDiffDense.java | 18 +- .../samediff/TestSameDiffDenseVertex.java | 6 +- .../layers/samediff/TestSameDiffLambda.java | 8 +- .../layers/samediff/TestSameDiffOutput.java | 6 +- .../nn/layers/variational/TestVAE.java | 2 +- .../nn/misc/CloseNetworkTests.java | 4 +- .../deeplearning4j/nn/misc/LargeNetTest.java | 4 +- .../deeplearning4j/nn/misc/TestLrChanges.java | 28 +- .../nn/misc/TestNetConversion.java | 4 +- .../nn/misc/WorkspaceTests.java | 2 +- .../nn/multilayer/BackPropMLPTest.java | 2 +- .../nn/multilayer/MultiLayerTest.java | 46 +- .../nn/multilayer/MultiLayerTestRNN.java | 4 +- .../nn/multilayer/TestMasking.java | 8 +- .../nn/multilayer/TestSetGetParameters.java | 26 +- .../nn/multilayer/TestVariableLengthTS.java | 14 +- .../rl/TestMultiModelGradientApplication.java | 20 +- .../TransferLearningCompGraphTest.java | 65 +- .../TransferLearningComplex.java | 23 +- .../TransferLearningHelperTest.java | 30 +- .../TransferLearningMLNTest.java | 124 +- .../nn/updater/TestUpdaters.java | 56 +- .../nn/updater/custom/TestCustomUpdater.java | 18 +- .../solver/BackTrackLineSearchTest.java | 34 +- .../optimize/solver/TestOptimizers.java | 96 +- .../listener/TestCheckpointListener.java | 14 +- .../listener/TestFailureListener.java | 6 +- .../optimizer/listener/TestListeners.java | 20 +- .../parallelism/RandomTests.java | 6 +- .../listener/TestSystemInfoPrintListener.java | 2 +- .../regressiontest/RegressionTest050.java | 6 +- .../regressiontest/RegressionTest060.java | 6 +- .../regressiontest/RegressionTest071.java | 6 +- .../regressiontest/RegressionTest080.java | 6 +- .../regressiontest/RegressionTest100b3.java | 4 +- .../regressiontest/RegressionTest100b4.java | 4 +- .../regressiontest/RegressionTest100b6.java | 4 +- .../customlayer100a/CustomLayer.java | 2 +- .../customlayer100a/CustomLayerImpl.java | 11 +- .../CompareTrainingImplementations.java | 2 +- .../util/CrashReportingUtilTest.java | 4 +- .../deeplearning4j/util/ModelGuesserTest.java | 4 +- .../util/ModelSerializerTest.java | 14 +- .../modelimport/keras/layers/TFOpLayer.java | 10 +- .../keras/layers/recurrent/KerasLSTM.java | 6 +- .../layers/recurrent/KerasSimpleRnn.java | 6 +- .../nn/modelimport/keras/KerasTestUtils.java | 6 +- .../keras/e2e/KerasModelEndToEndTest.java | 4 +- .../models/word2vec/Word2VecTestsSmall.java | 2 +- .../ai/dnn/api/ILayerConfiguration.java | 38 +- .../java/net/brutex/ai/dnn/api/IModel.java | 41 +- .../dnn/api/INeuralNetworkConfiguration.java | 6 + .../main/java/net/brutex/ai/dnn/api/NN.java | 1 + .../ai/dnn/conf/layer/Layer_Descriptions.md | 31 + .../dnn/networks/ArtificialNeuralNetwork.java | 58 + .../trainer/BaseEarlyStoppingTrainer.java | 10 +- .../gradientcheck/GradientCheckUtil.java | 41 +- .../{Trainable.java => ITrainableLayer.java} | 25 +- ...va => ITraininableLayerConfiguration.java} | 5 +- .../java/org/deeplearning4j/nn/api/Layer.java | 11 +- .../org/deeplearning4j/nn/api/Updater.java | 4 +- .../conf/ComputationGraphConfiguration.java | 12 +- .../NeuralNetBaseBuilderConfiguration.java | 8 +- .../nn/conf/NeuralNetConfiguration.java | 64 +- .../nn/conf/graph/LayerVertex.java | 4 +- .../nn/conf/layers/ActivationLayer.java | 8 +- .../nn/conf/layers/AutoEncoder.java | 2 +- ...Layer.java => BaseLayerConfiguration.java} | 33 +- .../nn/conf/layers/BatchNormalization.java | 2 +- .../nn/conf/layers/CenterLossOutputLayer.java | 2 
+- .../nn/conf/layers/Cnn3DLossLayer.java | 2 +- .../nn/conf/layers/CnnLossLayer.java | 2 +- .../nn/conf/layers/Convolution1DLayer.java | 2 +- .../nn/conf/layers/Convolution3D.java | 2 +- .../nn/conf/layers/ConvolutionLayer.java | 2 +- .../nn/conf/layers/Deconvolution2D.java | 2 +- .../nn/conf/layers/Deconvolution3D.java | 2 +- .../nn/conf/layers/DenseLayer.java | 3 +- .../conf/layers/DepthwiseConvolution2D.java | 2 +- .../nn/conf/layers/DropoutLayer.java | 2 +- .../nn/conf/layers/EmbeddingLayer.java | 2 +- .../conf/layers/EmbeddingSequenceLayer.java | 2 +- .../nn/conf/layers/FeedForwardLayer.java | 5 +- .../nn/conf/layers/GlobalPoolingLayer.java | 2 +- .../conf/layers/GravesBidirectionalLSTM.java | 2 +- .../nn/conf/layers/GravesLSTM.java | 2 +- .../deeplearning4j/nn/conf/layers/LSTM.java | 2 +- .../nn/conf/layers/LayerConfiguration.java | 34 +- .../nn/conf/layers/LayerValidation.java | 10 +- .../layers/LocalResponseNormalization.java | 13 +- .../nn/conf/layers/LossLayer.java | 2 +- .../nn/conf/layers/NoParamLayer.java | 21 +- .../nn/conf/layers/OutputLayer.java | 2 +- .../nn/conf/layers/PReLULayer.java | 4 +- .../nn/conf/layers/RnnLossLayer.java | 2 +- .../nn/conf/layers/RnnOutputLayer.java | 2 +- .../conf/layers/SeparableConvolution2D.java | 2 +- .../nn/conf/layers/SpaceToBatchLayer.java | 2 +- .../nn/conf/layers/SpaceToDepthLayer.java | 2 +- .../nn/conf/layers/Subsampling1DLayer.java | 2 +- .../nn/conf/layers/Subsampling3DLayer.java | 2 +- .../nn/conf/layers/SubsamplingLayer.java | 2 +- .../nn/conf/layers/Upsampling1D.java | 2 +- .../nn/conf/layers/Upsampling2D.java | 2 +- .../nn/conf/layers/Upsampling3D.java | 2 +- .../nn/conf/layers/ZeroPadding1DLayer.java | 2 +- .../nn/conf/layers/ZeroPadding3DLayer.java | 2 +- .../nn/conf/layers/ZeroPaddingLayer.java | 2 +- .../conf/layers/convolutional/Cropping1D.java | 2 +- .../conf/layers/convolutional/Cropping2D.java | 2 +- .../conf/layers/convolutional/Cropping3D.java | 2 +- .../misc/ElementWiseMultiplicationLayer.java | 2 +- .../nn/conf/layers/misc/FrozenLayer.java | 10 - .../layers/misc/FrozenLayerWithBackprop.java | 15 +- .../nn/conf/layers/misc/RepeatVector.java | 2 +- .../layers/objdetect/Yolo2OutputLayer.java | 14 +- .../conf/layers/recurrent/Bidirectional.java | 16 +- .../conf/layers/recurrent/LastTimeStep.java | 4 +- .../nn/conf/layers/recurrent/SimpleRnn.java | 2 +- .../layers/recurrent/TimeDistributed.java | 4 +- .../conf/layers/samediff/SameDiffVertex.java | 4 +- .../nn/conf/layers/util/MaskZeroLayer.java | 6 +- .../variational/VariationalAutoencoder.java | 2 +- .../conf/layers/wrapper/BaseWrapperLayer.java | 123 - .../BaseWrapperLayerConfiguration.java | 196 ++ .../nn/conf/misc/DummyConfig.java | 4 +- .../nn/conf/ocnn/OCNNOutputLayer.java | 8 +- .../conf/serde/BaseNetConfigDeserializer.java | 43 +- ...utationGraphConfigurationDeserializer.java | 31 +- .../NeuralNetConfigurationDeserializer.java | 27 +- .../nn/graph/ComputationGraph.java | 81 +- .../nn/graph/vertex/BaseGraphVertex.java | 8 +- .../nn/graph/vertex/BaseWrapperVertex.java | 10 +- .../nn/graph/vertex/GraphVertex.java | 4 +- .../nn/graph/vertex/impl/FrozenVertex.java | 9 +- .../nn/graph/vertex/impl/LayerVertex.java | 10 +- .../nn/layers/AbstractLayer.java | 591 ++-- .../nn/layers/ActivationLayer.java | 27 +- .../deeplearning4j/nn/layers/BaseLayer.java | 487 ++- .../nn/layers/BaseOutputLayer.java | 17 +- .../nn/layers/BasePretrainNetwork.java | 9 +- .../nn/layers/DropoutLayer.java | 7 +- .../deeplearning4j/nn/layers/FrozenLayer.java | 12 +- 
 .../nn/layers/FrozenLayerWithBackprop.java | 12 +-
 .../deeplearning4j/nn/layers/LossLayer.java | 20 +-
 .../nn/layers/RepeatVector.java | 13 +-
 .../nn/layers/convolution/Cnn3DLossLayer.java | 47 +-
 .../nn/layers/convolution/CnnLossLayer.java | 27 +-
 .../convolution/Convolution1DLayer.java | 27 +-
 .../convolution/Convolution3DLayer.java | 11 +-
 .../layers/convolution/ConvolutionLayer.java | 75 +-
 .../layers/convolution/Cropping1DLayer.java | 4 +
 .../layers/convolution/Cropping2DLayer.java | 3 +-
 .../convolution/Deconvolution2DLayer.java | 35 +-
 .../convolution/Deconvolution3DLayer.java | 49 +-
 .../DepthwiseConvolution2DLayer.java | 35 +-
 .../SeparableConvolution2DLayer.java | 39 +-
 .../nn/layers/convolution/SpaceToBatch.java | 21 +-
 .../nn/layers/convolution/SpaceToDepth.java | 15 +-
 .../layers/convolution/ZeroPaddingLayer.java | 11 +-
 .../subsampling/Subsampling1DLayer.java | 7 +-
 .../subsampling/Subsampling3DLayer.java | 33 +-
 .../subsampling/SubsamplingLayer.java | 75 +-
 .../convolution/upsampling/Upsampling1D.java | 5 +-
 .../convolution/upsampling/Upsampling2D.java | 13 +-
 .../convolution/upsampling/Upsampling3D.java | 13 +-
 .../nn/layers/feedforward/PReLU.java | 3 +-
 .../feedforward/autoencoder/AutoEncoder.java | 11 +-
 .../layers/feedforward/dense/DenseLayer.java | 9 +-
 .../ElementWiseMultiplicationLayer.java | 5 +-
 .../feedforward/embedding/EmbeddingLayer.java | 9 +-
 .../embedding/EmbeddingSequenceLayer.java | 31 +-
 .../normalization/BatchNormalization.java | 79 +-
 .../LocalResponseNormalization.java | 37 +-
 .../nn/layers/objdetect/Yolo2OutputLayer.java | 43 +-
 .../nn/layers/ocnn/OCNNOutputLayer.java | 23 +-
 .../nn/layers/pooling/GlobalPoolingLayer.java | 9 +-
 .../layers/recurrent/BaseRecurrentLayer.java | 5 +-
 .../layers/recurrent/BidirectionalLayer.java | 58 +-
 .../recurrent/GravesBidirectionalLSTM.java | 23 +-
 .../nn/layers/recurrent/GravesLSTM.java | 9 +-
 .../nn/layers/recurrent/LSTM.java | 13 +-
 .../nn/layers/recurrent/LSTMHelpers.java | 7 +-
 .../nn/layers/recurrent/RnnLossLayer.java | 27 +-
 .../nn/layers/recurrent/RnnOutputLayer.java | 21 +-
 .../nn/layers/recurrent/SimpleRnn.java | 11 +-
 .../layers/samediff/SameDiffGraphVertex.java | 6 +-
 .../nn/layers/samediff/SameDiffLayer.java | 19 +-
 .../layers/samediff/SameDiffOutputLayer.java | 21 +-
 .../training/CenterLossOutputLayer.java | 35 +-
 .../variational/VariationalAutoencoder.java | 191 +-
 .../nn/layers/wrapper/BaseWrapperLayer.java | 119 +-
 .../nn/multilayer/MultiLayerNetwork.java | 2639 +++++++++--------
 .../nn/params/PReLUParamInitializer.java | 6 +-
 .../params/WrapperLayerParamInitializer.java | 8 +-
 .../FineTuneConfiguration.java | 8 +-
 .../nn/transferlearning/TransferLearning.java | 16 +-
 .../TransferLearningHelper.java | 8 +-
 .../nn/updater/BaseMultiLayerUpdater.java | 38 +-
 .../nn/updater/LayerUpdater.java | 8 +-
 .../nn/updater/MultiLayerUpdater.java | 8 +-
 .../nn/updater/UpdaterBlock.java | 21 +-
 .../nn/updater/UpdaterCreator.java | 9 +-
 .../nn/updater/UpdaterUtils.java | 10 +-
 .../graph/ComputationGraphUpdater.java | 16 +-
 .../CollectScoresIterationListener.java | 2 +-
 .../listeners/CollectScoresListener.java | 2 +-
 .../listeners/PerformanceListener.java | 2 +-
 .../listeners/ScoreIterationListener.java | 2 +-
 .../listeners/ScoreToChartListener.java | 2 +-
 .../optimize/solvers/BackTrackLineSearch.java | 4 +-
 .../optimize/solvers/BaseOptimizer.java | 2 +-
 .../optimize/solvers/LBFGS.java | 2 +-
 .../solvers/StochasticGradientDescent.java | 6 +-
 .../EncodedGradientsAccumulator.java | 2 +-
 .../util/CrashReportingUtil.java | 8 +-
 .../deeplearning4j/util/ModelSerializer.java | 4 +-
 .../org/deeplearning4j/util/NetworkUtils.java | 24 +-
 .../deeplearning4j/util/OutputLayerUtil.java | 7 +-
 .../main/resources/simplelogger.properties | 5 +-
 .../java/net/brutex/ai/dnn/api/dnnTest.java | 86 +-
 .../ParameterServerTrainer.java | 4 +-
 .../EarlyStoppingParallelTrainer.java | 10 +-
 .../parallelism/InplaceParallelInference.java | 2 +-
 .../parallelism/ParallelInference.java | 4 +-
 .../parallelism/ParallelWrapper.java | 12 +-
 .../factory/SymmetricTrainerContext.java | 2 +-
 .../parallelism/trainer/DefaultTrainer.java | 14 +-
 .../InplaceParallelInferenceTest.java | 6 +-
 .../parallelism/ParallelInferenceTest.java | 4 +-
 .../parallelism/ParallelWrapperTest.java | 3 +-
 .../parallelism/TestListeners.java | 4 +-
 .../TestParallelEarlyStopping.java | 4 +-
 .../TestParallelEarlyStoppingUI.java | 2 +-
 .../impl/graph/SparkComputationGraph.java | 16 +-
 .../impl/multilayer/SparkDl4jMultiLayer.java | 10 +-
 .../ParameterAveragingTrainingMaster.java | 4 +-
 .../ParameterAveragingTrainingWorker.java | 8 +-
 .../spark/TestEarlyStoppingSpark.java | 10 +-
 .../TestEarlyStoppingSparkCompGraph.java | 10 +-
 .../impl/customlayer/layer/CustomLayer.java | 2 +-
 .../impl/graph/TestSparkComputationGraph.java | 7 +-
 .../impl/multilayer/TestMiscFunctions.java | 4 +-
 ...arameterAveragingSparkVsSingleMachine.java | 56 +-
 ...TestSparkMultiLayerParameterAveraging.java | 40 +-
 .../pw/SharedTrainingWrapper.java | 10 +-
 .../training/SharedTrainingMaster.java | 6 +-
 .../train/GradientSharingTrainingTest.java | 12 +-
 .../deeplearning4j/plot/BarnesHutTsne.java | 113 +-
 .../org/deeplearning4j/ui/ManualTests.java | 7 +-
 .../ui/weights/TestConvolutionalListener.java | 4 +-
 .../ui/model/stats/BaseStatsListener.java | 3 +-
 .../ui/stats/TestStatsListener.java | 4 +-
 .../ui/stats/TestTransferStatsCollection.java | 2 +-
 .../ui/module/train/TrainModule.java | 4 +-
 .../deeplearning4j/ui/TestRemoteReceiver.java | 2 +-
 .../org/deeplearning4j/ui/TestVertxUI.java | 10 +-
 .../deeplearning4j/ui/TestVertxUIManual.java | 4 +-
 .../ui/TestVertxUIMultiSession.java | 4 +-
 .../org/deeplearning4j/zoo/TestImageNet.java | 6 +-
 .../deeplearning4j/zoo/TestInstantiation.java | 4 +-
 .../org/deeplearning4j/zoo/TestUtils.java | 4 +-
 317 files changed, 4528 insertions(+), 4191 deletions(-)
 rename cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/{BaseLayerTest.java => BaseLayerConfigurationTest.java} (98%)
 create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/Layer_Descriptions.md
 rename cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/{Trainable.java => ITrainableLayer.java} (90%)
 rename cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/{TrainingConfig.java => ITraininableLayerConfiguration.java} (96%)
 rename cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/{BaseLayer.java => BaseLayerConfiguration.java} (96%)
 delete mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java
 create mode 100644 cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayerConfiguration.java

diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java
index c03d9f5c2..aba07ef0d 100644
--- a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java
+++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java
@@ -47,6 +47,7 @@ import
org.datavec.image.transform.ShowImageTransform; import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ActivationLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -54,9 +55,11 @@ import org.deeplearning4j.nn.conf.layers.DropoutLayer; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop; +import org.deeplearning4j.nn.conf.weightnoise.WeightNoise; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitXavier; +import org.deeplearning4j.optimize.listeners.PerformanceListener; import org.deeplearning4j.optimize.listeners.ScoreToChartListener; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; @@ -181,6 +184,7 @@ public class App { .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer) .gradientNormalizationThreshold( 100 ) //.weightInitFn( new WeightInitXavier() ) //this is internal + .weightNoise(new WeightNoise(new NormalDistribution(0.5, 0.5))) .weightInit( WeightInit.XAVIER) //.activationFn( new ActivationIdentity()) //this is internal .activation( Activation.IDENTITY ) @@ -232,10 +236,10 @@ public class App { copyParams(gen, dis, gan); - //gen.setListeners(new PerformanceListener(10, true)); - //dis.setListeners(new PerformanceListener(10, true)); - //gan.setListeners(new PerformanceListener(10, true)); - gan.setListeners(new ScoreToChartListener("gan")); + gen.addTrainingListeners(new PerformanceListener(10, true)); + dis.addTrainingListeners(new PerformanceListener(10, true)); + gan.addTrainingListeners(new PerformanceListener(10, true)); + gan.addTrainingListeners(new ScoreToChartListener("gan")); //dis.setListeners(new ScoreToChartListener("dis")); gan.fit(Nd4j.rand(batchSize, CHANNELS, X_DIM, Y_DIM), Nd4j.zeros(batchSize, 1)); @@ -322,23 +326,25 @@ public class App { int genLayerCount = gen.getLayers().length; for (int i = 0; i < gan.getLayers().length; i++) { if (i < genLayerCount) { - gen.getLayer(i).setParams(gan.getLayer(i).params()); + if(gan.getLayer(i).getParams() != null) + gen.getLayer(i).setParams(gan.getLayer(i).getParams()); } else { - dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params()); + if(gan.getLayer(i).getParams() != null) + dis.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams()); } } } private static void updateGen(MultiLayerNetwork gen, MultiLayerNetwork gan) { for (int i = 0; i < gen.getLayers().length; i++) { - gen.getLayer(i).setParams(gan.getLayer(i).params()); + gen.getLayer(i).setParams(gan.getLayer(i).getParams()); } } private static void updateGan(MultiLayerNetwork gen, MultiLayerNetwork dis, MultiLayerNetwork gan) { int genLayerCount = gen.getLayers().length; for (int i = genLayerCount; i < gan.getLayers().length; i++) { - gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).params()); + gan.getLayer(i).setParams(dis.getLayer(i - genLayerCount).getParams()); } } diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java b/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java index b1e780d59..41eb277d7 100644 --- 
a/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/GAN.java @@ -115,15 +115,15 @@ public class GAN { public void setGeneratorListeners(BaseTrainingListener[] listeners) { - generator.setListeners(listeners); + generator.addTrainingListeners(listeners); } public void setDiscriminatorListeners(BaseTrainingListener[] listeners) { - discriminator.setListeners(listeners); + discriminator.addTrainingListeners(listeners); } public void setGanListeners(BaseTrainingListener[] listeners) { - gan.setListeners(listeners); + gan.addTrainingListeners(listeners); } public void fit(DataSetIterator realData, int numEpochs) { @@ -239,9 +239,9 @@ public class GAN { int genLayerCount = generator.getLayers().length; for (int i = 0; i < gan.getLayers().length; i++) { if (i < genLayerCount) { - generator.getLayer(i).setParams(gan.getLayer(i).params()); + generator.getLayer(i).setParams(gan.getLayer(i).getParams()); } else { - discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).params()); + discriminator.getLayer(i - genLayerCount).setParams(gan.getLayer(i).getParams()); } } } @@ -252,7 +252,7 @@ public class GAN { */ private void updateGeneratorFromGan() { for (int i = 0; i < generator.getLayers().length; i++) { - generator.getLayer(i).setParams(gan.getLayer(i).params()); + generator.getLayer(i).setParams(gan.getLayer(i).getParams()); } } @@ -263,7 +263,7 @@ public class GAN { private void updateGanWithDiscriminator() { int genLayerCount = generator.getLayers().length; for (int i = genLayerCount; i < gan.getLayers().length; i++) { - gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).params()); + gan.getLayer(i).setParams(discriminator.getLayer(i - genLayerCount).getParams()); } } diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java index 07e6a148a..4dd171fea 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/MnistDCGANExample.java @@ -155,8 +155,8 @@ public class MnistDCGANExample { .updater(new RmsProp.Builder().learningRate(0.0008).rmsDecay(1e-8).build()) .build(); - gan.getGenerator().setListeners(new PerformanceListener(1, true)); - gan.getDiscriminator().setListeners(new PerformanceListener(1, true)); + gan.getGenerator().addTrainingListeners(new PerformanceListener(1, true)); + gan.getDiscriminator().addTrainingListeners(new PerformanceListener(1, true)); Nd4j.getMemoryManager().setAutoGcWindow(15 * 1000); diff --git a/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java b/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java index c2d6f739c..db8a74ae7 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java +++ b/brutex-extended-tests/src/test/java/net/brutex/spark/TestServer2.java @@ -205,7 +205,7 @@ public class TestServer2 { //PostgresStatsStorage psqlStore = new PostgresStatsStorage(); int listenerFrequency = 2; //net.setListeners(new StatsListener(psqlStore, listenerFrequency), new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200)); - net.setListeners(new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200)); + net.addTrainingListeners(new StatsListener(statsStorage, listenerFrequency), new ScoreIterationListener(200)); //Attach the StatsStorage instance to the UI: this allows the 
contents of the StatsStorage to be visualized diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java index 0842ebfd4..8775bfc2e 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestBaselineGenerator.java @@ -290,7 +290,7 @@ public class IntegrationTestBaselineGenerator { for (int i : layersToTrain) { mln.pretrainLayer(i, dsi); } - paramsPostTraining = mln.params(); + paramsPostTraining = mln.getModelParams(); } else if (modelType == ModelType.CG) { String[] layersToTrain = tc.getUnsupervisedTrainLayersCG(); Preconditions.checkState(layersToTrain != null, "ILayer names must not be null"); @@ -298,7 +298,7 @@ public class IntegrationTestBaselineGenerator { for (String i : layersToTrain) { cg.pretrainLayer(i, iter); } - paramsPostTraining = cg.params(); + paramsPostTraining = cg.getModelParams(); } else { throw new UnsupportedOperationException("SameDiff not supported for unsupervised training tests"); } @@ -314,7 +314,7 @@ public class IntegrationTestBaselineGenerator { CollectScoresListener l = new CollectScoresListener(1); if (modelType != ModelType.SAMEDIFF) - m.setListeners(l); + m.addTrainingListeners(l); History h = null; if (modelType == ModelType.MLN) { @@ -349,7 +349,7 @@ public class IntegrationTestBaselineGenerator { } } else { File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME); - IntegrationTestRunner.write(m.params(), p); + IntegrationTestRunner.write(m.getModelParams(), p); } } } diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java index e68751c1b..786c6d6b9 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/IntegrationTestRunner.java @@ -191,7 +191,7 @@ public class IntegrationTestRunner { MultiLayerNetwork loaded = MultiLayerNetwork.load(savedModel, true); assertEquals(loaded.getNetConfiguration(), mln.getNetConfiguration(), "Configs not equal"); - assertEquals( loaded.params(), mln.params(), "Params not equal"); + assertEquals( loaded.getModelParams(), mln.getModelParams(), "Params not equal"); assertEquals( loaded.getParamTable(), mln.getParamTable(), "Param table not equal"); } else if(config instanceof ComputationGraphConfiguration ){ ComputationGraphConfiguration cgc = (ComputationGraphConfiguration) config; @@ -201,7 +201,7 @@ public class IntegrationTestRunner { ComputationGraph loaded = ComputationGraph.load(savedModel, true); assertEquals(loaded.getComputationGraphConfiguration(), cg.getComputationGraphConfiguration(), "Configs not equal" ); - assertEquals( loaded.params(), cg.params(), "Params not equal"); + assertEquals( loaded.getModelParams(), cg.getModelParams(), "Params not equal"); assertEquals(loaded.getParamTable(), cg.getParamTable(), "Param table not equal"); } else if(config instanceof SameDiff){ sd = (SameDiff)config; @@ -389,7 +389,7 @@ public class IntegrationTestRunner { for( int i : layersToTrain){ mln.pretrainLayer(i, dsi); } - paramsPostTraining = mln.params(); + paramsPostTraining = mln.getModelParams(); layers = 
mln.getLayers(); } else if(modelType == ModelType.CG) { String[] layersToTrain = tc.getUnsupervisedTrainLayersCG(); @@ -398,7 +398,7 @@ public class IntegrationTestRunner { for( String i : layersToTrain){ cg.pretrainLayer(i, iter); } - paramsPostTraining = cg.params(); + paramsPostTraining = cg.getModelParams(); layers = cg.getLayers(); } else { throw new UnsupportedOperationException("Unsupported layerwise pretraining not supported for SameDiff models"); @@ -439,7 +439,7 @@ public class IntegrationTestRunner { CountingMultiDataSetIterator countingIter = new CountingMultiDataSetIterator(trainData, isTbptt, tbpttLength); CollectScoresListener l = new CollectScoresListener(1); if(modelType != ModelType.SAMEDIFF) { - m.setListeners(l); + m.addTrainingListeners(l); } int iterBefore; @@ -519,10 +519,10 @@ public class IntegrationTestRunner { if(modelType != ModelType.SAMEDIFF) { File p = new File(testBaseDir, IntegrationTestRunner.PARAMS_POST_TRAIN_FILENAME); INDArray paramsExp = read(p); - INDArray z = exceedsRelError(m.params(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining()); + INDArray z = exceedsRelError(m.getModelParams(), paramsExp, tc.getMaxRelativeErrorParamsPostTraining(), tc.getMinAbsErrorParamsPostTraining()); int count = z.sumNumber().intValue(); if (count > 0) { - logFailedParams(20, "Parameter", layers, z, paramsExp, m.params()); + logFailedParams(20, "Parameter", layers, z, paramsExp, m.getModelParams()); } assertEquals( 0, count, "Number of params exceeded max relative error"); } else { @@ -607,12 +607,12 @@ public class IntegrationTestRunner { ModelSerializer.writeModel(m, f, true); MultiLayerNetwork restored = MultiLayerNetwork.load(f, true); assertEquals(mln.getNetConfiguration(), restored.getNetConfiguration()); - assertEquals(mln.params(), restored.params()); + assertEquals(mln.getModelParams(), restored.getModelParams()); } else if(modelType == ModelType.CG){ ModelSerializer.writeModel(m, f, true); ComputationGraph restored = ComputationGraph.load(f, true); assertEquals(cg.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration()); - assertEquals(cg.params(), restored.params()); + assertEquals(cg.getModelParams(), restored.getModelParams()); } else { sd.save(f, true); SameDiff restored = SameDiff.load(f, true); diff --git a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java index 5bdae5d39..60e314d71 100644 --- a/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java +++ b/brutex-extended-tests/src/test/java/org/deeplearning4j/integration/TestUtils.java @@ -49,7 +49,7 @@ public class TestUtils { restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); assertEquals(net.getNetConfiguration(), restored.getNetConfiguration()); - assertEquals(net.params(), restored.params()); + assertEquals(net.getModelParams(), restored.getModelParams()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); @@ -74,7 +74,7 @@ public class TestUtils { restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration()); - assertEquals(net.params(), restored.params()); + assertEquals(net.getModelParams(), restored.getModelParams()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); diff --git 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationIdentity.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationIdentity.java index 46124c636..0a2c48fee 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationIdentity.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationIdentity.java @@ -26,7 +26,7 @@ import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; -/** +/** The ActivationIdentity activation function, just returns the input as is. * f(x) = x */ @EqualsAndHashCode(callSuper = false) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java index 9baf97578..a0f45a6d1 100644 --- a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/BaseWorkspaceMgr.java @@ -195,7 +195,7 @@ public abstract class BaseWorkspaceMgr> implements WorkspaceMg } @Override - public INDArray validateArrayLocation(@NonNull T arrayType, @NonNull INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) { + public INDArray validateArrayLocation(T arrayType, INDArray array, boolean migrateIfInvalid, boolean exceptionIfDetached) { validateConfig(arrayType); if(scopeOutOfWs.contains(arrayType)){ diff --git a/cavis-dnn/cavis-dnn-core/build.gradle b/cavis-dnn/cavis-dnn-core/build.gradle index 18c322532..e40b8482f 100644 --- a/cavis-dnn/cavis-dnn-core/build.gradle +++ b/cavis-dnn/cavis-dnn-core/build.gradle @@ -19,6 +19,7 @@ dependencies { testImplementation projects.cavisNative.cavisNativeCommon testImplementation projects.cavisNd4j.cavisNd4jCommonTests testImplementation projects.cavisDnn.cavisDnnCommonTests + testImplementation projects.cavisDnn.cavisDnnNn implementation "org.apache.commons:commons-lang3" diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java index db11f8cc7..308b7c7ad 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/LayerHelperValidationUtil.java @@ -116,7 +116,7 @@ public class LayerHelperValidationUtil { MultiLayerNetwork net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone()); net2With.init(); - net2With.params().assign(netOrig.params()); + net2With.getModelParams().assign(netOrig.getModelParams()); log.info("Removing all except for specified helpers from network copy 2: " + t.getAllowHelpersForClasses()); removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses()); @@ -124,7 +124,7 @@ public class LayerHelperValidationUtil { Preconditions.checkNotNull(t.getFeatures(), "Features are not set (null)"); for (boolean train : new boolean[]{false, true}) { - assertEquals(net1NoHelper.params(), net2With.params()); + assertEquals(net1NoHelper.getModelParams(), net2With.getModelParams()); String s = "Feed forward test - " + t.getTestName() + " - " + (train ? 
"Train: " : "Test: "); List ff1; try { @@ -180,7 +180,7 @@ public class LayerHelperValidationUtil { double maxRE = relError.maxNumber().doubleValue(); log.info(s + "Output, max relative error: " + maxRE); - assertEquals(net1NoHelper.params(), net2With.params()); //Check that forward pass does not modify params + assertEquals(net1NoHelper.getModelParams(), net2With.getModelParams()); //Check that forward pass does not modify params assertTrue(maxRE < t.getMaxRelError(), s + "Max RE: " + maxRE); } } @@ -255,24 +255,24 @@ public class LayerHelperValidationUtil { net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone()); net2With.init(); - net2With.params().assign(netOrig.params()); + net2With.getModelParams().assign(netOrig.getModelParams()); log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses()); removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses()); CollectScoresListener listener = new CollectScoresListener(1); - net2With.setListeners(listener); + net2With.addTrainingListeners(listener); net2With.fit(t.getData()); for( int i=0; i<2; i++ ) { net2With = new MultiLayerNetwork(netOrig.getNetConfiguration().clone()); net2With.init(); - net2With.params().assign(netOrig.params()); + net2With.getModelParams().assign(netOrig.getModelParams()); log.info("Removing all except for specified layer helpers from network copy 2: " + t.getAllowHelpersForClasses()); removeHelpers(net2With.getLayers(), t.getAllowHelpersForClasses()); CollectScoresListener listener2 = new CollectScoresListener(1); - net2With.setListeners(listener2); + net2With.addTrainingListeners(listener2); net2With.fit(t.getData()); DoubleArrayList listOrig = listener.getListScore(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java index 374724ae5..495b21e18 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -25,7 +25,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer; @@ -67,7 +67,7 @@ public class TestUtils { restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); assertEquals(net.getNetConfiguration(), restored.getNetConfiguration()); - assertEquals(net.params(), restored.params()); + assertEquals(net.getModelParams(), restored.getModelParams()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); @@ -91,7 +91,7 @@ public class TestUtils { restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration()); - assertEquals(net.params(), restored.params()); + assertEquals(net.getModelParams(), restored.getModelParams()); } catch (IOException e){ //Should never happen throw new RuntimeException(e); @@ -205,8 +205,8 @@ public class TestUtils { return null; } - public static L2Regularization getL2Reg(BaseLayer baseLayer){ - return 
getL2Reg(baseLayer.getRegularization()); + public static L2Regularization getL2Reg(BaseLayerConfiguration baseLayerConfiguration){ + return getL2Reg(baseLayerConfiguration.getRegularization()); } public static L2Regularization getL2Reg(List l){ @@ -218,7 +218,7 @@ public class TestUtils { return null; } - public static WeightDecay getWeightDecayReg(BaseLayer bl){ + public static WeightDecay getWeightDecayReg(BaseLayerConfiguration bl){ return getWeightDecayReg(bl.getRegularization()); } @@ -231,7 +231,7 @@ public class TestUtils { return null; } - public static double getL1(BaseLayer layer) { + public static double getL1(BaseLayerConfiguration layer) { List l = layer.getRegularization(); return getL1(l); } @@ -246,7 +246,7 @@ public class TestUtils { return l1Reg.getL1().valueAt(0,0); } - public static double getL2(BaseLayer layer) { + public static double getL2(BaseLayerConfiguration layer) { List l = layer.getRegularization(); return getL2(l); } @@ -269,7 +269,7 @@ public class TestUtils { return getL2(layer.getRegularization()); } - public static double getWeightDecay(BaseLayer layer) { + public static double getWeightDecay(BaseLayerConfiguration layer) { return getWeightDecayReg(layer.getRegularization()).getCoeff().valueAt(0,0); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java index f391f35f9..be740689b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/datasets/iterator/DataSetIteratorTest.java @@ -32,7 +32,6 @@ import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -183,7 +182,7 @@ public class DataSetIteratorTest extends BaseDL4JTest { MultiLayerNetwork model = new MultiLayerNetwork(builder.build()); model.init(); - model.setListeners(new ScoreIterationListener(listenerFreq)); + model.addTrainingListeners(new ScoreIterationListener(listenerFreq)); model.fit(lfw.next()); @@ -247,7 +246,7 @@ public class DataSetIteratorTest extends BaseDL4JTest { //model.setListeners(Arrays.asList((TrainingListener) new ScoreIterationListener(listenerFreq))); CollectScoresIterationListener listener = new CollectScoresIterationListener(listenerFreq); - model.setListeners(listener); + model.addTrainingListeners(listener); model.fit(cifar); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java index 12e17fa3a..0923ba407 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStopping.java @@ -226,7 +226,7 @@ public class TestEarlyStopping extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new 
ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); @@ -255,7 +255,7 @@ public class TestEarlyStopping extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter); @@ -304,7 +304,7 @@ public class TestEarlyStopping extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); @@ -343,7 +343,7 @@ public class TestEarlyStopping extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); @@ -386,7 +386,7 @@ public class TestEarlyStopping extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); @@ -430,7 +430,7 @@ public class TestEarlyStopping extends BaseDL4JTest { .build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); int nSamples = 100; //Generate the training data INDArray x = Nd4j.linspace(-10, 10, nSamples).reshape(nSamples, 1); @@ -473,7 +473,7 @@ public class TestEarlyStopping extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); MultipleEpochsIterator mIter = new MultipleEpochsIterator(10, irisIter); @@ -496,9 +496,9 @@ public class TestEarlyStopping extends BaseDL4JTest { assertEquals(net.getnLayers(), mln.getnLayers()); assertEquals(net.getNetConfiguration().getOptimizationAlgo(), mln.getNetConfiguration().getOptimizationAlgo()); - BaseLayer bl = (BaseLayer) net.getLayerConfiguration(); - assertEquals(bl.getActivationFn().toString(), ((BaseLayer) mln.getLayerConfiguration()).getActivationFn().toString()); - assertEquals(bl.getIUpdater(), ((BaseLayer) mln.getLayerConfiguration()).getIUpdater()); + BaseLayerConfiguration bl = (BaseLayerConfiguration) net.getLayerConfiguration(); + assertEquals(bl.getActivationFn().toString(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getActivationFn().toString()); + assertEquals(bl.getIUpdater(), ((BaseLayerConfiguration) mln.getLayerConfiguration()).getIUpdater()); } @Test @@ -511,7 +511,7 @@ public class TestEarlyStopping extends BaseDL4JTest { 
.lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); @@ -792,7 +792,7 @@ public class TestEarlyStopping extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); TestListener tl = new TestListener(); - net.setListeners(tl); + net.addTrainingListeners(tl); DataSetIterator irisIter = new IrisDataSetIterator(50, 150); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java index fb55e2957..22b739f89 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/earlystopping/TestEarlyStoppingCompGraph.java @@ -84,7 +84,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); @@ -128,7 +128,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); @@ -165,7 +165,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); @@ -207,7 +207,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); @@ -241,7 +241,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(150, 150); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); @@ -538,7 +538,7 @@ public class TestEarlyStoppingCompGraph extends BaseDL4JTest { ComputationGraph net = new ComputationGraph(conf); TestEarlyStopping.TestListener tl = new TestEarlyStopping.TestListener(); - 
net.setListeners(tl); + net.addTrainingListeners(tl); DataSetIterator irisIter = new IrisDataSetIterator(50, 150); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java index 8b5f5d46b..04d6f440f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvalTest.java @@ -84,7 +84,7 @@ public class EvalTest extends BaseDL4JTest { // Instantiate model MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); - model.addListeners(new ScoreIterationListener(1)); + model.addTrainingListeners(new ScoreIterationListener(1)); // Train-test split DataSetIterator iter = new IrisDataSetIterator(150, 150); @@ -324,7 +324,7 @@ public class EvalTest extends BaseDL4JTest { MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - net2.setParams(net1.params()); + net2.setParams(net1.getModelParams()); for(boolean useMask : new boolean[]{false, true}) { @@ -405,7 +405,7 @@ public class EvalTest extends BaseDL4JTest { ComputationGraph net2 = new ComputationGraph(conf2); net2.init(); - net2.setParams(net1.params()); + net2.setParams(net1.getModelParams()); for (boolean useMask : new boolean[]{false, true}) { @@ -492,7 +492,7 @@ public class EvalTest extends BaseDL4JTest { DataSetIterator iter = new IrisDataSetIterator(30, 150); DataSetIterator iterTest = new IrisDataSetIterator(30, 150); - net.setListeners(new EvaluativeListener(iterTest, 3)); + net.addTrainingListeners(new EvaluativeListener(iterTest, 3)); for( int i=0; i<3; i++ ){ net.fit(iter); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java index 5e6ed72bd..0380ed2a0 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/BNGradientCheckTest.java @@ -26,7 +26,6 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.distribution.UniformDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -219,11 +218,11 @@ public class BNGradientCheckTest extends BaseDL4JTest { mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); mln.computeGradientAndScore(); - double scoreBefore = mln.score(); + double scoreBefore = mln.getScore(); for (int k = 0; k < 20; k++) mln.fit(ds); mln.computeGradientAndScore(); - double scoreAfter = mln.score(); + double scoreAfter = mln.getScore(); //Can't test in 'characteristic mode of operation' if not learning String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" @@ -323,11 +322,11 @@ public class BNGradientCheckTest extends BaseDL4JTest { mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); mln.computeGradientAndScore(); - double scoreBefore = mln.score(); + double scoreBefore = mln.getScore(); for (int 
k = 0; k < 10; k++) mln.fit(ds); mln.computeGradientAndScore(); - double scoreAfter = mln.score(); + double scoreAfter = mln.getScore(); //Can't test in 'characteristic mode of operation' if not learning String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" @@ -554,11 +553,11 @@ public class BNGradientCheckTest extends BaseDL4JTest { net.setInput(0, ds.getFeatures()); net.setLabels(ds.getLabels()); net.computeGradientAndScore(); - double scoreBefore = net.score(); + double scoreBefore = net.getScore(); for (int k = 0; k < 20; k++) net.fit(ds); net.computeGradientAndScore(); - double scoreAfter = net.score(); + double scoreAfter = net.getScore(); //Can't test in 'characteristic mode of operation' if not learning String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index bee788e55..d11bd33c6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -27,7 +27,6 @@ import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; @@ -120,11 +119,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest { mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); mln.computeGradientAndScore(); - double scoreBefore = mln.score(); + double scoreBefore = mln.getScore(); for (int j = 0; j < 10; j++) mln.fit(ds); mln.computeGradientAndScore(); - double scoreAfter = mln.score(); + double scoreAfter = mln.getScore(); //Can't test in 'characteristic mode of operation' if not learning String msg = name + " - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation @@ -212,11 +211,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest { mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); mln.computeGradientAndScore(); - double scoreBefore = mln.score(); + double scoreBefore = mln.getScore(); for (int j = 0; j < 10; j++) mln.fit(ds); mln.computeGradientAndScore(); - double scoreAfter = mln.score(); + double scoreAfter = mln.getScore(); //Can't test in 'characteristic mode of operation' if not learning String msg = testName + "- score did not (sufficiently) decrease during learning - activationFn=" diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java index 6cefb32aa..39dc54659 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/GradientCheckTests.java @@ -105,11 +105,11 @@ public class GradientCheckTests extends BaseDL4JTest { mln.setInput(ds.getFeatures()); 
mln.setLabels(ds.getLabels()); mln.computeGradientAndScore(); - double scoreBefore = mln.score(); + double scoreBefore = mln.getScore(); for (int j = 0; j < 10; j++) mln.fit(ds); mln.computeGradientAndScore(); - double scoreAfter = mln.score(); + double scoreAfter = mln.getScore(); //Can't test in 'characteristic mode of operation' if not learning String msg = "testMinibatchApplication() - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation @@ -184,11 +184,11 @@ public class GradientCheckTests extends BaseDL4JTest { mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); mln.computeGradientAndScore(); - double scoreBefore = mln.score(); + double scoreBefore = mln.getScore(); for (int j = 0; j < 10; j++) mln.fit(ds); mln.computeGradientAndScore(); - double scoreAfter = mln.score(); + double scoreAfter = mln.getScore(); //Can't test in 'characteristic mode of operation' if not learning String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation @@ -278,11 +278,11 @@ public class GradientCheckTests extends BaseDL4JTest { mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); mln.computeGradientAndScore(); - double scoreBefore = mln.score(); + double scoreBefore = mln.getScore(); for (int j = 0; j < 10; j++) mln.fit(ds); mln.computeGradientAndScore(); - double scoreAfter = mln.score(); + double scoreAfter = mln.getScore(); //Can't test in 'characteristic mode of operation' if not learning String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation @@ -452,11 +452,11 @@ public class GradientCheckTests extends BaseDL4JTest { mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); mln.computeGradientAndScore(); - double scoreBefore = mln.score(); + double scoreBefore = mln.getScore(); for (int j = 0; j < 10; j++) mln.fit(ds); mln.computeGradientAndScore(); - double scoreAfter = mln.score(); + double scoreAfter = mln.getScore(); //Can't test in 'characteristic mode of operation' if not learning msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", outputActivation=" + outputActivation @@ -523,13 +523,13 @@ public class GradientCheckTests extends BaseDL4JTest { netGraph.setInputs(features); netGraph.setLabels(labels); netGraph.computeGradientAndScore(); - double scoreBefore = netGraph.score(); + double scoreBefore = netGraph.getScore(); String msg; for (int epoch = 0; epoch < 5; epoch++) netGraph.fit(new INDArray[]{features}, new INDArray[]{labels}); netGraph.computeGradientAndScore(); - double scoreAfter = netGraph.score(); + double scoreAfter = netGraph.getScore(); //Can't test in 'characteristic mode of operation' if not learning msg = "elementWiseMultiplicationLayerTest() - score did not (sufficiently) decrease during learning - activationFn=" + "Id" + ", lossFn=" + "Cos-sim" + ", outputActivation=" + "Id" @@ -757,11 +757,11 @@ public class GradientCheckTests extends BaseDL4JTest { mln.setInput(ds.getFeatures()); mln.setLabels(ds.getLabels()); mln.computeGradientAndScore(); - double scoreBefore = mln.score(); + double scoreBefore = mln.getScore(); for (int j = 0; j < 10; j++) mln.fit(ds); mln.computeGradientAndScore(); - double scoreAfter = mln.score(); + 
double scoreAfter = mln.getScore(); //Can't test in 'characteristic mode of operation' if not learning String msg = "testGradMLP2LayerIrisSimple() - score did not (sufficiently) decrease during learning - activationFn=" + afn + ", lossFn=" + lf + ", layerNorm=" + layerNorm + ", outputActivation=" + outputActivation diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java index 0cf7ebd1b..6197f73d3 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java @@ -666,7 +666,7 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { net.init(); //Check params to avoid test flakiness on small or large params - INDArray params = net.params(); + INDArray params = net.getModelParams(); for( int x=0; x 1.5){ double d = Nd4j.getRandom().nextDouble(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java index 700b70a6b..eead1511f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/MultiLayerNeuralNetConfigurationTest.java @@ -37,10 +37,9 @@ import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.exception.DL4JInvalidConfigException; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; @@ -254,8 +253,8 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { MultiLayerNetwork model2 = new MultiLayerNetwork(getConf()); model2.init(); - float[] p1 = model1.params().data().asFloat(); - float[] p2 = model2.params().data().asFloat(); + float[] p1 = model1.getModelParams().data().asFloat(); + float[] p2 = model2.getModelParams().data().asFloat(); System.out.println(Arrays.toString(p1)); System.out.println(Arrays.toString(p2)); @@ -266,20 +265,20 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { public void testTrainingListener() { MultiLayerNetwork model1 = new MultiLayerNetwork(getConf()); model1.init(); - model1.addListeners(new ScoreIterationListener(1)); + model1.addTrainingListeners(new ScoreIterationListener(1)); MultiLayerNetwork model2 = new MultiLayerNetwork(getConf()); - model2.addListeners(new ScoreIterationListener(1)); + model2.addTrainingListeners(new ScoreIterationListener(1)); model2.init(); Layer[] l1 = model1.getLayers(); for (int i = 0; i < l1.length; i++) { - assertTrue(l1[i].getListeners() != null && l1[i].getListeners().size() == 1); + assertTrue(l1[i].getTrainingListeners() != null && l1[i].getTrainingListeners().size() == 1); } Layer[] l2 = 
model2.getLayers(); for (int i = 0; i < l2.length; i++) { - assertTrue(l2[i].getListeners() != null && l2[i].getListeners().size() == 1); + assertTrue(l2[i].getTrainingListeners() != null && l2[i].getTrainingListeners().size() == 1); } } @@ -384,10 +383,10 @@ public class MultiLayerNeuralNetConfigurationTest extends BaseDL4JTest { .weightInit(WeightInit.XAVIER).activation(Activation.SOFTMAX).build()) .inputType(InputType.convolutional(28, 28, 1)).build(); - org.deeplearning4j.nn.conf.layers.BaseLayer l0 = (BaseLayer) conf.getConf(0).getLayer(); - org.deeplearning4j.nn.conf.layers.BaseLayer l1 = (BaseLayer) conf.getConf(1).getLayer(); - org.deeplearning4j.nn.conf.layers.BaseLayer l2 = (BaseLayer) conf.getConf(2).getLayer(); - org.deeplearning4j.nn.conf.layers.BaseLayer l3 = (BaseLayer) conf.getConf(3).getLayer(); + BaseLayerConfiguration l0 = (BaseLayerConfiguration) conf.getConf(0).getLayer(); + BaseLayerConfiguration l1 = (BaseLayerConfiguration) conf.getConf(1).getLayer(); + BaseLayerConfiguration l2 = (BaseLayerConfiguration) conf.getConf(2).getLayer(); + BaseLayerConfiguration l3 = (BaseLayerConfiguration) conf.getConf(3).getLayer(); assertEquals(0.5, ((Adam) l0.getUpdaterByParam("b")).getLearningRate(), 1e-6); assertEquals(1e-2, ((Adam) l0.getUpdaterByParam("W")).getLearningRate(), 1e-6); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java index 6a7ec6408..0ef220c25 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/NeuralNetConfigurationTest.java @@ -25,7 +25,7 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -100,7 +100,7 @@ public class NeuralNetConfigurationTest extends BaseDL4JTest { @Test public void testClone() { NeuralNetConfiguration conf = getConfig(1, 1, new WeightInitUniform(), true); - BaseLayer bl = (BaseLayer) conf.getFlattenedLayerConfigurations().get(0); + BaseLayerConfiguration bl = (BaseLayerConfiguration) conf.getFlattenedLayerConfigurations().get(0); conf.setStepFunction(new DefaultStepFunction()); NeuralNetConfiguration conf2 = conf.clone(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java index 9cf99a89c..be78b1ecf 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/graph/ShiftVertexTest.java @@ -158,7 +158,7 @@ public class ShiftVertexTest extends BaseDL4JTest { cg.setInput(0, input); cg.setLabel(0, target); cg.computeGradientAndScore(); - double score_dl4j = cg.score(); + double score_dl4j = cg.getScore(); Map weights = cg.getParamTable(); Gradient g = cg.gradient(); Map gradients = g.gradientForVariable(); diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java index db3731f6d..28d17c150 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java @@ -72,8 +72,8 @@ public class LayerConfigTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString()); - assertEquals("relu", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString()); + assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getActivationFn().toString()); + assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getActivationFn().toString()); //With conf = NeuralNetConfiguration.builder().activation(Activation.RELU) @@ -83,8 +83,8 @@ public class LayerConfigTest extends BaseDL4JTest { net = new MultiLayerNetwork(conf); net.init(); - assertEquals("relu", ((BaseLayer) conf.getConf(0).getLayer()).getActivationFn().toString()); - assertEquals("tanh", ((BaseLayer) conf.getConf(1).getLayer()).getActivationFn().toString()); + assertEquals("relu", ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getActivationFn().toString()); + assertEquals("tanh", ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getActivationFn().toString()); } @@ -99,11 +99,11 @@ public class LayerConfigTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn()); - assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn()); + assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn()); + assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn()); - assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0); - assertEquals(1, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0); + assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0); + assertEquals(1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0); //With: final Distribution overriddenDistribution = new UniformDistribution(0, 1); @@ -117,11 +117,11 @@ public class LayerConfigTest extends BaseDL4JTest { net = new MultiLayerNetwork(conf); net.init(); - assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayer) conf.getConf(0).getLayer()).getWeightInitFn()); - assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayer) conf.getConf(1).getLayer()).getWeightInitFn()); + assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn()); + assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn()); - assertEquals(1, ((BaseLayer) conf.getConf(0).getLayer()).getBiasInit(), 0.0); - assertEquals(0, ((BaseLayer) conf.getConf(1).getLayer()).getBiasInit(), 0.0); + assertEquals(1, ((BaseLayerConfiguration) 
conf.getConf(0).getLayer()).getBiasInit(), 0.0); + assertEquals(0, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0); } /* @@ -137,8 +137,8 @@ public class LayerConfigTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(0.3, ((BaseLayer) conf.getConf(0).getLayer()).getLearningRate(), 0.0); - assertEquals(0.3, ((BaseLayer) conf.getConf(1).getLayer()).getLearningRate(), 0.0); + assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getLearningRate(), 0.0); + assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getLearningRate(), 0.0); //With: conf = NeuralNetConfiguration.builder().learningRate(0.3) @@ -148,8 +148,8 @@ public class LayerConfigTest extends BaseDL4JTest { net = new MultiLayerNetwork(conf); net.init(); - assertEquals(0.3, ((BaseLayer) conf.getConf(0).getLayer()).getLearningRate(), 0.0); - assertEquals(0.2, ((BaseLayer) conf.getConf(1).getLayer()).getLearningRate(), 0.0); + assertEquals(0.3, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getLearningRate(), 0.0); + assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getLearningRate(), 0.0); //L1 and L2 without layerwise override: conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2) @@ -158,10 +158,10 @@ public class LayerConfigTest extends BaseDL4JTest { net = new MultiLayerNetwork(conf); net.init(); - assertEquals(0.1, ((BaseLayer) conf.getConf(0).getLayer()).getL1(), 0.0); - assertEquals(0.1, ((BaseLayer) conf.getConf(1).getLayer()).getL1(), 0.0); - assertEquals(0.2, ((BaseLayer) conf.getConf(0).getLayer()).getL2(), 0.0); - assertEquals(0.2, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0); + assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL1(), 0.0); + assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL1(), 0.0); + assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL2(), 0.0); + assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL2(), 0.0); //L1 and L2 with layerwise override: conf = NeuralNetConfiguration.builder().l1(0.1).l2(0.2) @@ -170,10 +170,10 @@ public class LayerConfigTest extends BaseDL4JTest { net = new MultiLayerNetwork(conf); net.init(); - assertEquals(0.9, ((BaseLayer) conf.getConf(0).getLayer()).getL1(), 0.0); - assertEquals(0.1, ((BaseLayer) conf.getConf(1).getLayer()).getL1(), 0.0); - assertEquals(0.2, ((BaseLayer) conf.getConf(0).getLayer()).getL2(), 0.0); - assertEquals(0.8, ((BaseLayer) conf.getConf(1).getLayer()).getL2(), 0.0); + assertEquals(0.9, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL1(), 0.0); + assertEquals(0.1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL1(), 0.0); + assertEquals(0.2, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getL2(), 0.0); + assertEquals(0.8, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getL2(), 0.0); }*/ @@ -213,8 +213,8 @@ public class LayerConfigTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); - assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); + assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); + assertEquals(0.1, 
((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); Map testMomentumAfter2 = new HashMap<>(); testMomentumAfter2.put(0, 0.2); @@ -227,8 +227,8 @@ public class LayerConfigTest extends BaseDL4JTest { net = new MultiLayerNetwork(conf); net.init(); - assertEquals(0.1, ((Nesterovs)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); - assertEquals(0.2, ((Nesterovs)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); + assertEquals(0.1, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); + assertEquals(0.2, ((Nesterovs)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getMomentumISchedule().valueAt(0,0), 0.0); } @Test @@ -239,10 +239,10 @@ public class LayerConfigTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta); - assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta); - assertEquals(0.5, ((AdaDelta)((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0); - assertEquals(0.01, ((AdaDelta)((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0); + assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof AdaDelta); + assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta); + assertEquals(0.5, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRho(), 0.0); + assertEquals(0.01, ((AdaDelta)((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0); conf = NeuralNetConfiguration.builder().updater(new RmsProp(1.0, 2.0, RmsProp.DEFAULT_RMSPROP_EPSILON)) .layer(0, new DenseLayer.Builder().nIn(2).nOut(2).updater(new RmsProp(1.0, 1.0, RmsProp.DEFAULT_RMSPROP_EPSILON)).build()) @@ -252,10 +252,10 @@ public class LayerConfigTest extends BaseDL4JTest { net = new MultiLayerNetwork(conf); net.init(); - assertTrue(((BaseLayer) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp); - assertTrue(((BaseLayer) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta); - assertEquals(1.0, ((RmsProp) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0); - assertEquals(0.5, ((AdaDelta) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0); + assertTrue(((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater() instanceof RmsProp); + assertTrue(((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater() instanceof AdaDelta); + assertEquals(1.0, ((RmsProp) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getRmsDecay(), 0.0); + assertEquals(0.5, ((AdaDelta) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getRho(), 0.0); } @@ -270,10 +270,10 @@ public class LayerConfigTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0); - assertEquals(0.6, ((Adam) ((BaseLayer) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0); - assertEquals(0.5, ((Adam) ((BaseLayer) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0); - assertEquals(0.7, ((Adam) ((BaseLayer) 
conf.getConf(1).getLayer()).getIUpdater()).getBeta2(), 0.0);
+        assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta1(), 0.0);
+        assertEquals(0.6, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta1(), 0.0);
+        assertEquals(0.5, ((Adam) ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getIUpdater()).getBeta2(), 0.0);
+        assertEquals(0.7, ((Adam) ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getIUpdater()).getBeta2(), 0.0);
     }

     @Test
@@ -287,13 +287,12 @@ public class LayerConfigTest extends BaseDL4JTest {
                         .layer(1, new DenseLayer.Builder().nIn(2).nOut(2).build()).build();
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-
-        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
-                        conf.getConf(0).getLayer().getGradientNormalization());
-        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
-                        conf.getConf(1).getLayer().getGradientNormalization());
-        assertEquals(10, conf.getConf(0).getLayer().getGradientNormalizationThreshold(), 0.0);
-        assertEquals(10, conf.getConf(1).getLayer().getGradientNormalizationThreshold(), 0.0);
+        BaseLayerConfiguration bconf = (BaseLayerConfiguration) conf.getConf(0).getLayer();
+        BaseLayerConfiguration bconf1 = (BaseLayerConfiguration) conf.getConf(1).getLayer();
+        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf.getGradientNormalization());
+        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf1.getGradientNormalization());
+        assertEquals(10, bconf.getGradientNormalizationThreshold(), 0.0);
+        assertEquals(10, bconf1.getGradientNormalizationThreshold(), 0.0);

         //With:
         conf = NeuralNetConfiguration.builder()
@@ -308,11 +307,12 @@ public class LayerConfigTest extends BaseDL4JTest {
         net = new MultiLayerNetwork(conf);
         net.init();
-        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue,
-                        conf.getConf(0).getLayer().getGradientNormalization());
-        assertEquals(GradientNormalization.None, conf.getConf(1).getLayer().getGradientNormalization());
-        assertEquals(10, conf.getConf(0).getLayer().getGradientNormalizationThreshold(), 0.0);
-        assertEquals(2.5, conf.getConf(1).getLayer().getGradientNormalizationThreshold(), 0.0);
+        bconf = (BaseLayerConfiguration) conf.getConf(0).getLayer();
+        bconf1 = (BaseLayerConfiguration) conf.getConf(1).getLayer();
+        assertEquals(GradientNormalization.ClipElementWiseAbsoluteValue, bconf.getGradientNormalization());
+        assertEquals(GradientNormalization.None, bconf1.getGradientNormalization());
+        assertEquals(10, bconf.getGradientNormalizationThreshold(), 0.0);
+        assertEquals(2.5, bconf1.getGradientNormalizationThreshold(), 0.0);
     }


diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java
index 65532a0bc..dae839a06 100644
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java
+++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java
@@ -162,12 +162,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest {
         MultiLayerNetwork net = new MultiLayerNetwork(conf);
         net.init();
-        BaseLayer layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration();
+        BaseLayerConfiguration layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration();
         assertEquals(expectedMomentum, ((Nesterovs) layerConf.getIUpdater()).getMomentum(), 1e-3);
         assertNull(TestUtils.getL1Reg(layerConf.getRegularization()));
         assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3);
-        BaseLayer layerConf1 = (BaseLayer)
net.getLayer(1).getLayerConfiguration(); + BaseLayerConfiguration layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration(); assertEquals(0.4, ((Nesterovs) layerConf1.getIUpdater()).getMomentum(), 1e-3); // Adam Updater @@ -178,11 +178,11 @@ public class LayerConfigValidationTest extends BaseDL4JTest { net = new MultiLayerNetwork(conf); net.init(); - layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration(); + layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration(); assertEquals(0.3, TestUtils.getL1(layerConf), 1e-3); assertEquals(0.5, TestUtils.getL2(layerConf), 1e-3); - layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration(); + layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration(); assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3); assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3); assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn()); @@ -196,12 +196,12 @@ public class LayerConfigValidationTest extends BaseDL4JTest { net = new MultiLayerNetwork(conf); net.init(); - layerConf = (BaseLayer) net.getLayer(0).getLayerConfiguration(); + layerConf = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration(); assertEquals(expectedRmsDecay, ((RmsProp) layerConf.getIUpdater()).getRmsDecay(), 1e-3); assertNull(TestUtils.getL1Reg(layerConf.getRegularization())); assertNull(TestUtils.getL2Reg(layerConf.getRegularization())); - layerConf1 = (BaseLayer) net.getLayer(1).getLayerConfiguration(); + layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration(); assertEquals(0.4, ((RmsProp) layerConf1.getIUpdater()).getRmsDecay(), 1e-3); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java index 4d4b36013..8977d1b3f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/weightnoise/TestWeightNoise.java @@ -29,7 +29,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -75,9 +75,9 @@ public class TestWeightNoise extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(wn, ((BaseLayer) net.getLayer(0).getLayerConfiguration()).getWeightNoise()); - assertEquals(new DropConnect(0.25), ((BaseLayer) net.getLayer(1).getLayerConfiguration()).getWeightNoise()); - assertEquals(wn, ((BaseLayer) net.getLayer(2).getLayerConfiguration()).getWeightNoise()); + assertEquals(wn, ((BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration()).getWeightNoise()); + assertEquals(new DropConnect(0.25), ((BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration()).getWeightNoise()); + assertEquals(wn, ((BaseLayerConfiguration) net.getLayer(2).getLayerConfiguration()).getWeightNoise()); TestUtils.testModelSerialization(net); @@ -95,9 
+95,9 @@ public class TestWeightNoise extends BaseDL4JTest { ComputationGraph graph = new ComputationGraph(conf2); graph.init(); - assertEquals(wn, ((BaseLayer) graph.getLayer(0).getLayerConfiguration()).getWeightNoise()); - assertEquals(new DropConnect(0.25), ((BaseLayer) graph.getLayer(1).getLayerConfiguration()).getWeightNoise()); - assertEquals(wn, ((BaseLayer) graph.getLayer(2).getLayerConfiguration()).getWeightNoise()); + assertEquals(wn, ((BaseLayerConfiguration) graph.getLayer(0).getLayerConfiguration()).getWeightNoise()); + assertEquals(new DropConnect(0.25), ((BaseLayerConfiguration) graph.getLayer(1).getLayerConfiguration()).getWeightNoise()); + assertEquals(wn, ((BaseLayerConfiguration) graph.getLayer(2).getLayerConfiguration()).getWeightNoise()); TestUtils.testModelSerialization(graph); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java index e37b7b7cb..2f2a316dd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/dtypes/DTypeTests.java @@ -124,7 +124,7 @@ import org.deeplearning4j.nn.conf.layers.recurrent.TimeDistributed; import org.deeplearning4j.nn.conf.layers.util.MaskLayer; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration; import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.CnnToRnnPreProcessor; @@ -260,8 +260,8 @@ public class DTypeTests extends BaseDL4JTest { for (NeuralNetConfiguration nnc : conf.getNetConfigurations()) { LayerConfiguration l = nnc.getFlattenedLayerConfigurations().get(0); seenLayers.add(l.getClass()); - if (l instanceof BaseWrapperLayer) { - BaseWrapperLayer bwl = (BaseWrapperLayer) l; + if (l instanceof BaseWrapperLayerConfiguration) { + BaseWrapperLayerConfiguration bwl = (BaseWrapperLayerConfiguration) l; seenLayers.add(bwl.getUnderlying().getClass()); } else if (l instanceof Bidirectional) { seenLayers.add(((Bidirectional) l).getFwd().getClass()); @@ -321,17 +321,17 @@ public class DTypeTests extends BaseDL4JTest { net.setInput(inD); net.setLabels(lD); net.computeGradientAndScore(); - double scoreDouble = net.score(); + double scoreDouble = net.getScore(); INDArray grads = net.getFlattenedGradients(); INDArray u = net.getUpdater().getStateViewArray(); - assertEquals(DataType.DOUBLE, net.params().dataType()); + assertEquals(DataType.DOUBLE, net.getModelParams().dataType()); assertEquals(DataType.DOUBLE, grads.dataType()); assertEquals(DataType.DOUBLE, u.dataType()); MultiLayerNetwork netFloat = net.convertDataType(DataType.FLOAT); netFloat.initGradientsView(); - assertEquals(DataType.FLOAT, netFloat.params().dataType()); + assertEquals(DataType.FLOAT, netFloat.getModelParams().dataType()); assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType()); assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType()); INDArray inF = inD.castTo(DataType.FLOAT); @@ -340,7 +340,7 @@ public class DTypeTests extends BaseDL4JTest { netFloat.setInput(inF); netFloat.setLabels(lF); netFloat.computeGradientAndScore(); - double 
scoreFloat = netFloat.score(); + double scoreFloat = netFloat.getScore(); INDArray gradsFloat = netFloat.getFlattenedGradients(); INDArray uFloat = netFloat.getUpdater().getStateViewArray(); @@ -352,7 +352,7 @@ public class DTypeTests extends BaseDL4JTest { MultiLayerNetwork netFP16 = net.convertDataType(DataType.HALF); netFP16.initGradientsView(); - assertEquals(DataType.HALF, netFP16.params().dataType()); + assertEquals(DataType.HALF, netFP16.getModelParams().dataType()); assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType()); assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType()); @@ -362,7 +362,7 @@ public class DTypeTests extends BaseDL4JTest { netFP16.setInput(inH); netFP16.setLabels(lH); netFP16.computeGradientAndScore(); - double scoreHalf = netFP16.score(); + double scoreHalf = netFP16.getScore(); INDArray gradsHalf = netFP16.getFlattenedGradients(); INDArray uHalf = netFP16.getUpdater().getStateViewArray(); @@ -406,17 +406,17 @@ public class DTypeTests extends BaseDL4JTest { net.setInput(0, inD); net.setLabels(lD); net.computeGradientAndScore(); - double scoreDouble = net.score(); + double scoreDouble = net.getScore(); INDArray grads = net.getFlattenedGradients(); INDArray u = net.getUpdater().getStateViewArray(); - assertEquals(DataType.DOUBLE, net.params().dataType()); + assertEquals(DataType.DOUBLE, net.getModelParams().dataType()); assertEquals(DataType.DOUBLE, grads.dataType()); assertEquals(DataType.DOUBLE, u.dataType()); ComputationGraph netFloat = net.convertDataType(DataType.FLOAT); netFloat.initGradientsView(); - assertEquals(DataType.FLOAT, netFloat.params().dataType()); + assertEquals(DataType.FLOAT, netFloat.getModelParams().dataType()); assertEquals(DataType.FLOAT, netFloat.getFlattenedGradients().dataType()); assertEquals(DataType.FLOAT, netFloat.getUpdater(true).getStateViewArray().dataType()); INDArray inF = inD.castTo(DataType.FLOAT); @@ -425,7 +425,7 @@ public class DTypeTests extends BaseDL4JTest { netFloat.setInput(0, inF); netFloat.setLabels(lF); netFloat.computeGradientAndScore(); - double scoreFloat = netFloat.score(); + double scoreFloat = netFloat.getScore(); INDArray gradsFloat = netFloat.getFlattenedGradients(); INDArray uFloat = netFloat.getUpdater().getStateViewArray(); @@ -437,7 +437,7 @@ public class DTypeTests extends BaseDL4JTest { ComputationGraph netFP16 = net.convertDataType(DataType.HALF); netFP16.initGradientsView(); - assertEquals(DataType.HALF, netFP16.params().dataType()); + assertEquals(DataType.HALF, netFP16.getModelParams().dataType()); assertEquals(DataType.HALF, netFP16.getFlattenedGradients().dataType()); assertEquals(DataType.HALF, netFP16.getUpdater(true).getStateViewArray().dataType()); @@ -447,7 +447,7 @@ public class DTypeTests extends BaseDL4JTest { netFP16.setInput(0, inH); netFP16.setLabels(lH); netFP16.computeGradientAndScore(); - double scoreHalf = netFP16.score(); + double scoreHalf = netFP16.getScore(); INDArray gradsHalf = netFP16.getFlattenedGradients(); INDArray uHalf = netFP16.getUpdater().getStateViewArray(); @@ -536,7 +536,7 @@ public class DTypeTests extends BaseDL4JTest { net.init(); net.initGradientsView(); - assertEquals(networkDtype, net.params().dataType(), msg); + assertEquals(networkDtype, net.getModelParams().dataType(), msg); assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); @@ -641,7 +641,7 @@ public class DTypeTests extends BaseDL4JTest { net.init(); 
net.initGradientsView(); - assertEquals(networkDtype, net.params().dataType(), msg); + assertEquals(networkDtype, net.getModelParams().dataType(), msg); assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); @@ -754,7 +754,7 @@ public class DTypeTests extends BaseDL4JTest { net.init(); net.initGradientsView(); - assertEquals(networkDtype, net.params().dataType(), msg); + assertEquals(networkDtype, net.getModelParams().dataType(), msg); assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); @@ -827,7 +827,7 @@ public class DTypeTests extends BaseDL4JTest { net.init(); net.initGradientsView(); - assertEquals(networkDtype, net.params().dataType(), msg); + assertEquals(networkDtype, net.getModelParams().dataType(), msg); assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); @@ -916,7 +916,7 @@ public class DTypeTests extends BaseDL4JTest { net.init(); net.initGradientsView(); - assertEquals(networkDtype, net.params().dataType(), msg); + assertEquals(networkDtype, net.getModelParams().dataType(), msg); assertEquals(networkDtype, net.getFlattenedGradients().dataType(), msg); assertEquals(networkDtype, net.getUpdater(true).getStateViewArray().dataType(), msg); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java index de8c16075..4197263b6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/ComputationGraphTestRNN.java @@ -520,9 +520,9 @@ public class ComputationGraphTestRNN extends BaseDL4JTest { INDArray inputLong = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); INDArray labelsLong = Nd4j.rand(miniBatchSize, nOut, timeSeriesLength); - INDArray initialParams = graph.params().dup(); + INDArray initialParams = graph.getModelParams().dup(); graph.fit(new INDArray[] {inputLong}, new INDArray[] {labelsLong}); - INDArray afterParams = graph.params(); + INDArray afterParams = graph.getModelParams(); assertNotEquals(initialParams, afterParams); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java index d83f4ac17..4129592b6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphCNN.java @@ -117,7 +117,7 @@ public class TestCompGraphCNN extends BaseDL4JTest { boolean orderOK = Arrays.equals(expOrder1, order) || Arrays.equals(expOrder2, order); assertTrue(orderOK); - INDArray params = graph.params(); + INDArray params = graph.getModelParams(); assertNotNull(params); // confirm param shape is what is expected @@ -129,7 +129,7 @@ public class TestCompGraphCNN extends BaseDL4JTest { // params are set graph.setParams(arr); - params = graph.params(); + params = graph.getModelParams(); assertEquals(arr, params); //Number of inputs and outputs: diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java index b24dc76ed..2cf9e0db4 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestCompGraphUnsupervised.java @@ -108,7 +108,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest { } } - int count = Nd4j.getExecutioner().exec(new MatchCondition(cg.params(), Conditions.isNan())).getInt(0); + int count = Nd4j.getExecutioner().exec(new MatchCondition(cg.getModelParams(), Conditions.isNan())).getInt(0); assertEquals(0, count); @@ -125,7 +125,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest { } } - count = Nd4j.getExecutioner().exec(new MatchCondition(cg.params(), Conditions.isNan())).getInt(0); + count = Nd4j.getExecutioner().exec(new MatchCondition(cg.getModelParams(), Conditions.isNan())).getInt(0); assertEquals(0, count); } } @@ -176,7 +176,7 @@ public class TestCompGraphUnsupervised extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); cg.pretrainLayer("0", ds); - assertEquals(net.params(), cg.params()); + assertEquals(net.getModelParams(), cg.getModelParams()); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index 7feb29ddb..46180da6d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -159,7 +159,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { DataSet ds = iris.next(); graph.setInput(0, ds.getFeatures()); - net.setParams(graph.params()); + net.setParams(graph.getModelParams()); Map activations = graph.feedForward(false); List feedForward = net.feedForward(ds.getFeatures()); @@ -184,7 +184,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { int[] expOrder = new int[]{0, 1, 2}; assertArrayEquals(expOrder, order); //Only one valid order: 0 (input) -> 1 (firstlayer) -> 2 (outputlayer) - INDArray params = graph.params(); + INDArray params = graph.getModelParams(); assertNotNull(params); int nParams = getNumParams(); @@ -194,7 +194,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { assertEquals(nParams, arr.length()); graph.setParams(arr); - params = graph.params(); + params = graph.getModelParams(); assertEquals(arr, params); //Number of inputs and outputs: @@ -315,8 +315,8 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { graph.fit(iris); //Check that parameters are equal for both models after fitting: - INDArray paramsMLN = net.params(); - INDArray paramsGraph = graph.params(); + INDArray paramsMLN = net.getModelParams(); + INDArray paramsGraph = graph.getModelParams(); assertNotEquals(params, paramsGraph); assertEquals(paramsMLN, paramsGraph); @@ -636,7 +636,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph net = new ComputationGraph(conf); net.init(); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator iter = new IrisDataSetIterator(10, 150); net.pretrain(iter); @@ -675,7 +675,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph netNoReg = new ComputationGraph(confNoReg); 
netNoReg.init(); - netNoReg.setParams(net.params().dup()); + netNoReg.setParams(net.getModelParams().dup()); //Score single example, and compare to scoreExamples: INDArray input = Nd4j.rand(3, nIn); @@ -878,13 +878,13 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { net.setParam("first_b", Nd4j.ones(1, 5)); net.setParam("output_W", Nd4j.ones(5, 3)); net.setParam("output_b", Nd4j.ones(1, 3)); - INDArray actualParams = net.params(); + INDArray actualParams = net.getModelParams(); // Confirm params assertEquals(Nd4j.ones(1, 43), actualParams); net.update(expectedGradient); - actualParams = net.params(); + actualParams = net.getModelParams(); assertEquals(Nd4j.ones(1, 43).addi(1), actualParams); } @@ -1638,7 +1638,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { conf3.setTopologicalOrderStr(null); ComputationGraph cg3 = new ComputationGraph(conf3); cg3.init(); - cg3.setParams(cg2.params()); + cg3.setParams(cg2.getModelParams()); int[] order3 = cg3.topologicalSortOrder(); List strOrder3 = cg.getComputationGraphConfiguration().getTopologicalOrderStr(); @@ -1712,7 +1712,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { exp.add(ComputationGraph.class); MultiLayerTest.CheckModelsListener listener = new MultiLayerTest.CheckModelsListener(); - net.setListeners(listener); + net.addTrainingListeners(listener); INDArray f = Nd4j.create(1,10); INDArray l = Nd4j.create(1,10); @@ -1874,7 +1874,7 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { ComputationGraph cg = new ComputationGraph(conf); cg.init(); - cg.params().assign(Nd4j.linspace(1, 220, 220).reshape(1, -11)); + cg.getModelParams().assign(Nd4j.linspace(1, 220, 220).reshape(1, -11)); INDArray p0w = cg.getParam("layer_zero_W"); assertEquals(Nd4j.linspace(1, 100, 100).reshape('f', 10, 10), p0w); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java index 2f752b316..685920d10 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestSetGetParameters.java @@ -56,7 +56,7 @@ public class TestSetGetParameters extends BaseDL4JTest { ComputationGraph net = new ComputationGraph(conf); net.init(); - INDArray params = net.params(); + INDArray params = net.getModelParams(); ComputationGraph net2 = new ComputationGraph(conf); @@ -65,11 +65,11 @@ public class TestSetGetParameters extends BaseDL4JTest { ComputationGraph net3 = new ComputationGraph(conf); net3.init(params, false); - assertEquals(params, net2.params()); - assertEquals(params, net3.params()); + assertEquals(params, net2.getModelParams()); + assertEquals(params, net3.getModelParams()); - assertNotSame(params, net2.params()); //Different objects due to clone - assertSame(params, net3.params()); //Same object due to clone + assertNotSame(params, net2.getModelParams()); //Different objects due to clone + assertSame(params, net3.getModelParams()); //Same object due to clone Map paramsMap = net.getParamTable(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java index 237e7550e..7023e0039 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/graph/TestVariableLengthTSCG.java @@ -103,14 +103,14 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { net.setInput(0, in1); net.setLabel(0, labels1); net.computeGradientAndScore(); - double score1 = net.score(); + double score1 = net.getScore(); Gradient g1 = net.gradient(); net.setInput(0, in2); net.setLabel(0, labels2); net.setLayerMaskArrays(null, new INDArray[] {labelMask}); net.computeGradientAndScore(); - double score2 = net.score(); + double score2 = net.getScore(); Gradient g2 = net.gradient(); //Scores and gradients should be identical for two cases (given mask array) @@ -134,7 +134,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { } net.setLabel(0, labels2); net.computeGradientAndScore(); - double score2a = net.score(); + double score2a = net.getScore(); Gradient g2a = net.gradient(); assertEquals(score2, score2a, 1e-6); for (String s : g2map.keySet()) { @@ -200,7 +200,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { net.setInput(0, in1); net.setLabel(0, labels1); net.computeGradientAndScore(); - double score1 = net.score(); + double score1 = net.getScore(); Gradient g1 = net.gradient(); Map map = g1.gradientForVariable(); for (String s : map.keySet()) { @@ -211,7 +211,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { net.setLabel(0, labels2); net.setLayerMaskArrays(new INDArray[] {inputMask}, null); net.computeGradientAndScore(); - double score2 = net.score(); + double score2 = net.getScore(); Gradient g2 = net.gradient(); Map activations2 = net.feedForward(); @@ -236,7 +236,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { net.setInput(0, in2); net.setLayerMaskArrays(new INDArray[]{inputMask}, null); net.computeGradientAndScore(); - double score2a = net.score(); + double score2a = net.getScore(); Gradient g2a = net.gradient(); assertEquals(score2, score2a, 1e-12); for (String s : g2.gradientForVariable().keySet()) { @@ -330,7 +330,7 @@ public class TestVariableLengthTSCG extends BaseDL4JTest { net.setLabel(0, labels); net.computeGradientAndScore(); - double score = net.score(); + double score = net.getScore(); assertEquals(expScore, score, 0.1, msg); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerConfigurationTest.java similarity index 98% rename from cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java rename to cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerConfigurationTest.java index 189467ab4..c481d20df 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/BaseLayerConfigurationTest.java @@ -40,7 +40,7 @@ import java.util.Map; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotEquals; -public class BaseLayerTest extends BaseDL4JTest { +public class BaseLayerConfigurationTest extends BaseDL4JTest { protected INDArray weight = Nd4j.create(new double[] {0.10, -0.20, -0.15, 0.05}, new int[] {2, 2}); protected INDArray bias = Nd4j.create(new double[] {0.5, 0.5}, new int[] {1, 2}); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java 
index 002495133..7898d35ad 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CacheModeTest.java @@ -56,10 +56,10 @@ public class CacheModeTest extends BaseDL4JTest { INDArray out2 = net2.output(in); assertEquals(out1, out2); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); net1.fit(in, labels); net2.fit(in, labels); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); } private static NeuralNetConfiguration getConf(CacheMode cacheMode){ @@ -99,10 +99,10 @@ public class CacheModeTest extends BaseDL4JTest { INDArray out2 = net2.output(in); assertEquals(out1, out2); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); net1.fit(in, labels); net2.fit(in, labels); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); } } @@ -145,10 +145,10 @@ public class CacheModeTest extends BaseDL4JTest { INDArray out2 = net2.outputSingle(in); assertEquals(out1, out2); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); net1.fit(new DataSet(in, labels)); net2.fit(new DataSet(in, labels)); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); } private static ComputationGraphConfiguration getConfCG(CacheMode cacheMode){ diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java index 9f5597199..84f94928b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/CenterLossOutputLayerTest.java @@ -121,7 +121,7 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest { graph.setInput(0, input); graph.setLabel(0, labels); graph.computeGradientAndScore(); - results[i] = graph.score(); + results[i] = graph.getScore(); } assertNotEquals(results[0], results[1]); @@ -137,7 +137,7 @@ public class CenterLossOutputLayerTest extends BaseDL4JTest { ComputationGraph net = getCNNMnistConfig(); net.init(); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); for (int i = 0; i < 50; i++) { net.fit(mnistTrain.next()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java index 716bbb8a9..80cf35543 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/DropoutLayerTest.java @@ -265,7 +265,7 @@ public class DropoutLayerTest extends BaseDL4JTest { MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate); netSeparate.init(); - assertEquals(netIntegrated.params(), netSeparate.params()); + assertEquals(netIntegrated.getModelParams(), netSeparate.getModelParams()); Nd4j.getRandom().setSeed(12345); netIntegrated.fit(next); @@ -273,7 +273,7 @@ public class DropoutLayerTest extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); netSeparate.fit(next); - 
assertEquals(netIntegrated.params(), netSeparate.params()); + assertEquals(netIntegrated.getModelParams(), netSeparate.getModelParams()); // check parameters assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W")); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java index 2b8977ed0..20880d71a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerTest.java @@ -80,7 +80,7 @@ public class FrozenLayerTest extends BaseDL4JTest { .setFeatureExtractor(1).build(); INDArray paramsLastTwoLayers = - Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()); + Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams()); MultiLayerNetwork notFrozen = new MultiLayerNetwork( (NeuralNetConfiguration) overallConf.clone() .layer(0, new Builder().nIn(2).nOut(3).build()) @@ -102,9 +102,9 @@ public class FrozenLayerTest extends BaseDL4JTest { modelNow.fit(randomData); } - INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), - notFrozen.params()); - INDArray act = modelNow.params(); + INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), modelToFineTune.getLayer(1).getParams(), + notFrozen.getModelParams()); + INDArray act = modelNow.getModelParams(); assertEquals(expected, act); } @@ -136,7 +136,7 @@ public class FrozenLayerTest extends BaseDL4JTest { assertEquals(modelNow.getNetConfiguration().toJson(), clonedModel.getNetConfiguration().toJson()); //Check params - assertEquals(modelNow.params(), clonedModel.params()); + assertEquals(modelNow.getModelParams(), clonedModel.getModelParams()); MultiLayerNetwork notFrozen = new MultiLayerNetwork( (NeuralNetConfiguration) overallConf.layer(0, new Builder().nIn(2).nOut(3).build()) @@ -145,7 +145,7 @@ public class FrozenLayerTest extends BaseDL4JTest { .activation(Activation.SOFTMAX).nIn(3).nOut(3) .build()) .build(), - Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params())); + Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams())); int i = 0; while (i < 5) { @@ -155,10 +155,10 @@ public class FrozenLayerTest extends BaseDL4JTest { i++; } - INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), - modelToFineTune.getLayer(1).params(), notFrozen.params()); - assertEquals(expectedParams, modelNow.params()); - assertEquals(expectedParams, clonedModel.params()); + INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), + modelToFineTune.getLayer(1).getParams(), notFrozen.getModelParams()); + assertEquals(expectedParams, modelNow.getModelParams()); + assertEquals(expectedParams, clonedModel.getModelParams()); } @@ -199,8 +199,8 @@ public class FrozenLayerTest extends BaseDL4JTest { .setOutputs("layer1").build()); notFrozen.init(); - notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), - modelToFineTune.getLayer("layer3").params())); + notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").getParams(), + modelToFineTune.getLayer("layer3").getParams())); int i = 0; while (i < 5) { @@ -209,8 +209,8 @@ public class FrozenLayerTest extends BaseDL4JTest { i++; } - 
assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), - modelToFineTune.getLayer("layer1").params(), notFrozen.params()), modelNow.params()); + assertEquals(Nd4j.hstack(modelToFineTune.getLayer("layer0").getParams(), + modelToFineTune.getLayer("layer1").getParams(), notFrozen.getModelParams()), modelNow.getModelParams()); } @Test @@ -244,7 +244,7 @@ public class FrozenLayerTest extends BaseDL4JTest { assertEquals(clonedModel.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson()); //Check params - assertEquals(modelNow.params(), clonedModel.params()); + assertEquals(modelNow.getModelParams(), clonedModel.getModelParams()); ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") .addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In") @@ -256,8 +256,8 @@ public class FrozenLayerTest extends BaseDL4JTest { "layer0") .setOutputs("layer1").build()); notFrozen.init(); - notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(), - modelToFineTune.getLayer("layer3").params())); + notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").getParams(), + modelToFineTune.getLayer("layer3").getParams())); int i = 0; @@ -268,10 +268,10 @@ public class FrozenLayerTest extends BaseDL4JTest { i++; } - INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").params(), - modelToFineTune.getLayer("layer1").params(), notFrozen.params()); - assertEquals(expectedParams, modelNow.params()); - assertEquals(expectedParams, clonedModel.params()); + INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").getParams(), + modelToFineTune.getLayer("layer1").getParams(), notFrozen.getModelParams()); + assertEquals(expectedParams, modelNow.getModelParams()); + assertEquals(expectedParams, clonedModel.getModelParams()); } @@ -305,7 +305,7 @@ public class FrozenLayerTest extends BaseDL4JTest { MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); String json = conf2.toJson(); @@ -362,7 +362,7 @@ public class FrozenLayerTest extends BaseDL4JTest { ComputationGraph net2 = new ComputationGraph(conf2); net2.init(); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); String json = conf2.toJson(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java index 89c359ae7..d47973a89 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackpropTest.java @@ -75,7 +75,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); String json = conf2.toJson(); @@ -130,7 +130,7 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { ComputationGraph net2 = new ComputationGraph(conf2); net2.init(); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); String json = conf2.toJson(); @@ -170,19 +170,19 @@ public class 
FrozenLayerWithBackpropTest extends BaseDL4JTest { MultiLayerNetwork network = new MultiLayerNetwork(conf1); network.init(); - INDArray unfrozenLayerParams = network.getLayer(0).params().dup(); - INDArray frozenLayerParams1 = network.getLayer(1).params().dup(); - INDArray frozenLayerParams2 = network.getLayer(2).params().dup(); - INDArray frozenOutputLayerParams = network.getLayer(3).params().dup(); + INDArray unfrozenLayerParams = network.getLayer(0).getParams().dup(); + INDArray frozenLayerParams1 = network.getLayer(1).getParams().dup(); + INDArray frozenLayerParams2 = network.getLayer(2).getParams().dup(); + INDArray frozenOutputLayerParams = network.getLayer(3).getParams().dup(); for (int i = 0; i < 100; i++) { network.fit(randomData); } - assertNotEquals(unfrozenLayerParams, network.getLayer(0).params()); - assertEquals(frozenLayerParams1, network.getLayer(1).params()); - assertEquals(frozenLayerParams2, network.getLayer(2).params()); - assertEquals(frozenOutputLayerParams, network.getLayer(3).params()); + assertNotEquals(unfrozenLayerParams, network.getLayer(0).getParams()); + assertEquals(frozenLayerParams1, network.getLayer(1).getParams()); + assertEquals(frozenLayerParams2, network.getLayer(2).getParams()); + assertEquals(frozenOutputLayerParams, network.getLayer(3).getParams()); } @@ -228,19 +228,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { ComputationGraph computationGraph = new ComputationGraph(computationGraphConf); computationGraph.init(); - INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); - INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); - INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); - INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).params().dup(); + INDArray unfrozenLayerParams = computationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup(); + INDArray frozenLayerParams1 = computationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup(); + INDArray frozenLayerParams2 = computationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup(); + INDArray frozenOutputLayerParams = computationGraph.getLayer(frozenBranchOutput).getParams().dup(); for (int i = 0; i < 100; i++) { computationGraph.fit(randomData); } - assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).params()); - assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).params()); - assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).params()); - assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).params()); + assertNotEquals(unfrozenLayerParams, computationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams()); + assertEquals(frozenLayerParams1, computationGraph.getLayer(frozenBranchFrozenLayer1).getParams()); + assertEquals(frozenLayerParams2, computationGraph.getLayer(frozenBranchFrozenLayer2).getParams()); + assertEquals(frozenOutputLayerParams, computationGraph.getLayer(frozenBranchOutput).getParams()); } @@ -275,17 +275,17 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { .build(); MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen); frozenNetwork.init(); - INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).params().dup(); - INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).params().dup(); - INDArray 
frozenLayerParams2 = frozenNetwork.getLayer(2).params().dup(); - INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).params().dup(); + INDArray unfrozenLayerParams = frozenNetwork.getLayer(0).getParams().dup(); + INDArray frozenLayerParams1 = frozenNetwork.getLayer(1).getParams().dup(); + INDArray frozenLayerParams2 = frozenNetwork.getLayer(2).getParams().dup(); + INDArray frozenOutputLayerParams = frozenNetwork.getLayer(3).getParams().dup(); MultiLayerNetwork sgdNetwork = new MultiLayerNetwork(confSgd); sgdNetwork.init(); - INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).params().dup(); - INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).params().dup(); - INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).params().dup(); - INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).params().dup(); + INDArray unfrozenSgdLayerParams = sgdNetwork.getLayer(0).getParams().dup(); + INDArray frozenSgdLayerParams1 = sgdNetwork.getLayer(1).getParams().dup(); + INDArray frozenSgdLayerParams2 = sgdNetwork.getLayer(2).getParams().dup(); + INDArray frozenSgdOutputLayerParams = sgdNetwork.getLayer(3).getParams().dup(); for (int i = 0; i < 100; i++) { frozenNetwork.fit(randomData); @@ -294,10 +294,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { sgdNetwork.fit(randomData); } - assertEquals(frozenNetwork.getLayer(0).params(), sgdNetwork.getLayer(0).params()); - assertEquals(frozenNetwork.getLayer(1).params(), sgdNetwork.getLayer(1).params()); - assertEquals(frozenNetwork.getLayer(2).params(), sgdNetwork.getLayer(2).params()); - assertEquals(frozenNetwork.getLayer(3).params(), sgdNetwork.getLayer(3).params()); + assertEquals(frozenNetwork.getLayer(0).getParams(), sgdNetwork.getLayer(0).getParams()); + assertEquals(frozenNetwork.getLayer(1).getParams(), sgdNetwork.getLayer(1).getParams()); + assertEquals(frozenNetwork.getLayer(2).getParams(), sgdNetwork.getLayer(2).getParams()); + assertEquals(frozenNetwork.getLayer(3).getParams(), sgdNetwork.getLayer(3).getParams()); } @@ -360,17 +360,17 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { ComputationGraph frozenComputationGraph = new ComputationGraph(computationGraphConf); frozenComputationGraph.init(); - INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); - INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); - INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); - INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).params().dup(); + INDArray unfrozenLayerParams = frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup(); + INDArray frozenLayerParams1 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup(); + INDArray frozenLayerParams2 = frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup(); + INDArray frozenOutputLayerParams = frozenComputationGraph.getLayer(frozenBranchOutput).getParams().dup(); ComputationGraph sgdComputationGraph = new ComputationGraph(computationSgdGraphConf); sgdComputationGraph.init(); - INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params().dup(); - INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params().dup(); - INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params().dup(); - INDArray 
frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).params().dup(); + INDArray unfrozenSgdLayerParams = sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams().dup(); + INDArray frozenSgdLayerParams1 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams().dup(); + INDArray frozenSgdLayerParams2 = sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams().dup(); + INDArray frozenSgdOutputLayerParams = sgdComputationGraph.getLayer(frozenBranchOutput).getParams().dup(); for (int i = 0; i < 100; i++) { frozenComputationGraph.fit(randomData); @@ -379,10 +379,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest { sgdComputationGraph.fit(randomData); } - assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).params()); - assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).params()); - assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).params(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).params()); - assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).params(), sgdComputationGraph.getLayer(frozenBranchOutput).params()); + assertEquals(frozenComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams(), sgdComputationGraph.getLayer(frozenBranchUnfrozenLayer0).getParams()); + assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer1).getParams()); + assertEquals(frozenComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams(), sgdComputationGraph.getLayer(frozenBranchFrozenLayer2).getParams()); + assertEquals(frozenComputationGraph.getLayer(frozenBranchOutput).getParams(), sgdComputationGraph.getLayer(frozenBranchOutput).getParams()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java index 0d4f0d710..6e2132d92 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/OutputLayerTest.java @@ -68,9 +68,9 @@ public class OutputLayerTest extends BaseDL4JTest { INDArray params = Nd4j.create(1, numParams); OutputLayer l = (OutputLayer) conf.getFlattenedLayerConfigurations().get(0).instantiate(conf, Collections.singletonList(new ScoreIterationListener(1)), 0, params, true, params.dataType()); - params = l.params(); + params = l.getModelParams(); l.setParamsTable(params); - assertEquals(params, l.params()); + assertEquals(params, l.getModelParams()); } @Test @@ -217,8 +217,8 @@ public class OutputLayerTest extends BaseDL4JTest { //However: OutputLayer version has miniBatch*timeSeriesLength "examples" (after reshaping) //RnnOutputLayer has miniBatch examples //Hence: expect difference in scores by factor of timeSeriesLength - double score = mln.score() * timeSeriesLength; - double scoreRNN = mlnRnn.score(); + double score = mln.getScore() * timeSeriesLength; + double scoreRNN = mlnRnn.getScore(); assertFalse(Double.isNaN(score)); assertFalse(Double.isNaN(scoreRNN)); @@ -234,7 +234,7 @@ public class OutputLayerTest extends BaseDL4JTest { RnnOutputLayer rnnol = (RnnOutputLayer) mlnRnn.getOutputLayer(); //assertArrayEquals(rnnol.getInput().shape(),new 
int[]{miniBatchSize,layerSize,timeSeriesLength}); - //Input may be set by BaseLayer methods. Thus input may end up as reshaped 2d version instead of original 3d version. + //Input may be set by BaseLayerConfiguration methods. Thus input may end up as reshaped 2d version instead of original 3d version. //Not ideal, but everything else works. assertArrayEquals(rnnol.getLabels().shape(), new long[] {miniBatchSize, nOut, timeSeriesLength}); @@ -303,7 +303,7 @@ public class OutputLayerTest extends BaseDL4JTest { MultiLayerNetwork mln2 = new MultiLayerNetwork(conf2); mln2.init(); - mln2.setParams(mln.params()); + mln2.setParams(mln.getModelParams()); INDArray in = Nd4j.rand(miniBatchSize, nIn, timeSeriesLength); @@ -330,7 +330,7 @@ public class OutputLayerTest extends BaseDL4JTest { mln2.computeGradientAndScore(); assertEquals(mln.gradient().gradient(), mln2.gradient().gradient()); - assertEquals(mln.score(), mln2.score(), 1e-6); + assertEquals(mln.getScore(), mln2.getScore(), 1e-6); TestUtils.testModelSerialization(mln); } @@ -386,7 +386,7 @@ public class OutputLayerTest extends BaseDL4JTest { mln2.init(); - mln2.setParams(mln.params()); + mln2.setParams(mln.getModelParams()); INDArray in = Nd4j.rand(3, 3, 5, 5); @@ -407,7 +407,7 @@ public class OutputLayerTest extends BaseDL4JTest { mln.computeGradientAndScore(); mln2.computeGradientAndScore(); - assertEquals(mln.score(), mln2.score(), 1e-6); + assertEquals(mln.getScore(), mln2.getScore(), 1e-6); assertEquals(mln.gradient().gradient(), mln2.gradient().gradient()); //Also check computeScoreForExamples @@ -479,7 +479,7 @@ public class OutputLayerTest extends BaseDL4JTest { graph2.init(); - graph2.setParams(graph.params()); + graph2.setParams(graph.getModelParams()); INDArray in = Nd4j.rand(3, 3, 5, 5); @@ -500,7 +500,7 @@ public class OutputLayerTest extends BaseDL4JTest { graph.computeGradientAndScore(); graph2.computeGradientAndScore(); - assertEquals(graph.score(), graph2.score(), 1e-6); + assertEquals(graph.getScore(), graph2.getScore(), 1e-6); assertEquals(graph.gradient().gradient(), graph2.gradient().gradient()); //Also check computeScoreForExamples diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java index 6306c333b..f1e64f204 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/SeedTest.java @@ -59,13 +59,13 @@ public class SeedTest extends BaseDL4JTest { layer.fit(data.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - double score = layer.score(); - INDArray parameters = layer.params(); + double score = layer.getScore(); + INDArray parameters = layer.getParams(); layer.setParams(parameters); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - double score2 = layer.score(); - assertEquals(parameters, layer.params()); + double score2 = layer.getScore(); + assertEquals(parameters, layer.getParams()); assertEquals(score, score2, 1e-4); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java index 44ee236c8..6f24ff226 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvDataFormatTests.java @@ -845,9 +845,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { public static void testHelper(TestCase tc) { - tc.net2.params().assign(tc.net1.params()); - tc.net3.params().assign(tc.net1.params()); - tc.net4.params().assign(tc.net1.params()); + tc.net2.getModelParams().assign(tc.net1.getModelParams()); + tc.net3.getModelParams().assign(tc.net1.getModelParams()); + tc.net4.getModelParams().assign(tc.net1.getModelParams()); //Test forward pass: INDArray inNCHW = tc.inNCHW; @@ -909,9 +909,9 @@ public class ConvDataFormatTests extends BaseDL4JTest { tc.net3.fit(inNHWC, tc.labelsNHWC); tc.net4.fit(inNHWC, tc.labelsNHWC); - assertEquals(tc.net1.params(), tc.net2.params(), tc.msg); - assertEquals(tc.net1.params(), tc.net3.params(), tc.msg); - assertEquals(tc.net1.params(), tc.net4.params(), tc.msg); + assertEquals(tc.net1.getModelParams(), tc.net2.getModelParams(), tc.msg); + assertEquals(tc.net1.getModelParams(), tc.net3.getModelParams(), tc.msg); + assertEquals(tc.net1.getModelParams(), tc.net4.getModelParams(), tc.msg); //Test serialization MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java index f234d3b78..4b5458b15 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayerTest.java @@ -30,7 +30,6 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ConvolutionMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; @@ -38,7 +37,6 @@ import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.nn.weights.WeightInitNormal; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; @@ -450,10 +448,10 @@ public class ConvolutionLayerTest extends BaseDL4JTest { MultiLayerNetwork net = getCNNMLNConfig(true, false); - INDArray paramsOrig = net.params().dup(); + INDArray paramsOrig = net.getModelParams().dup(); net.setParams(paramsOrig); - INDArray params2 = net.params(); + INDArray params2 = net.getModelParams(); assertEquals(paramsOrig, params2); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java index 4ef8fab18..75e48861d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/TestCustomLayers.java @@ -154,7 +154,7 @@ public class TestCustomLayers extends BaseDL4JTest { MultiLayerNetwork net2 = new 
MultiLayerNetwork(conf2); net2.init(); - assertEquals(net2.params(), net.params()); + assertEquals(net2.getModelParams(), net.getModelParams()); INDArray testFeatures = Nd4j.rand(1, 10); INDArray testLabels = Nd4j.zeros(1, 10); @@ -207,7 +207,7 @@ public class TestCustomLayers extends BaseDL4JTest { ComputationGraph net2 = new ComputationGraph(conf2); net2.init(); - assertEquals(net2.params(), net.params()); + assertEquals(net2.getModelParams(), net.getModelParams()); INDArray testFeatures = Nd4j.rand(1, 10); INDArray testLabels = Nd4j.zeros(1, 10); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java index ea59a8091..4fafcde0b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomLayer.java @@ -56,7 +56,7 @@ public class CustomLayer extends FeedForwardLayer { boolean initializeParams, DataType networkDataType) { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); CustomLayerImpl ret = new CustomLayerImpl(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java index 80c983589..350f24f4b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/custom/testclasses/CustomOutputLayer.java @@ -54,7 +54,7 @@ public class CustomOutputLayer extends BaseOutputLayer { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); CustomOutputLayerImpl ret = new CustomOutputLayerImpl(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java index ba1129cef..01cc7f2dd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/dense/DenseTest.java @@ -72,7 +72,7 @@ public class DenseTest extends BaseDL4JTest { DataSet test = iter.next(); - assertEquals(model.params(), model2.params()); + assertEquals(model.getModelParams(), model2.getModelParams()); Evaluation eval = new Evaluation(); INDArray output = model.output(test.getFeatures()); @@ -99,7 +99,7 @@ public class DenseTest extends BaseDL4JTest { DataSet test = iter.next(); - assertEquals(model.params(), model2.params()); + assertEquals(model.getModelParams(), 
model2.getModelParams()); Evaluation eval = new Evaluation(); INDArray output = model.output(test.getFeatures()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java index 60c4e3b0d..742f38a2d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/feedforward/embedding/EmbeddingLayerTest.java @@ -169,7 +169,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.init(); net2.init(); - net2.setParams(net.params().dup()); + net2.setParams(net.getModelParams().dup()); int batchSize = 3; INDArray inEmbedding = Nd4j.create(batchSize, 1); @@ -216,7 +216,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.init(); net2.init(); - net2.setParams(net.params().dup()); + net2.setParams(net.getModelParams().dup()); int batchSize = 3; INDArray inEmbedding = Nd4j.create(batchSize, 1); @@ -262,7 +262,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.init(); net2.init(); - net2.setParams(net.params().dup()); + net2.setParams(net.getModelParams().dup()); int batchSize = 3; INDArray inEmbedding = Nd4j.create(batchSize, 1); @@ -287,7 +287,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.computeGradientAndScore(); net2.computeGradientAndScore(); - assertEquals(net2.score(), net.score(), 1e-6); + assertEquals(net2.getScore(), net.getScore(), 1e-6); Map gradient = net.gradient().gradientForVariable(); Map gradient2 = net2.gradient().gradientForVariable(); @@ -323,7 +323,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.init(); net2.init(); - net2.setParams(net.params().dup()); + net2.setParams(net.getModelParams().dup()); int batchSize = 3; INDArray inEmbedding = Nd4j.create(batchSize, 1); @@ -349,7 +349,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net2.computeGradientAndScore(); // System.out.println(net.score() + "\t" + net2.score()); - assertEquals(net2.score(), net.score(), 1e-6); + assertEquals(net2.getScore(), net.getScore(), 1e-6); Map gradient = net.gradient().gradientForVariable(); Map gradient2 = net2.gradient().gradientForVariable(); @@ -395,7 +395,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.init(); net2.init(); - net2.setParams(net.params().dup()); + net2.setParams(net.getModelParams().dup()); INDArray inEmbedding = Nd4j.create(batchSize, 1, timeSeriesLength); INDArray inOneHot = Nd4j.create(batchSize, nClassesIn, timeSeriesLength); @@ -422,7 +422,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net2.computeGradientAndScore(); // System.out.println(net.score() + "\t" + net2.score()); - assertEquals(net2.score(), net.score(), 1e-5); + assertEquals(net2.getScore(), net.getScore(), 1e-5); Map gradient = net.gradient().gradientForVariable(); Map gradient2 = net2.gradient().gradientForVariable(); @@ -484,7 +484,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - net2.setParams(net.params().dup()); + net2.setParams(net.getModelParams().dup()); INDArray inEmbedding = Nd4j.zeros(nExamples, 1, timeSeriesLength); INDArray inDense = Nd4j.zeros(nExamples, numInputClasses, timeSeriesLength); @@ -523,7 +523,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net2.computeGradientAndScore(); // 
System.out.println(net.score() + "\t" + net2.score()); - assertEquals(net2.score(), net.score(), 1e-5); + assertEquals(net2.getScore(), net.getScore(), 1e-5); Map gradients = net.gradient().gradientForVariable(); Map gradients2 = net2.gradient().gradientForVariable(); @@ -640,7 +640,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - net2.setParams(net.params().dup()); + net2.setParams(net.getModelParams().dup()); INDArray inEmbedding = Nd4j.zeros(inLabelDtype, inputRank == 2 ? new long[]{nExamples, timeSeriesLength} : new long[]{nExamples, 1, timeSeriesLength}); INDArray inDense = Nd4j.zeros(inLabelDtype, nExamples, numInputClasses, timeSeriesLength); @@ -678,7 +678,7 @@ public class EmbeddingLayerTest extends BaseDL4JTest { net.computeGradientAndScore(); net2.computeGradientAndScore(); - assertEquals(net2.score(), net.score(), 1e-5); + assertEquals(net2.getScore(), net.getScore(), 1e-5); Map gradients = net.gradient().gradientForVariable(); Map gradients2 = net2.gradient().gradientForVariable(); @@ -777,9 +777,9 @@ public class EmbeddingLayerTest extends BaseDL4JTest { MultiLayerNetwork net3 = new MultiLayerNetwork(conf3); net3.init(); - INDArray p1 = net.params(); - INDArray p2 = net2.params(); - INDArray p3 = net3.params(); + INDArray p1 = net.getModelParams(); + INDArray p2 = net2.getModelParams(); + INDArray p3 = net3.getModelParams(); boolean eq = p1.equalsWithEps(p2, 1e-4); String str = (seq ? "EmbeddingSequenceLayer" : "EmbeddingLayer") + " - " + wi; assertTrue(eq, str + " p1/p2 params not equal"); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java index 558041072..5ef6fb110 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/objdetect/TestYolo2OutputLayer.java @@ -514,7 +514,7 @@ public class TestYolo2OutputLayer extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.setListeners(new ScoreIterationListener(100)); + net.addTrainingListeners(new ScoreIterationListener(100)); int nEpochs = 1000; DataSet ds = iter.next(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java index c0f6fa24c..a52716589 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/ocnn/OCNNOutputLayerTest.java @@ -79,13 +79,13 @@ public class OCNNOutputLayerTest extends BaseDL4JTest { if (doLearningFirst) { //Run a number of iterations of learning network.setInput(arr); - network.setListeners(new ScoreIterationListener(1)); + network.addTrainingListeners(new ScoreIterationListener(1)); network.computeGradientAndScore(); - double scoreBefore = network.score(); + double scoreBefore = network.getScore(); for (int j = 0; j < 10; j++) network.fit(ds); network.computeGradientAndScore(); - double scoreAfter = network.score(); + double scoreAfter = network.getScore(); //Can't test in 'characteristic mode of operation' if not learning String msg = "testLayer() - score did not (sufficiently) decrease during learning - 
activationFn=" + "relu" + ", lossFn=" + "ocnn" + ", " + "sigmoid" @@ -147,7 +147,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest { tmpFile.deleteOnExit(); MultiLayerNetwork multiLayerNetwork = ModelSerializer.restoreMultiLayerNetwork(tmpFile); - assertEquals(network.params(),multiLayerNetwork.params()); + assertEquals(network.getModelParams(),multiLayerNetwork.getModelParams()); assertEquals(network.numParams(),multiLayerNetwork.numParams()); } @@ -187,7 +187,7 @@ public class OCNNOutputLayerTest extends BaseDL4JTest { .build(); MultiLayerNetwork network = new MultiLayerNetwork(configuration); network.init(); - network.setListeners(new ScoreIterationListener(1)); + network.addTrainingListeners(new ScoreIterationListener(1)); return network; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java index 8e329077c..101a55edb 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/BidirectionalTest.java @@ -124,7 +124,7 @@ public class BidirectionalTest extends BaseDL4JTest { assertEquals(n1, n2); } - net2.setParams(net1.params()); //Assuming exact same layout here... + net2.setParams(net1.getModelParams()); //Assuming exact same layout here... INDArray in; if (rnnDataFormat == NCW){ @@ -154,7 +154,7 @@ public class BidirectionalTest extends BaseDL4JTest { net2.computeGradientAndScore(); //Ensure scores are equal: - assertEquals(net1.score(), net2.score(), 1e-6); + assertEquals(net1.getScore(), net2.getScore(), 1e-6); //Ensure gradients are equal: Gradient g1 = net1.gradient(); @@ -174,8 +174,8 @@ public class BidirectionalTest extends BaseDL4JTest { net1.fit(in, labels); net2.fit(in, labels); - INDArray p1 = net1.params(); - INDArray p2 = net2.params(); + INDArray p1 = net1.getModelParams(); + INDArray p2 = net2.getModelParams(); assertEquals(p1, p2); } } @@ -232,7 +232,7 @@ public class BidirectionalTest extends BaseDL4JTest { assertEquals(n1, n2); } - net2.setParams(net1.params()); //Assuming exact same layout here... + net2.setParams(net1.getModelParams()); //Assuming exact same layout here... 
INDArray in = Nd4j.rand(3, 10, 5); @@ -253,7 +253,7 @@ public class BidirectionalTest extends BaseDL4JTest { net2.computeGradientAndScore(); //Ensure scores are equal: - assertEquals(net1.score(), net2.score(), 1e-6); + assertEquals(net1.getScore(), net2.getScore(), 1e-6); //Ensure gradients are equal: Gradient g1 = net1.gradient(); @@ -273,8 +273,8 @@ public class BidirectionalTest extends BaseDL4JTest { net1.fit(new DataSet(in, labels)); net2.fit(new DataSet(in, labels)); - INDArray p1 = net1.params(); - INDArray p2 = net2.params(); + INDArray p1 = net1.getModelParams(); + INDArray p2 = net2.getModelParams(); assertEquals(p1, p2); } } @@ -340,7 +340,7 @@ public class BidirectionalTest extends BaseDL4JTest { net1.computeGradientAndScore(); net2.computeGradientAndScore(); - assertEquals(net1.score(), net2.score(), 1e-6); + assertEquals(net1.getScore(), net2.getScore(), 1e-6); assertEquals(net1.gradient().gradient(), net2.gradient().gradient()); } } @@ -403,7 +403,7 @@ public class BidirectionalTest extends BaseDL4JTest { net1.computeGradientAndScore(); net2.computeGradientAndScore(); - assertEquals(net1.score(), net2.score(), 1e-6); + assertEquals(net1.getScore(), net2.getScore(), 1e-6); assertEquals(net1.gradient().gradient(), net2.gradient().gradient()); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java index 7d8dd8977..be04304b6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTMTest.java @@ -277,7 +277,7 @@ public class GravesBidirectionalLSTMTest extends BaseDL4JTest { final INDArray act1 = bidirectionalLSTM.activate(sig, false, LayerWorkspaceMgr.noWorkspaces()); - params = bidirectionalLSTM.params(); + params = bidirectionalLSTM.getModelParams(); bidirectionalLSTM.setParamsTable(params); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java index c6b315cb5..93a60f38c 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/RnnDataFormatTests.java @@ -285,9 +285,9 @@ public class RnnDataFormatTests extends BaseDL4JTest { public static void testHelper(TestCase tc) { - tc.net2.params().assign(tc.net1.params()); - tc.net3.params().assign(tc.net1.params()); - tc.net4.params().assign(tc.net1.params()); + tc.net2.getModelParams().assign(tc.net1.getModelParams()); + tc.net3.getModelParams().assign(tc.net1.getModelParams()); + tc.net4.getModelParams().assign(tc.net1.getModelParams()); INDArray inNCW = tc.inNCW; INDArray inNWC = tc.inNCW.permute(0, 2, 1).dup(); @@ -352,9 +352,9 @@ public class RnnDataFormatTests extends BaseDL4JTest { tc.net3.fit(inNWC, tc.labelsNWC); tc.net4.fit(inNWC, tc.labelsNWC); - assertEquals(tc.net1.params(), tc.net2.params(), tc.msg); - assertEquals(tc.net1.params(), tc.net3.params(), tc.msg); - assertEquals(tc.net1.params(), tc.net4.params(), tc.msg); + assertEquals(tc.net1.getModelParams(), tc.net2.getModelParams(), tc.msg); + assertEquals(tc.net1.getModelParams(), tc.net3.getModelParams(), tc.msg); + 
assertEquals(tc.net1.getModelParams(), tc.net4.getModelParams(), tc.msg); //Test serialization MultiLayerNetwork net1a = TestUtils.testModelSerialization(tc.net1); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java index e2b6bc359..d6e0369d4 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestRnnLayers.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.recurrent; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.dropout.TestDropout; import org.deeplearning4j.nn.conf.layers.GravesLSTM; @@ -173,8 +172,8 @@ public class TestRnnLayers extends BaseDL4JTest { MultiLayerNetwork netD2 = new MultiLayerNetwork(confD2); netD2.init(); - assertEquals(net.params(), netD.params(), s); - assertEquals(net.params(), netD2.params(), s); + assertEquals(net.getModelParams(), netD.getModelParams(), s); + assertEquals(net.getModelParams(), netD2.getModelParams(), s); INDArray f = Nd4j.rand(DataType.FLOAT, 3, 10, 10); @@ -193,7 +192,7 @@ public class TestRnnLayers extends BaseDL4JTest { INDArray l = TestUtils.randomOneHotTimeSeries(3, 10, 10, 12345); net.fit(f.dup(), l); netD.fit(f.dup(), l); - assertNotEquals(net.params(), netD.params(), s); + assertNotEquals(net.getModelParams(), netD.getModelParams(), s); netD2.fit(f.dup(), l); netD2.fit(f.dup(), l); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java index 5a31cf4df..ec8008379 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/TestTimeDistributed.java @@ -115,7 +115,7 @@ public class TestTimeDistributed extends BaseDL4JTest { net1.fit(ds); net2.fit(ds); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); MultiLayerNetwork net3 = TestUtils.testModelSerialization(net2); out2 = net2.output(in); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java index 60446d43f..93d9421c3 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDense.java @@ -124,10 +124,10 @@ public class TestSameDiffDense extends BaseDL4JTest { MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - net.params().assign(net2.params()); + net.getModelParams().assign(net2.getModelParams()); //Check params: - assertEquals(net2.params(), net.params()); + assertEquals(net2.getModelParams(), net.getModelParams()); Map params1 = net.getParamTable(); Map params2 = net2.getParamTable(); assertEquals(params2, params1); @@ -209,10 +209,10 @@ public class 
TestSameDiffDense extends BaseDL4JTest { MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - assertEquals(net2.params(), net.params()); + assertEquals(net2.getModelParams(), net.getModelParams()); //Check params: - assertEquals(net2.params(), net.params()); + assertEquals(net2.getModelParams(), net.getModelParams()); Map params1 = net.getParamTable(); Map params2 = net2.getParamTable(); assertEquals(params2, params1); @@ -287,10 +287,10 @@ public class TestSameDiffDense extends BaseDL4JTest { MultiLayerNetwork netStandard = new MultiLayerNetwork(conf2); netStandard.init(); - netSD.params().assign(netStandard.params()); + netSD.getModelParams().assign(netStandard.getModelParams()); //Check params: - assertEquals(netStandard.params(), netSD.params()); + assertEquals(netStandard.getModelParams(), netSD.getModelParams()); assertEquals(netStandard.getParamTable(), netSD.getParamTable()); INDArray in = Nd4j.rand(minibatch, nIn); @@ -379,10 +379,10 @@ public class TestSameDiffDense extends BaseDL4JTest { MultiLayerNetwork netStandard = new MultiLayerNetwork(conf2); netStandard.init(); - netSD.params().assign(netStandard.params()); + netSD.getModelParams().assign(netStandard.getModelParams()); //Check params: - assertEquals(netStandard.params(), netSD.params()); + assertEquals(netStandard.getModelParams(), netSD.getModelParams()); assertEquals(netStandard.getParamTable(), netSD.getParamTable()); DataSetIterator iter = new IrisDataSetIterator(150, 150); @@ -398,7 +398,7 @@ public class TestSameDiffDense extends BaseDL4JTest { netStandard.fit(ds); String s = String.valueOf(i); assertEquals( netStandard.getFlattenedGradients(), netSD.getFlattenedGradients(), s); - assertEquals( netStandard.params(), netSD.params(), s); + assertEquals( netStandard.getModelParams(), netSD.getModelParams(), s); assertEquals( netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray(), s); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java index 5e67862ff..5fd371d13 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffDenseVertex.java @@ -100,10 +100,10 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest { ComputationGraph netStandard = new ComputationGraph(conf2); netStandard.init(); - netSD.params().assign(netStandard.params()); + netSD.getModelParams().assign(netStandard.getModelParams()); //Check params: - assertEquals(netStandard.params(), netSD.params()); + assertEquals(netStandard.getModelParams(), netSD.getModelParams()); assertEquals(netStandard.getParamTable(), netSD.getParamTable()); INDArray in = Nd4j.rand(minibatch, nIn); @@ -160,7 +160,7 @@ public class TestSameDiffDenseVertex extends BaseDL4JTest { netStandard.fit(ds); assertEquals(netStandard.getParamTable(), netSD.getParamTable()); - assertEquals(netStandard.params(), netSD.params()); + assertEquals(netStandard.getModelParams(), netSD.getModelParams()); assertEquals(netStandard.getFlattenedGradients(), netSD.getFlattenedGradients()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java index 
8da331f8e..1514a6709 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffLambda.java @@ -98,7 +98,7 @@ public class TestSameDiffLambda extends BaseDL4JTest { ComputationGraph std = new ComputationGraph(confStd); std.init(); - lambda.setParams(std.params()); + lambda.setParams(std.getModelParams()); INDArray in = Nd4j.rand(3, 5); INDArray labels = TestUtils.randomOneHot(3, 5); @@ -119,7 +119,7 @@ public class TestSameDiffLambda extends BaseDL4JTest { std.fit(ds); String s = String.valueOf(i); - assertEquals(std.params(), lambda.params(), s); + assertEquals(std.getModelParams(), lambda.getModelParams(), s); assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s); } @@ -182,7 +182,7 @@ public class TestSameDiffLambda extends BaseDL4JTest { ComputationGraph std = new ComputationGraph(confStd); std.init(); - lambda.setParams(std.params()); + lambda.setParams(std.getModelParams()); INDArray in1 = Nd4j.rand(3, 5); INDArray in2 = Nd4j.rand(3, 5); @@ -204,7 +204,7 @@ public class TestSameDiffLambda extends BaseDL4JTest { std.fit(mds); String s = String.valueOf(i); - assertEquals(std.params(), lambda.params(), s); + assertEquals(std.getModelParams(), lambda.getModelParams(), s); assertEquals(std.getFlattenedGradients(), lambda.getFlattenedGradients(), s); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java index 8ff1d6bc9..0a3d2f915 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/TestSameDiffOutput.java @@ -85,7 +85,7 @@ public class TestSameDiffOutput extends BaseDL4JTest { netSD.fit(ds); netStd.fit(ds); - assertEquals(netStd.params(), netSD.params()); + assertEquals(netStd.getModelParams(), netSD.getModelParams()); assertEquals(netStd.getFlattenedGradients(), netSD.getFlattenedGradients()); } @@ -131,7 +131,7 @@ public class TestSameDiffOutput extends BaseDL4JTest { MultiLayerNetwork netStd = new MultiLayerNetwork(confStd); netStd.init(); - netSD.params().assign(netStd.params()); + netSD.getModelParams().assign(netStd.getModelParams()); assertEquals(netStd.getParamTable(), netSD.getParamTable()); @@ -165,7 +165,7 @@ public class TestSameDiffOutput extends BaseDL4JTest { netSD.fit(ds); netStd.fit(ds); String s = String.valueOf(i); - assertEquals( netStd.params(), netSD.params(), s); + assertEquals( netStd.getModelParams(), netSD.getModelParams(), s); assertEquals( netStd.getFlattenedGradients(), netSD.getFlattenedGradients(),s ); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java index 639520492..b7c89e007 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/variational/TestVAE.java @@ -77,7 +77,7 @@ public class TestVAE extends BaseDL4JTest { net.init(); System.out.println("Exp num params: " + expNumParams); - assertEquals(expNumParams, net.getLayer(0).params().length()); + assertEquals(expNumParams, net.getLayer(0).getParams().length()); 
Map paramTable = net.getLayer(0).getParamTable(); int count = 0; for (INDArray arr : paramTable.values()) { diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java index 0b8b1877d..5ed2a9c2b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java @@ -79,7 +79,7 @@ public class CloseNetworkTests extends BaseDL4JTest { net.close(); - assertTrue(net.params().wasClosed()); + assertTrue(net.getModelParams().wasClosed()); if(train) { assertTrue(net.getGradientsViewArray().wasClosed()); Updater u = net.getUpdater(false); @@ -127,7 +127,7 @@ public class CloseNetworkTests extends BaseDL4JTest { net.close(); - assertTrue(net.params().wasClosed()); + assertTrue(net.getModelParams().wasClosed()); if(train) { assertTrue(net.getGradientsViewArray().wasClosed()); Updater u = net.getUpdater(false); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java index 09dfb45ea..052e1fa07 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/LargeNetTest.java @@ -57,7 +57,7 @@ public class LargeNetTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray params = net.params(); + INDArray params = net.getModelParams(); long paramsLength = params.length(); long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10; assertEquals(expParamsLength, paramsLength); @@ -91,7 +91,7 @@ public class LargeNetTest extends BaseDL4JTest { ComputationGraph net = new ComputationGraph(conf); net.init(); - INDArray params = net.params(); + INDArray params = net.getModelParams(); long paramsLength = params.length(); long expParamsLength = 10_000_000L * 300 + 300 * 10 + 10; assertEquals(expParamsLength, paramsLength); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java index f6ddd312c..69099f0a0 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestLrChanges.java @@ -76,7 +76,7 @@ public class TestLrChanges extends BaseDL4JTest { net2.init(); net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); conf2.setIterationCount(conf.getIterationCount()); - net2.setParams(net.params().dup()); + net2.setParams(net.getModelParams().dup()); assertEquals(0.1, net.getLearningRate(0).doubleValue(), 0.0); net.setLearningRate(0, 0.5); //Set LR for layer 0 to 0.5 @@ -96,7 +96,7 @@ public class TestLrChanges extends BaseDL4JTest { net2.fit(in, l); } - assertEquals(net.params(), net2.params()); + assertEquals(net.getModelParams(), net2.getModelParams()); assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray()); INDArray in1 = Nd4j.rand(10, 10); @@ -110,7 +110,7 @@ public class TestLrChanges extends BaseDL4JTest { net2.setLabels(l1); net2.computeGradientAndScore(); - assertEquals(net.score(), net2.score(), 1e-8); + assertEquals(net.getScore(), net2.getScore(), 1e-8); //Now: Set *all* LRs to 
say 0.3... @@ -126,7 +126,7 @@ public class TestLrChanges extends BaseDL4JTest { net3.init(); net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); conf3.setIterationCount(conf.getIterationCount()); - net3.setParams(net.params().dup()); + net3.setParams(net.getModelParams().dup()); net.setLearningRate(0.3); @@ -139,7 +139,7 @@ public class TestLrChanges extends BaseDL4JTest { net3.fit(in, l); } - assertEquals(net.params(), net3.params()); + assertEquals(net.getModelParams(), net3.getModelParams()); assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray()); } @@ -206,7 +206,7 @@ public class TestLrChanges extends BaseDL4JTest { net2.init(); net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); conf2.setIterationCount(conf.getIterationCount()); - net2.setParams(net.params().dup()); + net2.setParams(net.getModelParams().dup()); net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )); //Set LR for layer 0 to 0.5 @@ -224,7 +224,7 @@ public class TestLrChanges extends BaseDL4JTest { net2.fit(in, l); } - assertEquals(net.params(), net2.params()); + assertEquals(net.getModelParams(), net2.getModelParams()); assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray()); } @@ -270,7 +270,7 @@ public class TestLrChanges extends BaseDL4JTest { net2.init(); net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); conf2.setIterationCount(conf.getIterationCount()); - net2.setParams(net.params().dup()); + net2.setParams(net.getModelParams().dup()); assertEquals(0.1, net.getLearningRate("0").doubleValue(), 0.0); net.setLearningRate("0", 0.5); //Set LR for layer 0 to 0.5 @@ -290,7 +290,7 @@ public class TestLrChanges extends BaseDL4JTest { net2.fit(new DataSet(in, l)); } - assertEquals(net.params(), net2.params()); + assertEquals(net.getModelParams(), net2.getModelParams()); assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray()); INDArray in1 = Nd4j.rand(10, 10); @@ -304,7 +304,7 @@ public class TestLrChanges extends BaseDL4JTest { net2.setLabels(l1); net2.computeGradientAndScore(); - assertEquals(net.score(), net2.score(), 1e-8); + assertEquals(net.getScore(), net2.getScore(), 1e-8); //Now: Set *all* LRs to say 0.3... 
@@ -320,7 +320,7 @@ public class TestLrChanges extends BaseDL4JTest { net3.init(); net3.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); conf3.setIterationCount(conf.getIterationCount()); - net3.setParams(net.params().dup()); + net3.setParams(net.getModelParams().dup()); net.setLearningRate(0.3); @@ -333,7 +333,7 @@ public class TestLrChanges extends BaseDL4JTest { net3.fit(new DataSet(in, l)); } - assertEquals(net.params(), net3.params()); + assertEquals(net.getModelParams(), net3.getModelParams()); assertEquals(net.getUpdater().getStateViewArray(), net3.getUpdater().getStateViewArray()); } @@ -375,7 +375,7 @@ public class TestLrChanges extends BaseDL4JTest { net2.init(); net2.getUpdater().getStateViewArray().assign(net.getUpdater().getStateViewArray()); conf2.setIterationCount(conf.getIterationCount()); - net2.setParams(net.params().dup()); + net2.setParams(net.getModelParams().dup()); net.setLearningRate(new ExponentialSchedule(ScheduleType.ITERATION, 0.5, 0.8 )); //Set LR for layer 0 to 0.5 @@ -393,7 +393,7 @@ public class TestLrChanges extends BaseDL4JTest { net2.fit(new DataSet(in, l)); } - assertEquals(net.params(), net2.params()); + assertEquals(net.getModelParams(), net2.getModelParams()); assertEquals(net.getUpdater().getStateViewArray(), net2.getUpdater().getStateViewArray()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java index fdfb16fcd..01278db4e 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/TestNetConversion.java @@ -77,14 +77,14 @@ public class TestNetConversion extends BaseDL4JTest { n.computeGradientAndScore(); cg.computeGradientAndScore(); - assertEquals(n.score(), cg.score(), 1e-6); + assertEquals(n.getScore(), cg.getScore(), 1e-6); assertEquals(n.gradient().gradient(), cg.gradient().gradient()); n.fit(in, labels); cg.fit(new INDArray[]{in}, new INDArray[]{labels}); - assertEquals(n.params(), cg.params()); + assertEquals(n.getModelParams(), cg.getModelParams()); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java index 794b45411..904dd845b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/misc/WorkspaceTests.java @@ -476,7 +476,7 @@ public class WorkspaceTests extends BaseDL4JTest { final ComputationGraph computationGraph = new ComputationGraph(config); computationGraph.init(); - computationGraph.setListeners(new ScoreIterationListener(3)); + computationGraph.addTrainingListeners(new ScoreIterationListener(3)); WSTestDataSetIterator iterator = new WSTestDataSetIterator(); computationGraph.fit(iterator); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java index 27efa9149..c818f2281 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/BackPropMLPTest.java @@ -54,7 +54,7 @@ public class BackPropMLPTest extends BaseDL4JTest { public void 
testMLPTrivial() { //Simplest possible case: 1 hidden layer, 1 hidden neuron, batch size of 1. MultiLayerNetwork network = new MultiLayerNetwork(getIrisMLPSimpleConfig(new int[] {1}, Activation.SIGMOID)); - network.setListeners(new ScoreIterationListener(1)); + network.addTrainingListeners(new ScoreIterationListener(1)); network.init(); DataSetIterator iter = new IrisDataSetIterator(1, 10); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java index 4fb1c3fad..ac1626eda 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTest.java @@ -64,7 +64,7 @@ import org.deeplearning4j.nn.conf.distribution.NormalDistribution; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ActivationLayer; import org.deeplearning4j.nn.conf.layers.AutoEncoder; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -184,13 +184,13 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork network3 = new MultiLayerNetwork(conf); network3.init(); - INDArray params = network3.params(); + INDArray params = network3.getModelParams(); INDArray weights = network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY).dup(); INDArray bias = network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY).dup(); network3.setParameters(params); assertEquals(weights, network3.getLayer(0).getParam(DefaultParamInitializer.WEIGHT_KEY)); assertEquals(bias, network3.getLayer(0).getParam(DefaultParamInitializer.BIAS_KEY)); - INDArray params4 = network3.params(); + INDArray params4 = network3.getModelParams(); assertEquals(params, params4); } @@ -211,7 +211,7 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - network.setListeners(new ScoreIterationListener(1)); + network.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator iter = new IrisDataSetIterator(150, 150); @@ -242,7 +242,7 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); - network.setListeners(new ScoreIterationListener(1)); + network.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator iter = new IrisDataSetIterator(150, 150); @@ -330,7 +330,7 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork model = new MultiLayerNetwork(conf); model.init(); - model.addListeners(new ScoreIterationListener(listenerFreq)); + model.addTrainingListeners(new ScoreIterationListener(listenerFreq)); log.info("Train model...."); int cnt = 0; @@ -503,7 +503,7 @@ public class MultiLayerTest extends BaseDL4JTest { assertEquals(layerNameList.get(0), net.getLayer(0).getLayerConfiguration().getLayerName()); assertEquals(layerNameList, net.getLayerNames()); - BaseLayer b = (BaseLayer) net.getLayer(layerNameList.get(2)).getLayerConfiguration(); + BaseLayerConfiguration b = (BaseLayerConfiguration) net.getLayer(layerNameList.get(2)).getLayerConfiguration(); assertEquals("softmax", b.getActivationFn().toString()); } @@ -535,7 +535,7 
@@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork netNoReg = new MultiLayerNetwork(confNoReg); netNoReg.init(); - netNoReg.setParameters(net.params().dup()); + netNoReg.setParameters(net.getModelParams().dup()); //Score single example, and compare to scoreExamples: INDArray input = Nd4j.rand(3, nIn); @@ -703,7 +703,7 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); net.fit(iter.next()); - // TODO validate actual layer gradientView - issue getting var out of BaseLayer w/o adding MLN getter that gets confused with local gradient vars + // TODO validate actual layer gradientView - issue getting var out of BaseLayerConfiguration w/o adding MLN getter that gets confused with local gradient vars Gradient actualGradient = net.gradient; assertNotEquals(expectedGradient.getGradientFor("0_W"), actualGradient.getGradientFor("0_W")); @@ -716,13 +716,13 @@ public class MultiLayerTest extends BaseDL4JTest { net.setParam("0_b", Nd4j.ones(1, 5)); net.setParam("1_W", Nd4j.ones(5, 3)); net.setParam("1_b", Nd4j.ones(1, 3)); - INDArray actualParams = net.params(); + INDArray actualParams = net.getModelParams(); // Confirm params assertEquals(expectedGradient.gradient(), actualParams); net.update(expectedGradient); - actualParams = net.params(); + actualParams = net.getModelParams(); assertEquals(Nd4j.ones(1, 43).addi(1), actualParams); } @@ -762,7 +762,7 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork aePre = getAeModel(true, nIn, nOut); int actualNP = (int) aePre.numParams(); assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP); - INDArray params = aePre.params(); + INDArray params = aePre.getModelParams(); assertEquals(params.length(), actualNP); // check num params Map paramTable = aePre.getParamTable(); assertTrue(paramTable.containsKey("0_vb")); // check vb exists for pretrain layer @@ -774,7 +774,7 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork aeNoPre = getAeModel(false, nIn, nOut); actualNP = (int) aeNoPre.numParams(); assertEquals(2 * (nIn * nOut + nOut) + nIn, actualNP); - params = aeNoPre.params(); + params = aeNoPre.getModelParams(); assertEquals(params.length(), actualNP); paramTable = aePre.getParamTable(); assertTrue(paramTable.containsKey("0_vb")); @@ -865,14 +865,14 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - BaseLayer bl0 = (BaseLayer) net2.getLayer(0).getLayerConfiguration(); + BaseLayerConfiguration bl0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration(); assertEquals(0.1, TestUtils.getL1(bl0.getRegularizationBias()), 1e-6); assertEquals(0.2, TestUtils.getL2(bl0.getRegularizationBias()), 1e-6); INDArray features = Nd4j.rand(10, 10); INDArray labels = Nd4j.rand(10, 10); - net2.setParams(net1.params().dup()); + net2.setParams(net1.getModelParams().dup()); net1.setInput(features); net1.setLabels(labels); @@ -888,15 +888,15 @@ public class MultiLayerTest extends BaseDL4JTest { r = net2.calcRegularizationScore(true); assertEquals(0.0, r, 0.0); - double s1 = net1.score(); - double s2 = net2.score(); + double s1 = net1.getScore(); + double s2 = net2.getScore(); assertEquals(s1, s2, 1e-6); //Biases initialized to 0 -> should initially have same score for (int i = 0; i < 10; i++) { net1.fit(features, labels); } - net2.setParams(net1.params().dup()); + net2.setParams(net1.getModelParams().dup()); net1.computeGradientAndScore(); net2.computeGradientAndScore(); @@ 
-906,8 +906,8 @@ public class MultiLayerTest extends BaseDL4JTest { r = net2.calcRegularizationScore(true); assertTrue(r > 0.0); - s1 = net1.score(); - s2 = net2.score(); + s1 = net1.getScore(); + s2 = net2.getScore(); assertNotEquals(s1, s2, 1e-6); //Scores should differ due to bias l1/l2 @@ -1022,11 +1022,11 @@ public class MultiLayerTest extends BaseDL4JTest { MultiLayerNetwork net2 = new MultiLayerNetwork(conf2); net2.init(); - assertNotEquals(net1.params(), net2.params()); + assertNotEquals(net1.getModelParams(), net2.getModelParams()); assertNotEquals(net1.getParamTable(), net2.getParamTable()); net1.setParamTable(net2.getParamTable()); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); assertEquals(net1.getParamTable(), net2.getParamTable()); } @@ -1412,7 +1412,7 @@ public class MultiLayerTest extends BaseDL4JTest { exp.add(MultiLayerNetwork.class); CheckModelsListener listener = new CheckModelsListener(); - net.setListeners(listener); + net.addTrainingListeners(listener); INDArray f = Nd4j.create(1, 10); INDArray l = Nd4j.create(1, 10); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java index 99c1c6077..29d7e7a6a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/MultiLayerTestRNN.java @@ -753,9 +753,9 @@ public class MultiLayerTestRNN extends BaseDL4JTest { DataSet ds = new DataSet(features, labels, maskArrayInput, maskArrayOutput); - INDArray initialParams = mln.params().dup(); + INDArray initialParams = mln.getModelParams().dup(); mln.fit(ds); - INDArray afterParams = mln.params(); + INDArray afterParams = mln.getModelParams(); assertNotEquals(initialParams, afterParams); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java index 1cca6ede8..d98cd58b2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestMasking.java @@ -172,7 +172,7 @@ public class TestMasking extends BaseDL4JTest { net.setLabels(labels); net.computeGradientAndScore(); - double score1 = net.score(); + double score1 = net.getScore(); INDArray grad1 = net.gradient().gradient(); //Now: change the label values for the masked steps. 
The @@ -187,7 +187,7 @@ public class TestMasking extends BaseDL4JTest { assertNotEquals(labels, newLabels); - double score2 = net.score(); + double score2 = net.getScore(); INDArray grad2 = net.gradient().gradient(); assertEquals(score1, score2, 1e-6); @@ -214,7 +214,7 @@ public class TestMasking extends BaseDL4JTest { graph.setLabels(labels); graph.computeGradientAndScore(); - double gScore1 = graph.score(); + double gScore1 = graph.getScore(); INDArray gGrad1 = graph.gradient().gradient(); graph.setLayerMaskArrays(null, new INDArray[] {labelMask}); @@ -222,7 +222,7 @@ public class TestMasking extends BaseDL4JTest { graph.setLabels(newLabels); graph.computeGradientAndScore(); - double gScore2 = graph.score(); + double gScore2 = graph.getScore(); INDArray gGrad2 = graph.gradient().gradient(); assertEquals(gScore1, gScore2, 1e-6); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java index 7b75bc97b..9c3c1407b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestSetGetParameters.java @@ -53,12 +53,12 @@ public class TestSetGetParameters extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray initParams = net.params().dup(); + INDArray initParams = net.getModelParams().dup(); Map initParams2 = net.getParamTable(); - net.setParams(net.params()); + net.setParams(net.getModelParams()); - INDArray initParamsAfter = net.params(); + INDArray initParamsAfter = net.getModelParams(); Map initParams2After = net.getParamTable(); for (String s : initParams2.keySet()) { @@ -71,7 +71,7 @@ public class TestSetGetParameters extends BaseDL4JTest { INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape()); net.setParams(randomParams.dup()); - assertEquals(net.params(), randomParams); + assertEquals(net.getModelParams(), randomParams); } @Test @@ -90,12 +90,12 @@ public class TestSetGetParameters extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray initParams = net.params().dup(); + INDArray initParams = net.getModelParams().dup(); Map initParams2 = net.getParamTable(); - net.setParams(net.params()); + net.setParams(net.getModelParams()); - INDArray initParamsAfter = net.params(); + INDArray initParamsAfter = net.getModelParams(); Map initParams2After = net.getParamTable(); for (String s : initParams2.keySet()) { @@ -108,7 +108,7 @@ public class TestSetGetParameters extends BaseDL4JTest { INDArray randomParams = Nd4j.rand(initParams.dataType(), initParams.shape()); net.setParams(randomParams.dup()); - assertEquals(net.params(), randomParams); + assertEquals(net.getModelParams(), randomParams); } @Test @@ -128,7 +128,7 @@ public class TestSetGetParameters extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - INDArray params = net.params(); + INDArray params = net.getModelParams(); MultiLayerNetwork net2 = new MultiLayerNetwork(conf); @@ -137,11 +137,11 @@ public class TestSetGetParameters extends BaseDL4JTest { MultiLayerNetwork net3 = new MultiLayerNetwork(conf); net3.init(params, false); - assertEquals(params, net2.params()); - assertEquals(params, net3.params()); + assertEquals(params, net2.getModelParams()); + assertEquals(params, net3.getModelParams()); - 
assertNotSame(params, net2.params()); //Different objects due to clone - assertSame(params, net3.params()); //Same object due to clone + assertNotSame(params, net2.getModelParams()); //Different objects due to clone + assertSame(params, net3.getModelParams()); //Same object due to clone Map paramsMap = net.getParamTable(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java index 7dc7480c6..6f3747e84 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/multilayer/TestVariableLengthTS.java @@ -103,14 +103,14 @@ public class TestVariableLengthTS extends BaseDL4JTest { net.setInput(in1); net.setLabels(labels1); net.computeGradientAndScore(); - double score1 = net.score(); + double score1 = net.getScore(); Gradient g1 = net.gradient(); net.setInput(in2); net.setLabels(labels2); net.setLayerMaskArrays(null, labelMask); net.computeGradientAndScore(); - double score2 = net.score(); + double score2 = net.getScore(); Gradient g2 = net.gradient(); //Scores and gradients should be identical for two cases (given mask array) @@ -134,7 +134,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { } net.setLabels(labels2); net.computeGradientAndScore(); - double score2a = net.score(); + double score2a = net.getScore(); Gradient g2a = net.gradient(); assertEquals(score2, score2a, 1e-6); for (String s : g2map.keySet()) { @@ -196,7 +196,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { net.setInput(in1); net.setLabels(labels1); net.computeGradientAndScore(); - double score1 = net.score(); + double score1 = net.getScore(); Gradient g1 = net.gradient(); Map map1 = g1.gradientForVariable(); for (String s : map1.keySet()) { @@ -207,7 +207,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { net.setLabels(labels2); net.setLayerMaskArrays(inputMask, null); net.computeGradientAndScore(); - double score2 = net.score(); + double score2 = net.getScore(); Gradient g2 = net.gradient(); net.setInput(in2); @@ -240,7 +240,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { net.setInput(in2); net.setLayerMaskArrays(inputMask, null); net.computeGradientAndScore(); - double score2a = net.score(); + double score2a = net.getScore(); Gradient g2a = net.gradient(); assertEquals(score2, score2a, 1e-12); for (String s : g2.gradientForVariable().keySet()) { @@ -327,7 +327,7 @@ public class TestVariableLengthTS extends BaseDL4JTest { mln.setLabels(labels); mln.computeGradientAndScore(); - double score = mln.score(); + double score = mln.getScore(); assertEquals(expScore, score, 0.1, msg); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java index 19360abb7..0539c6262 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/rl/TestMultiModelGradientApplication.java @@ -77,7 +77,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest { MultiLayerNetwork net2GradUpd = new MultiLayerNetwork(conf.clone()); net2GradUpd.init(); - assertEquals(net1GradCalc.params(), net2GradUpd.params()); + 
assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams()); INDArray f = Nd4j.rand(minibatch, nIn); INDArray l = Nd4j.create(minibatch, nOut); @@ -109,17 +109,17 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest { //Also: if we apply the gradient using a subi op, we should get the same final params as if we did a fit op // on the original network - net2GradUpd.params().subi(g.gradient()); + net2GradUpd.getModelParams().subi(g.gradient()); net1GradCalc.fit(f, l); - assertEquals(net1GradCalc.params(), net2GradUpd.params()); + assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams()); //============================= if (!(u instanceof Sgd)) { net2GradUpd.getUpdater().getStateViewArray().assign(net1GradCalc.getUpdater().getStateViewArray()); } - assertEquals(net1GradCalc.params(), net2GradUpd.params()); + assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams()); assertEquals(net1GradCalc.getUpdater().getStateViewArray(), net2GradUpd.getUpdater().getStateViewArray()); @@ -130,7 +130,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest { for (int i = 0; i < 100; i++) { net1GradCalc.fit(f, l); net2GradUpd.fit(f, l); - assertEquals(net1GradCalc.params(), net2GradUpd.params()); + assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams()); } } } @@ -169,7 +169,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest { ComputationGraph net2GradUpd = new ComputationGraph(conf.clone()); net2GradUpd.init(); - assertEquals(net1GradCalc.params(), net2GradUpd.params()); + assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams()); INDArray f = Nd4j.rand(minibatch, nIn); INDArray l = Nd4j.create(minibatch, nOut); @@ -201,16 +201,16 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest { //Also: if we apply the gradient using a subi op, we should get the same final params as if we did a fit op // on the original network - net2GradUpd.params().subi(g.gradient()); + net2GradUpd.getModelParams().subi(g.gradient()); net1GradCalc.fit(new INDArray[] {f}, new INDArray[] {l}); - assertEquals(net1GradCalc.params(), net2GradUpd.params()); + assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams()); //============================= if (!(u instanceof Sgd)) { net2GradUpd.getUpdater().getStateViewArray().assign(net1GradCalc.getUpdater().getStateViewArray()); } - assertEquals(net1GradCalc.params(), net2GradUpd.params()); + assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams()); assertEquals(net1GradCalc.getUpdater().getStateViewArray(), net2GradUpd.getUpdater().getStateViewArray()); @@ -222,7 +222,7 @@ public class TestMultiModelGradientApplication extends BaseDL4JTest { for (int i = 0; i < 100; i++) { net1GradCalc.fit(new INDArray[] {f}, new INDArray[] {l}); net2GradUpd.fit(new INDArray[] {f}, new INDArray[] {l}); - assertEquals(net1GradCalc.params(), net2GradUpd.params()); + assertEquals(net1GradCalc.getModelParams(), net2GradUpd.getModelParams()); } } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java index 195ee2f6d..a8bfd8d97 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java @@ -25,7 +25,6 @@ import org.deeplearning4j.TestUtils; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.constraint.UnitNormConstraint; import org.deeplearning4j.nn.conf.distribution.ConstantDistribution; import org.deeplearning4j.nn.conf.distribution.NormalDistribution; @@ -94,7 +93,7 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { ComputationGraph modelToFineTune = new ComputationGraph(expectedConf); modelToFineTune.init(); - modelToFineTune.setParams(expectedModel.params()); + modelToFineTune.setParams(expectedModel.getModelParams()); //model after applying changes with transfer learning ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune) @@ -108,8 +107,8 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { //Check params after fit modelNow.fit(randomData); expectedModel.fit(randomData); - assertEquals(modelNow.score(), expectedModel.score(), 1e-8); - assertEquals(modelNow.params(), expectedModel.params()); + assertEquals(modelNow.getScore(), expectedModel.getScore(), 1e-8); + assertEquals(modelNow.getModelParams(), expectedModel.getModelParams()); } @Test @@ -139,9 +138,9 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { //.setOutputs("layer3") .build(); - BaseLayer bl0 = ((BaseLayer) modelNow.getLayer("layer0").getLayerConfiguration()); - BaseLayer bl1 = ((BaseLayer) modelNow.getLayer("layer1").getLayerConfiguration()); - BaseLayer bl3 = ((BaseLayer) modelNow.getLayer("layer3").getLayerConfiguration()); + BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getLayer("layer0").getLayerConfiguration()); + BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getLayer("layer1").getLayerConfiguration()); + BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getLayer("layer3").getLayerConfiguration()); assertEquals(bl0.getWeightInitFn(), new WeightInitDistribution(new NormalDistribution(1, 1e-1))); assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); @@ -161,22 +160,22 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { modelExpectedArch.init(); //modelNow should have the same architecture as modelExpectedArch - assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(), - modelNow.getLayer("layer0").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(), - modelNow.getLayer("layer1").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(), - modelNow.getLayer("layer2").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(), - modelNow.getLayer("layer3").params().shape()); + assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer0").getParams().shape(), + modelNow.getLayer("layer0").getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer1").getParams().shape(), + modelNow.getLayer("layer1").getParams().shape()); 
+ assertArrayEquals(modelExpectedArch.getLayer("layer2").getParams().shape(), + modelNow.getLayer("layer2").getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer3").getParams().shape(), + modelNow.getLayer("layer3").getParams().shape()); - modelNow.setParams(modelExpectedArch.params()); + modelNow.setParams(modelExpectedArch.getModelParams()); //fit should give the same results modelExpectedArch.fit(randomData); modelNow.fit(randomData); - assertEquals(modelExpectedArch.score(), modelNow.score(), 1e-8); - assertEquals(modelExpectedArch.params(), modelNow.params()); + assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 1e-8); + assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams()); } @Test @@ -227,22 +226,22 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { modelExpectedArch.init(); //modelNow should have the same architecture as modelExpectedArch - assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer0").params().shape(), - modelNow.getLayer("layer0").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer1").params().shape(), - modelNow.getLayer("layer1").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer2").params().shape(), - modelNow.getLayer("layer2").params().shape()); - assertArrayEquals(modelExpectedArch.getLayer("layer3").params().shape(), - modelNow.getLayer("layer3").params().shape()); + assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer0").getParams().shape(), + modelNow.getLayer("layer0").getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer1").getParams().shape(), + modelNow.getLayer("layer1").getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer2").getParams().shape(), + modelNow.getLayer("layer2").getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer("layer3").getParams().shape(), + modelNow.getLayer("layer3").getParams().shape()); - modelNow.setParams(modelExpectedArch.params()); + modelNow.setParams(modelExpectedArch.getModelParams()); //fit should give the same results modelExpectedArch.fit(randomData); modelNow.fit(randomData); - assertEquals(modelExpectedArch.score(), modelNow.score(), 1e-8); - assertEquals(modelExpectedArch.params(), modelNow.params()); + assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 1e-8); + assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams()); } @Test @@ -385,14 +384,14 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { assertEquals(modelExpectedArch.getComputationGraphConfiguration().toJson(), modelNow.getComputationGraphConfiguration().toJson()); - modelNow.setParams(modelExpectedArch.params()); + modelNow.setParams(modelExpectedArch.getModelParams()); int i = 0; while (i < 5) { modelExpectedArch.fit(randomData); modelNow.fit(randomData); i++; } - assertEquals(modelExpectedArch.params(), modelNow.params()); + assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java index ba201c62a..cee6e2f90 100644 --- 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningComplex.java @@ -26,10 +26,9 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -99,7 +98,7 @@ public class TransferLearningComplex extends BaseDL4JTest { } //Also check config: - BaseLayer bl = ((BaseLayer) l.getLayerConfiguration()); + BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration()); assertEquals(new Adam(2e-2), bl.getIUpdater()); assertEquals(Activation.LEAKYRELU.getActivationFunction(), bl.getActivationFn()); } @@ -154,8 +153,8 @@ public class TransferLearningComplex extends BaseDL4JTest { .setOutputs("outRight").build(); ComputationGraph modelOther = new ComputationGraph(otherConf); modelOther.init(); - modelOther.getLayer("denseRight0").setParams(modelToTune.getLayer("denseRight0").params()); - modelOther.getLayer("outRight").setParams(modelToTune.getLayer("outRight").params()); + modelOther.getLayer("denseRight0").setParams(modelToTune.getLayer("denseRight0").getParams()); + modelOther.getLayer("outRight").setParams(modelToTune.getLayer("outRight").getParams()); modelToTune.getVertex("denseCentre0").setLayerAsFrozen(); ComputationGraph modelNow = @@ -179,11 +178,11 @@ public class TransferLearningComplex extends BaseDL4JTest { assertEquals(otherRandData.getFeatures(0), modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0")); - assertEquals(modelOther.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params()); - assertEquals(modelOther.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params()); + assertEquals(modelOther.getLayer("denseRight0").getParams(), modelNow.getLayer("denseRight0").getParams()); + assertEquals(modelOther.getLayer("denseRight0").getParams(), modelToTune.getLayer("denseRight0").getParams()); - assertEquals(modelOther.getLayer("outRight").params(), modelNow.getLayer("outRight").params()); - assertEquals(modelOther.getLayer("outRight").params(), modelToTune.getLayer("outRight").params()); + assertEquals(modelOther.getLayer("outRight").getParams(), modelNow.getLayer("outRight").getParams()); + assertEquals(modelOther.getLayer("outRight").getParams(), modelToTune.getLayer("outRight").getParams()); n++; } @@ -237,11 +236,11 @@ public class TransferLearningComplex extends BaseDL4JTest { assertEquals(otherRandData.getFeatures(0), modelToTune.feedForward(randData.getFeatures(), false).get("denseCentre0")); - assertEquals(modelToTune.getLayer("denseRight0").params(), modelNow.getLayer("denseRight0").params()); + assertEquals(modelToTune.getLayer("denseRight0").getParams(), modelNow.getLayer("denseRight0").getParams()); - assertEquals(modelToTune.getLayer("outRight").params(), modelNow.getLayer("outRight").params()); + 
assertEquals(modelToTune.getLayer("outRight").getParams(), modelNow.getLayer("outRight").getParams()); - assertEquals(modelToTune.getLayer("outCentre").params(), modelNow.getLayer("outCentre").params()); + assertEquals(modelToTune.getLayer("outCentre").getParams(), modelNow.getLayer("outCentre").getParams()); n++; } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java index f606e6402..48963619b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelperTest.java @@ -178,25 +178,25 @@ public class TransferLearningHelperTest extends BaseDL4JTest { TransferLearningHelper helper = new TransferLearningHelper(modelToTune, "denseCentre2"); MultiDataSet featurizedDataSet = helper.featurize(origData); - assertEquals(modelIdentical.getLayer("denseRight0").params(), modelToTune.getLayer("denseRight0").params()); + assertEquals(modelIdentical.getLayer("denseRight0").getParams(), modelToTune.getLayer("denseRight0").getParams()); modelIdentical.fit(origData); helper.fitFeaturized(featurizedDataSet); - assertEquals(modelIdentical.getLayer("denseCentre0").params(), modelToTune.getLayer("denseCentre0").params()); - assertEquals(modelIdentical.getLayer("denseCentre1").params(), modelToTune.getLayer("denseCentre1").params()); - assertEquals(modelIdentical.getLayer("denseCentre2").params(), modelToTune.getLayer("denseCentre2").params()); - assertEquals(modelIdentical.getLayer("denseCentre3").params(), modelToTune.getLayer("denseCentre3").params()); - assertEquals(modelIdentical.getLayer("outCentre").params(), modelToTune.getLayer("outCentre").params()); + assertEquals(modelIdentical.getLayer("denseCentre0").getParams(), modelToTune.getLayer("denseCentre0").getParams()); + assertEquals(modelIdentical.getLayer("denseCentre1").getParams(), modelToTune.getLayer("denseCentre1").getParams()); + assertEquals(modelIdentical.getLayer("denseCentre2").getParams(), modelToTune.getLayer("denseCentre2").getParams()); + assertEquals(modelIdentical.getLayer("denseCentre3").getParams(), modelToTune.getLayer("denseCentre3").getParams()); + assertEquals(modelIdentical.getLayer("outCentre").getParams(), modelToTune.getLayer("outCentre").getParams()); assertEquals(modelIdentical.getLayer("denseRight").getNetConfiguration().toJson(), modelToTune.getLayer("denseRight").getNetConfiguration().toJson()); - assertEquals(modelIdentical.getLayer("denseRight").params(), modelToTune.getLayer("denseRight").params()); + assertEquals(modelIdentical.getLayer("denseRight").getParams(), modelToTune.getLayer("denseRight").getParams()); assertEquals(modelIdentical.getLayer("denseRight0").getNetConfiguration().toJson(), modelToTune.getLayer("denseRight0").getNetConfiguration().toJson()); //assertEquals(modelIdentical.getLayer("denseRight0").params(),modelToTune.getLayer("denseRight0").params()); - assertEquals(modelIdentical.getLayer("denseRight1").params(), modelToTune.getLayer("denseRight1").params()); - assertEquals(modelIdentical.getLayer("outRight").params(), modelToTune.getLayer("outRight").params()); - assertEquals(modelIdentical.getLayer("denseLeft0").params(), modelToTune.getLayer("denseLeft0").params()); - assertEquals(modelIdentical.getLayer("outLeft").params(), 
modelToTune.getLayer("outLeft").params()); + assertEquals(modelIdentical.getLayer("denseRight1").getParams(), modelToTune.getLayer("denseRight1").getParams()); + assertEquals(modelIdentical.getLayer("outRight").getParams(), modelToTune.getLayer("outRight").getParams()); + assertEquals(modelIdentical.getLayer("denseLeft0").getParams(), modelToTune.getLayer("denseLeft0").getParams()); + assertEquals(modelIdentical.getLayer("outLeft").getParams(), modelToTune.getLayer("outLeft").getParams()); // log.info(modelIdentical.summary()); // log.info(helper.unfrozenGraph().summary()); @@ -230,7 +230,7 @@ public class TransferLearningHelperTest extends BaseDL4JTest { TransferLearningHelper helper = new TransferLearningHelper(modelToFineTune, 1); INDArray paramsLastTwoLayers = - Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params()); + Nd4j.hstack(modelToFineTune.getLayer(2).getParams(), modelToFineTune.getLayer(3).getParams()); MultiLayerNetwork notFrozen = new MultiLayerNetwork( (NeuralNetConfiguration) overallConf.clone().list() .layer(0, new Builder().nIn(2).nOut(3).build()) @@ -248,9 +248,9 @@ public class TransferLearningHelperTest extends BaseDL4JTest { modelNow.fit(randomData); } - INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), - notFrozen.params()); - INDArray act = modelNow.params(); + INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), modelToFineTune.getLayer(1).getParams(), + notFrozen.getModelParams()); + INDArray act = modelNow.getModelParams(); assertEquals(expected, act); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java index f33c48738..88e8d5d01 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java @@ -91,7 +91,7 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .build(); for (org.deeplearning4j.nn.api.Layer l : modelNow.getLayers()) { - BaseLayer bl = ((BaseLayer) l.getLayerConfiguration()); + BaseLayerConfiguration bl = ((BaseLayerConfiguration) l.getLayerConfiguration()); assertEquals(new RmsProp(0.5), bl.getIUpdater()); } @@ -107,9 +107,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .build()) .build()); expectedModel.init(); - expectedModel.setParams(modelToFineTune.params().dup()); + expectedModel.setParams(modelToFineTune.getModelParams().dup()); - assertEquals(expectedModel.params(), modelNow.params()); + assertEquals(expectedModel.getModelParams(), modelNow.getModelParams()); //Check json NeuralNetConfiguration expectedConf = expectedModel.getNetConfiguration(); @@ -119,9 +119,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest { modelNow.fit(randomData); expectedModel.fit(randomData); - assertEquals(modelNow.score(), expectedModel.score(), 1e-6); - INDArray pExp = expectedModel.params(); - INDArray pNow = modelNow.params(); + assertEquals(modelNow.getScore(), expectedModel.getScore(), 1e-6); + INDArray pExp = expectedModel.getModelParams(); + INDArray pNow = modelNow.getModelParams(); assertEquals(pExp, pNow); } @@ -160,9 +160,9 @@ public class TransferLearningMLNTest extends BaseDL4JTest { //Will fail - expected because of dist and weight init changes 
//assertEquals(modelExpectedArch.getConfiguration().toJson(), modelNow.getConfiguration().toJson()); - BaseLayer bl0 = ((BaseLayer) modelNow.getNetConfiguration().getConf(0).getLayer()); - BaseLayer bl1 = ((BaseLayer) modelNow.getNetConfiguration().getConf(1).getLayer()); - BaseLayer bl3 = ((BaseLayer) modelNow.getNetConfiguration().getConf(3).getLayer()); + BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(0).getLayer()); + BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(1).getLayer()); + BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(3).getLayer()); assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class); try { assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()), @@ -173,18 +173,18 @@ public class TransferLearningMLNTest extends BaseDL4JTest { assertEquals(bl3.getWeightInitFn(), new WeightInitXavier()); //modelNow should have the same architecture as modelExpectedArch - assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape()); + assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer(1).getParams().shape(), modelNow.getLayer(1).getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape()); - modelNow.setParams(modelExpectedArch.params()); + modelNow.setParams(modelExpectedArch.getModelParams()); //fit should give the same results modelExpectedArch.fit(randomData); modelNow.fit(randomData); - assertEquals(modelExpectedArch.score(), modelNow.score(), 0.000001); - assertEquals(modelExpectedArch.params(), modelNow.params()); + assertEquals(modelExpectedArch.getScore(), modelNow.getScore(), 0.000001); + assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams()); } @@ -227,20 +227,20 @@ public class TransferLearningMLNTest extends BaseDL4JTest { modelExpectedArch.init(); //modelNow should have the same architecture as modelExpectedArch - assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape()); + assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape()); + 
assertArrayEquals(modelExpectedArch.getLayer(1).getParams().shape(), modelNow.getLayer(1).getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape()); - modelNow.setParams(modelExpectedArch.params()); + modelNow.setParams(modelExpectedArch.getModelParams()); //fit should give the same results modelExpectedArch.fit(randomData); modelNow.fit(randomData); - double scoreExpected = modelExpectedArch.score(); - double scoreActual = modelNow.score(); + double scoreExpected = modelExpectedArch.getScore(); + double scoreActual = modelNow.getScore(); assertEquals(scoreExpected, scoreActual, 1e-4); - assertEquals(modelExpectedArch.params(), modelNow.params()); + assertEquals(modelExpectedArch.getModelParams(), modelNow.getModelParams()); } @Test @@ -370,14 +370,14 @@ public class TransferLearningMLNTest extends BaseDL4JTest { assertEquals(modelExpectedArch.getNetConfiguration().getConf(5).toJson(), modelNow.getNetConfiguration().getConf(5).toJson()); - assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); + assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape()); //subsampling has no params //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(2).params().shape(), modelNow.getLayer(2).params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(3).params().shape(), modelNow.getLayer(3).params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(4).params().shape(), modelNow.getLayer(4).params().shape()); - assertArrayEquals(modelExpectedArch.getLayer(5).params().shape(), modelNow.getLayer(5).params().shape()); + assertArrayEquals(modelExpectedArch.getLayer(2).getParams().shape(), modelNow.getLayer(2).getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer(3).getParams().shape(), modelNow.getLayer(3).getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer(4).getParams().shape(), modelNow.getLayer(4).getParams().shape()); + assertArrayEquals(modelExpectedArch.getLayer(5).getParams().shape(), modelNow.getLayer(5).getParams().shape()); } @@ -449,23 +449,23 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .inputType(InputType.convolutionalFlat(12, 12, 20)).build()); notFrozen.init(); - assertArrayEquals(modelToFineTune.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); + assertArrayEquals(modelToFineTune.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape()); //subsampling has no params //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); - assertArrayEquals(notFrozen.getLayer(0).params().shape(), modelNow.getLayer(2).params().shape()); - modelNow.getLayer(2).setParams(notFrozen.getLayer(0).params()); + assertArrayEquals(notFrozen.getLayer(0).getParams().shape(), modelNow.getLayer(2).getParams().shape()); + modelNow.getLayer(2).setParams(notFrozen.getLayer(0).getParams()); //subsampling has no params //assertArrayEquals(notFrozen.getLayer(1).params().shape(), 
modelNow.getLayer(3).params().shape()); - assertArrayEquals(notFrozen.getLayer(2).params().shape(), modelNow.getLayer(4).params().shape()); - modelNow.getLayer(4).setParams(notFrozen.getLayer(2).params()); - assertArrayEquals(notFrozen.getLayer(3).params().shape(), modelNow.getLayer(5).params().shape()); - modelNow.getLayer(5).setParams(notFrozen.getLayer(3).params()); - assertArrayEquals(notFrozen.getLayer(4).params().shape(), modelNow.getLayer(6).params().shape()); - modelNow.getLayer(6).setParams(notFrozen.getLayer(4).params()); - assertArrayEquals(notFrozen.getLayer(5).params().shape(), modelNow.getLayer(7).params().shape()); - modelNow.getLayer(7).setParams(notFrozen.getLayer(5).params()); - assertArrayEquals(notFrozen.getLayer(6).params().shape(), modelNow.getLayer(8).params().shape()); - modelNow.getLayer(8).setParams(notFrozen.getLayer(6).params()); + assertArrayEquals(notFrozen.getLayer(2).getParams().shape(), modelNow.getLayer(4).getParams().shape()); + modelNow.getLayer(4).setParams(notFrozen.getLayer(2).getParams()); + assertArrayEquals(notFrozen.getLayer(3).getParams().shape(), modelNow.getLayer(5).getParams().shape()); + modelNow.getLayer(5).setParams(notFrozen.getLayer(3).getParams()); + assertArrayEquals(notFrozen.getLayer(4).getParams().shape(), modelNow.getLayer(6).getParams().shape()); + modelNow.getLayer(6).setParams(notFrozen.getLayer(4).getParams()); + assertArrayEquals(notFrozen.getLayer(5).getParams().shape(), modelNow.getLayer(7).getParams().shape()); + modelNow.getLayer(7).setParams(notFrozen.getLayer(5).getParams()); + assertArrayEquals(notFrozen.getLayer(6).getParams().shape(), modelNow.getLayer(8).getParams().shape()); + modelNow.getLayer(8).setParams(notFrozen.getLayer(6).getParams()); int i = 0; while (i < 3) { @@ -474,8 +474,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest { i++; } - INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), notFrozen.params()); - assertEquals(expectedParams, modelNow.params()); + INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), notFrozen.getModelParams()); + assertEquals(expectedParams, modelNow.getModelParams()); } @@ -503,13 +503,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest { //Check original net isn't modified: - BaseLayer l0 = (BaseLayer) net.getLayer(0).getLayerConfiguration(); + BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration(); assertEquals(new Adam(1e-4), l0.getIUpdater()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); assertEquals(0.1, TestUtils.getL1(l0), 1e-6); - BaseLayer l1 = (BaseLayer) net.getLayer(1).getLayerConfiguration(); + BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration(); assertEquals(new Adam(1e-4), l1.getIUpdater()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); @@ -518,13 +518,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest { assertEquals(BackpropType.Standard, conf.getBackpropType()); //Check new net has only the appropriate things modified (i.e., LR) - l0 = (BaseLayer) net2.getLayer(0).getLayerConfiguration(); + l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration(); assertEquals(new Adam(2e-2), l0.getIUpdater()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); assertEquals(new 
WeightInitRelu(), l0.getWeightInitFn()); assertEquals(0.1, TestUtils.getL1(l0), 1e-6); - l1 = (BaseLayer) net2.getLayer(1).getLayerConfiguration(); + l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration(); assertEquals(new Adam(2e-2), l1.getIUpdater()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); @@ -586,17 +586,17 @@ public class TransferLearningMLNTest extends BaseDL4JTest { .build()); notFrozen.init(); - assertArrayEquals(modelToFineTune.getLayer(0).params().shape(), modelNow.getLayer(0).params().shape()); + assertArrayEquals(modelToFineTune.getLayer(0).getParams().shape(), modelNow.getLayer(0).getParams().shape()); //subsampling has no params //assertArrayEquals(modelExpectedArch.getLayer(1).params().shape(), modelNow.getLayer(1).params().shape()); - assertArrayEquals(notFrozen.getLayer(0).params().shape(), modelNow.getLayer(2).params().shape()); - modelNow.getLayer(2).setParams(notFrozen.getLayer(0).params()); - assertArrayEquals(notFrozen.getLayer(1).params().shape(), modelNow.getLayer(3).params().shape()); - modelNow.getLayer(3).setParams(notFrozen.getLayer(1).params()); - assertArrayEquals(notFrozen.getLayer(2).params().shape(), modelNow.getLayer(4).params().shape()); - modelNow.getLayer(4).setParams(notFrozen.getLayer(2).params()); - assertArrayEquals(notFrozen.getLayer(3).params().shape(), modelNow.getLayer(5).params().shape()); - modelNow.getLayer(5).setParams(notFrozen.getLayer(3).params()); + assertArrayEquals(notFrozen.getLayer(0).getParams().shape(), modelNow.getLayer(2).getParams().shape()); + modelNow.getLayer(2).setParams(notFrozen.getLayer(0).getParams()); + assertArrayEquals(notFrozen.getLayer(1).getParams().shape(), modelNow.getLayer(3).getParams().shape()); + modelNow.getLayer(3).setParams(notFrozen.getLayer(1).getParams()); + assertArrayEquals(notFrozen.getLayer(2).getParams().shape(), modelNow.getLayer(4).getParams().shape()); + modelNow.getLayer(4).setParams(notFrozen.getLayer(2).getParams()); + assertArrayEquals(notFrozen.getLayer(3).getParams().shape(), modelNow.getLayer(5).getParams().shape()); + modelNow.getLayer(5).setParams(notFrozen.getLayer(3).getParams()); int i = 0; while (i < 3) { @@ -605,8 +605,8 @@ public class TransferLearningMLNTest extends BaseDL4JTest { i++; } - INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).params(), notFrozen.params()); - assertEquals(expectedParams, modelNow.params()); + INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer(0).getParams(), notFrozen.getModelParams()); + assertEquals(expectedParams, modelNow.getModelParams()); } @Test diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java index cf73bb012..f92e34bf2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java @@ -99,7 +99,7 @@ public class TestUpdaters extends BaseDL4JTest { BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); - int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); + int updaterStateSize = (int) 
layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams); INDArray updaterState = Nd4j.create(1, updaterStateSize); updater.setStateViewArray(layer, updaterState, true); @@ -144,7 +144,7 @@ public class TestUpdaters extends BaseDL4JTest { msdx.put(key, msdxTmp); count++; } - assertEquals(rho, ((AdaDelta)layer.layerConf().getIUpdater()).getRho(), 1e-4); + assertEquals(rho, ((AdaDelta)layer.getTypedLayerConfiguration().getIUpdater()).getRho(), 1e-4); } assertEquals(4, count); @@ -165,7 +165,7 @@ public class TestUpdaters extends BaseDL4JTest { BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); - int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); + int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams); INDArray updaterState = Nd4j.create(1, updaterStateSize); updater.setStateViewArray(layer, updaterState, true); @@ -185,7 +185,7 @@ public class TestUpdaters extends BaseDL4JTest { assertEquals(gradExpected, gradient.getGradientFor(entry.getKey())); count++; } - assertEquals(lr, ((AdaGrad)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4); + assertEquals(lr, ((AdaGrad)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4); assertEquals(2, count); } @@ -209,7 +209,7 @@ public class TestUpdaters extends BaseDL4JTest { BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); - int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); + int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams); INDArray updaterState = Nd4j.create(1, updaterStateSize); updater.setStateViewArray(layer, updaterState, true); @@ -245,8 +245,8 @@ public class TestUpdaters extends BaseDL4JTest { count++; } - assertEquals(beta1, ((Adam)layer.layerConf().getIUpdater()).getBeta1(), 1e-4); - assertEquals(beta2, ((Adam)layer.layerConf().getIUpdater()).getBeta2(), 1e-4); + assertEquals(beta1, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4); + assertEquals(beta2, ((Adam)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4); assertEquals(2, count); } @@ -273,7 +273,7 @@ public class TestUpdaters extends BaseDL4JTest { layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); - int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); + int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams); INDArray updaterState = Nd4j.create(1, updaterStateSize); updater.setStateViewArray(layer, updaterState, true); @@ -362,7 +362,7 @@ public class TestUpdaters extends BaseDL4JTest { BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); - int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); + int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams); INDArray updaterState = Nd4j.create(1, updaterStateSize); updater.setStateViewArray(layer, updaterState, true); @@ -398,8 +398,8 @@ 
public class TestUpdaters extends BaseDL4JTest { count++; } - assertEquals(beta1, ((AdaMax)layer.layerConf().getIUpdater()).getBeta1(), 1e-4); - assertEquals(beta2, ((AdaMax)layer.layerConf().getIUpdater()).getBeta2(), 1e-4); + assertEquals(beta1, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta1(), 1e-4); + assertEquals(beta2, ((AdaMax)layer.getTypedLayerConfiguration().getIUpdater()).getBeta2(), 1e-4); assertEquals(2, count); } @@ -418,7 +418,7 @@ public class TestUpdaters extends BaseDL4JTest { BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); - int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); + int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams); INDArray updaterState = Nd4j.create(1, updaterStateSize); updater.setStateViewArray(layer, updaterState, true); @@ -443,7 +443,7 @@ public class TestUpdaters extends BaseDL4JTest { count++; } - assertEquals(mu, ((Nesterovs)layer.layerConf().getIUpdater()).getMomentum(), 1e-4); + assertEquals(mu, ((Nesterovs)layer.getTypedLayerConfiguration().getIUpdater()).getMomentum(), 1e-4); assertEquals(2, count); } @@ -465,7 +465,7 @@ public class TestUpdaters extends BaseDL4JTest { BaseLayer layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); Updater updater = UpdaterCreator.getUpdater(layer); - int updaterStateSize = (int) layer.layerConf().getIUpdater().stateSize(numParams); + int updaterStateSize = (int) layer.getTypedLayerConfiguration().getIUpdater().stateSize(numParams); INDArray updaterState = Nd4j.create(1, updaterStateSize); updater.setStateViewArray(layer, updaterState, true); @@ -495,7 +495,7 @@ public class TestUpdaters extends BaseDL4JTest { assertEquals(gradExpected, gradient.getGradientFor(entry.getKey())); lastG.put(key, lastGTmp); } - assertEquals(rmsDecay, ((RmsProp)layer.layerConf().getIUpdater()).getRmsDecay(), 1e-4); + assertEquals(rmsDecay, ((RmsProp)layer.getTypedLayerConfiguration().getIUpdater()).getRmsDecay(), 1e-4); } @Test @@ -527,7 +527,7 @@ public class TestUpdaters extends BaseDL4JTest { gradExpected = val.mul(lr); assertEquals(gradExpected, gradient.getGradientFor(entry.getKey())); } - assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4); + assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4); } @@ -769,7 +769,7 @@ public class TestUpdaters extends BaseDL4JTest { gradExpected = val.mul(lr); assertEquals(gradExpected, gradient.getGradientFor(entry.getKey())); } - assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4); + assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4); //Test with pretrain == false @@ -797,7 +797,7 @@ public class TestUpdaters extends BaseDL4JTest { layer = (BaseLayer) conf.getFirstLayer().instantiate(conf, null, 0, params, true, params.dataType()); layer.setBackpropGradientsViewArray(gradients); updater = UpdaterCreator.getUpdater(layer); - assertEquals(lr, ((Sgd)layer.layerConf().getIUpdater()).getLearningRate(), 1e-4); + assertEquals(lr, ((Sgd)layer.getTypedLayerConfiguration().getIUpdater()).getLearningRate(), 1e-4); } @Test @@ -858,11 +858,11 @@ public class TestUpdaters extends BaseDL4JTest { //Check 
first updater block: UpdaterBlock ub0 = blocks.get(0); assertEquals(3, ub0.getLayersAndVariablesInBlock().size()); - assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName()); + assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName()); assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub0.getLayersAndVariablesInBlock().get(0).getParamName()); - assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName()); + assertEquals("l0", ub0.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName()); assertEquals(DefaultParamInitializer.BIAS_KEY, ub0.getLayersAndVariablesInBlock().get(1).getParamName()); - assertEquals("l1", ub0.getLayersAndVariablesInBlock().get(2).getLayer().getConfig().getLayerName()); + assertEquals("l1", ub0.getLayersAndVariablesInBlock().get(2).getLayer().getTrainingConfig().getLayerName()); assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub0.getLayersAndVariablesInBlock().get(2).getParamName()); int nParams0 = 10 * 10 + 10 + 10 * 10; @@ -875,7 +875,7 @@ public class TestUpdaters extends BaseDL4JTest { //Check second updater block: UpdaterBlock ub1 = blocks.get(1); assertEquals(1, ub1.getLayersAndVariablesInBlock().size()); - assertEquals("l1", ub1.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName()); + assertEquals("l1", ub1.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName()); assertEquals(DefaultParamInitializer.BIAS_KEY, ub1.getLayersAndVariablesInBlock().get(0).getParamName()); int nParams1 = 10; @@ -888,9 +888,9 @@ public class TestUpdaters extends BaseDL4JTest { //Check third updater block: UpdaterBlock ub2 = blocks.get(2); assertEquals(2, ub2.getLayersAndVariablesInBlock().size()); - assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName()); + assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName()); assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub2.getLayersAndVariablesInBlock().get(0).getParamName()); - assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName()); + assertEquals("l2", ub2.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName()); assertEquals(DefaultParamInitializer.BIAS_KEY, ub2.getLayersAndVariablesInBlock().get(1).getParamName()); int nParams2 = 10 * 10 + 10; @@ -903,9 +903,9 @@ public class TestUpdaters extends BaseDL4JTest { //Check fourth updater block: UpdaterBlock ub3 = blocks.get(3); assertEquals(2, ub3.getLayersAndVariablesInBlock().size()); - assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName()); + assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName()); assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub3.getLayersAndVariablesInBlock().get(0).getParamName()); - assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName()); + assertEquals("l3", ub3.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName()); assertEquals(DefaultParamInitializer.BIAS_KEY, ub3.getLayersAndVariablesInBlock().get(1).getParamName()); int nParams3 = 10 * 10 + 10; @@ -918,9 +918,9 @@ public class TestUpdaters extends BaseDL4JTest { //Check fifth updater black UpdaterBlock ub4 = blocks.get(4); assertEquals(2, 
ub4.getLayersAndVariablesInBlock().size()); - assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(0).getLayer().getConfig().getLayerName()); + assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(0).getLayer().getTrainingConfig().getLayerName()); assertEquals(DefaultParamInitializer.WEIGHT_KEY, ub4.getLayersAndVariablesInBlock().get(0).getParamName()); - assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(1).getLayer().getConfig().getLayerName()); + assertEquals("l4", ub4.getLayersAndVariablesInBlock().get(1).getLayer().getTrainingConfig().getLayerName()); assertEquals(DefaultParamInitializer.BIAS_KEY, ub4.getLayersAndVariablesInBlock().get(1).getParamName()); int nParams4 = 10 * 10 + 10; diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java index e5caf981f..e52b126f2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/custom/TestCustomUpdater.java @@ -22,7 +22,7 @@ package org.deeplearning4j.nn.updater.custom; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; @@ -61,18 +61,18 @@ public class TestCustomUpdater extends BaseDL4JTest { .build(); //First: Check updater config - assertTrue(((BaseLayer) conf1.getConf(0).getLayer()).getIUpdater() instanceof CustomIUpdater); - assertTrue(((BaseLayer) conf1.getConf(1).getLayer()).getIUpdater() instanceof CustomIUpdater); - assertTrue(((BaseLayer) conf2.getConf(0).getLayer()).getIUpdater() instanceof Sgd); - assertTrue(((BaseLayer) conf2.getConf(1).getLayer()).getIUpdater() instanceof Sgd); + assertTrue(((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater() instanceof CustomIUpdater); + assertTrue(((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater() instanceof CustomIUpdater); + assertTrue(((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater() instanceof Sgd); + assertTrue(((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater() instanceof Sgd); - CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayer) conf1.getConf(0).getLayer()).getIUpdater(); - CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayer) conf1.getConf(1).getLayer()).getIUpdater(); + CustomIUpdater u0_0 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(0).getLayer()).getIUpdater(); + CustomIUpdater u0_1 = (CustomIUpdater) ((BaseLayerConfiguration) conf1.getConf(1).getLayer()).getIUpdater(); assertEquals(lr, u0_0.getLearningRate(), 1e-6); assertEquals(lr, u0_1.getLearningRate(), 1e-6); - Sgd u1_0 = (Sgd) ((BaseLayer) conf2.getConf(0).getLayer()).getIUpdater(); - Sgd u1_1 = (Sgd) ((BaseLayer) conf2.getConf(1).getLayer()).getIUpdater(); + Sgd u1_0 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(0).getLayer()).getIUpdater(); + Sgd u1_1 = (Sgd) ((BaseLayerConfiguration) conf2.getConf(1).getLayer()).getIUpdater(); assertEquals(lr, u1_0.getLearningRate(), 1e-6); assertEquals(lr, u1_1.getLearningRate(), 1e-6); diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java index 692f0f44f..91101fccc 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/BackTrackLineSearchTest.java @@ -81,7 +81,7 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, layer.getOptimizer()); - double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); + double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); assertEquals(1.0, step, 1e-3); } @@ -97,11 +97,11 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.setLabels(irisData.getLabels()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - score1 = layer.score(); + score1 = layer.getScore(); BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, new NegativeDefaultStepFunction(), layer.getOptimizer()); - double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); + double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); assertEquals(1.0, step, 1e-3); } @@ -118,18 +118,18 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.setLabels(irisData.getLabels()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - score1 = layer.score(); + score1 = layer.getScore(); INDArray origGradient = layer.gradient().gradient().dup(); NegativeDefaultStepFunction sf = new NegativeDefaultStepFunction(); BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer()); - double step = lineSearch.optimize(layer.params(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); - INDArray currParams = layer.params(); + double step = lineSearch.optimize(layer.getModelParams(), layer.gradient().gradient(), layer.gradient().gradient(), LayerWorkspaceMgr.noWorkspacesImmutable()); + INDArray currParams = layer.getModelParams(); sf.step(currParams, origGradient, step); layer.setParamsTable(currParams); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - score2 = layer.score(); + score2 = layer.getScore(); assertTrue(score1 > score2, "score1=" + score1 + ", score2=" + score2); @@ -146,19 +146,19 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { layer.setInput(irisData.getFeatures(), LayerWorkspaceMgr.noWorkspaces()); layer.setLabels(irisData.getLabels()); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - score1 = layer.score(); + score1 = layer.getScore(); INDArray origGradient = layer.gradient().gradient().dup(); DefaultStepFunction sf = new DefaultStepFunction(); BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer()); - double 
step = lineSearch.optimize(layer.params().dup(), layer.gradient().gradient().dup(), + double step = lineSearch.optimize(layer.getModelParams().dup(), layer.gradient().gradient().dup(), layer.gradient().gradient().dup(), LayerWorkspaceMgr.noWorkspacesImmutable()); - INDArray currParams = layer.params(); + INDArray currParams = layer.getModelParams(); sf.step(currParams, origGradient, step); layer.setParamsTable(currParams); layer.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - score2 = layer.score(); + score2 = layer.getScore(); assertTrue(score1 < score2, "score1 = " + score1 + ", score2 = " + score2); } @@ -190,12 +190,12 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.SIGMOID, optimizer)); network.init(); TrainingListener listener = new ScoreIterationListener(10); - network.setListeners(Collections.singletonList(listener)); + network.addTrainingListeners(Collections.singletonList(listener)); double oldScore = network.score(data); for( int i=0; i<100; i++ ) { network.fit(data.getFeatures(), data.getLabels()); } - double score = network.score(); + double score = network.getScore(); assertTrue(score < oldScore); } @@ -208,13 +208,13 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer)); network.init(); TrainingListener listener = new ScoreIterationListener(10); - network.setListeners(Collections.singletonList(listener)); + network.addTrainingListeners(Collections.singletonList(listener)); double firstScore = network.score(data); for( int i=0; i<5; i++ ) { network.fit(data.getFeatures(), data.getLabels()); } - double score = network.score(); + double score = network.getScore(); assertTrue(score < firstScore); } @@ -227,13 +227,13 @@ public class BackTrackLineSearchTest extends BaseDL4JTest { MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, optimizer)); network.init(); TrainingListener listener = new ScoreIterationListener(10); - network.setListeners(Collections.singletonList(listener)); + network.addTrainingListeners(Collections.singletonList(listener)); double oldScore = network.score(data); for( int i=0; i<5; i++ ) { network.fit(data.getFeatures(), data.getLabels()); } - double score = network.score(); + double score = network.getScore(); assertTrue(score < oldScore); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java index 69afb6330..7883c899f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java @@ -28,6 +28,7 @@ import org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator; import org.deeplearning4j.nn.api.*; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -211,38 +212,38 @@ public class TestOptimizers extends BaseDL4JTest { System.out.println("---------\n Alg= " + oa + ", nIter= " + numLineSearchIter + ", nDimensions= " + 
nDimensions); - NeuralNetConfiguration conf = NeuralNetConfiguration.builder().maxNumLineSearchIterations(numLineSearchIter) + LayerConfiguration conf = NeuralNetConfiguration.builder().maxNumLineSearchIterations(numLineSearchIter) .updater(new Sgd(1e-2)) - .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build(); - conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here + .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build().getFlattenedLayerConfigurations().get(0); + conf.addVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here Random rng = new DefaultRandom(12345L); org.nd4j.linalg.api.rng.distribution.Distribution dist = new org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution(rng, -10, 10); IModel m = new SphereFunctionModel(nDimensions, dist, conf); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - double scoreBefore = m.score(); + double scoreBefore = m.getScore(); assertTrue(!Double.isNaN(scoreBefore) && !Double.isInfinite(scoreBefore)); if (PRINT_OPT_RESULTS) { System.out.println("Before:"); System.out.println(scoreBefore); - System.out.println(m.params()); + System.out.println(m.getModelParams()); } - ConvexOptimizer opt = getOptimizer(oa, conf, m); + ConvexOptimizer opt = getOptimizer(oa, conf.getNetConfiguration(), m); opt.setupSearchState(m.gradientAndScore()); for( int i=0; i<100; i++ ) { opt.optimize(LayerWorkspaceMgr.noWorkspaces()); } m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - double scoreAfter = m.score(); + double scoreAfter = m.getScore(); assertTrue(!Double.isNaN(scoreAfter) && !Double.isInfinite(scoreAfter)); if (PRINT_OPT_RESULTS) { System.out.println("After:"); System.out.println(scoreAfter); - System.out.println(m.params()); + System.out.println(m.getModelParams()); } //Expected behaviour after optimization: @@ -279,17 +280,17 @@ public class TestOptimizers extends BaseDL4JTest { .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build(); conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here - IModel m = new SphereFunctionModel(100, dist, conf); + IModel m = new SphereFunctionModel(100, dist, conf.getFlattenedLayerConfigurations().get(0)); if (i == 0) { m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - scores[0] = m.score(); //Before optimization + scores[0] = m.getScore(); //Before optimization } else { ConvexOptimizer opt = getOptimizer(oa, conf, m); for( int j=0; j<100; j++ ) { opt.optimize(LayerWorkspaceMgr.noWorkspaces()); } m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - scores[i] = m.score(); + scores[i] = m.getScore(); assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i])); } } @@ -316,7 +317,7 @@ public class TestOptimizers extends BaseDL4JTest { private static final long serialVersionUID = -6963606137417355405L; private SphereFunctionModel(int nParams, org.nd4j.linalg.api.rng.distribution.Distribution distribution, - NeuralNetConfiguration conf) { + LayerConfiguration conf) { super(distribution.sample(new int[] {1, nParams}), conf); } @@ -437,7 +438,7 @@ public class TestOptimizers extends BaseDL4JTest { } @Override - public void setListeners(TrainingListener... listeners) { + public void addTrainingListeners(TrainingListener... 
listeners) { } @@ -499,17 +500,17 @@ public class TestOptimizers extends BaseDL4JTest { .layer(new DenseLayer.Builder().nIn(1).nOut(1).build()).build(); conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here - IModel m = new RastriginFunctionModel(10, conf); + IModel m = new RastriginFunctionModel(10, conf.getFlattenedLayerConfigurations().get(0)); int nParams = (int)m.numParams(); if (i == 0) { m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - scores[0] = m.score(); //Before optimization + scores[0] = m.getScore(); //Before optimization } else { ConvexOptimizer opt = getOptimizer(oa, conf, m); opt.getUpdater().setStateViewArray((Layer) m, Nd4j.create(new int[] {1, nParams}, 'c'), true); opt.optimize(LayerWorkspaceMgr.noWorkspaces()); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - scores[i] = m.score(); + scores[i] = m.getScore(); assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i])); } } @@ -540,7 +541,7 @@ public class TestOptimizers extends BaseDL4JTest { private static class RastriginFunctionModel extends SimpleOptimizableModel { private static final long serialVersionUID = -1772954508787487941L; - private RastriginFunctionModel(int nDimensions, NeuralNetConfiguration conf) { + private RastriginFunctionModel(int nDimensions, LayerConfiguration conf) { super(initParams(nDimensions), conf); } @@ -710,7 +711,7 @@ public class TestOptimizers extends BaseDL4JTest { } @Override - public void setListeners(TrainingListener... listeners) { + public void addTrainingListeners(TrainingListener... listeners) { } @@ -768,15 +769,15 @@ public class TestOptimizers extends BaseDL4JTest { .build(); conf.addNetWideVariable("W"); //Normally done by ParamInitializers, but obviously that isn't done here - IModel m = new RosenbrockFunctionModel(100, conf); + IModel m = new RosenbrockFunctionModel(100, conf.getFlattenedLayerConfigurations().get(0)); if (i == 0) { m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - scores[0] = m.score(); //Before optimization + scores[0] = m.getScore(); //Before optimization } else { ConvexOptimizer opt = getOptimizer(oa, conf, m); opt.optimize(LayerWorkspaceMgr.noWorkspaces()); m.computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - scores[i] = m.score(); + scores[i] = m.getScore(); assertTrue(!Double.isNaN(scores[i]) && !Double.isInfinite(scores[i]), "NaN or infinite score: " + scores[i]); } } @@ -810,7 +811,7 @@ public class TestOptimizers extends BaseDL4JTest { private static class RosenbrockFunctionModel extends SimpleOptimizableModel { private static final long serialVersionUID = -5129494342531033706L; - private RosenbrockFunctionModel(int nDimensions, NeuralNetConfiguration conf) { + private RosenbrockFunctionModel(int nDimensions, LayerConfiguration conf) { super(initParams(nDimensions), conf); } @@ -995,7 +996,7 @@ public class TestOptimizers extends BaseDL4JTest { } @Override - public void setListeners(TrainingListener... listeners) { + public void addTrainingListeners(TrainingListener... 
listeners) { } @@ -1029,13 +1030,31 @@ public class TestOptimizers extends BaseDL4JTest { private static final long serialVersionUID = 4409380971404019303L; protected INDArray parameters; protected INDArray gradientView; - protected final NeuralNetConfiguration conf; + protected final LayerConfiguration conf; protected Gradient gradient; protected double score; + /** + * @return 1d parameter vector + */ + @Override + public INDArray getParams() { + throw new RuntimeException("Not implemented"); + } + + /** + * Get a reference to the network this layer is part of. + * + * @return + */ + @Override + public IModel getNet() { + throw new RuntimeException("Not implemented"); + } + /**@param parameterInit Initial parameters. Also determines dimensionality of problem. Should be row vector. */ - private SimpleOptimizableModel(INDArray parameterInit, NeuralNetConfiguration conf) { + private SimpleOptimizableModel(INDArray parameterInit, LayerConfiguration conf) { this.parameters = parameterInit.dup(); this.gradientView = Nd4j.create(parameterInit.shape()); this.conf = conf; @@ -1048,17 +1067,12 @@ public class TestOptimizers extends BaseDL4JTest { */ @Override public LayerConfiguration getLayerConfiguration() { - return this.conf.getFirstLayer(); + return this.conf; } @Override - public void addListeners(TrainingListener... listener) { - // no-op - } - - @Override - public TrainingConfig getConfig() { - return conf.getFirstLayer(); + public ITraininableLayerConfiguration getTrainingConfig() { + return (BaseLayerConfiguration) conf; } /** @@ -1092,7 +1106,7 @@ public class TestOptimizers extends BaseDL4JTest { } @Override - public void setListeners(TrainingListener... listeners) { + public void addTrainingListeners(TrainingListener... listeners) { } @@ -1112,7 +1126,7 @@ public class TestOptimizers extends BaseDL4JTest { } @Override - public double score() { + public double getScore() { return score; } @@ -1132,7 +1146,7 @@ public class TestOptimizers extends BaseDL4JTest { } @Override - public INDArray params() { + public INDArray getModelParams() { return parameters; } @@ -1154,7 +1168,7 @@ public class TestOptimizers extends BaseDL4JTest { @Override public Pair gradientAndScore() { computeGradientAndScore(LayerWorkspaceMgr.noWorkspaces()); - return new Pair<>(gradient(), score()); + return new Pair<>(gradient(), getScore()); } @Override @@ -1164,7 +1178,7 @@ public class TestOptimizers extends BaseDL4JTest { @Override public NeuralNetConfiguration getNetConfiguration() { - return conf; + return conf.getNetConfiguration(); } @Override @@ -1225,12 +1239,12 @@ public class TestOptimizers extends BaseDL4JTest { } @Override - public Collection getListeners() { + public Collection getTrainingListeners() { return null; } @Override - public void setListeners(Collection listeners) { + public void addTrainingListeners(Collection listeners) { throw new UnsupportedOperationException(); } @@ -1310,4 +1324,6 @@ public class TestOptimizers extends BaseDL4JTest { public void close(){ } } + + } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java index 6f422fda1..2e34fcd46 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestCheckpointListener.java @@ -76,7 +76,7 @@ public class 
TestCheckpointListener extends BaseDL4JTest { .keepAll() .saveEveryNEpochs(2) .build(); - net.setListeners(l); + net.addTrainingListeners(l); for(int i=0; i<10; i++ ){ net.fit(iter); @@ -125,7 +125,7 @@ public class TestCheckpointListener extends BaseDL4JTest { .keepLast(3) .saveEveryNIterations(5) .build(); - net.setListeners(l); + net.addTrainingListeners(l); for(int i=0; i<20; i++ ){ //40 iterations total net.fit(iter); @@ -167,7 +167,7 @@ public class TestCheckpointListener extends BaseDL4JTest { MultiLayerNetwork netStatic2 = CheckpointListener.loadLastCheckpointMLN(f); assertEquals(35, netStatic2.getIterationCount()); - assertEquals(netStatic.params(), netStatic2.params()); + assertEquals(netStatic.getModelParams(), netStatic2.getModelParams()); } @Test @@ -182,7 +182,7 @@ public class TestCheckpointListener extends BaseDL4JTest { .keepLast(3) .saveEvery(4900, TimeUnit.MILLISECONDS) .build(); - net.setListeners(l); + net.addTrainingListeners(l); for(int i=0; i<3; i++ ){ //10 iterations total net.fit(iter); @@ -226,7 +226,7 @@ public class TestCheckpointListener extends BaseDL4JTest { .keepLastAndEvery(3, 3) .saveEveryNEpochs(2) .build(); - net.setListeners(l); + net.addTrainingListeners(l); for(int i=0; i<20; i++ ){ //40 iterations total net.fit(iter); @@ -272,7 +272,7 @@ public class TestCheckpointListener extends BaseDL4JTest { .keepAll() .saveEveryNEpochs(1) .build(); - net.setListeners(l); + net.addTrainingListeners(l); for(int i=0; i<3; i++ ){ net.fit(iter); @@ -294,7 +294,7 @@ public class TestCheckpointListener extends BaseDL4JTest { .saveEveryNEpochs(1) .deleteExisting(true) .build(); - net.setListeners(l); + net.addTrainingListeners(l); net.fit(iter); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java index a1933c247..fb500772d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestFailureListener.java @@ -58,7 +58,7 @@ public class TestFailureListener extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.setListeners(new FailureTestingListener( + net.addTrainingListeners(new FailureTestingListener( // FailureTestingListener.FailureMode.OOM, FailureTestingListener.FailureMode.SYSTEM_EXIT_1, new FailureTestingListener.IterationEpochTrigger(false, 10))); @@ -84,7 +84,7 @@ public class TestFailureListener extends BaseDL4JTest { assertNotNull(username); assertFalse(username.isEmpty()); - net.setListeners(new FailureTestingListener( + net.addTrainingListeners(new FailureTestingListener( FailureTestingListener.FailureMode.SYSTEM_EXIT_1, new FailureTestingListener.Or( new FailureTestingListener.IterationEpochTrigger(false, 10000), @@ -112,7 +112,7 @@ public class TestFailureListener extends BaseDL4JTest { assertNotNull(hostname); assertFalse(hostname.isEmpty()); - net.setListeners(new FailureTestingListener( + net.addTrainingListeners(new FailureTestingListener( FailureTestingListener.FailureMode.ILLEGAL_STATE, new FailureTestingListener.And( new FailureTestingListener.HostNameTrigger(hostname), diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java index b335d43a6..f3d4f5dee 100644 
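For reference, a minimal sketch of how the renamed listener API is used (not part of the patch; `conf`, `iter`, and `checkpointDir` are placeholders for an existing network configuration, dataset iterator, and output directory):

    // Attach listeners with the renamed method; addTrainingListeners(...) is assumed to take the
    // same TrainingListener varargs that setListeners(...) did before this change.
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    CheckpointListener checkpoints = new CheckpointListener.Builder(checkpointDir)
            .keepLast(3)
            .saveEveryNEpochs(2)
            .build();

    net.addTrainingListeners(new ScoreIterationListener(1), checkpoints);
    net.fit(iter);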
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/optimizer/listener/TestListeners.java @@ -77,17 +77,17 @@ public class TestListeners extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.setListeners(new ScoreIterationListener(), new TestRoutingListener()); + net.addTrainingListeners(new ScoreIterationListener(), new TestRoutingListener()); for (Layer l : net.getLayers()) { - Collection layerListeners = l.getListeners(); + Collection layerListeners = l.getTrainingListeners(); assertEquals(2, layerListeners.size(), l.getClass().toString()); TrainingListener[] lArr = layerListeners.toArray(new TrainingListener[2]); assertTrue(lArr[0] instanceof ScoreIterationListener); assertTrue(lArr[1] instanceof TestRoutingListener); } - Collection netListeners = net.getListeners(); + Collection netListeners = net.getTrainingListeners(); assertEquals(2, netListeners.size()); TrainingListener[] lArr = netListeners.toArray(new TrainingListener[2]); assertTrue(lArr[0] instanceof ScoreIterationListener); @@ -101,17 +101,17 @@ public class TestListeners extends BaseDL4JTest { ComputationGraph cg = new ComputationGraph(gConf); cg.init(); - cg.setListeners(new ScoreIterationListener(), new TestRoutingListener()); + cg.addTrainingListeners(new ScoreIterationListener(), new TestRoutingListener()); for (Layer l : cg.getLayers()) { - Collection layerListeners = l.getListeners(); + Collection layerListeners = l.getTrainingListeners(); assertEquals(2, layerListeners.size()); lArr = layerListeners.toArray(new TrainingListener[2]); assertTrue(lArr[0] instanceof ScoreIterationListener); assertTrue(lArr[1] instanceof TestRoutingListener); } - netListeners = cg.getListeners(); + netListeners = cg.getTrainingListeners(); assertEquals(2, netListeners.size()); lArr = netListeners.toArray(new TrainingListener[2]); assertTrue(lArr[0] instanceof ScoreIterationListener); @@ -180,7 +180,7 @@ public class TestListeners extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.setListeners(listeners); + net.addTrainingListeners(listeners); net.fit(iter); @@ -199,7 +199,7 @@ public class TestListeners extends BaseDL4JTest { listeners2.add(il2); } - net.setListeners(listeners2); + net.addTrainingListeners(listeners2); net.fit(iter); } @@ -216,7 +216,7 @@ public class TestListeners extends BaseDL4JTest { net.init(); TestListener tl = new TestListener(); - net.setListeners(tl); + net.addTrainingListeners(tl); DataSetIterator irisIter = new IrisDataSetIterator(50, 150); @@ -260,7 +260,7 @@ public class TestListeners extends BaseDL4JTest { tl = new TestListener(); ComputationGraph cg = net.toComputationGraph(); - cg.setListeners(tl); + cg.addTrainingListeners(tl); cg.fit(irisIter, 2); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java index 114d90887..2c214eeff 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/parallelism/RandomTests.java @@ -94,7 +94,7 @@ public class RandomTests extends BaseDL4JTest { // at the end of day, model params has to for (int i = 0; i < models.size(); i++) { - assertEquals(models.get(0).params(), models.get(i).params()); + assertEquals(models.get(0).getModelParams(), 
models.get(i).getModelParams()); } } @@ -119,7 +119,7 @@ public class RandomTests extends BaseDL4JTest { MultiLayerNetwork net2 = new MultiLayerNetwork(conf); net2.init(); - assertEquals(net1.params(), net2.params()); + assertEquals(net1.getModelParams(), net2.getModelParams()); NeuralNetConfiguration fromJson = NeuralNetConfiguration.fromJson(json); @@ -127,6 +127,6 @@ public class RandomTests extends BaseDL4JTest { MultiLayerNetwork net3 = new MultiLayerNetwork(fromJson); net3.init(); - assertEquals(net1.params(), net3.params()); + assertEquals(net1.getModelParams(), net3.getModelParams()); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java index c52f4943f..6b2d882e3 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java @@ -63,7 +63,7 @@ public class TestSystemInfoPrintListener extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.setListeners(systemInfoFilePrintListener); + net.addTrainingListeners(systemInfoFilePrintListener); DataSetIterator iter = new IrisDataSetIterator(10, 150); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java index df6f1e0cb..773ccbae8 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java @@ -87,7 +87,7 @@ public class RegressionTest050 extends BaseDL4JTest { assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new Nesterovs().stateSize(net.numParams()); assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @@ -126,7 +126,7 @@ public class RegressionTest050 extends BaseDL4JTest { assertEquals(new WeightDecay(0.2, false), TestUtils.getWeightDecayReg(l1)); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new RmsProp().stateSize(numParams); assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @@ -170,7 +170,7 @@ public class RegressionTest050 extends BaseDL4JTest { assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new RmsProp().stateSize(numParams); 
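A short sketch of the renamed accessors that the tests above now call, assuming their semantics are unchanged from the old params() and score() methods (`conf` and `iter` are placeholders):

    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    net.fit(iter);

    INDArray params = net.getModelParams(); // formerly net.params(): flattened view of all parameters
    double score    = net.getScore();       // formerly net.score(): score from the last fit/computeGradientAndScore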
assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java index d6c88b4d3..c75c11d11 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java @@ -89,7 +89,7 @@ public class RegressionTest060 extends BaseDL4JTest { assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new Nesterovs().stateSize(numParams); assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @@ -132,7 +132,7 @@ public class RegressionTest060 extends BaseDL4JTest { assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new RmsProp().stateSize(numParams); assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @@ -178,7 +178,7 @@ public class RegressionTest060 extends BaseDL4JTest { assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new RmsProp().stateSize(numParams); assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java index bf14dba46..63ea30e49 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java @@ -90,7 +90,7 @@ public class RegressionTest071 extends BaseDL4JTest { assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); long numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new Nesterovs().stateSize(numParams); assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @@ -133,7 +133,7 @@ public class RegressionTest071 extends BaseDL4JTest { assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5); long numParams = 
net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new RmsProp().stateSize(numParams); assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @@ -179,7 +179,7 @@ public class RegressionTest071 extends BaseDL4JTest { assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); long numParams = net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new RmsProp().stateSize(numParams); assertEquals(Nd4j.linspace(1, updaterSize, updaterSize).reshape(1,numParams), net.getUpdater().getStateViewArray()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java index 4cc26f05a..010ac9733 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java @@ -94,7 +94,7 @@ public class RegressionTest080 extends BaseDL4JTest { assertEquals(0.15, n.getLearningRate(), 1e-6); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new Nesterovs().stateSize(numParams); assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @@ -143,7 +143,7 @@ public class RegressionTest080 extends BaseDL4JTest { assertEquals(1.5, l1.getGradientNormalizationThreshold(), 1e-5); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new RmsProp().stateSize(numParams); assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } @@ -194,7 +194,7 @@ public class RegressionTest080 extends BaseDL4JTest { assertTrue(conf.getInputPreProcess(2) instanceof CnnToFeedForwardPreProcessor); int numParams = (int)net.numParams(); - assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.params()); + assertEquals(Nd4j.linspace(1, numParams, numParams, Nd4j.dataType()).reshape(1,numParams), net.getModelParams()); int updaterSize = (int) new RmsProp().stateSize(numParams); assertEquals(Nd4j.linspace(1, updaterSize, updaterSize, Nd4j.dataType()).reshape(1,numParams), net.getUpdater().getStateViewArray()); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java index 6d73c1074..829fc8c2b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java @@ -97,7 +97,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { assertEquals(dt, in.dataType()); assertEquals(dt, outExp.dataType()); - assertEquals(dt, net.params().dataType()); + assertEquals(dt, net.getModelParams().dataType()); assertEquals(dt, net.getFlattenedGradients().dataType()); assertEquals(dt, net.getUpdater().getStateViewArray().dataType()); @@ -109,7 +109,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { List activations = net.feedForward(in); assertEquals(dt, net.getNetConfiguration().getDataType()); - assertEquals(dt, net.params().dataType()); + assertEquals(dt, net.getModelParams().dataType()); assertEquals( outExp, outAct, dtype); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java index bd2f231d2..b1247b3c1 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -116,7 +116,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertEquals(dtype, in.dataType()); assertEquals(dtype, outExp.dataType()); - assertEquals(dtype, net.params().dataType()); + assertEquals(dtype, net.getModelParams().dataType()); assertEquals(dtype, net.getFlattenedGradients().dataType()); assertEquals(dtype, net.getUpdater().getStateViewArray().dataType()); @@ -126,7 +126,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertEquals(dtype, outAct.dataType()); assertEquals(dtype, net.getNetConfiguration().getDataType()); - assertEquals(dtype, net.params().dataType()); + assertEquals(dtype, net.getModelParams().dataType()); boolean eq = outExp.equalsWithEps(outAct, 0.01); assertTrue(eq, "Test for dtype: " + dtypeName + "\n" + outExp + " vs " + outAct); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java index bf13cff1b..f00b9c437 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java @@ -98,7 +98,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertEquals(dtype, in.dataType()); assertEquals(dtype, outExp.dataType()); - assertEquals(dtype, net.params().dataType()); + assertEquals(dtype, net.getModelParams().dataType()); assertEquals(dtype, net.getFlattenedGradients().dataType()); assertEquals(dtype, net.getUpdater().getStateViewArray().dataType()); @@ -108,7 +108,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertEquals(dtype, outAct.dataType()); assertEquals(dtype, net.getNetConfiguration().getDataType()); - assertEquals(dtype, net.params().dataType()); + assertEquals(dtype, net.getModelParams().dataType()); boolean eq = outExp.equalsWithEps(outAct, 0.01); assertTrue( eq, "Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java index 72b55f9e6..b20ad6f00 100644 
--- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayer.java @@ -76,7 +76,7 @@ public class CustomLayer extends FeedForwardLayer { //For the most part, it's the same for each type of layer CustomLayerImpl myCustomLayer = new CustomLayerImpl(lconf, networkDataType); - myCustomLayer.setListeners(iterationListeners); //Set the iteration listeners, if any + myCustomLayer.addTrainingListeners(iterationListeners); //Set the iteration listeners, if any myCustomLayer.setIndex(layerIndex); //Integer index of the layer //Parameter view array: In Deeplearning4j, the network parameters for the entire network (all layers) are diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java index d233a5da3..14e13634b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/customlayer100a/CustomLayerImpl.java @@ -20,7 +20,6 @@ package org.deeplearning4j.regressiontest.customlayer100a; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -56,7 +55,7 @@ public class CustomLayerImpl extends BaseLayer { //Generic paramete INDArray firstHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(0, columns / 2)); INDArray secondHalf = output.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns)); - IActivation activation1 = layerConf().getActivationFn(); + IActivation activation1 = getTypedLayerConfiguration().getActivationFn(); IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction(); //IActivation function instances modify the activation functions in-place @@ -75,7 +74,7 @@ public class CustomLayerImpl extends BaseLayer { //Generic paramete @Override public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { /* - The baockprop gradient method here is very similar to the BaseLayer backprop gradient implementation + The baockprop gradient method here is very similar to the BaseLayerConfiguration backprop gradient implementation The only major difference is the two activation functions we have added in this example. Note that epsilon is dL/da - i.e., the derivative of the loss function with respect to the activations. 
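A minimal sketch of the two configuration accessors used inside the custom layer after this change; it assumes getTypedLayerConfiguration() returns the configuration already cast to the layer's own type parameter, while getLayerConfiguration() returns the plain LayerConfiguration and must be cast manually:

    // Inside a layer implementation such as CustomLayerImpl:
    IActivation activation1 = getTypedLayerConfiguration().getActivationFn();                        // typed accessor, replaces layerConf()
    IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction(); // untyped accessor, cast by hand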
@@ -105,14 +104,14 @@ public class CustomLayerImpl extends BaseLayer { //Generic paramete INDArray epsilonFirstHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(0, columns / 2)); INDArray epsilonSecondHalf = epsilon.get(NDArrayIndex.all(), NDArrayIndex.interval(columns / 2, columns)); - IActivation activation1 = layerConf().getActivationFn(); + IActivation activation1 = getTypedLayerConfiguration().getActivationFn(); IActivation activation2 = ((CustomLayer) getLayerConfiguration()).getSecondActivationFunction(); //IActivation backprop method modifies the 'firstHalf' and 'secondHalf' arrays in-place, to contain dL/dz activation1.backprop(firstHalf, epsilonFirstHalf); activation2.backprop(secondHalf, epsilonSecondHalf); - //The remaining code for this method: just copy & pasted from BaseLayer.backpropGradient + //The remaining code for this method: just copy & pasted from BaseLayerConfiguration.backpropGradient // INDArray delta = epsilon.muli(activationDerivative); if (maskArray != null) { activationDerivative.muliColumnVector(maskArray); @@ -128,7 +127,7 @@ public class CustomLayerImpl extends BaseLayer { //Generic paramete ret.gradientForVariable().put(DefaultParamInitializer.WEIGHT_KEY, weightGrad); ret.gradientForVariable().put(DefaultParamInitializer.BIAS_KEY, biasGrad); - INDArray epsilonNext = paramsTable.get(DefaultParamInitializer.WEIGHT_KEY).mmul(activationDerivative.transpose()).transpose(); + INDArray epsilonNext = getParamTable().get(DefaultParamInitializer.WEIGHT_KEY).mmul(activationDerivative.transpose()).transpose(); return new Pair<>(ret, epsilonNext); } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java index 03b8192f4..b4edb0ba8 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java @@ -190,7 +190,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { //Check score - double scoreDl4j = net.score(); + double scoreDl4j = net.getScore(); double scoreSd = map.get(lossMse.name()).getDouble(0) + sd.calcRegularizationScore(); assertEquals(scoreDl4j, scoreSd, 1e-6, testName); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java index 49a9c7fa1..e0eeef88d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/CrashReportingUtilTest.java @@ -104,7 +104,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.addListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); //Test net that hasn't been trained yet Exception e = new Exception(); @@ -161,7 +161,7 @@ public class CrashReportingUtilTest extends BaseDL4JTest { CrashReportingUtil.crashDumpOutputDirectory(dir); ComputationGraph cg = net.toComputationGraph(); - cg.setListeners(new ScoreIterationListener(1)); + cg.addTrainingListeners(new ScoreIterationListener(1)); //Test net that hasn't been trained yet CrashReportingUtil.writeMemoryCrashDump(cg, e); diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java index 4415c5455..e941c75ee 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelGuesserTest.java @@ -156,7 +156,7 @@ public class ModelGuesserTest extends BaseDL4JTest { MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(tempFile.getAbsolutePath()); assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson()); - assertEquals(net.params(), network.params()); + assertEquals(net.getModelParams(), network.getModelParams()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @@ -173,7 +173,7 @@ public class ModelGuesserTest extends BaseDL4JTest { MultiLayerNetwork network = (MultiLayerNetwork) ModelGuesser.loadModelGuess(inputStream); Assertions.assertNotNull(network); assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson()); - assertEquals(net.params(), network.params()); + assertEquals(net.getModelParams(), network.getModelParams()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java index 5124e15ac..495b403d5 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/util/ModelSerializerTest.java @@ -81,7 +81,7 @@ public class ModelSerializerTest extends BaseDL4JTest { MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(tempFile); assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson()); - assertEquals(net.params(), network.params()); + assertEquals(net.getModelParams(), network.getModelParams()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @@ -125,7 +125,7 @@ public class ModelSerializerTest extends BaseDL4JTest { MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(fis); assertEquals(network.getNetConfiguration().toJson(), net.getNetConfiguration().toJson()); - assertEquals(net.params(), network.params()); + assertEquals(net.getModelParams(), network.getModelParams()); assertEquals(net.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @@ -151,7 +151,7 @@ public class ModelSerializerTest extends BaseDL4JTest { ComputationGraph network = ModelSerializer.restoreComputationGraph(tempFile); assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson()); - assertEquals(cg.params(), network.params()); + assertEquals(cg.getModelParams(), network.getModelParams()); assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @@ -177,7 +177,7 @@ public class ModelSerializerTest extends BaseDL4JTest { ComputationGraph network = ModelSerializer.restoreComputationGraph(fis); assertEquals(network.getComputationGraphConfiguration().toJson(), cg.getComputationGraphConfiguration().toJson()); - assertEquals(cg.params(), network.params()); + assertEquals(cg.getModelParams(), network.getModelParams()); 
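For reference, a sketch of the save/restore round trip these serializer tests exercise, expressed with the renamed accessors (`tempFile` is a placeholder; the ModelSerializer method signatures are assumed to be unchanged by this patch):

    ModelSerializer.writeModel(net, tempFile, true);                                  // true: also save updater state
    MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(tempFile);

    assertEquals(net.getNetConfiguration().toJson(), restored.getNetConfiguration().toJson());
    assertEquals(net.getModelParams(), restored.getModelParams());                    // formerly params()
    assertEquals(net.getUpdater().getStateViewArray(), restored.getUpdater().getStateViewArray());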
assertEquals(cg.getUpdater().getStateViewArray(), network.getUpdater().getStateViewArray()); } @@ -346,7 +346,7 @@ public class ModelSerializerTest extends BaseDL4JTest { //Also test reading both model and normalizer from stream (correctly) Pair pair = ModelSerializer.restoreMultiLayerNetworkAndNormalizer(new FileInputStream(tempFile), true); - assertEquals(net.params(), pair.getFirst().params()); + assertEquals(net.getModelParams(), pair.getFirst().getModelParams()); assertNotNull(pair.getSecond()); } @@ -395,7 +395,7 @@ public class ModelSerializerTest extends BaseDL4JTest { //Also test reading both model and normalizer from stream (correctly) Pair pair = ModelSerializer.restoreComputationGraphAndNormalizer(new FileInputStream(tempFile), true); - assertEquals(net.params(), pair.getFirst().params()); + assertEquals(net.getModelParams(), pair.getFirst().getModelParams()); assertNotNull(pair.getSecond()); } @@ -496,6 +496,6 @@ public class ModelSerializerTest extends BaseDL4JTest { assertTrue(entries.contains("otherData.bin")); ComputationGraph restoredNet = ModelSerializer.restoreComputationGraph(tempFile); - assertEquals(net.params(), restoredNet.params()); + assertEquals(net.getModelParams(), restoredNet.getModelParams()); } } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java index 11bb40d58..86ebdf3ea 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/TFOpLayer.java @@ -21,7 +21,6 @@ package org.deeplearning4j.nn.modelimport.keras.layers; import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; @@ -80,10 +79,6 @@ public class TFOpLayer extends LayerConfiguration { public void setNIn(InputType inputType, boolean override){} - @Override - public GradientNormalization getGradientNormalization(){return null;} - - @Override public org.deeplearning4j.nn.api.Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, @@ -91,14 +86,11 @@ public class TFOpLayer extends LayerConfiguration { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); TFOpLayerImpl tfOpLayerImpl = new TFOpLayerImpl(nodeDef, constants, lconf, networkDataType); - tfOpLayerImpl.setListeners(trainingListeners); + tfOpLayerImpl.addTrainingListeners(trainingListeners); tfOpLayerImpl.setIndex(layerIndex); return tfOpLayerImpl; } - @Override - public double getGradientNormalizationThreshold(){return 0.;} - @Override public List getRegularizationByParam(String paramName){return null;} diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java index 97ceac993..b2e5a15a2 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java +++ 
b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTM.java @@ -31,7 +31,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; @@ -448,8 +448,8 @@ public class KerasLSTM extends KerasLayer { FeedForwardLayer ffl; - if(this.layer instanceof BaseWrapperLayer){ - BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer; + if(this.layer instanceof BaseWrapperLayerConfiguration){ + BaseWrapperLayerConfiguration bwl = (BaseWrapperLayerConfiguration)this.layer; ffl = (FeedForwardLayer)bwl.getUnderlying(); } else { ffl = (FeedForwardLayer) this.layer; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java index 35a1aed01..3c850ecfa 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnn.java @@ -33,7 +33,7 @@ import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.conf.layers.util.MaskZeroLayer; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; @@ -296,8 +296,8 @@ public class KerasSimpleRnn extends KerasLayer { } FeedForwardLayer ffl; - if(this.layer instanceof BaseWrapperLayer){ - BaseWrapperLayer bwl = (BaseWrapperLayer)this.layer; + if(this.layer instanceof BaseWrapperLayerConfiguration){ + BaseWrapperLayerConfiguration bwl = (BaseWrapperLayerConfiguration)this.layer; ffl = (FeedForwardLayer)bwl.getUnderlying(); } else { ffl = (FeedForwardLayer) this.layer; diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java index 12a00d4f7..0ee7ce776 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/KerasTestUtils.java @@ -20,7 +20,7 @@ package org.deeplearning4j.nn.modelimport.keras; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import 
org.deeplearning4j.nn.conf.layers.samediff.AbstractSameDiffLayer; import org.nd4j.linalg.learning.regularization.L1Regularization; import org.nd4j.linalg.learning.regularization.L2Regularization; @@ -34,7 +34,7 @@ public class KerasTestUtils { private KerasTestUtils(){ } - public static double getL1(BaseLayer layer) { + public static double getL1(BaseLayerConfiguration layer) { List l = layer.getRegularization(); return getL1(l); } @@ -49,7 +49,7 @@ public class KerasTestUtils { return l1Reg.getL1().valueAt(0,0); } - public static double getL2(BaseLayer layer) { + public static double getL2(BaseLayerConfiguration layer) { List l = layer.getRegularization(); return getL2(l); } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index 2fea0bb82..1dad7c549 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -286,7 +286,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, true, true); Layer outLayer = net.getOutputLayer(); assertTrue(outLayer instanceof org.deeplearning4j.nn.layers.LossLayer); - LossLayer llConf = (LossLayer) outLayer.getConfig(); + LossLayer llConf = (LossLayer) outLayer.getTrainingConfig(); assertEquals(new LossSparseMCXENT(), llConf.getLossFn()); } @@ -656,7 +656,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { MultiLayerNetwork net = importEndModelTest(modelPath, inputsOutputPath, true, true, true, true, false, null, null); Layer l = net.getLayer(0); - Convolution1DLayer c1d = (Convolution1DLayer) l.getConfig(); + Convolution1DLayer c1d = (Convolution1DLayer) l.getTrainingConfig(); assertEquals(ConvolutionMode.Causal, c1d.getConvolutionMode()); } } diff --git a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java index 49237b20c..bd89639b3 100644 --- a/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java +++ b/cavis-dnn/cavis-dnn-nlp/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTestsSmall.java @@ -209,6 +209,6 @@ public class Word2VecTestsSmall extends BaseDL4JTest { final MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); assertEquals(net.getNetConfiguration(), restored.getNetConfiguration()); - assertTrue(net.params().equalsWithEps(restored.params(), 2e-3)); + assertTrue(net.getModelParams().equalsWithEps(restored.getModelParams(), 2e-3)); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/ILayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/ILayerConfiguration.java index e0f5d856b..1462661bb 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/ILayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/ILayerConfiguration.java @@ -23,41 +23,5 @@ package net.brutex.ai.dnn.api; public interface ILayerConfiguration { - /** - * Create and return an instance of a ILayerConfiguration. 
- * - * @param network the "holding" network for the instance - * @return the new layer instance - */ - ILayer instantiate(IModel network); - - - /** - * Defines the valid input type for this ILayer - * - * @return InputType - */ - org.deeplearning4j.nn.conf.inputs.InputType.Type getInputType(); - - - /** - * Defines the valid input type for this ILayer - * - * @return InputType - */ - org.deeplearning4j.nn.conf.inputs.InputType.Type getOutputType(); - - - /** - * Number of trainable parameter in this layer - * @return number of parameter - */ - long numParameters(); - - /** - * An implementation should provide a method to validate the network - * @return true if no errors found; false otherwise - */ - boolean isValid(); - + } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java index 9f81fd3d8..3f84a7004 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/IModel.java @@ -24,6 +24,7 @@ package net.brutex.ai.dnn.api; import java.util.Collection; import java.util.Map; import lombok.NonNull; +import org.deeplearning4j.nn.api.ITrainableLayer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; @@ -44,16 +45,17 @@ import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; * {@link #getNetConfiguration()} methods. **/ -public interface IModel { +public interface IModel extends ITrainableLayer { /** - * The param table + * The full param table for the model. Each layer may get a subset of its parameters. * - * @return + * @return full table of parameters */ - Map getParamTable(); + Map getParamTable(boolean backpropOnly); + void setParamTable(Map paramTable); /** * This method returns updater state (if applicable), null otherwise @@ -113,6 +115,11 @@ public interface IModel { */ T[] doEvaluation(MultiDataSetIterator iterator, T... evaluations); + /** + * Get the configuration of this model. + * + * @return the neural net configuration + */ NeuralNetConfiguration getNetConfiguration(); void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration); @@ -124,6 +131,7 @@ public interface IModel { /** * Get the number of parameters in this model + * * @return number of parameters */ long numParams(); @@ -148,11 +156,12 @@ public interface IModel { /** - * The score for the model + * The score for the model. No calculation occurs, this simply returns the score calculated before + * by the {@link #computeGradientAndScore(LayerWorkspaceMgr)} method. * * @return the score for the model */ - double score(); + double getScore(); /** @@ -165,7 +174,7 @@ public interface IModel { * * @return the parameters of the model */ - INDArray params(); + INDArray getModelParams(); /** @@ -243,15 +252,16 @@ public interface IModel { /** * Get a parameter array for a given parameter type key + * * @param param the key of the parameter * @return ndarray of parameters */ INDArray getParam(String param); - /** * Set the parameters for a given parameter type. 
+ * * @param key the param type key to set * @param val the new parameters ndarray */ @@ -273,20 +283,19 @@ public interface IModel { /** * Get the TrainingListeners + * * @return training listener */ - Collection getListeners(); + Collection getTrainingListeners(); /** * Replace the TrainingListeners for this model + * * @param listeners new listeners */ - void setListeners(TrainingListener... listeners); - void setListeners(Collection listeners); + void addTrainingListeners(TrainingListener... listeners); + + void addTrainingListeners(Collection listeners); + - /** - * Add TrainingListeners to the model - * @param listener listener to add - */ - void addListeners(TrainingListener... listener); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java index b317e4ab0..02ae2d45f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java @@ -29,6 +29,12 @@ public interface INeuralNetworkConfiguration extends Serializable, Cloneable { INeuralNetworkConfiguration clone(); void init(); + + /** + * The model (if initiated) + * @return + */ + IModel getNet(); } /** /** diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java index 3e13e811a..8d6f778d0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/NN.java @@ -23,6 +23,7 @@ package net.brutex.ai.dnn.api; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; +import org.deeplearning4j.nn.conf.layers.DenseLayer; /** * A fluent API to configure and create artificial neural networks diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/Layer_Descriptions.md b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/Layer_Descriptions.md new file mode 100644 index 000000000..74343c891 --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/conf/layer/Layer_Descriptions.md @@ -0,0 +1,31 @@ +# Layer Descriptions # + +## abstract LayerConfiguration and Interface TrainingConfig ## + +Every layer configuration is inherited from LayerConfiguration (and some also from TrainableLayerConfiguration) + + +### NoParamLayer ### + +The following are examples of No ParamLayers. No parameter layers are not inheriting from BaseConfigurationLayer, +but directly from LayerConfiguration. 
+ +* ActivationLayer +* SubsamplingLayer +* ZeroPadding1DLayer +* MaskLayer +* CroppingLayer +* GlobalPoolingLayer + +### SameDiffLayer ### + +### BaseWrapperLayer ### + +### FrozenLayer ### + +### LocalResponseNormalization ### + +### Bidirectional ### + +### TFOpLayer ### + diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java index 2b900a5ff..aa0465659 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/networks/ArtificialNeuralNetwork.java @@ -21,6 +21,8 @@ package net.brutex.ai.dnn.networks; +import java.util.Arrays; +import java.util.HashMap; import java.util.Map; import lombok.Getter; import lombok.NonNull; @@ -58,6 +60,62 @@ public abstract class ArtificialNeuralNetwork implements IModel { @NonNull private NeuralNetConfiguration netConfiguration; + @Getter + @Setter + private Map paramTable; + + /** + * Table of parameters by key, for backprop. For many models (dense layers, etc) - all parameters + * are backprop parameters + * + * @param backpropParamsOnly If true, return backprop params only. If false: return all params + * (equivalent to paramsTable()) + */ + @Override + public Map getParamTable(boolean backpropParamsOnly) { + return paramTable; + } + + + /** + * Set the parameters of the network. Note that the parameter keys must match the format as + * described in {@link #getParam(String)} and {@link #getParamTable()}. Note that the values of the + * parameters used as an argument to this method are copied - i.e., it is safe to later + * modify/reuse the values in the provided paramTable without this impacting the network. + * + * @param paramTable Parameters to set + */ + @Override + public void setParamTable(Map paramTable) { + Map currParamTable = getParamTable(); + if(currParamTable == null) { + currParamTable = paramTable; + } else if (!currParamTable.keySet().equals(paramTable.keySet())) { + throw new IllegalArgumentException( + "Cannot set param table: parameter keys do not match.\n" + "Current: " + + currParamTable.keySet() + "\nTo set: " + paramTable.keySet()); + } + + for (String s : paramTable.keySet()) { + INDArray curr = currParamTable.get(s); + INDArray toSet = paramTable.get(s); + if (!Arrays.equals(curr.shape(), toSet.shape())) { + throw new IllegalArgumentException( + "Cannot set parameter table: parameter \"" + s + "\" shapes " + + "do not match. 
Current = " + Arrays.toString(curr.shape()) + ", to set = " + + Arrays.toString(toSet.shape())); + } + } + + //Now that we've checked ALL params (to avoid leaving net in half-modified state) + for (String s : paramTable.keySet()) { + INDArray curr = currParamTable.get(s); + INDArray toSet = paramTable.get(s); + curr.assign(toSet); + } + } + + /** * Create a new network from configuration diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java index db65ca7bb..770512e4d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/earlystopping/trainer/BaseEarlyStoppingTrainer.java @@ -168,12 +168,12 @@ public abstract class BaseEarlyStoppingTrainer implements IEar if(pretrain){ //TODO support for non-first-layer pretraining if(model instanceof MultiLayerNetwork){ - lastScore = (((MultiLayerNetwork) model).getLayer(0)).score(); + lastScore = (((MultiLayerNetwork) model).getLayer(0)).getScore(); } else { - lastScore = (((ComputationGraph) model).getLayer(0)).score(); + lastScore = (((ComputationGraph) model).getLayer(0)).getScore(); } } else { - lastScore = model.score(); + lastScore = model.getScore(); } for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) { if (c.terminate(lastScore)) { @@ -341,11 +341,11 @@ public abstract class BaseEarlyStoppingTrainer implements IEar Collection listeners; if(model instanceof MultiLayerNetwork){ MultiLayerNetwork n = ((MultiLayerNetwork) model); - listeners = n.getListeners(); + listeners = n.getTrainingListeners(); n.setEpochCount(epochNum); } else if(model instanceof ComputationGraph){ ComputationGraph cg = ((ComputationGraph) model); - listeners = cg.getListeners(); + listeners = cg.getTrainingListeners(); cg.getComputationGraphConfiguration().setEpochCount(epochNum); } else { return; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java index d106f827f..0cccc2a4f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/gradientcheck/GradientCheckUtil.java @@ -23,6 +23,7 @@ package org.deeplearning4j.gradientcheck; import lombok.*; import lombok.experimental.Accessors; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.exception.ND4JArraySizeException; @@ -32,10 +33,8 @@ import org.nd4j.common.primitives.Pair; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.api.layers.IOutputLayer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex; -import org.deeplearning4j.nn.conf.layers.BaseLayer; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.layers.BaseOutputLayer; @@ -83,12 +82,12 @@ public class GradientCheckUtil { IActivation afn = null; if(outputLayer 
instanceof BaseOutputLayer){ BaseOutputLayer o = (BaseOutputLayer)outputLayer; - lfn = ((org.deeplearning4j.nn.conf.layers.BaseOutputLayer)o.layerConf()).getLossFn(); + lfn = ((org.deeplearning4j.nn.conf.layers.BaseOutputLayer)o.getTypedLayerConfiguration()).getLossFn(); afn = o.getLayerConfiguration().getActivationFn(); } else if(outputLayer instanceof LossLayer){ LossLayer o = (LossLayer) outputLayer; - lfn = o.layerConf().getLossFn(); - afn = o.layerConf().getActivationFn(); + lfn = o.getTypedLayerConfiguration().getLossFn(); + afn = o.getTypedLayerConfiguration().getActivationFn(); } if (lfn instanceof LossMCXENT && afn instanceof ActivationSoftmax && ((LossMCXENT) lfn).getSoftmaxClipEps() != 0) { @@ -211,17 +210,17 @@ public class GradientCheckUtil { + "is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil"); } - if(netDataType != c.net.params().dataType()){ + if(netDataType != c.net.getModelParams().dataType()){ throw new IllegalStateException("Parameters datatype does not match network configuration datatype (" - + "is: " + c.net.params().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE."); + + "is: " + c.net.getModelParams().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE."); } //Check network configuration: int layerCount = 0; for (LayerConfiguration n : c.net.getNetConfiguration().getFlattenedLayerConfigurations()) { - if (n instanceof BaseLayer) { - BaseLayer bl = (BaseLayer) n; + if (n instanceof BaseLayerConfiguration) { + BaseLayerConfiguration bl = (BaseLayerConfiguration) n; IUpdater u = bl.getIUpdater(); if (u instanceof Sgd) { //Must have LR of 1.0 @@ -274,7 +273,7 @@ public class GradientCheckUtil { updater.update(c.net, gradAndScore.getFirst(), 0, 0, c.net.batchSize(), LayerWorkspaceMgr.noWorkspaces()); INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup(); //need dup: gradients are a *view* of the full gradient array (which will change every time backprop is done) - INDArray originalParams = c.net.params().dup(); //need dup: params are a *view* of full parameters + INDArray originalParams = c.net.getModelParams().dup(); //need dup: params are a *view* of full parameters val nParams = originalParams.length(); @@ -323,7 +322,7 @@ public class GradientCheckUtil { log.info("NOTE: parameters will be skipped due to config: {}", c.excludeParams); } - INDArray params = c.net.params(); //Assumption here: params is a view that we can modify in-place + INDArray params = c.net.getModelParams(); //Assumption here: params is a view that we can modify in-place for (long i = 0; i < nParams; ) { //Get param name if (i >= paramEnds[currParamNameIdx]) { @@ -438,9 +437,9 @@ public class GradientCheckUtil { + "is: " + netDataType + "). Double precision must be used for gradient checks. Create network with .dataType(DataType.DOUBLE) before using GradientCheckUtil"); } - if(netDataType != c.net.params().dataType()){ + if(netDataType != c.net.getModelParams().dataType()){ throw new IllegalStateException("Parameters datatype does not match network configuration datatype (" - + "is: " + c.net.params().dataType() + "). If network datatype is set to DOUBLE, parameters must also be DOUBLE."); + + "is: " + c.net.getModelParams().dataType() + "). 
If network datatype is set to DOUBLE, parameters must also be DOUBLE."); } //Check configuration @@ -451,8 +450,8 @@ public class GradientCheckUtil { continue; LayerVertex lv = (LayerVertex) gv; - if (lv.getLayerConfiguration() instanceof BaseLayer) { - BaseLayer bl = (BaseLayer) lv.getLayerConfiguration(); + if (lv.getLayerConfiguration() instanceof BaseLayerConfiguration) { + BaseLayerConfiguration bl = (BaseLayerConfiguration) lv.getLayerConfiguration(); IUpdater u = bl.getIUpdater(); if (u instanceof Sgd) { //Must have LR of 1.0 @@ -510,7 +509,7 @@ public class GradientCheckUtil { updater.update(gradAndScore.getFirst(), 0, 0, c.net.batchSize(), LayerWorkspaceMgr.noWorkspaces()); INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup(); //need dup: gradients are a *view* of the full gradient array (which will change every time backprop is done) - INDArray originalParams = c.net.params().dup(); //need dup: params are a *view* of full parameters + INDArray originalParams = c.net.getModelParams().dup(); //need dup: params are a *view* of full parameters val nParams = originalParams.length(); @@ -530,7 +529,7 @@ public class GradientCheckUtil { int totalNFailures = 0; double maxError = 0.0; MultiDataSet mds = new MultiDataSet(c.inputs, c.labels, c.inputMask, c.labelMask); - INDArray params = c.net.params(); //Assumption here: params is a view that we can modify in-place + INDArray params = c.net.getModelParams(); //Assumption here: params is a view that we can modify in-place for (long i = 0; i < nParams; i++) { //Get param name if (i >= paramEnds[currParamNameIdx]) { @@ -643,7 +642,7 @@ public class GradientCheckUtil { updater.update(layer, gradAndScore.getFirst(), 0, 0, layer.batchSize(), LayerWorkspaceMgr.noWorkspaces()); INDArray gradientToCheck = gradAndScore.getFirst().gradient().dup(); //need dup: gradients are a *view* of the full gradient array (which will change every time backprop is done) - INDArray originalParams = layer.params().dup(); //need dup: params are a *view* of full parameters + INDArray originalParams = layer.getParams().dup(); //need dup: params are a *view* of full parameters val nParams = originalParams.length(); @@ -660,7 +659,7 @@ public class GradientCheckUtil { double maxError = 0.0; int currParamNameIdx = 0; - INDArray params = layer.params(); //Assumption here: params is a view that we can modify in-place + INDArray params = layer.getParams(); //Assumption here: params is a view that we can modify in-place for (int i = 0; i < nParams; i++) { //Get param name if (i >= paramEnds[currParamNameIdx]) { @@ -675,13 +674,13 @@ public class GradientCheckUtil { //TODO add a 'score' method that doesn't calculate gradients... 
Nd4j.getRandom().setSeed(rngSeed); layer.computeGradientAndScore(mgr); - double scorePlus = layer.score(); + double scorePlus = layer.getScore(); //(w-epsilon): Do forward pass and score params.putScalar(i, origValue - epsilon); Nd4j.getRandom().setSeed(rngSeed); layer.computeGradientAndScore(mgr); - double scoreMinus = layer.score(); + double scoreMinus = layer.getScore(); //Reset original param value params.putScalar(i, origValue); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Trainable.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ITrainableLayer.java similarity index 90% rename from cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Trainable.java rename to cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ITrainableLayer.java index 33f87a736..d9c85d1f3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Trainable.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ITrainableLayer.java @@ -20,16 +20,21 @@ package org.deeplearning4j.nn.api; +import java.util.Map; import org.nd4j.linalg.api.ndarray.INDArray; -import java.util.Map; +public interface ITrainableLayer { -public interface Trainable { + + Map getParamTable(); + Map getParamTable(boolean isBackprop); + + void setParamTable(Map paramTable); /** * @return Training configuration */ - TrainingConfig getConfig(); + ITraininableLayerConfiguration getTrainingConfig(); /** * @return Number of parameters @@ -39,14 +44,15 @@ public interface Trainable { /** * @return 1d parameter vector */ - INDArray params(); + INDArray getParams(); /** * The param table * * @return - */ + Map getParamTable(); + */ /** * Table of parameters by key, for backprop. For many models (dense layers, etc) - all parameters @@ -54,16 +60,15 @@ public interface Trainable { * * @param backpropParamsOnly If true, return backprop params only. If false: return all params * (equivalent to paramsTable()) - */ - Map getParamTable(boolean backpropParamsOnly); + Map getParamTable(boolean backpropParamsOnly); +*/ /** * Setter for the param table * * @param paramTable - */ - void setParamTable(Map paramTable); - + void setParamTable(Map paramTable); +*/ /** diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/TrainingConfig.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ITraininableLayerConfiguration.java similarity index 96% rename from cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/TrainingConfig.java rename to cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ITraininableLayerConfiguration.java index 58f101260..40a3170b4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/TrainingConfig.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/ITraininableLayerConfiguration.java @@ -27,7 +27,7 @@ import org.nd4j.linalg.learning.regularization.Regularization; import java.util.List; -public interface TrainingConfig { +public interface ITraininableLayerConfiguration { /** * @return Name of the layer @@ -55,7 +55,7 @@ public interface TrainingConfig { boolean isPretrainParam(String paramName); /** - * Get the updater for the given parameter. Typically the same updater will be used for all updaters, but this + * Get the updater for the given parameter. 
Typically the same updater will be used for all parameters, but this * is not necessarily the case * * @param paramName Parameter name @@ -74,5 +74,4 @@ public interface TrainingConfig { double getGradientNormalizationThreshold(); void setDataType(DataType dataType); - } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java index 41051df53..7ff694e99 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Layer.java @@ -21,7 +21,7 @@ package org.deeplearning4j.nn.api; -import lombok.NonNull; +import java.util.Map; import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -29,12 +29,10 @@ import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; -import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; import java.io.Serializable; -import java.util.Collection; /** * A layer is the highest-level building block in deep learning. A layer is a container that usually @@ -46,7 +44,7 @@ import java.util.Collection; * * @see NVIDIA Deep Learning In A Nutshell */ -public interface Layer extends Serializable, Cloneable, Trainable, IModel { +public interface Layer extends Serializable, Cloneable, IModel { //IModel /** * Return the configuration of this layer @@ -234,6 +232,11 @@ public interface Layer extends Serializable, Cloneable, Trainable, IModel { LayerHelper getHelper(); + /** + * Get a reference to the network this layer is part of. 
+ * @return + */ + IModel getNet(); enum Type { FEED_FORWARD, RECURRENT, CONVOLUTIONAL, CONVOLUTIONAL3D, SUBSAMPLING, UPSAMPLING, RECURSIVE, MULTILAYER, NORMALIZATION diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Updater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Updater.java index 2c01298cb..ae301a40c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Updater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/api/Updater.java @@ -40,7 +40,7 @@ public interface Updater extends Serializable { * @param viewArray View array * @param initialize Whether to initialize the array or not */ - void setStateViewArray(Trainable layer, INDArray viewArray, boolean initialize); + void setStateViewArray(ITrainableLayer layer, INDArray viewArray, boolean initialize); /** * @return the view array for this updater @@ -54,5 +54,5 @@ public interface Updater extends Serializable { * @param gradient * @param iteration */ - void update(Trainable layer, Gradient gradient, int iteration, int epoch, int miniBatchSize, LayerWorkspaceMgr workspaceMgr); + void update(ITrainableLayer layer, Gradient gradient, int iteration, int epoch, int miniBatchSize, LayerWorkspaceMgr workspaceMgr); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java index afba61743..e5e94ef3c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java @@ -27,7 +27,7 @@ import org.deeplearning4j.nn.conf.graph.LayerVertex; import org.deeplearning4j.nn.conf.graph.MergeVertex; import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.recurrent.LastTimeStep; @@ -209,7 +209,8 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { if (lv.getNetConfiguration() != null && lv.getLayerConfiguration() != null) { LayerConfiguration layer = lv.getLayerConfiguration(); - if (layer instanceof BaseLayer && ((BaseLayer) layer).getActivationFn() == null) { + if (layer instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration) layer).getActivationFn() == null) { String layerName = layer.getLayerName(); try { @@ -235,7 +236,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { if (activationFunction != null) { IActivation ia = Activation.fromString(activationFunction.asText()).getActivationFunction(); - ((BaseLayer) layer).setActivationFn(ia); + ((BaseLayerConfiguration) layer).setActivationFn(ia); } } catch (IOException e) { @@ -257,7 +258,8 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { * @return True if all is well and layer iteration shall continue. False else-wise. 
*/ private static void handleLegacyWeightInitFromJson(String json, LayerConfiguration layer, ObjectMapper mapper, JsonNode vertices) { - if (layer instanceof BaseLayer && ((BaseLayer) layer).getWeightInitFn() == null) { + if (layer instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration) layer).getWeightInitFn() == null) { String layerName = layer.getLayerName(); try { @@ -289,7 +291,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { if (weightInit != null) { final IWeightInit wi = WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist); - ((BaseLayer) layer).setWeightInitFn(wi); + ((BaseLayerConfiguration) layer).setWeightInitFn(wi); } } catch (IOException e) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java index 8ff512612..a11c21adc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java @@ -31,13 +31,11 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Random; -import lombok.Builder; import lombok.Data; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.NonNull; import lombok.Setter; -import lombok.Singular; import lombok.experimental.SuperBuilder; import lombok.extern.slf4j.Slf4j; import net.brutex.ai.dnn.api.INeuralNetworkConfiguration; @@ -47,7 +45,7 @@ import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.serde.JsonMappers; @@ -520,7 +518,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l, ObjectMapper mapper, JsonNode confs, int layerCount) { - if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) { + if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInitFn() == null) { try { JsonNode jsonNode = mapper.readTree(json); if (confs == null) { @@ -551,7 +549,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor if (weightInit != null) { final IWeightInit wi = WeightInit.valueOf(weightInit.asText()) .getWeightInitFunction(dist); - ((BaseLayer) l).setWeightInitFn(wi); + ((BaseLayerConfiguration) l).setWeightInitFn(wi); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java index cb87e885c..ed5a406b4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java @@ -20,7 +20,10 @@ package org.deeplearning4j.nn.conf; +import com.fasterxml.jackson.annotation.JsonIdentityInfo; +import com.fasterxml.jackson.annotation.JsonIgnore; 
import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.ObjectIdGenerators; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.exc.InvalidTypeIdException; @@ -41,13 +44,14 @@ import lombok.experimental.SuperBuilder; import lombok.extern.jackson.Jacksonized; import lombok.extern.slf4j.Slf4j; import lombok.val; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; @@ -65,6 +69,7 @@ import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport; import org.deeplearning4j.nn.conf.serde.JsonMappers; import org.deeplearning4j.nn.conf.stepfunctions.StepFunction; import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; +import org.deeplearning4j.nn.conf.weightnoise.WeightNoise; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.OutputLayerUtil; @@ -116,11 +121,15 @@ import org.nd4j.linalg.lossfunctions.impl.LossNegativeLogLikelihood; @EqualsAndHashCode(exclude = {"iterationCount", "epochCount"}) @Jacksonized @JsonIgnoreProperties(ignoreUnknown = true) +@JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id") + //The inner builder, that we can then extend ... @SuperBuilder //TODO fix access public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { + private IModel net; private static final int DEFAULT_TBPTT_LENGTH = 20; + private boolean initCalled = false; /** * Set constraints to be applied to all layers. Default: no constraints.
@@ -634,7 +643,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { //Also, pre 0.7.2: activation functions were Strings ("activationFunction" field), not classes ("activationFn") //Try to load the old format if necessary, and create the appropriate IActivation instance - if ((l instanceof BaseLayer) && ((BaseLayer) l).getActivationFn() == null) { + if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getActivationFn() == null) { try { JsonNode jsonNode = mapper.readTree(json); if (confs == null) { @@ -660,7 +669,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { if (activationFunction != null) { IActivation ia = Activation.fromString(activationFunction.asText()) .getActivationFunction(); - ((BaseLayer) l).setActivationFn(ia); + ((BaseLayerConfiguration) l).setActivationFn(ia); } } @@ -689,7 +698,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l, ObjectMapper mapper, JsonNode confs, int layerCount) { - if ((l instanceof BaseLayer) && ((BaseLayer) l).getWeightInitFn() == null) { + if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInitFn() == null) { try { JsonNode jsonNode = mapper.readTree(json); if (confs == null) { @@ -720,7 +729,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { if (weightInit != null) { final IWeightInit wi = WeightInit.valueOf(weightInit.asText()) .getWeightInitFunction(dist); - ((BaseLayer) l).setWeightInitFn(wi); + ((BaseLayerConfiguration) l).setWeightInitFn(wi); } } @@ -825,8 +834,39 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { */ @Override public void init() { - getNetConfigurations().stream().forEach( conf -> conf.init()); //call init on all embedded configurations - innerConfigurations.add(0, this); //put this configuration at first place + if(initCalled) return; + initCalled=true; + + /** + * Run init() for each layer + */ + + getNetConfigurations().stream().forEach( conf -> { + conf.init(); //do not call on self + }); //call init on all embedded net configurations + innerConfigurations.add(0, this); //put this configuration at first place + + /** + * Inherit network wide configuration setting to those layer configurations + * that do not have an individual setting (nor a default) + */ + for(LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) { + if(lconf.getActivationFn() == null ) lconf.setActivationFn(this.getActivationFn()); + if(lconf.getIUpdater() == null ) lconf.setIUpdater( this.getIUpdater() ); + if(lconf.getIDropout() == null ) lconf.setIDropout( this.getIdropOut() ); + if(lconf.getWeightNoise() == null ) lconf.setWeightNoise( this.getWeightNoise()); + + // ... maybe more to set here ... + if(lconf instanceof BaseLayerConfiguration ) { // then we can set some additional config settings + BaseLayerConfiguration bconf = (BaseLayerConfiguration) lconf; + if(bconf.getBiasUpdater() == null) bconf.setBiasUpdater(this.getBiasUpdater()); + if(bconf.getGradientNormalization() == null) bconf.setGradientNormalization(this.getGradientNormalization()); + // ... maybe more to set here ... 
+ } + } + + + getLayerConfigurations().stream().forEach( lconf -> lconf.setNetConfiguration(this)); //set this as net config for all layers (defined in here, not stacked @@ -1009,6 +1049,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { public List netWideVariables() { + return netWideVariables; } @@ -1131,7 +1172,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { * and getFlattenedLayerConfigurations().get(0); * @return */ - @Deprecated + @Deprecated @JsonIgnore public LayerConfiguration getFirstLayer() { log.warn("This getFirstLayer method is an ugly workaround and will be removed."); return getFlattenedLayerConfigurations().get(0); @@ -1155,5 +1196,12 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { } + public IModel getNeuralNet() { + return net; + } + + public void setNeuralNet(IModel model) { + this.net = model; + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java index f93c1619b..67f6ee365 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/graph/LayerVertex.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -153,7 +154,8 @@ public class LayerVertex extends GraphVertex { @Override public void setDataType(DataType dataType){ - layerConfiguration.setDataType(dataType); + if(layerConfiguration instanceof BaseLayerConfiguration) + ((BaseLayerConfiguration)layerConfiguration).setDataType(dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java index 378ae01a2..d7ee4b8ef 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ActivationLayer.java @@ -35,6 +35,7 @@ import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.learning.config.IUpdater; import java.util.Collection; import java.util.Map; @@ -74,6 +75,11 @@ public class ActivationLayer extends NoParamLayer { return clone; } + @Override + public IUpdater getIUpdater() { + return null; + } + @Override public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { @@ -81,7 +87,7 @@ public class ActivationLayer extends NoParamLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.ActivationLayer ret = new org.deeplearning4j.nn.layers.ActivationLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + 
ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java index 311359f7f..72615eca8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java @@ -60,7 +60,7 @@ public class AutoEncoder extends BasePretrainNetwork { org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder ret = new org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java similarity index 96% rename from cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java rename to cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java index bf30e0f7a..121f9b38f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java @@ -21,6 +21,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; +import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.distribution.Distribution; @@ -31,6 +32,7 @@ import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.util.NetworkUtils; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.L1Regularization; import org.nd4j.linalg.learning.regularization.L2Regularization; @@ -46,27 +48,26 @@ import java.util.List; */ @Data @EqualsAndHashCode(callSuper = true) -@NoArgsConstructor -public abstract class BaseLayer extends LayerConfiguration implements Serializable, Cloneable { +@NoArgsConstructor(force = true) +public abstract class BaseLayerConfiguration extends LayerConfiguration implements ITraininableLayerConfiguration, Serializable, Cloneable { - protected IActivation activationFn; @NonNull protected IWeightInit weightInitFn; - protected double biasInit; - protected double gainInit; + protected double biasInit = 0.0; + protected double gainInit = 0.0; protected List regularization; protected List regularizationBias; protected IUpdater iUpdater; protected IUpdater biasUpdater; - protected IWeightNoise weightNoise; + private DataType dataType; + protected GradientNormalization gradientNormalization = GradientNormalization.None; //Clipping, rescale based on l2 norm, etc protected double gradientNormalizationThreshold = 1.0; //Threshold for l2 and element-wise gradient clipping - public BaseLayer(Builder builder) { + public 
BaseLayerConfiguration(Builder builder) { super(builder); this.layerName = builder.layerName; - this.activationFn = builder.activationFn; this.weightInitFn = builder.weightInitFn; this.biasInit = builder.biasInit; this.gainInit = builder.gainInit; @@ -77,6 +78,7 @@ public abstract class BaseLayer extends LayerConfiguration implements Serializab this.gradientNormalization = builder.gradientNormalization; this.gradientNormalizationThreshold = builder.gradientNormalizationThreshold; this.weightNoise = builder.weightNoise; + super.setActivationFn(builder.activationFn); } /** @@ -99,8 +101,8 @@ public abstract class BaseLayer extends LayerConfiguration implements Serializab } @Override - public BaseLayer clone() { - BaseLayer clone = (BaseLayer) super.clone(); + public BaseLayerConfiguration clone() { + BaseLayerConfiguration clone = (BaseLayerConfiguration) super.clone(); if (clone.iDropout != null) { clone.iDropout = clone.iDropout.clone(); } @@ -121,7 +123,7 @@ public abstract class BaseLayer extends LayerConfiguration implements Serializab } /** - * Get the updater for the given parameter. Typically the same updater will be used for all updaters, but this is + * Get the updater for the given parameter. Typically the same updater will be used for all parameters, but this is * not necessarily the case * * @param paramName Parameter name @@ -174,13 +176,13 @@ public abstract class BaseLayer extends LayerConfiguration implements Serializab * Bias initialization value, for layers with biases. Defaults to 0 * */ - protected double biasInit = Double.NaN; + protected double biasInit = 0.0; /** * Gain initialization value, for layers with ILayer Normalization. Defaults to 1 * */ - protected double gainInit = Double.NaN; + protected double gainInit = 1.0; /** * Regularization for the parameters (excluding biases). 
@@ -504,7 +506,6 @@ public abstract class BaseLayer extends LayerConfiguration implements Serializab this.setWeightNoise(weightNoise); return (T) this; } - - } -} + +} \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java index 68e3a0851..ab0044448 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java @@ -97,7 +97,7 @@ public class BatchNormalization extends FeedForwardLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.normalization.BatchNormalization ret = new org.deeplearning4j.nn.layers.normalization.BatchNormalization(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java index a25a10947..afe3fcc48 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CenterLossOutputLayer.java @@ -66,7 +66,7 @@ public class CenterLossOutputLayer extends BaseOutputLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); Layer ret = new org.deeplearning4j.nn.layers.training.CenterLossOutputLayer(lconf, networkDataType); - ret.setListeners(trainingListeners.toArray(new TrainingListener[]{})); + ret.addTrainingListeners(trainingListeners.toArray(new TrainingListener[]{})); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java index 5c3cede7e..79782d956 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java @@ -60,7 +60,7 @@ public class Cnn3DLossLayer extends FeedForwardLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer ret = new org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java index bcad7fb65..b4f93482d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java +++ 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java @@ -65,7 +65,7 @@ public class CnnLossLayer extends FeedForwardLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.convolution.CnnLossLayer ret = new org.deeplearning4j.nn.layers.convolution.CnnLossLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java index eeb023374..cf4fb5a1a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution1DLayer.java @@ -69,7 +69,7 @@ public class Convolution1DLayer extends ConvolutionLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.convolution.Convolution1DLayer ret = new org.deeplearning4j.nn.layers.convolution.Convolution1DLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java index 28a03ed4e..99992463a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Convolution3D.java @@ -100,7 +100,7 @@ public class Convolution3D extends ConvolutionLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); Convolution3DLayer ret = new Convolution3DLayer(lconf, networkDataType); - ret.setListeners(iterationListeners); + ret.addTrainingListeners(iterationListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index a09d33506..25ad6ba4b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -177,7 +177,7 @@ public class ConvolutionLayer extends FeedForwardLayer { org.deeplearning4j.nn.layers.convolution.ConvolutionLayer ret = new org.deeplearning4j.nn.layers.convolution.ConvolutionLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java index d5b113b7f..d805561d0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution2D.java @@ -88,7 +88,7 @@ public class Deconvolution2D extends ConvolutionLayer { org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer ret = new org.deeplearning4j.nn.layers.convolution.Deconvolution2DLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java index 99ed3137b..ea19c1148 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Deconvolution3D.java @@ -85,7 +85,7 @@ public class Deconvolution3D extends ConvolutionLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); Deconvolution3DLayer ret = new Deconvolution3DLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java index fce42e8e5..bfd88a62d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java @@ -62,12 +62,13 @@ public class DenseLayer extends FeedForwardLayer { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret = new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(lconf, networkDataType); if(getWeightInitFn() == null) setWeightInitFn(new WeightInitXavier()); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java index 52eb89ecf..307604ce0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DepthwiseConvolution2D.java @@ -71,7 +71,7 @@ public class DepthwiseConvolution2D extends ConvolutionLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); DepthwiseConvolution2DLayer 
ret = new DepthwiseConvolution2DLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java index 573b6c617..521dacd23 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DropoutLayer.java @@ -74,7 +74,7 @@ public class DropoutLayer extends FeedForwardLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.DropoutLayer ret = new org.deeplearning4j.nn.layers.DropoutLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java index 3ef26352b..36d719ddc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingLayer.java @@ -61,7 +61,7 @@ public class EmbeddingLayer extends FeedForwardLayer { org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer ret = new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java index 133b0b6c1..2ec7b654c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java @@ -69,7 +69,7 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer { org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer ret = new org.deeplearning4j.nn.layers.feedforward.embedding.EmbeddingSequenceLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java index 3728e55bb..de733add8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/FeedForwardLayer.java @@ -28,13 +28,12 @@ import 
org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.preprocessor.Cnn3DToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor; import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; -import org.deeplearning4j.nn.params.DefaultParamInitializer; @Data @NoArgsConstructor @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) -public abstract class FeedForwardLayer extends BaseLayer { +public abstract class FeedForwardLayer extends BaseLayerConfiguration { protected long nIn; protected long nOut; @@ -123,7 +122,7 @@ public abstract class FeedForwardLayer extends BaseLayer { @Getter @Setter - public abstract static class Builder> extends BaseLayer.Builder { + public abstract static class Builder> extends BaseLayerConfiguration.Builder { /** * Number of inputs for the layer (usually the size of the last layer).
Note that for Convolutional layers, diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java index 1cd9e6c91..6d95ae93b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GlobalPoolingLayer.java @@ -73,7 +73,7 @@ public class GlobalPoolingLayer extends NoParamLayer { org.deeplearning4j.nn.layers.pooling.GlobalPoolingLayer ret = new org.deeplearning4j.nn.layers.pooling.GlobalPoolingLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java index ac6242e9a..792d735c3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesBidirectionalLSTM.java @@ -83,7 +83,7 @@ public class GravesBidirectionalLSTM extends BaseRecurrentLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTM ret = new org.deeplearning4j.nn.layers.recurrent.GravesBidirectionalLSTM(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java index bb84cedae..1cdd16dba 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java @@ -81,7 +81,7 @@ public class GravesLSTM extends AbstractLSTM { org.deeplearning4j.nn.layers.recurrent.GravesLSTM ret = new org.deeplearning4j.nn.layers.recurrent.GravesLSTM(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java index 8474d3089..85c440c18 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LSTM.java @@ -77,7 +77,7 @@ public class LSTM extends AbstractLSTM { LayerValidation.assertNInNOutSet("LSTM", getLayerName(), layerIndex, getNIn(), getNOut()); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.recurrent.LSTM ret = new org.deeplearning4j.nn.layers.recurrent.LSTM(lconf, networkDataType); - 
ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java index bb98be57d..b0131b80d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java @@ -20,7 +20,9 @@ package org.deeplearning4j.nn.conf.layers; +import com.fasterxml.jackson.annotation.JsonIdentityInfo; import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.annotation.ObjectIdGenerators; import java.io.Serializable; import java.lang.reflect.Field; import java.util.ArrayList; @@ -33,9 +35,10 @@ import lombok.Getter; import lombok.NoArgsConstructor; import lombok.NonNull; import lombok.Setter; +import lombok.extern.slf4j.Slf4j; +import net.brutex.ai.dnn.api.ILayerConfiguration; import net.brutex.ai.dnn.api.LayerType; import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -43,6 +46,7 @@ import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.IDropout; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; +import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.buffer.DataType; @@ -58,17 +62,18 @@ import org.nd4j.linalg.learning.regularization.Regularization; @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") @Data @NoArgsConstructor +@JsonIdentityInfo(generator= ObjectIdGenerators.IntSequenceGenerator.class, property="@id") +@Slf4j +public abstract class LayerConfiguration implements ILayerConfiguration, Serializable, Cloneable { // ITraininableLayerConfiguration -public abstract class LayerConfiguration implements TrainingConfig, Serializable, Cloneable { - - protected String layerName; + protected String layerName = "noname"; @Getter protected List variables = new ArrayList<>(); public void addVariable(String s) {variables.add(s);} protected IDropout iDropout; protected List constraints; - + protected IWeightNoise weightNoise; /** * The type of the layer, basically defines the base class and its properties */ @@ -247,7 +252,7 @@ public abstract class LayerConfiguration implements TrainingConfig, Serializable public abstract boolean isPretrainParam(String paramName); /** - * Get the updater for the given parameter. Typically the same updater will be used for all + * Get the updater for the given parameter. 
Typically, the same updater will be used for all * parameters, but this is not necessarily the case * * @param paramName Parameter name @@ -258,12 +263,13 @@ public abstract class LayerConfiguration implements TrainingConfig, Serializable "Not supported: all layers with parameters should override this method"); } - @Getter - private IUpdater iUpdater; - @Override - public void setDataType(DataType dataType) { - //No-op for most layers + public IUpdater getIUpdater() { + throw new UnsupportedOperationException( + "Not supported: all layers with parameters should override this method"); + } + public void setIUpdater(IUpdater iUpdater) { + log.warn("Setting an IUpdater on {} with name {} has no effect.", getClass().getSimpleName(), getLayerName()); } /** @@ -279,15 +285,15 @@ public abstract class LayerConfiguration implements TrainingConfig, Serializable this.variables.clear(); } - @Getter - public IActivation activationFn; + @Getter @Setter + private IActivation activationFn; @SuppressWarnings("unchecked") @Getter @Setter public abstract static class Builder> { - protected String layerName = null; + protected String layerName = "noname"; protected List allParamConstraints; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java index a125d4ffc..fcde1b127 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerValidation.java @@ -79,11 +79,11 @@ public class LayerValidation { List weightConstraints, List biasConstraints) { if (layer != null) { - if (layer instanceof BaseLayer) { - BaseLayer bLayer = (BaseLayer) layer; + if (layer instanceof BaseLayerConfiguration) { + BaseLayerConfiguration bLayer = (BaseLayerConfiguration) layer; configureBaseLayer(layerName, bLayer, iDropout, regularization, regularizationBias); - } else if (layer instanceof FrozenLayer && ((FrozenLayer) layer).getInnerConfiguration() instanceof BaseLayer) { - BaseLayer bLayer = (BaseLayer) ((FrozenLayer) layer).getInnerConfiguration(); + } else if (layer instanceof FrozenLayer && ((FrozenLayer) layer).getInnerConfiguration() instanceof BaseLayerConfiguration) { + BaseLayerConfiguration bLayer = (BaseLayerConfiguration) ((FrozenLayer) layer).getInnerConfiguration(); configureBaseLayer(layerName, bLayer, iDropout, regularization, regularizationBias); } else if (layer instanceof Bidirectional) { Bidirectional l = (Bidirectional) layer; @@ -128,7 +128,7 @@ public class LayerValidation { } } - private static void configureBaseLayer(String layerName, BaseLayer bLayer, IDropout iDropout, + private static void configureBaseLayer(String layerName, BaseLayerConfiguration bLayer, IDropout iDropout, List regularization, List regularizationBias) { if (regularization != null && !regularization.isEmpty()) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java index 77483640c..75397400b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocalResponseNormalization.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import 
org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.CNN2DFormat; -import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -79,7 +78,7 @@ public class LocalResponseNormalization extends LayerConfiguration { org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization ret = new org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); @@ -130,16 +129,6 @@ public class LocalResponseNormalization extends LayerConfiguration { return false; //No params in LRN } - @Override - public GradientNormalization getGradientNormalization() { - return GradientNormalization.None; - } - - @Override - public double getGradientNormalizationThreshold() { - return 0; - } - @Override public LayerMemoryReport getMemoryReport(InputType inputType) { val actElementsPerEx = inputType.arrayElementsPerExample(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java index 2e89f7ee7..226d3255d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LossLayer.java @@ -60,7 +60,7 @@ public class LossLayer extends FeedForwardLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.LossLayer ret = new org.deeplearning4j.nn.layers.LossLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/NoParamLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/NoParamLayer.java index 7d0c181f8..57a58f42c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/NoParamLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/NoParamLayer.java @@ -24,8 +24,10 @@ import lombok.NoArgsConstructor; import net.brutex.ai.dnn.api.LayerType; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.params.EmptyParamInitializer; +import org.nd4j.linalg.learning.config.IUpdater; import org.nd4j.linalg.learning.regularization.Regularization; import java.util.List; @@ -55,18 +57,17 @@ public abstract class NoParamLayer extends LayerConfiguration { return null; } - @Override - public GradientNormalization getGradientNormalization() { - return GradientNormalization.None; - } - - @Override - public double getGradientNormalizationThreshold() { - return 0; - } - @Override public boolean isPretrainParam(String paramName) { throw new UnsupportedOperationException(getClass().getSimpleName() + " does not contain parameters"); } + +/** +* + * 
@return +*/ + @Override + public IUpdater getIUpdater() { + return Updater.NONE.getIUpdaterWithDefaultConfig(); + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java index 2616ed8d9..f024caec2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java @@ -56,7 +56,7 @@ public class OutputLayer extends BaseOutputLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.OutputLayer ret = new org.deeplearning4j.nn.layers.OutputLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java index e44f7f709..50647d0f1 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java @@ -41,7 +41,7 @@ import java.util.Map; @NoArgsConstructor @ToString(callSuper = true) @EqualsAndHashCode(callSuper = true) -public class PReLULayer extends BaseLayer { +public class PReLULayer extends BaseLayerConfiguration { private long[] inputShape = null; private long[] sharedAxes = null; @@ -61,7 +61,7 @@ public class PReLULayer extends BaseLayer { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.feedforward.PReLU ret = new org.deeplearning4j.nn.layers.feedforward.PReLU(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java index 1127d0be0..4742b9e5b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java @@ -64,7 +64,7 @@ public class RnnLossLayer extends FeedForwardLayer { org.deeplearning4j.nn.layers.recurrent.RnnLossLayer ret = new org.deeplearning4j.nn.layers.recurrent.RnnLossLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java index 629e70da6..5b59c5399 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java +++ 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnOutputLayer.java @@ -63,7 +63,7 @@ public class RnnOutputLayer extends BaseOutputLayer { org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer ret = new org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java index 34bc03086..924c4cc2a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SeparableConvolution2D.java @@ -121,7 +121,7 @@ public class SeparableConvolution2D extends ConvolutionLayer { org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer ret = new org.deeplearning4j.nn.layers.convolution.SeparableConvolution2DLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java index ff4082075..50f91781b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToBatchLayer.java @@ -71,7 +71,7 @@ public class SpaceToBatchLayer extends NoParamLayer { org.deeplearning4j.nn.layers.convolution.SpaceToBatch ret = new org.deeplearning4j.nn.layers.convolution.SpaceToBatch(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java index 110d127b0..462f3ab5e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SpaceToDepthLayer.java @@ -77,7 +77,7 @@ public class SpaceToDepthLayer extends NoParamLayer { org.deeplearning4j.nn.layers.convolution.SpaceToDepth ret = new org.deeplearning4j.nn.layers.convolution.SpaceToDepth(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java index 5d48dfa6b..be544fb2f 100644 --- 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling1DLayer.java @@ -65,7 +65,7 @@ public class Subsampling1DLayer extends SubsamplingLayer { org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling1DLayer ret = new org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling1DLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java index d201c88b2..123df419b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Subsampling3DLayer.java @@ -117,7 +117,7 @@ public class Subsampling3DLayer extends NoParamLayer { org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling3DLayer ret = new org.deeplearning4j.nn.layers.convolution.subsampling.Subsampling3DLayer(lconf, networkDataType); - ret.setListeners(iterationListeners); + ret.addTrainingListeners(iterationListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java index 32983b01c..bddd9fc30 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java @@ -138,7 +138,7 @@ public class SubsamplingLayer extends NoParamLayer { org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer ret = new org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java index 6f7a7c091..a2d3c4fb8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling1D.java @@ -61,7 +61,7 @@ public class Upsampling1D extends BaseUpsamplingLayer { org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling1D ret = new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling1D(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java index 61693091a..48e86c848 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling2D.java @@ -67,7 +67,7 @@ public class Upsampling2D extends BaseUpsamplingLayer { org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D ret = new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling2D(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java index f4d5fa280..4d629e2fd 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Upsampling3D.java @@ -67,7 +67,7 @@ public class Upsampling3D extends BaseUpsamplingLayer { new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling3D(lconf, networkDataType); - ret.setListeners(iterationListeners); + ret.addTrainingListeners(iterationListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java index aa0268be1..43f6e4ed1 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding1DLayer.java @@ -70,7 +70,7 @@ public class ZeroPadding1DLayer extends NoParamLayer { org.deeplearning4j.nn.layers.convolution.ZeroPadding1DLayer ret = new org.deeplearning4j.nn.layers.convolution.ZeroPadding1DLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java index 21d77ae03..cdabe2788 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPadding3DLayer.java @@ -57,7 +57,7 @@ public class ZeroPadding3DLayer extends NoParamLayer { org.deeplearning4j.nn.layers.convolution.ZeroPadding3DLayer ret = new org.deeplearning4j.nn.layers.convolution.ZeroPadding3DLayer(lconf, networkDataType); - ret.setListeners(iterationListeners); + ret.addTrainingListeners(iterationListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java index 0d0e85d56..4582f42c5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ZeroPaddingLayer.java @@ -74,7 +74,7 @@ public class ZeroPaddingLayer extends NoParamLayer { org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer ret = new org.deeplearning4j.nn.layers.convolution.ZeroPaddingLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java index 2124e9eb9..ef3cedabe 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping1D.java @@ -79,7 +79,7 @@ public class Cropping1D extends NoParamLayer { setNetConfiguration(conf); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); Cropping1DLayer ret = new Cropping1DLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java index 604a269cb..d73d33950 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping2D.java @@ -95,7 +95,7 @@ public class Cropping2D extends NoParamLayer { setNetConfiguration(conf); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); Cropping2DLayer ret = new Cropping2DLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java index c22c8f429..a950ed633 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/convolutional/Cropping3D.java @@ -87,7 +87,7 @@ public class Cropping3D extends NoParamLayer { setNetConfiguration(conf); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); Cropping3DLayer ret = new Cropping3DLayer(lconf, networkDataType); - ret.setListeners(iterationListeners); + ret.addTrainingListeners(iterationListeners); ret.setIndex(layerIndex); Map paramTable = initializer().init(this, layerParamsView, initializeParams); 
ret.setParamTable(paramTable); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java index dc7e9b93d..703d95cea 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/ElementWiseMultiplicationLayer.java @@ -68,7 +68,7 @@ public class ElementWiseMultiplicationLayer extends org.deeplearning4j.nn.conf.l org.deeplearning4j.nn.layers.feedforward.elementwise.ElementWiseMultiplicationLayer ret = new org.deeplearning4j.nn.layers.feedforward.elementwise.ElementWiseMultiplicationLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java index 35a4cae8d..eb15350dc 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayer.java @@ -125,16 +125,6 @@ public class FrozenLayer extends LayerConfiguration { return null; } - @Override - public GradientNormalization getGradientNormalization() { - return innerConfiguration.getGradientNormalization(); - } - - @Override - public double getGradientNormalizationThreshold() { - return innerConfiguration.getGradientNormalizationThreshold(); - } - @Override public LayerMemoryReport getMemoryReport(InputType inputType) { return innerConfiguration.getMemoryReport(inputType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java index ae438958f..6abf467d3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/FrozenLayerWithBackprop.java @@ -26,7 +26,7 @@ import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration; import org.deeplearning4j.nn.params.FrozenLayerWithBackpropParamInitializer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; @@ -40,11 +40,15 @@ import java.util.List; @Data @EqualsAndHashCode(callSuper = false) -public class FrozenLayerWithBackprop extends BaseWrapperLayer { +public class FrozenLayerWithBackprop extends BaseWrapperLayerConfiguration { + /** + * Create a new Frozen Layer, that wraps another layer with backpropagation enabled. 
+ * + * @param layer configuration of the layer to wrap + */ public FrozenLayerWithBackprop(@JsonProperty("layer") LayerConfiguration layer) { super(layer); - underlying = layer; } public NeuralNetConfiguration getInnerConf(NeuralNetConfiguration conf) { @@ -66,9 +70,10 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayer { boolean initializeParams, DataType networkDataType) { //Need to be able to instantiate a layer, from a config - for JSON -> net type situations - org.deeplearning4j.nn.api.Layer underlying = getUnderlying().instantiate(conf, trainingListeners, + org.deeplearning4j.nn.api.Layer newUnderlyingLayer = underlying.instantiate(conf, trainingListeners, layerIndex, layerParamsView, initializeParams, networkDataType); + newUnderlyingLayer.setLayerConfiguration(underlying); //Fix a problem, where the embedded layer gets the conf of the frozen layer, rather than its own NeuralNetConfiguration nncUnderlying = underlying.getNetConfiguration(); if (nncUnderlying.netWideVariables() != null) { @@ -81,7 +86,7 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayer { } } - return new org.deeplearning4j.nn.layers.FrozenLayerWithBackprop(underlying); + return new org.deeplearning4j.nn.layers.FrozenLayerWithBackprop(newUnderlyingLayer); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java index ba85f879c..541c26914 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/misc/RepeatVector.java @@ -69,7 +69,7 @@ public class RepeatVector extends FeedForwardLayer { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.RepeatVector ret = new org.deeplearning4j.nn.layers.RepeatVector(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java index d2d4bec81..70bd048e6 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java @@ -27,7 +27,6 @@ import lombok.Setter; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.CNN2DFormat; -import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; @@ -84,7 +83,7 @@ public class Yolo2OutputLayer extends LayerConfiguration { org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer ret = new org.deeplearning4j.nn.layers.objdetect.Yolo2OutputLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, 
initializeParams); @@ -135,17 +134,6 @@ public class Yolo2OutputLayer extends LayerConfiguration { public boolean isPretrainParam(String paramName) { return false; //No params } - - @Override - public GradientNormalization getGradientNormalization() { - return GradientNormalization.None; - } - - @Override - public double getGradientNormalizationThreshold() { - return 1.0; - } - @Override public LayerMemoryReport getMemoryReport(InputType inputType) { long numValues = inputType.arrayElementsPerExample(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java index 5eda741e4..573492f3a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/Bidirectional.java @@ -30,7 +30,7 @@ import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.BaseRecurrentLayer; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.layers.recurrent.BidirectionalLayer; import org.deeplearning4j.nn.params.BidirectionalParamInitializer; @@ -93,7 +93,7 @@ public class Bidirectional extends LayerConfiguration { */ public Bidirectional(@NonNull Mode mode, @NonNull LayerConfiguration layer) { if (!(layer instanceof BaseRecurrentLayer || layer instanceof LastTimeStep - || layer instanceof BaseWrapperLayer)) { + || layer instanceof BaseWrapperLayerConfiguration)) { throw new IllegalArgumentException("Cannot wrap a non-recurrent layer: " + "config must extend BaseRecurrentLayer or LastTimeStep " + "Got class: " + layer.getClass()); @@ -211,16 +211,6 @@ public class Bidirectional extends LayerConfiguration { return fwd.getUpdaterByParam(sub); } - @Override - public GradientNormalization getGradientNormalization() { - return fwd.getGradientNormalization(); - } - - @Override - public double getGradientNormalizationThreshold() { - return fwd.getGradientNormalizationThreshold(); - } - @Override public void setLayerName(String layerName) { this.layerName = layerName; @@ -254,7 +244,7 @@ public class Bidirectional extends LayerConfiguration { public Builder rnnLayer(LayerConfiguration layer) { if (!(layer instanceof BaseRecurrentLayer || layer instanceof LastTimeStep - || layer instanceof BaseWrapperLayer)) { + || layer instanceof BaseWrapperLayerConfiguration)) { throw new IllegalArgumentException("Cannot wrap a non-recurrent layer: " + "config must extend BaseRecurrentLayer or LastTimeStep " + "Got class: " + layer.getClass()); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java index a869999dc..a5dff218f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/LastTimeStep.java @@ -23,7 +23,7 @@ package org.deeplearning4j.nn.conf.layers.recurrent; import 
org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration; import org.deeplearning4j.nn.layers.recurrent.LastTimeStepLayer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; @@ -31,7 +31,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Collection; -public class LastTimeStep extends BaseWrapperLayer { +public class LastTimeStep extends BaseWrapperLayerConfiguration { private LastTimeStep() {} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java index bda494c1d..1d4c182aa 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/SimpleRnn.java @@ -60,7 +60,7 @@ public class SimpleRnn extends BaseRecurrentLayer { org.deeplearning4j.nn.layers.recurrent.SimpleRnn ret = new org.deeplearning4j.nn.layers.recurrent.SimpleRnn(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java index 7ab6370b7..73cddbf14 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/recurrent/TimeDistributed.java @@ -28,7 +28,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration; import org.deeplearning4j.nn.layers.recurrent.TimeDistributedLayer; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; @@ -39,7 +39,7 @@ import java.util.Collection; @Data @EqualsAndHashCode(callSuper = true) -public class TimeDistributed extends BaseWrapperLayer { +public class TimeDistributed extends BaseWrapperLayerConfiguration { private RNNFormat rnnDataFormat = RNNFormat.NCW; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java index cfec8d653..e9bded983 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java @@ -23,7 +23,7 @@ package org.deeplearning4j.nn.conf.layers.samediff; import lombok.Data; import lombok.EqualsAndHashCode; import org.deeplearning4j.nn.api.MaskState; -import 
org.deeplearning4j.nn.api.TrainingConfig; +import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.graph.GraphVertex; @@ -46,7 +46,7 @@ import java.util.Map; @Data @EqualsAndHashCode(callSuper = false) -public abstract class SameDiffVertex extends GraphVertex implements TrainingConfig { +public abstract class SameDiffVertex extends GraphVertex implements ITraininableLayerConfiguration { private SDVertexParams vertexParams; private String name; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java index 7f11874e8..18f9cadc1 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/util/MaskZeroLayer.java @@ -25,7 +25,7 @@ import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration; import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.optimize.api.TrainingListener; import org.nd4j.linalg.api.buffer.DataType; @@ -36,7 +36,7 @@ import java.util.Collection; @Data @EqualsAndHashCode(callSuper = false) -public class MaskZeroLayer extends BaseWrapperLayer { +public class MaskZeroLayer extends BaseWrapperLayerConfiguration { private double maskingValue = 0.0; @@ -61,7 +61,7 @@ public class MaskZeroLayer extends BaseWrapperLayer { boolean initializeParams, DataType networkDataType) { NeuralNetConfiguration conf2 = conf.clone(); - conf2.setLayer(((BaseWrapperLayer) this).getUnderlying()); + conf2.setLayer(((BaseWrapperLayerConfiguration) this).getUnderlying()); org.deeplearning4j.nn.api.Layer underlyingLayer = underlying.instantiate(conf2, trainingListeners, layerIndex, layerParamsView, initializeParams, networkDataType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java index 4e6a0c41c..85f06a40b 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java @@ -73,7 +73,7 @@ public class VariationalAutoencoder extends BasePretrainNetwork { org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret = new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java deleted file mode 100644 index 
2495fbd56..000000000 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayer.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nn.conf.layers.wrapper; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.NonNull; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.GradientNormalization; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.inputs.InputType; -import org.deeplearning4j.nn.conf.layers.LayerConfiguration; -import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; -import org.deeplearning4j.nn.params.WrapperLayerParamInitializer; -import org.nd4j.linalg.learning.regularization.Regularization; - -import java.util.List; - -@Data -@EqualsAndHashCode(callSuper = false) -public abstract class BaseWrapperLayer extends LayerConfiguration { - - /** - * Set the net configuration for this configuration as well as for the underlying layer - * (if not null there) - * - * @param netConfiguration the neural net configuration - */ - @Override - public void setNetConfiguration(NeuralNetConfiguration netConfiguration) { - super.setNetConfiguration(netConfiguration); - if(getUnderlying().getNetConfiguration() == null) { - getUnderlying().setNetConfiguration( - netConfiguration); //also set netconf for underlying if not set - } - } - - protected LayerConfiguration underlying; - - protected BaseWrapperLayer(Builder builder) { - super(builder); - } - - protected BaseWrapperLayer() {} - - public BaseWrapperLayer(LayerConfiguration underlying) { - this.underlying = underlying; - this.setNetConfiguration(underlying.getNetConfiguration()); - } - - @Override - public ParamInitializer initializer() { - return WrapperLayerParamInitializer.getInstance(); - } - - @Override - public InputType getOutputType(int layerIndex, InputType inputType) { - return underlying.getOutputType(layerIndex, inputType); - } - - @Override - public void setNIn(InputType inputType, boolean override) { - underlying.setNIn(inputType, override); - } - - @Override - public InputPreProcessor getPreProcessorForInputType(InputType inputType) { - return underlying.getPreProcessorForInputType(inputType); - } - - @Override - public List getRegularizationByParam(String paramName){ - return underlying.getRegularizationByParam(paramName); - } - - @Override - public GradientNormalization getGradientNormalization() { - return underlying.getGradientNormalization(); - } - - @Override - public double getGradientNormalizationThreshold() { - return 
underlying.getGradientNormalizationThreshold(); - } - - @Override - public boolean isPretrainParam(String paramName) { - return underlying.isPretrainParam(paramName); - } - - @Override - public LayerMemoryReport getMemoryReport(InputType inputType) { - return underlying.getMemoryReport(inputType); - } - - @Override - public void setLayerName(String layerName) { - super.setLayerName(layerName); - if (underlying != null) { - //May be null at some points during JSON deserialization - underlying.setLayerName(layerName); - } - } -} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayerConfiguration.java new file mode 100644 index 000000000..74b71de1f --- /dev/null +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/wrapper/BaseWrapperLayerConfiguration.java @@ -0,0 +1,196 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +package org.deeplearning4j.nn.conf.layers.wrapper; + +import java.util.List; +import lombok.Data; +import lombok.EqualsAndHashCode; +import org.deeplearning4j.nn.api.ParamInitializer; +import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.conf.InputPreProcessor; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.dropout.IDropout; +import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.LayerConfiguration; +import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; +import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; +import org.deeplearning4j.nn.params.WrapperLayerParamInitializer; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.learning.config.IUpdater; +import org.nd4j.linalg.learning.regularization.Regularization; + +@Data +@EqualsAndHashCode(callSuper = false) +public abstract class BaseWrapperLayerConfiguration extends LayerConfiguration { + + /** + * The configuration to of another layer to wrap + */ + protected LayerConfiguration underlying; + + protected BaseWrapperLayerConfiguration(Builder builder) { + super(builder); + } + + protected BaseWrapperLayerConfiguration() { + } + + public BaseWrapperLayerConfiguration(LayerConfiguration underlying) { + this.underlying = underlying; + this.setNetConfiguration(underlying.getNetConfiguration()); + } + + /** + * Set the net configuration for this configuration as well as for the underlying layer (if not + * null there) + * + * @param netConfiguration the neural net configuration + */ + @Override + public void setNetConfiguration(NeuralNetConfiguration 
netConfiguration) { + super.setNetConfiguration(netConfiguration); + if (underlying.getNetConfiguration() == null) { + underlying.setNetConfiguration( + netConfiguration); //also set netconf for underlying if not set + } + } + + /** + * @return + */ + @Override + public IActivation getActivationFn() { + return underlying.getActivationFn(); + } + + /** + * @return + */ + @Override + public IDropout getIDropout() { + return underlying.getIDropout(); + } + + /** + * @param activationFn + */ + @Override + public void setActivationFn(IActivation activationFn) { + underlying.setActivationFn(activationFn); + } + + /** + * @param iDropout + */ + @Override + public void setIDropout(IDropout iDropout) { + underlying.setIDropout(iDropout); + } + + /** + * @param weightNoise + */ + @Override + public void setWeightNoise(IWeightNoise weightNoise) { + underlying.setWeightNoise(weightNoise); + } + + /** + * @param s + */ + @Override + public void addVariable(String s) { + underlying.addVariable(s); + } + + /** + * Get the updater for the given parameter. Typically, the same updater will be used for all + * parameters, but this is not necessarily the case + * + * @param paramName Parameter name + * @return IUpdater for the parameter + */ + @Override + public IUpdater getUpdaterByParam(String paramName) { + return underlying.getUpdaterByParam(paramName); + } + + /** + * @param iUpdater + */ + @Override + public void setIUpdater(IUpdater iUpdater) { + underlying.setIUpdater(iUpdater); + } + + /** + * @return + */ + @Override + public IUpdater getIUpdater() { + return underlying.getIUpdater(); + } + + @Override + public ParamInitializer initializer() { + return WrapperLayerParamInitializer.getInstance(); + } + + @Override + public InputType getOutputType(int layerIndex, InputType inputType) { + return underlying.getOutputType(layerIndex, inputType); + } + + @Override + public void setNIn(InputType inputType, boolean override) { + underlying.setNIn(inputType, override); + } + + @Override + public InputPreProcessor getPreProcessorForInputType(InputType inputType) { + return underlying.getPreProcessorForInputType(inputType); + } + + @Override + public List getRegularizationByParam(String paramName) { + return underlying.getRegularizationByParam(paramName); + } + + @Override + public boolean isPretrainParam(String paramName) { + return underlying.isPretrainParam(paramName); + } + + @Override + public LayerMemoryReport getMemoryReport(InputType inputType) { + return underlying.getMemoryReport(inputType); + } + + @Override + public void setLayerName(String layerName) { + super.setLayerName(layerName); + if (underlying != null) { + //May be null at some points during JSON deserialization + underlying.setLayerName(layerName); + } + } + +} diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/misc/DummyConfig.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/misc/DummyConfig.java index 8cc5e6e20..c2e149c25 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/misc/DummyConfig.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/misc/DummyConfig.java @@ -21,7 +21,7 @@ package org.deeplearning4j.nn.conf.misc; import lombok.AllArgsConstructor; -import org.deeplearning4j.nn.api.TrainingConfig; +import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.conf.GradientNormalization; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.learning.config.IUpdater; @@ -31,7 +31,7 @@ import 
org.nd4j.linalg.learning.regularization.Regularization; import java.util.List; @AllArgsConstructor -public class DummyConfig implements TrainingConfig { +public class DummyConfig implements ITraininableLayerConfiguration { private final String name; @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java index 8469c6f62..34a888303 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ocnn/OCNNOutputLayer.java @@ -73,7 +73,7 @@ public class OCNNOutputLayer extends BaseOutputLayer { super(builder); this.hiddenSize = builder.hiddenLayerSize; this.nu = builder.nu; - this.activationFn = builder.activation; + setActivationFn( builder.activation) ; this.windowSize = builder.windowSize; this.initialRValue = builder.initialRValue; this.configureR = builder.configureR; @@ -88,7 +88,7 @@ public class OCNNOutputLayer extends BaseOutputLayer { @JsonProperty("configureR") boolean configureR) { this.hiddenSize = hiddenSize; this.nu = nu; - this.activationFn = activation; + setActivationFn( activation); this.windowSize = windowSize; this.initialRValue = initialRValue; this.configureR = configureR; @@ -107,13 +107,13 @@ public class OCNNOutputLayer extends BaseOutputLayer { org.deeplearning4j.nn.layers.ocnn.OCNNOutputLayer ret = new org.deeplearning4j.nn.layers.ocnn.OCNNOutputLayer(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); ret.setLayerConfiguration(lconf); - ret.setActivation(activationFn); + ret.setActivation(getActivationFn()); if (lastEpochSinceRUpdated == 0 && configureR) { paramTable.get(OCNNParamInitializer.R_KEY).putScalar(0, initialRValue); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java index c6a2cbb26..292b85c10 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java @@ -24,7 +24,7 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.distribution.Distribution; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.*; @@ -66,8 +66,8 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im protected boolean requiresIUpdaterFromLegacy(LayerConfiguration[] layers){ for(LayerConfiguration l : layers){ - if(l instanceof BaseLayer){ - BaseLayer bl = (BaseLayer)l; + if(l instanceof BaseLayerConfiguration){ + BaseLayerConfiguration bl = (BaseLayerConfiguration)l; if(bl.getIUpdater() == null && bl.initializer().numParams(bl) > 0){ return true; } @@ -87,7 +87,8 @@ public abstract class BaseNetConfigDeserializer extends 
StdDeserializer im protected boolean requiresRegularizationFromLegacy(LayerConfiguration[] layers){ for(LayerConfiguration l : layers){ - if(l instanceof BaseLayer && ((BaseLayer)l).getRegularization() == null){ + if(l instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration)l).getRegularization() == null){ return true; } } @@ -96,7 +97,8 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im protected boolean requiresWeightInitFromLegacy(LayerConfiguration[] layers){ for(LayerConfiguration l : layers){ - if(l instanceof BaseLayer && ((BaseLayer)l).getWeightInitFn() == null){ + if(l instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration)l).getWeightInitFn() == null){ return true; } } @@ -105,7 +107,8 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im protected boolean requiresActivationFromLegacy(LayerConfiguration[] layers){ for(LayerConfiguration l : layers){ - if(l instanceof BaseLayer && ((BaseLayer)l).getActivationFn() == null){ + if(l instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration)l).getActivationFn() == null){ return true; } } @@ -121,7 +124,7 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im return false; } - protected void handleUpdaterBackwardCompatibility(BaseLayer layer, ObjectNode on){ + protected void handleUpdaterBackwardCompatibility(BaseLayerConfiguration layer, ObjectNode on){ if(on != null && on.has("updater")){ String updaterName = on.get("updater").asText(); if(updaterName != null){ @@ -202,42 +205,43 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im } } - protected void handleL1L2BackwardCompatibility(BaseLayer baseLayer, ObjectNode on){ + protected void handleL1L2BackwardCompatibility(BaseLayerConfiguration baseLayerConfiguration, ObjectNode on){ if(on != null && (on.has("l1") || on.has("l2"))){ //Legacy format JSON - baseLayer.setRegularization(new ArrayList()); - baseLayer.setRegularizationBias(new ArrayList()); + baseLayerConfiguration.setRegularization(new ArrayList()); + baseLayerConfiguration.setRegularizationBias(new ArrayList()); if(on.has("l1")){ double l1 = on.get("l1").doubleValue(); if(l1 > 0.0){ - baseLayer.getRegularization().add(new L1Regularization(l1)); + baseLayerConfiguration.getRegularization().add(new L1Regularization(l1)); } } if(on.has("l2")){ double l2 = on.get("l2").doubleValue(); if(l2 > 0.0){ //Default to non-LR based WeightDecay, to match behaviour in 1.0.0-beta3 - baseLayer.getRegularization().add(new WeightDecay(l2, false)); + baseLayerConfiguration.getRegularization().add(new WeightDecay(l2, false)); } } if(on.has("l1Bias")){ double l1Bias = on.get("l1Bias").doubleValue(); if(l1Bias > 0.0){ - baseLayer.getRegularizationBias().add(new L1Regularization(l1Bias)); + baseLayerConfiguration.getRegularizationBias().add(new L1Regularization(l1Bias)); } } if(on.has("l2Bias")){ double l2Bias = on.get("l2Bias").doubleValue(); if(l2Bias > 0.0){ //Default to non-LR based WeightDecay, to match behaviour in 1.0.0-beta3 - baseLayer.getRegularizationBias().add(new WeightDecay(l2Bias, false)); + baseLayerConfiguration.getRegularizationBias().add(new WeightDecay(l2Bias, false)); } } } } - protected void handleWeightInitBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){ + protected void handleWeightInitBackwardCompatibility( + BaseLayerConfiguration baseLayerConfiguration, ObjectNode on){ if(on != null && on.has("weightInit") ){ //Legacy format JSON if(on.has("weightInit")){ @@ -250,7 +254,7 @@ public abstract 
class BaseNetConfigDeserializer extends StdDeserializer im d = NeuralNetConfiguration.mapper().readValue(dist, Distribution.class); } IWeightInit iwi = w.getWeightInitFunction(d); - baseLayer.setWeightInitFn(iwi); + baseLayerConfiguration.setWeightInitFn(iwi); } catch (Throwable t){ log.warn("Failed to infer weight initialization from legacy JSON format",t); } @@ -259,8 +263,9 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im } //Changed after 0.7.1 from "activationFunction" : "softmax" to "activationFn" : - protected void handleActivationBackwardCompatibility(BaseLayer baseLayer, ObjectNode on){ - if(baseLayer.getActivationFn() == null && on.has("activationFunction")){ + protected void handleActivationBackwardCompatibility( + BaseLayerConfiguration baseLayerConfiguration, ObjectNode on){ + if(baseLayerConfiguration.getActivationFn() == null && on.has("activationFunction")){ String afn = on.get("activationFunction").asText(); IActivation a = null; try { @@ -272,7 +277,7 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im | InvocationTargetException instantiationException){ log.error(instantiationException.getMessage()); } - baseLayer.setActivationFn(a); + baseLayerConfiguration.setActivationFn(a); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java index cf9282771..92399e037 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java @@ -26,7 +26,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.graph.GraphVertex; import org.deeplearning4j.nn.conf.graph.LayerVertex; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; @@ -118,20 +118,24 @@ public class ComputationGraphConfigurationDeserializer continue; } - if(attemptIUpdaterFromLegacy && layers[layerIdx] instanceof BaseLayer && ((BaseLayer)layers[layerIdx]).getIUpdater() == null){ - handleUpdaterBackwardCompatibility((BaseLayer)layers[layerIdx], (ObjectNode)next); + if(attemptIUpdaterFromLegacy && layers[layerIdx] instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration)layers[layerIdx]).getIUpdater() == null){ + handleUpdaterBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next); } - if(requireLegacyRegularizationHandling && layers[layerIdx] instanceof BaseLayer && ((BaseLayer)layers[layerIdx]).getRegularization() == null){ - handleL1L2BackwardCompatibility((BaseLayer)layers[layerIdx], (ObjectNode)next); + if(requireLegacyRegularizationHandling && layers[layerIdx] instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration)layers[layerIdx]).getRegularization() == null){ + handleL1L2BackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next); } - if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayer && ((BaseLayer)layers[layerIdx]).getWeightInitFn() == null){ - 
handleWeightInitBackwardCompatibility((BaseLayer)layers[layerIdx], (ObjectNode)next); + if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration)layers[layerIdx]).getWeightInitFn() == null){ + handleWeightInitBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next); } - if(requiresLegacyActivationHandling && layers[layerIdx] instanceof BaseLayer && ((BaseLayer)layers[layerIdx]).getActivationFn() == null){ - handleActivationBackwardCompatibility((BaseLayer)layers[layerIdx], (ObjectNode)next); + if(requiresLegacyActivationHandling && layers[layerIdx] instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration)layers[layerIdx]).getActivationFn() == null){ + handleActivationBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next); } if(requiresLegacyLossHandling && layers[layerIdx] instanceof BaseOutputLayer && ((BaseOutputLayer)layers[layerIdx]).getLossFn() == null){ @@ -144,9 +148,9 @@ public class ComputationGraphConfigurationDeserializer double d = next.get("dropOut").asDouble(); if(!Double.isNaN(d)){ //Might be dropout or dropconnect... - if(layers[layerIdx] instanceof BaseLayer && confNode.has("useDropConnect") + if(layers[layerIdx] instanceof BaseLayerConfiguration && confNode.has("useDropConnect") && confNode.get("useDropConnect").asBoolean(false)){ - ((BaseLayer)layers[layerIdx]).setWeightNoise(new DropConnect(d)); + ((BaseLayerConfiguration)layers[layerIdx]).setWeightNoise(new DropConnect(d)); } else { layers[layerIdx].setIDropout(new Dropout(d)); } @@ -155,11 +159,12 @@ public class ComputationGraphConfigurationDeserializer } layerIdx++; } else if("org.deeplearning4j.nn.conf.graph.LayerVertex".equals(cls)){ - if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayer && ((BaseLayer)layers[layerIdx]).getWeightInitFn() == null) { + if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration)layers[layerIdx]).getWeightInitFn() == null) { //Post JSON format change for subclasses, but before WeightInit was made a class confNode = (ObjectNode) next.get("layerConf"); next = confNode.get("layer"); - handleWeightInitBackwardCompatibility((BaseLayer) layers[layerIdx], (ObjectNode) next); + handleWeightInitBackwardCompatibility((BaseLayerConfiguration) layers[layerIdx], (ObjectNode) next); } layerIdx++; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/NeuralNetConfigurationDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/NeuralNetConfigurationDeserializer.java index 17a474e78..633650b95 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/NeuralNetConfigurationDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/NeuralNetConfigurationDeserializer.java @@ -23,7 +23,7 @@ package org.deeplearning4j.nn.conf.serde; import org.apache.commons.io.IOUtils; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.dropout.Dropout; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.BaseOutputLayer; import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; @@ -86,7 +86,8 @@ public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserialize for( int 
i=0; i (first/only child) -> updater if(on.has("layer")){ confNode = on; @@ -96,7 +97,7 @@ public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserialize } on = (ObjectNode) on.elements().next(); - handleUpdaterBackwardCompatibility((BaseLayer)layers[i], on); + handleUpdaterBackwardCompatibility((BaseLayerConfiguration)layers[i], on); } if(attemptIUpdaterFromLegacy) { @@ -106,9 +107,10 @@ public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserialize double d = on.get("dropOut").asDouble(); if (!Double.isNaN(d)) { //Might be dropout or dropconnect... - if (confNode != null && layers[i] instanceof BaseLayer && confNode.has("useDropConnect") + if (confNode != null && layers[i] instanceof BaseLayerConfiguration + && confNode.has("useDropConnect") && confNode.get("useDropConnect").asBoolean(false)) { - ((BaseLayer) layers[i]).setWeightNoise(new DropConnect(d)); + ((BaseLayerConfiguration) layers[i]).setWeightNoise(new DropConnect(d)); } else { if (d > 0.0) { layers[i].setIDropout(new Dropout(d)); @@ -133,16 +135,19 @@ public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserialize } } - if(requiresLegacyRegularizationHandling && layers[i] instanceof BaseLayer && ((BaseLayer) layers[i]).getRegularization() == null) { - handleL1L2BackwardCompatibility((BaseLayer) layers[i], on); + if(requiresLegacyRegularizationHandling && layers[i] instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration) layers[i]).getRegularization() == null) { + handleL1L2BackwardCompatibility((BaseLayerConfiguration) layers[i], on); } - if(requiresLegacyWeightInitHandling && layers[i] instanceof BaseLayer && ((BaseLayer) layers[i]).getWeightInitFn() == null) { - handleWeightInitBackwardCompatibility((BaseLayer) layers[i], on); + if(requiresLegacyWeightInitHandling && layers[i] instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration) layers[i]).getWeightInitFn() == null) { + handleWeightInitBackwardCompatibility((BaseLayerConfiguration) layers[i], on); } - if(requiresLegacyActivationHandling && layers[i] instanceof BaseLayer && ((BaseLayer)layers[i]).getActivationFn() == null){ - handleActivationBackwardCompatibility((BaseLayer) layers[i], on); + if(requiresLegacyActivationHandling && layers[i] instanceof BaseLayerConfiguration + && ((BaseLayerConfiguration)layers[i]).getActivationFn() == null){ + handleActivationBackwardCompatibility((BaseLayerConfiguration) layers[i], on); } if(requiresLegacyLossHandling && layers[i] instanceof BaseOutputLayer && ((BaseOutputLayer)layers[i]).getLossFn() == null){ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java index 34d9b8c50..cc0c13506 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java @@ -25,7 +25,6 @@ import lombok.NonNull; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import lombok.val; -import net.brutex.ai.dnn.api.IModel; import net.brutex.ai.dnn.networks.ArtificialNeuralNetwork; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; @@ -690,7 +689,8 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali // now we init solver & optimizer if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new 
Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this).build(); + solver = new Solver.Builder().configure(getNetConfiguration()).listeners( + getTrainingListeners()).model(this).build(); solver.initOptimizer(); } } @@ -1159,7 +1159,8 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali } else { if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this).build(); + solver = new Solver.Builder().configure(getNetConfiguration()).listeners( + getTrainingListeners()).model(this).build(); } } @@ -2886,7 +2887,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali @Override public ComputationGraph clone() { ComputationGraph cg = new ComputationGraph(computationGraphConfiguration.clone()); - cg.init(params().dup(), false); + cg.init(getModelParams().dup(), false); if (solver != null) { //If solver is null: updater hasn't been initialized -> getUpdater call will force initialization, however ComputationGraphUpdater u = this.getUpdater(); @@ -2919,12 +2920,12 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali /** * Set the trainingListeners for the ComputationGraph (and all layers in the network) */ - public void setListeners(Collection listeners) { + public void addTrainingListeners(Collection listeners) { if (layers == null) init(); for (Layer l : layers) { - l.setListeners(listeners.toArray(new TrainingListener[]{})); + l.addTrainingListeners(listeners.toArray(new TrainingListener[]{})); } if (solver != null) { @@ -2962,7 +2963,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali /** * Set the trainingListeners for the ComputationGraph (and all layers in the network) */ - public void setListeners(TrainingListener... listeners) { + public void addTrainingListeners(TrainingListener... listeners) { List list = new ArrayList<>(); //Check: user might have done setListeners(null) thinking this would clear the current listeners. //This results in an TrainingListener[1] with a single null value -> results in a NPE later @@ -2972,7 +2973,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali list.add(i); } } - setListeners(list); + addTrainingListeners(list); } /** @@ -2980,26 +2981,11 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali * * @param listeners Listeners to add */ - @Override - public void addListeners(TrainingListener... 
listeners) { - if (this.trainingListeners == null) { - setListeners(listeners); - return; - } else { - List newListeners = new ArrayList<>(this.trainingListeners); //To avoid immutable list issues - Collections.addAll(newListeners, listeners); - setListeners(newListeners); - } - - if (solver != null) { - solver.setListeners(this.trainingListeners); - } - } /** * Get the trainingListeners for the ComputationGraph */ - public Collection getListeners() { + public Collection getTrainingListeners() { return trainingListeners; } @@ -3017,7 +3003,8 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali */ public ComputationGraphUpdater getUpdater(boolean initializeIfAbsent){ if (solver == null && initializeIfAbsent) { - solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this).build(); + solver = new Solver.Builder().configure(getNetConfiguration()).listeners( + getTrainingListeners()).model(this).build(); solver.getOptimizer().setUpdaterComputationGraph(new ComputationGraphUpdater(this)); } if(solver != null) { @@ -3031,7 +3018,8 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali */ public void setUpdater(ComputationGraphUpdater updater) { if (solver == null) { - solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this).build(); + solver = new Solver.Builder().configure(getNetConfiguration()).listeners( + getTrainingListeners()).model(this).build(); } solver.getOptimizer().setUpdaterComputationGraph(updater); } @@ -3048,11 +3036,11 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali } /** - * @deprecated To be removed. Use {@link #params()} + * @deprecated To be removed. Use {@link #getModelParams()} */ @Deprecated public INDArray params(boolean backwardOnly) { - return params(); + return getModelParams(); } /** @@ -3314,7 +3302,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali } @Override - public double score() { + public double getScore() { return score; } @@ -3323,7 +3311,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali } @Override - public INDArray params() { + public INDArray getModelParams() { return flattenedParams; } @@ -3410,7 +3398,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali @Override public Pair gradientAndScore() { - return new Pair<>(gradient(), score()); + return new Pair<>(gradient(), getScore()); } @Override @@ -3743,7 +3731,8 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) + solver = new Solver.Builder().configure(getNetConfiguration()).listeners( + getTrainingListeners()).model(this) .build(); } } @@ -4511,8 +4500,8 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali } ret.append(StringUtils.repeat("-", totalLength)) - .append(String.format("\n%30s %,d", "Total Parameters: ", params().length())) - .append(String.format("\n%30s %,d", "Trainable Parameters: ", params().length() - frozenParams)) + .append(String.format("\n%30s %,d", "Total Parameters: ", getModelParams().length())) + .append(String.format("\n%30s %,d", "ITrainableLayer Parameters: ", getModelParams().length() - frozenParams)) .append(String.format("\n%30s 
%,d", "Frozen Parameters: ", frozenParams)) .append("\n") .append(StringUtils.repeat("=", totalLength)) @@ -4643,12 +4632,12 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali */ public ComputationGraph convertDataType(@NonNull DataType dataType){ Preconditions.checkState(dataType.isFPType(), "Invalid DataType: %s. Can only convert network to a floating point type", dataType); - if(dataType == params().dataType()){ + if(dataType == getModelParams().dataType()){ return this; } try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - INDArray newParams = params().castTo(dataType); + INDArray newParams = getModelParams().castTo(dataType); String jsonConfig = this.getComputationGraphConfiguration().toJson(); ComputationGraphConfiguration newConf = ComputationGraphConfiguration.fromJson(jsonConfig); newConf.setDataType(dataType); @@ -4875,7 +4864,7 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali return false; if (obj instanceof ComputationGraph) { ComputationGraph network = (ComputationGraph) obj; - boolean paramsEquals = network.params().equals(params()); + boolean paramsEquals = network.getModelParams().equals(getModelParams()); boolean confEquals = this.getComputationGraphConfiguration().equals(network.getComputationGraphConfiguration()); boolean updaterEquals = getUpdater().equals(network.getUpdater()); return paramsEquals && confEquals && updaterEquals; @@ -4922,4 +4911,22 @@ public class ComputationGraph extends ArtificialNeuralNetwork implements Seriali Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); System.gc(); } + + @Override + public ITraininableLayerConfiguration getTrainingConfig() { + throw new UnsupportedOperationException("Not supported"); + } + + /** + * @return 1d parameter vector + */ + @Override + public INDArray getParams() { + throw new RuntimeException("Not supported"); + } + + @Override + public boolean updaterDivideByMinibatch(String paramName) { + return false; + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java index 759f214bc..269e67ac0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseGraphVertex.java @@ -23,7 +23,7 @@ package org.deeplearning4j.nn.graph.vertex; import lombok.Data; import lombok.Getter; import lombok.Setter; -import org.deeplearning4j.nn.api.TrainingConfig; +import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.impl.LayerVertex; import org.nd4j.linalg.api.buffer.DataType; @@ -213,16 +213,16 @@ public abstract class BaseGraphVertex implements GraphVertex { @Override public long numParams(){ - return params() == null ? 0 : params().length(); + return getParams() == null ? 
0 : getParams().length(); } @Override - public TrainingConfig getConfig() { + public ITraininableLayerConfiguration getTrainingConfig() { return null; } @Override - public INDArray params() { + public INDArray getParams() { return null; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java index 0d2a3a26d..d73315645 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/BaseWrapperVertex.java @@ -22,7 +22,7 @@ package org.deeplearning4j.nn.graph.vertex; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; -import org.deeplearning4j.nn.api.TrainingConfig; +import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.api.ndarray.INDArray; @@ -184,13 +184,13 @@ public abstract class BaseWrapperVertex implements GraphVertex { } @Override - public TrainingConfig getConfig() { - return underlying.getConfig(); + public ITraininableLayerConfiguration getTrainingConfig() { + return underlying.getTrainingConfig(); } @Override - public INDArray params() { - return underlying.params(); + public INDArray getParams() { + return underlying.getParams(); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java index 96ac34c19..51bd7ee62 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/GraphVertex.java @@ -20,9 +20,9 @@ package org.deeplearning4j.nn.graph.vertex; +import org.deeplearning4j.nn.api.ITrainableLayer; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; -import org.deeplearning4j.nn.api.Trainable; import org.deeplearning4j.nn.gradient.Gradient; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.common.primitives.Pair; @@ -31,7 +31,7 @@ import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import java.io.Serializable; import java.util.Map; -public interface GraphVertex extends Trainable, Serializable { +public interface GraphVertex extends ITrainableLayer, Serializable { /** Get the name/label of the GraphVertex */ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java index c0b5999ac..a3f45121a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/FrozenVertex.java @@ -21,16 +21,13 @@ package org.deeplearning4j.nn.graph.vertex.impl; import java.util.Map; -import lombok.AllArgsConstructor; + import lombok.EqualsAndHashCode; -import org.deeplearning4j.nn.api.TrainingConfig; -import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.conf.misc.DummyConfig; import org.deeplearning4j.nn.graph.vertex.BaseWrapperVertex; import 
org.deeplearning4j.nn.graph.vertex.GraphVertex; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.learning.config.IUpdater; -import org.nd4j.linalg.learning.config.NoOp; @EqualsAndHashCode(callSuper = true, exclude = {"config"}) public class FrozenVertex extends BaseWrapperVertex { @@ -41,7 +38,7 @@ public class FrozenVertex extends BaseWrapperVertex { private transient DummyConfig config; @Override - public TrainingConfig getConfig(){ + public ITraininableLayerConfiguration getTrainingConfig(){ if (config == null) { config = new DummyConfig(getVertexName()); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java index 5f9ebdb35..a0df3e1bb 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.java @@ -24,7 +24,7 @@ import lombok.Data; import lombok.EqualsAndHashCode; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; -import org.deeplearning4j.nn.api.TrainingConfig; +import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.api.layers.RecurrentLayer; import org.deeplearning4j.nn.conf.InputPreProcessor; @@ -263,13 +263,13 @@ public class LayerVertex extends BaseGraphVertex { } @Override - public TrainingConfig getConfig(){ - return getLayer().getConfig(); + public ITraininableLayerConfiguration getTrainingConfig(){ + return getLayer().getTrainingConfig(); } @Override - public INDArray params(){ - return layer.params(); + public INDArray getParams(){ + return layer.getParams(); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java index edaa3fb80..e8501f312 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java @@ -20,23 +20,15 @@ package org.deeplearning4j.nn.layers; -import java.lang.ref.Cleaner; -import java.lang.ref.PhantomReference; -import java.lang.ref.Reference; -import java.lang.ref.WeakReference; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; -import lombok.AccessLevel; -import lombok.Data; -import lombok.Getter; -import lombok.NoArgsConstructor; -import lombok.NonNull; -import lombok.Setter; +import lombok.*; +import net.brutex.ai.dnn.api.IModel; +import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; -import org.deeplearning4j.nn.api.TrainingConfig; import org.deeplearning4j.nn.api.layers.LayerConstraint; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -55,122 +47,71 @@ import org.nd4j.linalg.dataset.api.MultiDataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -/** - * A layer with input and output, no parameters or gradients - */ -@Data -@NoArgsConstructor +/** A layer with input and output, no parameters or gradients */ +@NoArgsConstructor(force = true) public abstract class AbstractLayer 
implements Layer { - @Setter(AccessLevel.NONE) - protected INDArray input; - protected INDArray preOutput; + private final @Getter List variables = new ArrayList<>(); + @Getter - @NonNull - protected LayerConf_T layerConfiguration; + @Setter(AccessLevel.MODULE) + protected INDArray + input; // TODO: this should be private, but too much code is still accessing input directly. + + protected INDArray preOutput; + /** The typed {@link LayerConfiguration}. */ + @Getter @NonNull protected LayerConf_T layerConfiguration; + protected boolean dropoutApplied = false; + @Getter @Setter @NonNull protected Collection trainingListeners = new ArrayList<>(); - @Deprecated public Collection getListeners() {return getTrainingListeners();} - @Deprecated public void setListeners(TrainingListener ... listeners) { setTrainingListeners(List.of(listeners));} - /** - * Set the {@link TrainingListener}s for this model. If any listeners have previously been set, - * they will be replaced by this method - * - * @param listeners - */ - @Deprecated - public void setListeners(Collection listeners) { - setTrainingListeners(listeners); - } - - protected int index = 0; - protected INDArray maskArray; - protected MaskState maskState; + protected @Getter @Setter int index = 0; + protected @Getter @Setter INDArray maskArray; + protected @Getter @Setter MaskState maskState; protected CacheMode cacheMode = CacheMode.NONE; protected boolean inputModificationAllowed = false; protected DataType dataType; - protected int iterationCount; - protected int epochCount; - private List variables = new ArrayList<>(); - public AbstractLayer(LayerConfiguration layerConfiguration, DataType dataType) { - this.layerConfiguration = (LayerConf_T) layerConfiguration; - if (layerConfiguration != null) { + protected @Getter @Setter int iterationCount; + protected @Getter @Setter int epochCount; + private @Getter @Setter IModel net; + + @Getter @Setter @NonNull private NeuralNetConfiguration netConfiguration; + + public AbstractLayer(@NonNull LayerConfiguration layerConf, @NonNull DataType dataType) { + //noinspection unchecked + this.layerConfiguration = (LayerConf_T) layerConf; + this.netConfiguration = layerConfiguration.getNetConfiguration(); + + if (layerConfiguration.getNetConfiguration() != null) { cacheMode = layerConfiguration.getNetConfiguration().getCacheMode(); } this.dataType = dataType; + this.net = layerConfiguration.getNetConfiguration().getNet(); + } + + public void addTrainingListeners(TrainingListener... listeners) { + trainingListeners.addAll(List.of(listeners)); + } + + public void addTrainingListeners(Collection listeners) { + trainingListeners.addAll(listeners); } - /** - * @param backpropOnly If true: return only parameters that are not exclusively used for layerwise - * pretraining - * @return Parameter table - */ @Override - public Map getParamTable(boolean backpropOnly) { - return null; - } - - public void setParamTable(Map map) { - throw new RuntimeException("Not implemented."); - } - /** - * @return 1D gradients view array - */ - @Override - public INDArray getGradientsViewArray() { - return null; + public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) { + setInput(input, workspaceMgr); + return activate(training, workspaceMgr); } /** - * Creates and returns a copy of this object. The precise meaning of "copy" may depend on the - * class of the object. The general intent is that, for any object {@code x}, the expression: - *
- *
-   * x.clone() != x
-   * will be true, and that the expression:
-   *
- *
-   * x.clone().getClass() == x.getClass()
-   * will be {@code true}, but these are not absolute requirements. While it is typically the case
-   * that:
-   *
- *
-   * x.clone().equals(x)
-   * will be {@code true}, this is not an absolute requirement.
-   *

-   * By convention, the returned object should be obtained by calling {@code super.clone}. If a
-   * class and all of its superclasses (except {@code Object}) obey this convention, it will be the
-   * case that {@code x.clone().getClass() == x.getClass()}.
-   *

-   * By convention, the object returned by this method should be independent of this object (which
-   * is being cloned). To achieve this independence, it may be necessary to modify one or more
-   * fields of the object returned by {@code super.clone} before returning it. Typically, this
-   * means copying any mutable objects that comprise the internal "deep structure" of the object
-   * being cloned and replacing the references to these objects with references to the copies. If a
-   * class contains only primitive fields or references to immutable objects, then it is usually the
-   * case that no fields in the object returned by {@code super.clone} need to be modified.
-   *

-   * The method {@code clone} for class {@code Object} performs a specific cloning operation. First,
-   * if the class of this object does not implement the interface {@code Cloneable}, then a
-   * {@code CloneNotSupportedException} is thrown. Note that all arrays are considered to implement
-   * the interface {@code Cloneable} and that the return type of the {@code clone} method of an
-   * array type {@code T[]} is {@code T[]} where T is any reference or primitive type. Otherwise,
-   * this method creates a new instance of the class of this object and initializes all its fields
-   * with exactly the contents of the corresponding fields of this object, as if by assignment; the
-   * contents of the fields are not themselves cloned. Thus, this method performs a "shallow copy"
-   * of this object, not a "deep copy" operation.
-   *

- * The class {@code Object} does not itself implement the interface {@code Cloneable}, so calling - * the {@code clone} method on an object whose class is {@code Object} will result in throwing an - * exception at run time. + * Creates and returns a copy of this object. * * @return a clone of this instance. * @throws CloneNotSupportedException if the object's class does not support the {@code Cloneable} - * interface. Subclasses that override the {@code clone} method - * can also throw this exception to indicate that an instance - * cannot be cloned. + * interface. Subclasses that override the {@code clone} method can also throw this exception + * to indicate that an instance cannot be cloned. * @see Cloneable */ @Override @@ -178,83 +119,6 @@ public abstract class AbstractLayer impl return super.clone(); } - /** - * Called by the garbage collector on an object when garbage collection determines that there are - * no more references to the object. A subclass overrides the {@code finalize} method to dispose - * of system resources or to perform other cleanup. - *

-   * The general contract of {@code finalize} is that it is invoked if and when the Java™
-   * virtual machine has determined that there is no longer any means by which this object can be
-   * accessed by any thread that has not yet died, except as a result of an action taken by the
-   * finalization of some other object or class which is ready to be finalized. The {@code finalize}
-   * method may take any action, including making this object available again to other threads; the
-   * usual purpose of {@code finalize}, however, is to perform cleanup actions before the object is
-   * irrevocably discarded. For example, the finalize method for an object that represents an
-   * input/output connection might perform explicit I/O transactions to break the connection before
-   * the object is permanently discarded.
-   *

-   * The {@code finalize} method of class {@code Object} performs no special action; it simply
-   * returns normally. Subclasses of {@code Object} may override this definition.
-   *

-   * The Java programming language does not guarantee which thread will invoke the {@code finalize}
-   * method for any given object. It is guaranteed, however, that the thread that invokes finalize
-   * will not be holding any user-visible synchronization locks when finalize is invoked. If an
-   * uncaught exception is thrown by the finalize method, the exception is ignored and finalization
-   * of that object terminates.
-   *

-   * After the {@code finalize} method has been invoked for an object, no further action is taken
-   * until the Java virtual machine has again determined that there is no longer any means by which
-   * this object can be accessed by any thread that has not yet died, including possible actions by
-   * other objects or classes which are ready to be finalized, at which point the object may be
-   * discarded.
-   *

-   * The {@code finalize} method is never invoked more than once by a Java virtual machine for any
-   * given object.
-   *

-   * Any exception thrown by the {@code finalize} method causes the finalization of this object to
-   * be halted, but is otherwise ignored.
-   *
-   * @throws Throwable the {@code Exception} raised by this method
-   * @apiNote Classes that embed non-heap resources have many options for cleanup of those
-   * resources. The class must ensure that the lifetime of each instance is longer than that of any
-   * resource it embeds. {@link Reference#reachabilityFence} can be used to ensure that objects
-   * remain reachable while resources embedded in the object are in use.
-   *

-   * A subclass should avoid overriding the {@code finalize} method unless the subclass embeds
-   * non-heap resources that must be cleaned up before the instance is collected. Finalizer
-   * invocations are not automatically chained, unlike constructors. If a subclass overrides
-   * {@code finalize} it must invoke the superclass finalizer explicitly. To guard against
-   * exceptions prematurely terminating the finalize chain, the subclass should use a
-   * {@code try-finally} block to ensure {@code super.finalize()} is always invoked. For example,

-   * {@code      @Override
-   *     protected void finalize() throws Throwable {
-   *         try {
-   *             ... // cleanup subclass state
-   *         } finally {
-   *             super.finalize();
-   *         }
-   *     }
-   * }
-   * @jls 12.6 Finalization of Class Instances
-   * @see WeakReference
-   * @see PhantomReference
-   * @deprecated The finalization mechanism is inherently problematic. Finalization can lead to
-   * performance issues, deadlocks, and hangs. Errors in finalizers can lead to resource leaks;
-   * there is no way to cancel finalization if it is no longer necessary; and no ordering is
-   * specified among calls to {@code finalize} methods of different objects. Furthermore, there are
-   * no guarantees regarding the timing of finalization. The {@code finalize} method might be called
-   * on a finalizable object only after an indefinite delay, if at all.
-   *

- * Classes whose instances hold non-heap resources should provide a method to enable explicit - * release of those resources, and they should also implement {@link AutoCloseable} if - * appropriate. The {@link Cleaner} and {@link PhantomReference} provide more flexible and - * efficient ways to release resources when an object becomes unreachable. - */ - @Override - protected void finalize() throws Throwable { - super.finalize(); - } - /** * This method returns updater state (if applicable), null otherwise * @@ -281,9 +145,7 @@ public abstract class AbstractLayer impl * @param dataSet */ @Override - public void fit(DataSet dataSet) { - - } + public void fit(DataSet dataSet) {} /** * This method fits model with a given MultiDataSet @@ -291,9 +153,7 @@ public abstract class AbstractLayer impl * @param dataSet */ @Override - public void fit(MultiDataSet dataSet) { - - } + public void fit(MultiDataSet dataSet) {} /** * This method fits model with a given DataSetIterator @@ -301,9 +161,7 @@ public abstract class AbstractLayer impl * @param iterator */ @Override - public void fit(DataSetIterator iterator) { - - } + public void fit(DataSetIterator iterator) {} /** * This method fits model with a given MultiDataSetIterator @@ -311,9 +169,7 @@ public abstract class AbstractLayer impl * @param iterator */ @Override - public void fit(MultiDataSetIterator iterator) { - - } + public void fit(MultiDataSetIterator iterator) {} /** * This method executes evaluation of the model against given iterator and evaluation @@ -339,31 +195,9 @@ public abstract class AbstractLayer impl return null; } - /** - * @param netConfiguration - */ + /** Init the model */ @Override - public void setNetConfiguration(@NonNull NeuralNetConfiguration netConfiguration) { - - } - - /** - * Init the model - */ - @Override - public void init() { - - } - - /** - * This method ADDS additional TrainingListener to existing listeners - * - * @param listener - */ - @Override - public void addListeners(TrainingListener... listener) { - this.trainingListeners.addAll(List.of(listener)); - } + public void init() {} /** * Update layer weights and biases with gradient change @@ -371,20 +205,16 @@ public abstract class AbstractLayer impl * @param gradient */ @Override - public void update(Gradient gradient) { - - } + public void update(Gradient gradient) {} /** * Perform one update applying the gradient * - * @param gradient the gradient to apply + * @param gradient the gradient to apply * @param paramType */ @Override - public void update(INDArray gradient, String paramType) { - - } + public void update(INDArray gradient, String paramType) {} /** * Update the score @@ -392,9 +222,7 @@ public abstract class AbstractLayer impl * @param workspaceMgr */ @Override - public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { - - } + public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) {} /** * the number of parameters for the model @@ -407,15 +235,18 @@ public abstract class AbstractLayer impl return 0; } + @Override + public void setParam(String s, INDArray array) {} + /** - * Set the parameters for this model. 
This expects a linear ndarray which then be unpacked - * internally relative to the expected ordering of the model + * Get a parameter array for a given parameter type key * - * @param params the parameters for the model + * @param param the key of the parameter + * @return ndarray of parameters */ @Override - public void setParams(INDArray params) { - + public INDArray getParam(String param) { + return null; } /** @@ -425,21 +256,16 @@ public abstract class AbstractLayer impl * @param params a 1 x nParams row vector that is a view of the larger (MLN/CG) parameters array */ @Override - public void setParamsViewArray(INDArray params) { - - } + public void setParamsViewArray(INDArray params) {} /** * Set the gradients array as a view of the full (backprop) network parameters NOTE: this is * intended to be used internally in MultiLayerNetwork and ComputationGraph, not by users. * - * @param gradients a 1 x nParams row vector that is a view of the larger (MLN/CG) gradients - * array + * @param gradients a 1 x nParams row vector that is a view of the larger (MLN/CG) gradients array */ @Override - public void setBackpropGradientsViewArray(INDArray gradients) { - - } + public void setBackpropGradientsViewArray(INDArray gradients) {} /** * The current inputs batch size @@ -458,78 +284,28 @@ public abstract class AbstractLayer impl */ @Override public INDArray input() { - return null; + return this.input; } - /** - * Get a parameter array for a given parameter type key - * - * @param param the key of the parameter - * @return ndarray of parameters - */ + /** */ @Override - public INDArray getParam(String param) { - return null; - } - - - /** - * The param table - * - * @return - */ - @Override - public Map getParamTable() { - return null; - } - - /** - * Set the parameters for a given parameter type. - * - * @param key the param type key to set - * @param val the new parameters ndarray - */ - @Override - public void setParam(String key, INDArray val) { - - } - - /** - * - */ - @Override - public void close() { - - } + public void close() {} /** * Calculate the gradient relative to the error in the next layer * - * @param epsilon w^(L+1)*delta^(L+1). Or, equiv: dC/da, i.e., (dC/dz)*(dz/da) = dC/da, where - * C is cost function a=sigma(z) is activation. + * @param epsilon w^(L+1)*delta^(L+1). Or, equiv: dC/da, i.e., (dC/dz)*(dz/da) = dC/da, where C is + * cost function a=sigma(z) is activation. * @param workspaceMgr Workspace manager * @return Pair where Gradient is gradient for this layer, INDArray is epsilon - * (activation gradient) needed by next layer, but before element-wise multiply by sigmaPrime(z). - * So for standard feed-forward layer, if this layer is L, then return.getSecond() == dL/dIn = - * (w^(L)*(delta^(L))^T)^T. Note that the returned array should be placed in the - * {@link ArrayType#ACTIVATION_GRAD} workspace via the workspace manager + * (activation gradient) needed by next layer, but before element-wise multiply by + * sigmaPrime(z). So for standard feed-forward layer, if this layer is L, then + * return.getSecond() == dL/dIn = (w^(L)*(delta^(L))^T)^T. 
Note that the returned array should + * be placed in the {@link ArrayType#ACTIVATION_GRAD} workspace via the workspace manager */ @Override - public Pair backpropGradient(INDArray epsilon, - LayerWorkspaceMgr workspaceMgr) { - return null; - } - - /** - * Perform forward pass and return the activations array with the last set input - * - * @param training training or test mode - * @param workspaceMgr Workspace manager - * @return the activation (layer output) of the last specified input. Note that the returned array - * should be placed in the {@link ArrayType#ACTIVATIONS} workspace via the workspace manager - */ - @Override - public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) { + public Pair backpropGradient( + INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { return null; } @@ -543,13 +319,9 @@ public abstract class AbstractLayer impl return false; } - /** - * - */ + /** */ @Override - public void clearNoiseWeightParams() { - - } + public void clearNoiseWeightParams() {} public List variables() { return variables; @@ -562,33 +334,14 @@ public abstract class AbstractLayer impl return variables; } - /** - * The configuration for the neural network - * - * @return the configuration for the neural network - */ - @Override - public NeuralNetConfiguration getNetConfiguration() { - return layerConfiguration.getNetConfiguration(); - } - public void addVariable(String variable) { if (!variables.contains(variable)) { variables.add(variable); } } - /** - * Return the configuration of this layer - * - * @return the configuration - */ - @Override - public LayerConfiguration getLayerConfiguration() { - return layerConf(); - } - public void setLayerConfiguration(LayerConfiguration layerConfiguration) { + //noinspection unchecked this.layerConfiguration = (LayerConf_T) layerConfiguration; } @@ -601,57 +354,39 @@ public abstract class AbstractLayer impl this.cacheMode = mode; } - public LayerConf_T layerConf() { + public LayerConf_T getTypedLayerConfiguration() { return this.layerConfiguration; } @Override - public TrainingConfig getConfig() { - return layerConfiguration; + public ITraininableLayerConfiguration getTrainingConfig() { + return (ITraininableLayerConfiguration) getTypedLayerConfiguration(); } protected String layerId() { String name = this.layerConfiguration.getLayerName(); - return "(layer name: " + (name == null ? "\"\"" : name) + ", layer index: " + index - + ", layer type: " + - getClass().getSimpleName() + ")"; - } - - public INDArray getInput() { - return input; - } - - public int getEpochCount() { - return epochCount; - } - - public void setEpochCount(int epochCount) { - this.epochCount = epochCount; + return "(layer name: " + + (name == null ? 
"\"\"" : name) + + ", layer index: " + + index + + ", layer type: " + + getClass().getSimpleName() + + ")"; } @Override - public void setInput(INDArray input, LayerWorkspaceMgr workspaceMgr) { + public void setInput(@NonNull INDArray input, LayerWorkspaceMgr workspaceMgr) { this.input = workspaceMgr.leverageTo(ArrayType.INPUT, input); dropoutApplied = false; } - @Override - public int getIndex() { - return index; - } - - @Override - public void setIndex(int index) { - this.index = index; - } - /** * Returns the parameters of the neural network as a flattened row vector * * @return the parameters of the neural network */ @Override - public INDArray params() { + public INDArray getModelParams() { return null; } @@ -671,65 +406,60 @@ public abstract class AbstractLayer impl to.muliColumnVector(maskArray.castTo(to.dataType())); } - @Override - public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) { - setInput(input, workspaceMgr); - return activate(training, workspaceMgr); - } - @Override public double calcRegularizationScore(boolean backpropParamsOnly) { return 0.0; } - @Deprecated public void clear() { input = null; maskArray = null; maskState = null; - if (layerConf().getIDropout() != null) { - layerConf().getIDropout().clear(); + if (getTypedLayerConfiguration().getIDropout() != null) { + getTypedLayerConfiguration().getIDropout().clear(); } } protected void applyDropOutIfNecessary(boolean training, LayerWorkspaceMgr workspaceMgr) { - if (training && !dropoutApplied && layerConf().getIDropout() != null) { + if (training && !dropoutApplied && getTypedLayerConfiguration().getIDropout() != null) { INDArray result; if (inputModificationAllowed) { result = input; } else { - result = workspaceMgr.createUninitialized(ArrayType.INPUT, input.dataType(), input.shape(), - input.ordering()); + result = + workspaceMgr.createUninitialized( + ArrayType.INPUT, input.dataType(), input.shape(), input.ordering()); } - input = layerConf().getIDropout() - .applyDropout(input, result, getIterationCount(), getEpochCount(), workspaceMgr); + input = + getTypedLayerConfiguration() + .getIDropout() + .applyDropout(input, result, getIterationCount(), getEpochCount(), workspaceMgr); dropoutApplied = true; } } protected INDArray backpropDropOutIfPresent(INDArray epsilon) { - if (layerConf().getIDropout() != null) { - layerConf().getIDropout().backprop(epsilon, epsilon, getIterationCount(), getEpochCount()); + if (getTypedLayerConfiguration().getIDropout() != null) { + getTypedLayerConfiguration() + .getIDropout() + .backprop(epsilon, epsilon, getIterationCount(), getEpochCount()); } return epsilon; } - @Override public Type type() { return Type.FEED_FORWARD; } - public void fit(INDArray input, LayerWorkspaceMgr workspaceMgr) { throw new UnsupportedOperationException("Not supported"); } - public Pair gradientAndScore() { - return new Pair<>(gradient(), score()); + return new Pair<>(gradient(), getScore()); } @Override @@ -738,23 +468,13 @@ public abstract class AbstractLayer impl } @Override - public void setInputMiniBatchSize(int size) { - } + public void setInputMiniBatchSize(int size) {} @Override - public INDArray getMaskArray() { - return maskArray; - } - - @Override - public void setMaskArray(INDArray maskArray) { - this.maskArray = maskArray; - } - - @Override - public Pair feedForwardMaskArray(INDArray maskArray, - MaskState currentMaskState, int minibatchSize) { - //Most layers: CNN, dense, activation, etc - set mask array, mask state and then leave the mask unmodified + 
public Pair feedForwardMaskArray( + INDArray maskArray, MaskState currentMaskState, int minibatchSize) { + // Most layers: CNN, dense, activation, etc - set mask array, mask state and then leave the mask + // unmodified this.maskArray = maskArray; this.maskState = currentMaskState; @@ -762,28 +482,24 @@ public abstract class AbstractLayer impl return new Pair<>(maskArray, currentMaskState); } - public Gradient gradient() { throw new UnsupportedOperationException( "Not supported for this layer, or should be overridden for layers requiring it"); } - public void fit() { throw new UnsupportedOperationException( "Not supported for this layer, or should be overridden for layers requiring it"); } - - public double score() { + public double getScore() { throw new UnsupportedOperationException( "Not supported for this layer, or should be overridden for layers requiring it"); } - public void applyConstraints(int iteration, int epoch) { - if (layerConf().getConstraints() != null) { - for (LayerConstraint lc : layerConf().getConstraints()) { + if (getTypedLayerConfiguration().getConstraints() != null) { + for (LayerConstraint lc : getTypedLayerConfiguration().getConstraints()) { lc.applyConstraint(this, iteration, epoch); } } @@ -793,11 +509,13 @@ public abstract class AbstractLayer impl if (input == null) { if (backprop) { throw new IllegalStateException( - "Cannot perform backprop in layer " + getClass().getSimpleName() + "Cannot perform backprop in layer " + + getClass().getSimpleName() + ": layer input field is not set"); } else { throw new IllegalStateException( - "Cannot perform forward pass in layer " + getClass().getSimpleName() + "Cannot perform forward pass in layer " + + getClass().getSimpleName() + ": layer input field is not set"); } } @@ -810,14 +528,79 @@ public abstract class AbstractLayer impl @Override public LayerHelper getHelper() { - //Layers with helpers should override this method! + // Layers with helpers should override this method! return null; } @Override public boolean updaterDivideByMinibatch(String paramName) { - //Majority of params's gradients should be... Exception: batch norm mean/variance estimate + // Majority of params's gradients should be... Exception: batch norm mean/variance estimate return true; } + /** + * The AbstractLayer does not implement Params, ParamTable and GradientView. A RuntimeException + * will be triggered when calling this. + * + * @return + */ + @Override + public Map getParamTable() { + throw new RuntimeException("Not implemented"); + } + + /** + * * The AbstractLayer does not implement Params, ParamTable and GradientView. A RuntimeException + * * will be triggered when calling this. + * + * @param paramTable + */ + @Override + public void setParamTable(Map paramTable) { + throw new RuntimeException("Not implemented"); + } + + /** + * * The AbstractLayer does not implement Params, ParamTable and GradientView. A RuntimeException + * * will be triggered when calling this. + * + * @param isBackprop + * @return + */ + @Override + public Map getParamTable(boolean isBackprop) { + throw new RuntimeException("Not implemented"); + } + + /** + * * The AbstractLayer does not implement Params, ParamTable and GradientView. A RuntimeException + * * will be triggered when calling this. + * + * @return 1d parameter vector + */ + @Override + public INDArray getParams() { + // throw new RuntimeException("Not implemented"); + return null; + } + + /** + * Set the parameters for this model. 
This expects a linear ndarray which then be unpacked + * internally relative to the expected ordering of the model + * + * @param params the parameters for the model + */ + @Override + public void setParams(INDArray params) {} + + /** + * * The AbstractLayer does not implement Params, ParamTable and GradientView. A RuntimeException + * * will be triggered when calling this. + * + * @return 1D gradients view array + */ + @Override + public INDArray getGradientsViewArray() { + throw new RuntimeException("Not implemented"); + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java index 7043275a0..48df25694 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ActivationLayer.java @@ -21,7 +21,7 @@ package org.deeplearning4j.nn.layers; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import java.util.Map; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -38,6 +38,7 @@ public class ActivationLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); INDArray temp = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, input, input.ordering()); - INDArray delta = layerConf().getActivationFn().backprop(temp, epsilon).getFirst(); //TODO handle activation function params + INDArray delta = getTypedLayerConfiguration().getActivationFn().backprop(temp, epsilon).getFirst(); //TODO handle activation function params if(delta == epsilon ){ //Edge case: identity activation + external errors -> no-op delta = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, delta); @@ -75,7 +76,7 @@ public class ActivationLayer extends AbstractLayer paramTable; + /** + * @param backpropOnly If true: return only parameters that are not exclusively used for layerwise + * pretraining + * @return Parameter table + */ + @Override + public Map getParamTable(boolean backpropOnly) { + return this.paramTable; + } + + /** + * @param map + */ + @Override + public void setParamTable(Map map) { + this.paramTable = map; + } @Override - public INDArray params() { + public INDArray getModelParams() { return null; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java index 68de26b7c..6363c77c5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java @@ -21,19 +21,16 @@ package org.deeplearning4j.nn.layers; import java.lang.reflect.Constructor; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; import lombok.Getter; import lombok.NonNull; +import lombok.Setter; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; +import org.deeplearning4j.nn.api.ITrainableLayer; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import 
org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -58,14 +55,31 @@ import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.regularization.Regularization; -/** - * A layer with parameters - * - * @author Adam Gibson - */ +/** A layer with parameters */ @Slf4j -public abstract class BaseLayer - extends AbstractLayer { +public abstract class BaseLayer + extends AbstractLayer implements ITrainableLayer { + + protected double score = 0.0; + protected ConvexOptimizer optimizer; + protected Gradient gradient; + protected Solver solver; + protected Map weightNoiseParams = new HashMap<>(); + protected INDArray paramsFlattened; + protected INDArray gradientsFlattened; + + @Getter @Setter protected Map paramTable; + + @Getter protected transient Map gradientViews; + + /** + * we put this as a virtual function to access the models paramTable. @Getter @Setter private + * INDArray params; + */ + public BaseLayer(LayerConfiguration conf, DataType dataType) { + + super(conf, dataType); + } /** * This method executes evaluation of the model against given iterator and evaluation @@ -91,31 +105,9 @@ public abstract class BaseLayer weightNoiseParams = new HashMap<>(); - protected INDArray paramsFlattened; - protected INDArray gradientsFlattened; - /** - * Full table of parameters - */ - protected Map paramsTable; - @Getter protected transient Map gradientViews; - - public BaseLayer(LayerConfiguration conf, DataType dataType) { - super(conf, dataType); - } - - /** * and others even use \epsilon (epsilon) * http://web.cs.swarthmore.edu/~meeden/cs81/s10/BackPropDeriv.pdf * - * @param epsilon w^(L+1)*delta^(L+1). Or, equiv: dC/da, i.e., (dC/dz)*(dz/da) = dC/da, where - * C is cost function a=sigma(z) is activation. + * @param epsilon w^(L+1)*delta^(L+1). Or, equiv: dC/da, i.e., (dC/dz)*(dz/da) = dC/da, where C is + * cost function a=sigma(z) is activation. 
* @param workspaceMgr Workspace manager * @return */ @Override - public Pair backpropGradient(INDArray epsilon, - LayerWorkspaceMgr workspaceMgr) { + public Pair backpropGradient( + INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - //If this layer is layer L, then epsilon is (w^(L+1)*(d^(L+1))^T) (or equivalent) + // If this layer is layer L, then epsilon is (w^(L+1)*(d^(L+1))^T) (or equivalent) Pair zAndPreNorm = preOutputWithPreNorm(true, true, workspaceMgr); - INDArray z = zAndPreNorm.getFirst(); //Note: using preOutput(INDArray) can't be used as this does a setInput(input) and resets the 'appliedDropout' flag + INDArray z = + zAndPreNorm.getFirst(); // Note: using preOutput(INDArray) can't be used as this does a + // setInput(input) and resets the 'appliedDropout' flag INDArray preNorm = zAndPreNorm.getSecond(); - INDArray delta = layerConf().getActivationFn().backprop(z, epsilon) - .getFirst(); //TODO handle activation function params + INDArray delta = + getTypedLayerConfiguration() + .getActivationFn() + .backprop(z, epsilon) + .getFirst(); // TODO handle activation function params if (maskArray != null) { applyMask(delta); @@ -317,29 +264,39 @@ public abstract class BaseLayer(ret, epsilonNext); } - public void computeGradientAndScore(LayerWorkspaceMgr workspaceMgr) { - if (this.input == null) { + if (getInput() == null) { log.warn("There is no input for this layer '{}'", layerConfiguration); return; } @@ -358,18 +314,15 @@ public abstract class BaseLayer parameterList = layerConfiguration.getVariables(); //netWideVariables(); + if (params == null) { + log.warn( + "setParams(INDArray params, char order): params is null. Skipping setParams in Layer {}[{}] at index {}", + getLayerConfiguration().getLayerName(), + getClass().getSimpleName(), + getIndex()); + return; + } + List parameterList = layerConfiguration.getVariables(); // netWideVariables(); int length = 0; - for (String s : parameterList) { - length += getParam(s).length(); - } - if (params.length() != length) { - throw new IllegalArgumentException("Unable to set parameters: must be of length " + length - + ", got params of length " + params.length() + " - " + layerId()); - } + for (String s : parameterList) { + length += getParam(s).length(); + } + if (params.length() != length) { + throw new IllegalArgumentException( + "Unable to set parameters: must be of length " + + length + + ", got params of length " + + params.length() + + " - " + + layerId()); + } int idx = 0; Set paramKeySet = this.getParamTable().keySet(); for (String s : paramKeySet) { INDArray param = getParam(s); - INDArray get = params.get(NDArrayIndex.point(0), - NDArrayIndex.interval(idx, idx + param.length())); - if (param.length() != get.length()) { - throw new IllegalStateException( - "Parameter " + s + " should have been of length " + param.length() - + " but was " + get.length() + " - " + layerId()); - } - param.assign(get.reshape(order, - param.shape())); //Use assign due to backprop params being a view of a larger array + INDArray get = + params.get(NDArrayIndex.point(0), NDArrayIndex.interval(idx, idx + param.length())); + if (param.length() != get.length()) { + throw new IllegalStateException( + "Parameter " + + s + + " should have been of length " + + param.length() + + " but was " + + get.length() + + " - " + + layerId()); + } + param.assign( + get.reshape( + order, + param.shape())); // Use assign due to backprop params being a view of a larger array idx += param.length(); } } @Override public void 
setParamsViewArray(INDArray params) { - if (this.paramsTable != null && params.length() != numParams()) { - throw new IllegalArgumentException("Invalid input: expect params of length " + numParams() - + ", got params of length " + params.length() + " - " + layerId()); - } - + if (this.getParamTable() != null && params.length() != numParams()) { + throw new IllegalArgumentException( + "Invalid input: expect params of length " + + numParams() + + ", got params of length " + + params.length() + + " - " + + layerId()); + } this.paramsFlattened = params; } + @Override + public Map getParamTable(boolean isBackprop) { + return paramTable; + } + @Override public INDArray getGradientsViewArray() { return gradientsFlattened; @@ -450,15 +431,19 @@ public abstract class BaseLayer 0 && weightNoiseParams.containsKey(param)) { - //Re-use these weights for both forward pass and backprop - don't want to use 2 different params here - //These should be cleared during backprop + // Re-use these weights for both forward pass and backprop - don't want to use 2 different + // params here + // These should be cleared during backprop return weightNoiseParams.get(param); } else { try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - p = layerConf().getWeightNoise() - .getParameter(this, param, getIterationCount(), getEpochCount(), training, - workspaceMgr); + p = + lconf + .getWeightNoise() + .getParameter( + this, param, getIterationCount(), getEpochCount(), training, workspaceMgr); } } if (training) { - //Store for re-use in backprop + // Store for re-use in backprop weightNoiseParams.put(param, p); } } else { @@ -502,34 +491,45 @@ public abstract class BaseLayer preOutputWithPreNorm(boolean training, boolean forBackprop, - LayerWorkspaceMgr workspaceMgr) { + protected Pair preOutputWithPreNorm( + boolean training, boolean forBackprop, @NonNull LayerWorkspaceMgr workspaceMgr) { assertInputSet(forBackprop); applyDropOutIfNecessary(training, workspaceMgr); INDArray W = getParamWithNoise(DefaultParamInitializer.WEIGHT_KEY, training, workspaceMgr); INDArray b = getParamWithNoise(DefaultParamInitializer.BIAS_KEY, training, workspaceMgr); INDArray g = (hasLayerNorm() ? getParam(DefaultParamInitializer.GAIN_KEY) : null); - INDArray input = this.input.castTo(dataType); + INDArray input = getInput().castTo(dataType); - //Input validation: + // Input validation: if (input.rank() != 2 || input.columns() != W.rows()) { if (input.rank() != 2) { throw new DL4JInvalidInputException( "Input that is not a matrix; expected matrix (rank 2), got rank " - + input.rank() + " array with shape " + Arrays.toString(input.shape()) - + ". Missing preprocessor or wrong input type? " + layerId()); + + input.rank() + + " array with shape " + + Arrays.toString(input.shape()) + + ". Missing preprocessor or wrong input type? " + + layerId()); } throw new DL4JInvalidInputException( - "Input size (" + input.columns() + " columns; shape = " + Arrays.toString(input.shape()) + "Input size (" + + input.columns() + + " columns; shape = " + + Arrays.toString(input.shape()) + ") is invalid: does not match layer input size (layer # inputs = " - + W.size(0) + ") " + layerId()); + + W.size(0) + + ") " + + layerId()); } - INDArray ret = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, W.dataType(), - input.size(0), W.size(1)); - input.castTo(ret.dataType()).mmuli(W, - ret); //TODO Can we avoid this cast? 
(It sohuld be a no op if not required, however) + INDArray ret = + workspaceMgr.createUninitialized( + ArrayType.ACTIVATIONS, W.dataType(), input.size(0), W.size(1)); + input + .castTo(ret.dataType()) + .mmuli( + W, ret); // TODO Can we avoid this cast? (It sohuld be a no op if not required, however) INDArray preNorm = ret; if (hasLayerNorm()) { @@ -550,8 +550,8 @@ public abstract class BaseLayer e : paramsTable.entrySet()) { - List l = layerConf().getRegularizationByParam(e.getKey()); + for (Map.Entry e : getParamTable().entrySet()) { + List l = getTypedLayerConfiguration().getRegularizationByParam(e.getKey()); if (l == null || l.isEmpty()) { continue; } @@ -582,7 +589,7 @@ public abstract class BaseLayer linkedTable = new LinkedHashMap<>(); - for (Map.Entry entry : paramsTable.entrySet()) { + for (Map.Entry entry : getParamTable().entrySet()) { linkedTable.put(entry.getKey(), entry.getValue().dup()); } layer.setParamTable(linkedTable); @@ -591,10 +598,8 @@ public abstract class BaseLayer gradientAndScore() { - return new Pair<>(gradient(), score()); + return new Pair<>(gradient(), getScore()); } @Override @@ -167,10 +166,10 @@ public abstract class BaseOutputLayer getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) { - ILossFunction lossFunction = layerConf().getLossFn(); + ILossFunction lossFunction = getTypedLayerConfiguration().getLossFn(); INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.BP_WORKING_MEM); //INDArray delta = lossFunction.computeGradient(labels2d, preOut, layerConf().getActivationFunction(), maskArray); - INDArray delta = lossFunction.computeGradient(labels2d, preOut, layerConf().getActivationFn(), maskArray); + INDArray delta = lossFunction.computeGradient(labels2d, preOut, getTypedLayerConfiguration().getActivationFn(), maskArray); Gradient gradient = new DefaultGradient(); @@ -350,6 +349,6 @@ public abstract class BaseOutputLayer(zeroGradient, underlying.score()); + return new Pair<>(zeroGradient, underlying.getScore()); } @Override @@ -199,9 +199,9 @@ public class FrozenLayer extends BaseWrapperLayer { } @Override - public TrainingConfig getConfig(){ + public ITraininableLayerConfiguration getTrainingConfig(){ if (config == null) { - config = new DummyConfig(getUnderlying().getConfig().getLayerName()); + config = new DummyConfig(getUnderlying().getTrainingConfig().getLayerName()); } return config; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java index 425ec454f..9cf762798 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/FrozenLayerWithBackprop.java @@ -42,7 +42,7 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayer { public FrozenLayerWithBackprop(final Layer insideLayer) { super(insideLayer); - this.zeroGradient = new DefaultGradient(insideLayer.params()); + this.zeroGradient = new DefaultGradient(insideLayer.getParams()); } protected String layerId() { @@ -58,7 +58,7 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayer { @Override public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { INDArray backpropEpsilon = underlying.backpropGradient(epsilon, workspaceMgr).getSecond(); - //backprop might have already changed the gradient view (like BaseLayer and BaseOutputLayer do) + //backprop might have already 
changed the gradient view (like BaseLayerConfiguration and BaseOutputLayer do) //so we want to put it back to zeroes INDArray gradientView = underlying.getGradientsViewArray(); if(gradientView != null){ @@ -72,12 +72,6 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayer { return underlying.activate(false, workspaceMgr); } - @Override - public INDArray activate(INDArray input, boolean training, LayerWorkspaceMgr workspaceMgr) { - logTestMode(training); - return underlying.activate(input, false, workspaceMgr); - } - @Override public void fit() { if (!logFit) { @@ -112,7 +106,7 @@ public class FrozenLayerWithBackprop extends BaseWrapperLayer { "Gradients for the frozen layer are not set and will therefore will not be updated.Warning will be issued only once per instance"); logGradient = true; } - underlying.score(); + underlying.getScore(); //no op } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java index e13a06219..43bbc69d8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/LossLayer.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.layers.IOutputLayer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -33,7 +32,6 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.lossfunctions.ILossFunction; import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.util.FeatureUtil; @@ -72,10 +70,10 @@ public class LossLayer extends BaseLayer gradientAndScore() { - return new Pair<>(gradient(), score()); + return new Pair<>(gradient(), getScore()); } @Override @@ -135,8 +133,8 @@ public class LossLayer extends BaseLayer getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) { // delta calculation - ILossFunction lossFunction = layerConf().getLossFn(); - INDArray delta = lossFunction.computeGradient(getLabels2d(), preOut, layerConf().getActivationFn(), maskArray); + ILossFunction lossFunction = getTypedLayerConfiguration().getLossFn(); + INDArray delta = lossFunction.computeGradient(getLabels2d(), preOut, getTypedLayerConfiguration().getActivationFn(), maskArray); // grab the empty gradient Gradient gradient = new DefaultGradient(); @@ -172,7 +170,7 @@ public class LossLayer extends BaseLayer fwd = preOutput(false,true,workspaceMgr); - IActivation afn = layerConf().getActivationFn(); + IActivation afn = getTypedLayerConfiguration().getActivationFn(); INDArray delta = afn.backprop(fwd.getFirst(), epsilon).getFirst(); //TODO handle activation function params - org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = layerConf(); + org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = getTypedLayerConfiguration(); Conv1DConfig conf = Conv1DConfig.builder() .k(c.getKernelSize()[0]) .s(c.getStride()[0]) @@ -86,11 +85,11 @@ public class Convolution1DLayer extends ConvolutionLayer { getRnnDataFormat()); INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, 
input.dataType(), input.shape()); INDArray input = this.input.castTo(dataType); - if(layerConf().getRnnDataFormat() == RNNFormat.NWC) { + if(getTypedLayerConfiguration().getRnnDataFormat() == RNNFormat.NWC) { input = input.permute(0,2,1); //NHWC to NCHW } - if(layerConf().hasBias()) { + if(getTypedLayerConfiguration().hasBias()) { INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY); b = b.reshape(b.length()); inputArrs = new INDArray[]{input, w, b, delta}; @@ -106,7 +105,7 @@ public class Convolution1DLayer extends ConvolutionLayer { Nd4j.exec(op); Gradient retGradient = new DefaultGradient(); - if(layerConf().hasBias()) { + if(getTypedLayerConfiguration().hasBias()) { retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, gradientViews.get(ConvolutionParamInitializer.BIAS_KEY)); } retGradient.setGradientFor(ConvolutionParamInitializer.WEIGHT_KEY, gradientViews.get(ConvolutionParamInitializer.WEIGHT_KEY), 'c'); @@ -130,11 +129,11 @@ public class Convolution1DLayer extends ConvolutionLayer { assertInputSet(false); INDArray input = this.input.castTo(dataType); - if(layerConf().getRnnDataFormat() == RNNFormat.NWC) { + if(getTypedLayerConfiguration().getRnnDataFormat() == RNNFormat.NWC) { input = input.permute(0,2,1); //NHWC to NCHW } - org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = layerConf(); + org.deeplearning4j.nn.conf.layers.Convolution1DLayer c = getTypedLayerConfiguration(); Conv1DConfig conf = Conv1DConfig.builder() .k(c.getKernelSize()[0]) .s(c.getStride()[0]) @@ -151,7 +150,7 @@ public class Convolution1DLayer extends ConvolutionLayer { INDArray[] inputs; - if(layerConf().hasBias()) { + if(getTypedLayerConfiguration().hasBias()) { INDArray b = getParam(ConvolutionParamInitializer.BIAS_KEY); b = b.reshape(b.length()); inputs = new INDArray[]{input, w, b}; @@ -193,18 +192,18 @@ public class Convolution1DLayer extends ConvolutionLayer { @Override public Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { - INDArray reduced = ConvolutionUtils.cnn1dMaskReduction(maskArray, layerConf().getKernelSize()[0], - layerConf().getStride()[0], layerConf().getPadding()[0], layerConf().getDilation()[0], - layerConf().getConvolutionMode()); + INDArray reduced = ConvolutionUtils.cnn1dMaskReduction(maskArray, getTypedLayerConfiguration().getKernelSize()[0], + getTypedLayerConfiguration().getStride()[0], getTypedLayerConfiguration().getPadding()[0], getTypedLayerConfiguration().getDilation()[0], + getTypedLayerConfiguration().getConvolutionMode()); return new Pair<>(reduced, currentMaskState); } @Override - public org.deeplearning4j.nn.conf.layers.Convolution1DLayer layerConf() { + public org.deeplearning4j.nn.conf.layers.Convolution1DLayer getTypedLayerConfiguration() { return (org.deeplearning4j.nn.conf.layers.Convolution1DLayer)layerConfiguration; } private RNNFormat getRnnDataFormat(){ - return layerConf().getRnnDataFormat(); + return getTypedLayerConfiguration().getRnnDataFormat(); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution3DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution3DLayer.java index 0edb2ed3b..184c46723 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution3DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Convolution3DLayer.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.layers.convolution; import 
org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -67,7 +66,7 @@ public class Convolution3DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); INDArray weights = getParamWithNoise(Convolution3DParamInitializer.WEIGHT_KEY, true, workspaceMgr); - Convolution3D layerConfig = (Convolution3D) layerConf(); + Convolution3D layerConfig = (Convolution3D) getTypedLayerConfiguration(); boolean isNCDHW = layerConfig.getDataFormat() == Convolution3D.DataFormat.NCDHW; @@ -76,7 +75,7 @@ public class Convolution3DLayer extends ConvolutionLayer { int inH = (int) (isNCDHW ? input.size(3) : input.size(2)); int inW = (int) (isNCDHW ? input.size(4) : input.size(3)); - int outEpsChannels = (int) layerConf().getNIn(); + int outEpsChannels = (int) getTypedLayerConfiguration().getNIn(); int[] dilation = layerConfig.getDilation(); int[] kernel = layerConfig.getKernelSize(); @@ -165,7 +164,7 @@ public class Convolution3DLayer extends ConvolutionLayer { protected Pair preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) { - Convolution3D layerConfig = (Convolution3D) layerConf(); + Convolution3D layerConfig = (Convolution3D) getTypedLayerConfiguration(); ConvolutionMode mode = layerConfig.getConvolutionMode(); boolean isNCDHW = layerConfig.getDataFormat() == Convolution3D.DataFormat.NCDHW; @@ -194,8 +193,8 @@ public class Convolution3DLayer extends ConvolutionLayer { int inH = (int) (isNCDHW ? input.size(3) : input.size(2)); int inW = (int) (isNCDHW ? 
input.size(4) : input.size(3)); - int outWeightChannels = (int)layerConf().getNOut(); - int inWeightChannels = (int)layerConf().getNIn(); + int outWeightChannels = (int) getTypedLayerConfiguration().getNOut(); + int inWeightChannels = (int) getTypedLayerConfiguration().getNIn(); if (inputChannels != inWeightChannels) { String layerName = layerConfiguration.getLayerName(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java index ffd36652a..c1a94338a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/ConvolutionLayer.java @@ -28,7 +28,6 @@ import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -82,7 +81,7 @@ public class ConvolutionLayer extends BaseLayer p = preOutput4d(true, true, workspaceMgr); INDArray z = p.getFirst(); - CNN2DFormat f = layerConf().getCnn2dDataFormat(); + CNN2DFormat f = getTypedLayerConfiguration().getCnn2dDataFormat(); if(f != CNN2DFormat.NCHW){ z = z.permute(0,3,1,2); //NHWC to NCHW } delta = afn.backprop(z, epsilon).getFirst(); //TODO handle activation function params - if (helper != null && (helperCountFail == 0 || !layerConf().isCudnnAllowFallback())) { + if (helper != null && (helperCountFail == 0 || !getTypedLayerConfiguration().isCudnnAllowFallback())) { INDArray helperDelta = delta; - if(layerConf().getCnn2dDataFormat() == CNN2DFormat.NHWC) + if(getTypedLayerConfiguration().getCnn2dDataFormat() == CNN2DFormat.NHWC) helperDelta = delta.permute(0,2,3,1); //NCHW to NHWC if(!hasBias() && !(helper instanceof MKLDNNConvHelper)){ //MKL-DNN supports no bias, CuDNN doesn't if(dummyBiasGrad == null){ try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - dummyBiasGrad = Nd4j.create(1, layerConf().getNOut()); + dummyBiasGrad = Nd4j.create(1, getTypedLayerConfiguration().getNOut()); } } biasGradView = dummyBiasGrad; @@ -177,8 +176,8 @@ public class ConvolutionLayer extends BaseLayer(preOutput, null); } @@ -413,7 +412,7 @@ public class ConvolutionLayer extends BaseLayer addiRowVector - if(layerConf().hasBias()){ + if(getTypedLayerConfiguration().hasBias()){ z.addiRowVector(bias); } @@ -499,7 +498,7 @@ public class ConvolutionLayer extends BaseLayer(maskArray, currentMaskState); } - INDArray outMask = ConvolutionUtils.cnn2dMaskReduction(maskArray, layerConf().getKernelSize(), layerConf().getStride(), - layerConf().getPadding(), layerConf().getDilation(), layerConf().getConvolutionMode()); + INDArray outMask = ConvolutionUtils.cnn2dMaskReduction(maskArray, getTypedLayerConfiguration().getKernelSize(), getTypedLayerConfiguration().getStride(), + getTypedLayerConfiguration().getPadding(), getTypedLayerConfiguration().getDilation(), getTypedLayerConfiguration().getConvolutionMode()); return new Pair<>(outMask, currentMaskState); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping1DLayer.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping1DLayer.java index ac29715e0..94f752a6e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping1DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping1DLayer.java @@ -20,14 +20,17 @@ package org.deeplearning4j.nn.layers.convolution; +import java.util.Map; import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; + import org.deeplearning4j.nn.conf.layers.convolutional.Cropping1D; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; +import org.deeplearning4j.nn.layers.BaseLayer; import org.deeplearning4j.nn.workspace.ArrayType; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.nd4j.linalg.api.buffer.DataType; @@ -98,4 +101,5 @@ public class Cropping1DLayer extends AbstractLayer { } } } + } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java index d72d2f3eb..83f17f216 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Cropping2DLayer.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.conf.CNN2DFormat; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -92,7 +91,7 @@ public class Cropping2DLayer extends AbstractLayer p = preOutput4d(true, true, workspaceMgr); delta = afn.backprop(p.getFirst(), epsilon).getFirst(); @@ -119,7 +118,7 @@ public class Deconvolution2DLayer extends ConvolutionLayer { INDArray[] opInputs; INDArray[] opOutputs; - if(layerConf().hasBias()){ + if(getTypedLayerConfiguration().hasBias()){ INDArray bias = getParamWithNoise(DeconvolutionParamInitializer.BIAS_KEY, true, workspaceMgr); opInputs = new INDArray[]{input, weights, bias, delta}; opOutputs = new INDArray[]{outEps, weightGradViewOp, biasGradView}; @@ -137,7 +136,7 @@ public class Deconvolution2DLayer extends ConvolutionLayer { Gradient retGradient = new DefaultGradient(); - if(layerConf().hasBias()){ + if(getTypedLayerConfiguration().hasBias()){ retGradient.setGradientFor(DeconvolutionParamInitializer.BIAS_KEY, biasGradView); } retGradient.setGradientFor(DeconvolutionParamInitializer.WEIGHT_KEY, weightGradView, 'c'); @@ -167,7 +166,7 @@ public class Deconvolution2DLayer extends ConvolutionLayer { + " " + layerId()); } - CNN2DFormat format = layerConf().getCnn2dDataFormat(); + CNN2DFormat format = getTypedLayerConfiguration().getCnn2dDataFormat(); boolean nchw = format == CNN2DFormat.NCHW; int cDim = nchw ? 1 : 3; int hDim = nchw ? 
2 : 1; @@ -199,9 +198,9 @@ public class Deconvolution2DLayer extends ConvolutionLayer { int kH = (int) weights.size(2); int kW = (int) weights.size(3); - int[] dilation = layerConf().getDilation(); - int[] kernel = layerConf().getKernelSize(); - int[] strides = layerConf().getStride(); + int[] dilation = getTypedLayerConfiguration().getDilation(); + int[] kernel = getTypedLayerConfiguration().getKernelSize(); + int[] strides = getTypedLayerConfiguration().getStride(); int[] pad; int[] outSize; @@ -210,7 +209,7 @@ public class Deconvolution2DLayer extends ConvolutionLayer { pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {(int) input.size(hDim), (int) input.size(wDim)}, kernel, strides, dilation ); } else { - pad = layerConf().getPadding(); + pad = getTypedLayerConfiguration().getPadding(); outSize = ConvolutionUtils.getDeconvolutionOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); //Also performs validation } @@ -235,7 +234,7 @@ public class Deconvolution2DLayer extends ConvolutionLayer { weights = weights.permute(2, 3, 1, 0); INDArray[] opInputs; - if (layerConf().hasBias()) { + if (getTypedLayerConfiguration().hasBias()) { opInputs = new INDArray[]{input, weights, bias}; } else { opInputs = new INDArray[]{input, weights}; @@ -262,10 +261,10 @@ public class Deconvolution2DLayer extends ConvolutionLayer { INDArray z = preOutput(training, false, workspaceMgr).getFirst(); - IActivation afn = layerConf().getActivationFn(); + IActivation afn = getTypedLayerConfiguration().getActivationFn(); if (helper != null && Shape.strideDescendingCAscendingF(z)) { - INDArray ret = helper.activate(z, layerConf().getActivationFn(), training); + INDArray ret = helper.activate(z, getTypedLayerConfiguration().getActivationFn(), training); if (ret != null) { return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution3DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution3DLayer.java index 302522ed3..a14414d03 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution3DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/Deconvolution3DLayer.java @@ -24,7 +24,6 @@ import lombok.val; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.conf.layers.Deconvolution3D; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; @@ -64,20 +63,20 @@ public class Deconvolution3DLayer extends BaseLayer { INDArray weights = getParamWithNoise(DeconvolutionParamInitializer.WEIGHT_KEY, true, workspaceMgr); - Convolution3D.DataFormat df = layerConf().getDataFormat(); - ConvolutionMode cm = layerConf().getConvolutionMode(); + Convolution3D.DataFormat df = getTypedLayerConfiguration().getDataFormat(); + ConvolutionMode cm = getTypedLayerConfiguration().getConvolutionMode(); - int[] dilation = layerConf().getDilation(); - int[] kernel = layerConf().getKernelSize(); - int[] strides = layerConf().getStride(); - int[] pad = layerConf().getPadding(); + int[] dilation = getTypedLayerConfiguration().getDilation(); + int[] kernel = getTypedLayerConfiguration().getKernelSize(); + int[] strides = getTypedLayerConfiguration().getStride(); + int[] pad = 
getTypedLayerConfiguration().getPadding(); INDArray biasGradView = gradientViews.get(DeconvolutionParamInitializer.BIAS_KEY); INDArray weightGradView = gradientViews.get(DeconvolutionParamInitializer.WEIGHT_KEY); INDArray outEps = workspaceMgr.create(ArrayType.ACTIVATION_GRAD, weights.dataType(), input.shape(), 'c'); - Integer sameMode = (layerConf().getConvolutionMode() == ConvolutionMode.Same) ? 1 : 0; + Integer sameMode = (getTypedLayerConfiguration().getConvolutionMode() == ConvolutionMode.Same) ? 1 : 0; int[] args = new int[] { kernel[0], kernel[1], kernel[2], strides[0], strides[1], strides[2], @@ -86,13 +85,13 @@ public class Deconvolution3DLayer extends BaseLayer { }; INDArray delta; - IActivation afn = layerConf().getActivationFn(); + IActivation afn = getTypedLayerConfiguration().getActivationFn(); INDArray preOutput = preOutput(true, workspaceMgr); delta = afn.backprop(preOutput, epsilon).getFirst(); INDArray[] opInputs; INDArray[] opOutputs; - if(layerConf().hasBias()){ + if(getTypedLayerConfiguration().hasBias()){ INDArray bias = getParamWithNoise(DeconvolutionParamInitializer.BIAS_KEY, true, workspaceMgr); opInputs = new INDArray[]{input, weights, bias, delta}; opOutputs = new INDArray[]{outEps, weightGradView, biasGradView}; @@ -110,7 +109,7 @@ public class Deconvolution3DLayer extends BaseLayer { Gradient retGradient = new DefaultGradient(); - if(layerConf().hasBias()){ + if(getTypedLayerConfiguration().hasBias()){ retGradient.setGradientFor(DeconvolutionParamInitializer.BIAS_KEY, biasGradView); } retGradient.setGradientFor(DeconvolutionParamInitializer.WEIGHT_KEY, weightGradView, 'c'); @@ -132,34 +131,34 @@ public class Deconvolution3DLayer extends BaseLayer { " [minibatchSize, inputHeight, inputWidth, inputDepth, channels]. " + layerId()); } - Convolution3D.DataFormat df = layerConf().getDataFormat(); - boolean ncdhw = layerConf().getDataFormat() == Convolution3D.DataFormat.NCDHW; + Convolution3D.DataFormat df = getTypedLayerConfiguration().getDataFormat(); + boolean ncdhw = getTypedLayerConfiguration().getDataFormat() == Convolution3D.DataFormat.NCDHW; int chDim = ncdhw ? 1 : 4; - if (input.size(chDim) != layerConf().getNIn() ) { + if (input.size(chDim) != getTypedLayerConfiguration().getNIn() ) { String layerName = getLayerConfiguration().getLayerName(); if (layerName == null) layerName = "(not named)"; throw new DL4JInvalidInputException("Cannot do forward pass in Deconvolution3D layer (layer name = " + layerName + ", layer index = " + index + "): input array channels does not match CNN layer configuration" + " (data input channels = " + input.size(chDim) + ", " + (ncdhw ? "[minibatch,channels,height,width,depth]=" : "[minibatch,height,width,depth,channels]=") - + Arrays.toString(input.shape()) + "; expected" + " input channels = " + layerConf().getNIn() + ") " + + Arrays.toString(input.shape()) + "; expected" + " input channels = " + getTypedLayerConfiguration().getNIn() + ") " + layerId()); } - int[] dilation = layerConf().getDilation(); - int[] kernel = layerConf().getKernelSize(); - int[] strides = layerConf().getStride(); + int[] dilation = getTypedLayerConfiguration().getDilation(); + int[] kernel = getTypedLayerConfiguration().getKernelSize(); + int[] strides = getTypedLayerConfiguration().getStride(); int[] pad; - ConvolutionMode cm = layerConf().getConvolutionMode(); + ConvolutionMode cm = getTypedLayerConfiguration().getConvolutionMode(); long[] outSize; int[] inSize = df == Convolution3D.DataFormat.NCDHW ? 
new int[]{(int)input.size(2), (int)input.size(3), (int)input.size(4)} : new int[]{(int)input.size(1), (int)input.size(2), (int)input.size(3)}; if (cm == ConvolutionMode.Same) { - outSize = ConvolutionUtils.getDeconvolution3DOutputSize(input, kernel, strides, null, dilation, cm, layerConf().getDataFormat()); //Also performs validation + outSize = ConvolutionUtils.getDeconvolution3DOutputSize(input, kernel, strides, null, dilation, cm, getTypedLayerConfiguration().getDataFormat()); //Also performs validation pad = ConvolutionUtils.getSameModeTopLeftPadding(ArrayUtil.toInts(outSize), inSize, kernel, strides, dilation ); } else { - pad = layerConf().getPadding(); - outSize = ConvolutionUtils.getDeconvolution3DOutputSize(input, kernel, strides, pad, dilation, cm, layerConf().getDataFormat()); //Also performs validation + pad = getTypedLayerConfiguration().getPadding(); + outSize = ConvolutionUtils.getDeconvolution3DOutputSize(input, kernel, strides, pad, dilation, cm, getTypedLayerConfiguration().getDataFormat()); //Also performs validation } long outH = outSize[0]; @@ -168,7 +167,7 @@ public class Deconvolution3DLayer extends BaseLayer { val miniBatch = input.size(0); - long[] outShape = df == Convolution3D.DataFormat.NCDHW ? new long[]{miniBatch, layerConf().getNOut(), outH, outW, outD} : new long[]{miniBatch, outH, outW, outD, layerConf().getNOut()}; + long[] outShape = df == Convolution3D.DataFormat.NCDHW ? new long[]{miniBatch, getTypedLayerConfiguration().getNOut(), outH, outW, outD} : new long[]{miniBatch, outH, outW, outD, getTypedLayerConfiguration().getNOut()}; INDArray output = workspaceMgr.create(ArrayType.ACTIVATIONS, input.dataType(), outShape, 'c'); int sameMode = (cm == ConvolutionMode.Same) ? 1 : 0; @@ -180,7 +179,7 @@ public class Deconvolution3DLayer extends BaseLayer { }; INDArray[] opInputs; - if (layerConf().hasBias()) { + if (getTypedLayerConfiguration().hasBias()) { opInputs = new INDArray[]{input, weights, bias}; } else { opInputs = new INDArray[]{input, weights}; @@ -207,7 +206,7 @@ public class Deconvolution3DLayer extends BaseLayer { INDArray z = preOutput(training, workspaceMgr); - IActivation afn = layerConf().getActivationFn(); + IActivation afn = getTypedLayerConfiguration().getActivationFn(); INDArray activation = afn.getActivation(z, training); return activation; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java index 888875129..2b39f70d2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/DepthwiseConvolution2DLayer.java @@ -25,7 +25,6 @@ import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -60,12 +59,12 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { @Override public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - CNN2DFormat format = layerConf().getCnn2dDataFormat(); + 
CNN2DFormat format = getTypedLayerConfiguration().getCnn2dDataFormat(); boolean nchw = format == CNN2DFormat.NCHW; if (input.rank() != 4) { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to Convolution layer with shape " + Arrays.toString(input.shape()) - + ". Expected rank 4 array with shape " + layerConf().getCnn2dDataFormat().dimensionNames() + ". " + + ". Expected rank 4 array with shape " + getTypedLayerConfiguration().getCnn2dDataFormat().dimensionNames() + ". " + layerId()); } INDArray bias; @@ -82,16 +81,16 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { int kH = (int) depthWiseWeights.size(0); int kW = (int) depthWiseWeights.size(1); - int[] dilation = layerConf().getDilation(); - int[] kernel = layerConf().getKernelSize(); - int[] strides = layerConf().getStride(); + int[] dilation = getTypedLayerConfiguration().getDilation(); + int[] kernel = getTypedLayerConfiguration().getKernelSize(); + int[] strides = getTypedLayerConfiguration().getStride(); int[] pad; if (convolutionMode == ConvolutionMode.Same) { int[] outSize = ConvolutionUtils.getOutputSize( input, kernel, strides, null, convolutionMode, dilation, format); pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[]{inH, inW}, kernel, strides, dilation); } else { - pad = layerConf().getPadding(); + pad = getTypedLayerConfiguration().getPadding(); ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); } @@ -110,13 +109,13 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { }; INDArray delta; - IActivation afn = layerConf().getActivationFn(); + IActivation afn = getTypedLayerConfiguration().getActivationFn(); Pair p = preOutput4d(true, true, workspaceMgr); delta = afn.backprop(p.getFirst(), epsilon).getFirst(); INDArray[] inputs; INDArray[] outputs; - if (layerConf().hasBias()) { + if (getTypedLayerConfiguration().hasBias()) { bias = getParamWithNoise(DepthwiseConvolutionParamInitializer.BIAS_KEY, true, workspaceMgr); inputs = new INDArray[]{input, depthWiseWeights, bias, delta}; outputs = new INDArray[]{outEpsilon, weightGradView, biasGradView}; @@ -134,7 +133,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { Nd4j.getExecutioner().exec(op); Gradient retGradient = new DefaultGradient(); - if (layerConf().hasBias()) { + if (getTypedLayerConfiguration().hasBias()) { retGradient.setGradientFor(DepthwiseConvolutionParamInitializer.BIAS_KEY, biasGradView); } retGradient.setGradientFor(DepthwiseConvolutionParamInitializer.WEIGHT_KEY, weightGradView, 'c'); @@ -159,7 +158,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to DepthwiseConvolution2D (layer name = " + layerName + ", layer index = " + index + ") with shape " + Arrays.toString(input.shape()) + ". " - + "Expected rank 4 array with shape " + layerConf().getCnn2dDataFormat().dimensionNames() + "." + + "Expected rank 4 array with shape " + getTypedLayerConfiguration().getCnn2dDataFormat().dimensionNames() + "." + (input.rank() == 2 ? 
" (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)" : "") + " " + layerId()); @@ -167,7 +166,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); //no-op if correct dtype - CNN2DFormat format = layerConf().getCnn2dDataFormat(); + CNN2DFormat format = getTypedLayerConfiguration().getCnn2dDataFormat(); boolean nchw = format == CNN2DFormat.NCHW; long inDepth = depthWiseWeights.size(2); @@ -197,9 +196,9 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { int kH = (int) depthWiseWeights.size(0); int kW = (int) depthWiseWeights.size(1); - int[] dilation = layerConf().getDilation(); - int[] kernel = layerConf().getKernelSize(); - int[] strides = layerConf().getStride(); + int[] dilation = getTypedLayerConfiguration().getDilation(); + int[] kernel = getTypedLayerConfiguration().getKernelSize(); + int[] strides = getTypedLayerConfiguration().getStride(); int[] pad; int[] outSize; @@ -212,7 +211,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { pad = ConvolutionUtils.getSameModeTopLeftPadding( outSize, new int[]{(int) input.size(nchw ? 2 : 1), (int) input.size(nchw ? 3 : 2)}, kernel, strides, dilation); } else { - pad = layerConf().getPadding(); + pad = getTypedLayerConfiguration().getPadding(); outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); } @@ -231,7 +230,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { }; INDArray[] inputs; - if (layerConf().hasBias()) { + if (getTypedLayerConfiguration().hasBias()) { inputs = new INDArray[]{input, depthWiseWeights, bias}; } else { inputs = new INDArray[]{input, depthWiseWeights}; @@ -260,7 +259,7 @@ public class DepthwiseConvolution2DLayer extends ConvolutionLayer { INDArray z = preOutput(training, false, workspaceMgr).getFirst(); //String afn = conf.getLayer().getActivationFunction(); - IActivation afn = layerConf().getActivationFn(); + IActivation afn = getTypedLayerConfiguration().getActivationFn(); INDArray activation = afn.getActivation(z, training); return activation; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java index d205017bf..dc660bfc8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SeparableConvolution2DLayer.java @@ -25,7 +25,6 @@ import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -64,7 +63,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { if (input.rank() != 4) { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to SubsamplingLayer with shape " + Arrays.toString(input.shape()) - + ". Expected rank 4 array with shape " + layerConf().getCnn2dDataFormat().dimensionNames() + ". " + + ". 
Expected rank 4 array with shape " + getTypedLayerConfiguration().getCnn2dDataFormat().dimensionNames() + ". " + layerId()); } INDArray bias; @@ -75,7 +74,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { INDArray input = this.input.castTo(dataType); - CNN2DFormat format = layerConf().getCnn2dDataFormat(); + CNN2DFormat format = getTypedLayerConfiguration().getCnn2dDataFormat(); boolean nchw = format == CNN2DFormat.NCHW; long miniBatch = input.size(0); @@ -86,15 +85,15 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { int kH = (int) depthWiseWeights.size(2); int kW = (int) depthWiseWeights.size(3); - int[] dilation = layerConf().getDilation(); - int[] kernel = layerConf().getKernelSize(); - int[] strides = layerConf().getStride(); + int[] dilation = getTypedLayerConfiguration().getDilation(); + int[] kernel = getTypedLayerConfiguration().getKernelSize(); + int[] strides = getTypedLayerConfiguration().getStride(); int[] pad; if (convolutionMode == ConvolutionMode.Same) { int[] outSize = ConvolutionUtils.getOutputSize(input, kernel, strides, null, convolutionMode, dilation, format); //Also performs validation pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[] {inH, inW}, kernel, strides, dilation); } else { - pad = layerConf().getPadding(); + pad = getTypedLayerConfiguration().getPadding(); ConvolutionUtils.getOutputSize(input, kernel, strides, pad, convolutionMode, dilation, format); //Also performs validation } @@ -114,7 +113,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { }; INDArray delta; - IActivation afn = layerConf().getActivationFn(); + IActivation afn = getTypedLayerConfiguration().getActivationFn(); Pair p = preOutput4d(true, true, workspaceMgr); delta = afn.backprop(p.getFirst(), epsilon).getFirst(); @@ -126,7 +125,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { INDArray opPointWiseWeightGradView = pointWiseWeightGradView.permute(2, 3, 1, 0); CustomOp op; - if(layerConf().hasBias()){ + if(getTypedLayerConfiguration().hasBias()){ bias = getParamWithNoise(SeparableConvolutionParamInitializer.BIAS_KEY, true, workspaceMgr); op = DynamicCustomOp.builder("sconv2d_bp") @@ -146,7 +145,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { Nd4j.getExecutioner().exec(op); Gradient retGradient = new DefaultGradient(); - if(layerConf().hasBias()){ + if(getTypedLayerConfiguration().hasBias()){ retGradient.setGradientFor(ConvolutionParamInitializer.BIAS_KEY, biasGradView); } retGradient.setGradientFor(SeparableConvolutionParamInitializer.DEPTH_WISE_WEIGHT_KEY, depthWiseWeightGradView, 'c'); @@ -168,7 +167,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { getParamWithNoise(SeparableConvolutionParamInitializer.POINT_WISE_WEIGHT_KEY, training, workspaceMgr); INDArray input = this.input.castTo(dataType); - if(layerConf().getCnn2dDataFormat() == CNN2DFormat.NHWC) { + if(getTypedLayerConfiguration().getCnn2dDataFormat() == CNN2DFormat.NHWC) { input = input.permute(0,3,1,2).dup(); } @@ -183,7 +182,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { throw new DL4JInvalidInputException("Got rank " + input.rank() + " array as input to SeparableConvolution2D (layer name = " + layerName + ", layer index = " + index + ") with shape " + Arrays.toString(input.shape()) + ". " - + "Expected rank 4 array with shape " + layerConf().getCnn2dDataFormat().dimensionNames() + "." 
+ + "Expected rank 4 array with shape " + getTypedLayerConfiguration().getCnn2dDataFormat().dimensionNames() + "." + (input.rank() == 2 ? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)" : "") @@ -200,7 +199,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { String s = "Cannot do forward pass in SeparableConvolution2D layer (layer name = " + layerName + ", layer index = " + index + "): input array channels does not match CNN layer configuration" - + " (data format = " + layerConf().getCnn2dDataFormat() + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]=" + + " (data format = " + getTypedLayerConfiguration().getCnn2dDataFormat() + ", data input channels = " + input.size(1) + ", [minibatch,inputDepth,height,width]=" + Arrays.toString(input.shape()) + "; expected" + " input channels = " + inDepth + ") " + layerId(); @@ -215,9 +214,9 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { int kH = (int) depthWiseWeights.size(2); int kW = (int) depthWiseWeights.size(3); - int[] dilation = layerConf().getDilation(); - int[] kernel = layerConf().getKernelSize(); - int[] strides = layerConf().getStride(); + int[] dilation = getTypedLayerConfiguration().getDilation(); + int[] kernel = getTypedLayerConfiguration().getKernelSize(); + int[] strides = getTypedLayerConfiguration().getStride(); int[] pad; int[] outSize; @@ -241,7 +240,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { strides, dilation); } else { - pad = layerConf().getPadding(); + pad = getTypedLayerConfiguration().getPadding(); outSize = ConvolutionUtils.getOutputSize( input, kernel, @@ -273,7 +272,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { pointWiseWeights = pointWiseWeights.permute(2, 3, 1, 0); INDArray[] opInputs; - if (layerConf().hasBias()) { + if (getTypedLayerConfiguration().hasBias()) { opInputs = new INDArray[]{input, depthWiseWeights, pointWiseWeights, bias}; } else { opInputs = new INDArray[]{input, depthWiseWeights, pointWiseWeights}; @@ -288,7 +287,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { .build(); Nd4j.getExecutioner().exec(op); - if(layerConf().getCnn2dDataFormat() == CNN2DFormat.NHWC) { + if(getTypedLayerConfiguration().getCnn2dDataFormat() == CNN2DFormat.NHWC) { output = output.permute(0,2,3,1); //NCHW to NHWC } @@ -307,7 +306,7 @@ public class SeparableConvolution2DLayer extends ConvolutionLayer { INDArray z = preOutput(training, false, workspaceMgr).getFirst(); //String afn = conf.getLayer().getActivationFunction(); - IActivation afn = layerConf().getActivationFn(); + IActivation afn = getTypedLayerConfiguration().getActivationFn(); INDArray activation = afn.getActivation(z, training); return activation; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java index fb824dfa3..1e5c7b270 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/SpaceToBatch.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.convolution; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.CNN2DFormat; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import 
org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -48,20 +47,20 @@ public class SpaceToBatch extends AbstractLayer feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { - INDArray reduced = ConvolutionUtils.cnn1dMaskReduction(maskArray, layerConf().getKernelSize()[0], - layerConf().getStride()[0], layerConf().getPadding()[0], layerConf().getDilation()[0], - layerConf().getConvolutionMode()); + INDArray reduced = ConvolutionUtils.cnn1dMaskReduction(maskArray, getTypedLayerConfiguration().getKernelSize()[0], + getTypedLayerConfiguration().getStride()[0], getTypedLayerConfiguration().getPadding()[0], getTypedLayerConfiguration().getDilation()[0], + getTypedLayerConfiguration().getConvolutionMode()); return new Pair<>(reduced, currentMaskState); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/Subsampling3DLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/Subsampling3DLayer.java index 01f1698f6..168d59357 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/Subsampling3DLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/Subsampling3DLayer.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.convolution.subsampling; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.ConvolutionMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.PoolingType; @@ -69,7 +68,7 @@ public class Subsampling3DLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - boolean isNCDHW = layerConf().getDataFormat() == Convolution3D.DataFormat.NCDHW; + boolean isNCDHW = getTypedLayerConfiguration().getDataFormat() == Convolution3D.DataFormat.NCDHW; long miniBatch = input.size(0); long inChannels = isNCDHW ? 
input.size(1) : input.size(4); @@ -77,9 +76,9 @@ public class Subsampling3DLayer extends AbstractLayer ret = null; try{ ret = helper.backpropGradient(input, epsilon, kernel, strides, pad, - layerConf().getPoolingType(), convolutionMode, dilation, dataFormat, workspaceMgr); + getTypedLayerConfiguration().getPoolingType(), convolutionMode, dilation, dataFormat, workspaceMgr); } catch (ND4JOpProfilerException e){ throw e; //NaN panic etc for debugging } catch (Exception e){ @@ -137,7 +136,7 @@ public class SubsamplingLayer extends AbstractLayer(maskArray, currentMaskState); } - INDArray outMask = ConvolutionUtils.cnn2dMaskReduction(maskArray, layerConf().getKernelSize(), layerConf().getStride(), - layerConf().getPadding(), layerConf().getDilation(), layerConf().getConvolutionMode()); + INDArray outMask = ConvolutionUtils.cnn2dMaskReduction(maskArray, getTypedLayerConfiguration().getKernelSize(), getTypedLayerConfiguration().getStride(), + getTypedLayerConfiguration().getPadding(), getTypedLayerConfiguration().getDilation(), getTypedLayerConfiguration().getConvolutionMode()); return super.feedForwardMaskArray(outMask, currentMaskState, minibatchSize); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling1D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling1D.java index e6630ad48..ae5417fc8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling1D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling1D.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.convolution.upsampling; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.CNN2DFormat; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.BaseUpsamplingLayer; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -57,7 +56,7 @@ public class Upsampling1D extends Upsampling2D { public Pair backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - int[] size = ((BaseUpsamplingLayer) layerConf()).getSize(); + int[] size = ((BaseUpsamplingLayer) getTypedLayerConfiguration()).getSize(); epsilon = epsilon.reshape(epsilon.size(0), epsilon.size(1), epsilon.size(2), 1); // we replicate the error term times "size" so that backprop works properly on it epsilon = epsilon.repeat(3, size[0]); @@ -95,7 +94,7 @@ public class Upsampling1D extends Upsampling2D { @Override protected int[] getSize(){ - return ((org.deeplearning4j.nn.conf.layers.Upsampling1D) getLayerConfiguration()).getSize(); + return getLayerConfiguration().getSize(); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java index bf0742870..cf9da710e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/upsampling/Upsampling2D.java @@ -24,7 +24,6 @@ import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.exception.DL4JInvalidInputException; import org.deeplearning4j.nn.conf.CNN2DFormat; import 
org.deeplearning4j.nn.conf.CacheMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -87,12 +86,12 @@ public class Upsampling2D extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - boolean ncdhw = layerConf().getDataFormat() == org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat.NCDHW; + boolean ncdhw = getTypedLayerConfiguration().getDataFormat() == org.deeplearning4j.nn.conf.layers.Convolution3D.DataFormat.NCDHW; // Assumes NCDHW order long miniBatch = input.size(0); long inChannels, inD, inH, inW; @@ -110,7 +109,7 @@ public class Upsampling3D extends AbstractLayer { - long[] axes = layerConf().getSharedAxes(); + long[] axes = getTypedLayerConfiguration().getSharedAxes(); public PReLU(LayerConfiguration conf, DataType dataType) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/autoencoder/AutoEncoder.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/autoencoder/AutoEncoder.java index b2264d5cb..5a65889f8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/autoencoder/AutoEncoder.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/feedforward/autoencoder/AutoEncoder.java @@ -20,7 +20,6 @@ package org.deeplearning4j.nn.layers.feedforward.autoencoder; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.layers.BasePretrainNetwork; import org.deeplearning4j.nn.params.PretrainParamInitializer; @@ -55,7 +54,7 @@ public class AutoEncoder extends BasePretrainNetwork 0 ? 
getCorruptedInput(input, corruptionLevel) : input; setInput(corruptedX, workspaceMgr); @@ -98,8 +97,8 @@ public class AutoEncoder extends BasePretrainNetwork backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { //If this layer is layer L, then epsilon for this layer is ((w^(L+1)*(delta^(L+1))^T))^T (or equivalent) INDArray z = preOutput(true, workspaceMgr); //Note: using preOutput(INDArray) can't be used as this does a setInput(input) and resets the 'appliedDropout' flag - INDArray delta = layerConf().getActivationFn().backprop(z, epsilon).getFirst(); //TODO handle activation function params + INDArray delta = getTypedLayerConfiguration().getActivationFn().backprop(z, epsilon).getFirst(); //TODO handle activation function params if (maskArray != null) { applyMask(delta); @@ -69,7 +68,7 @@ public class ElementWiseMultiplicationLayer extends BaseLayer Integer.MAX_VALUE) throw new ND4JArraySizeException(); @@ -126,7 +125,7 @@ public class EmbeddingLayer extends BaseLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); INDArray z = preOutput(true, workspaceMgr); - INDArray delta = layerConf().getActivationFn().backprop(z, epsilon).getFirst(); //Shape: [mb, vector, seqLength] + INDArray delta = getTypedLayerConfiguration().getActivationFn().backprop(z, epsilon).getFirst(); //Shape: [mb, vector, seqLength] - boolean ncw = layerConf().getOutputFormat() == RNNFormat.NCW; + boolean ncw = getTypedLayerConfiguration().getOutputFormat() == RNNFormat.NCW; if (maskArray != null) { if(ncw){ @@ -68,9 +67,9 @@ public class EmbeddingSequenceLayer extends BaseLayer [minibatch, nOut, seqLen] i.e., NWC -> NCW } return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, ret); @@ -176,7 +175,7 @@ public class EmbeddingSequenceLayer extends BaseLayer(Arrays.asList(listeners)); } @@ -618,7 +619,7 @@ public class BatchNormalization extends BaseLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - double k = layerConf().getK(); - double n = layerConf().getN(); - double alpha = layerConf().getAlpha(); - double beta = layerConf().getBeta(); + double k = getTypedLayerConfiguration().getK(); + double n = getTypedLayerConfiguration().getN(); + double alpha = getTypedLayerConfiguration().getAlpha(); + double beta = getTypedLayerConfiguration().getBeta(); int halfN = (int) n / 2; - if (helper != null && (helperCountFail == 0 || !layerConf().isCudnnAllowFallback())){ + if (helper != null && (helperCountFail == 0 || !getTypedLayerConfiguration().isCudnnAllowFallback())){ Pair ret = null; try { ret = helper.backpropGradient(input, epsilon, k, n, alpha, beta, workspaceMgr); @@ -120,7 +119,7 @@ public class LocalResponseNormalization //This is a memory exception - don't fallback to built-in implementation throw t; } - if(layerConf().isCudnnAllowFallback()){ + if(getTypedLayerConfiguration().isCudnnAllowFallback()){ helperCountFail++; log.warn("CuDNN LocalResponseNormalization backprop execution failed - falling back on built-in implementation",t); } else { @@ -132,7 +131,7 @@ public class LocalResponseNormalization } } - boolean nchw = layerConf().getDataFormat() == CNN2DFormat.NCHW; + boolean nchw = getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NCHW; int chDim = nchw ? 1 : 3; int hDim = nchw ? 2 : 1; int wDim = nchw ? 
3 : 2; @@ -185,13 +184,13 @@ public class LocalResponseNormalization private Triple activateHelper(boolean training, LayerWorkspaceMgr workspaceMgr, boolean forBackprop){ assertInputSet(false); - double k = layerConf().getK(); - double n = layerConf().getN(); - double alpha = layerConf().getAlpha(); - double beta = layerConf().getBeta(); + double k = getTypedLayerConfiguration().getK(); + double n = getTypedLayerConfiguration().getN(); + double alpha = getTypedLayerConfiguration().getAlpha(); + double beta = getTypedLayerConfiguration().getBeta(); int halfN = (int) n / 2; - if (helper != null && (helperCountFail == 0 || !layerConf().isCudnnAllowFallback())){ + if (helper != null && (helperCountFail == 0 || !getTypedLayerConfiguration().isCudnnAllowFallback())){ INDArray activations = null; try { activations = helper.activate(input, training, k, n, alpha, beta, workspaceMgr); @@ -203,7 +202,7 @@ public class LocalResponseNormalization throw t; } - if(layerConf().isCudnnAllowFallback()){ + if(getTypedLayerConfiguration().isCudnnAllowFallback()){ helperCountFail++; log.warn("CuDNN LocalResponseNormalization backprop execution failed - falling back on built-in implementation",t); } else { @@ -215,7 +214,7 @@ public class LocalResponseNormalization } } - boolean nchw = layerConf().getDataFormat() == CNN2DFormat.NCHW; + boolean nchw = getTypedLayerConfiguration().getDataFormat() == CNN2DFormat.NCHW; int chDim = nchw ? 1 : 3; val channel = input.size(chDim); @@ -287,13 +286,13 @@ public class LocalResponseNormalization } @Override - public INDArray params() { + public INDArray getModelParams() { return null; } @Override public INDArray getParam(String param) { - return params(); + return getModelParams(); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java index e5f0fbf1e..49a61f496 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/objdetect/Yolo2OutputLayer.java @@ -24,7 +24,6 @@ import lombok.*; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.conf.CNN2DFormat; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -86,19 +85,19 @@ public class Yolo2OutputLayer extends AbstractLayer Predicted WH in grid units (0 to 13 usually) INDArray predictedWHPreExp = input5.get(all(), all(), interval(2,4), all(), all()); INDArray predictedWH = Transforms.exp(predictedWHPreExp, true); - Broadcast.mul(predictedWH, layerConf().getBoundingBoxes().castTo(predictedWH.dataType()), predictedWH, 1, 2); //Box priors: [b, 2]; predictedWH: [mb, b, 2, h, w] + Broadcast.mul(predictedWH, getTypedLayerConfiguration().getBoundingBoxes().castTo(predictedWH.dataType()), predictedWH, 1, 2); //Box priors: [b, 2]; predictedWH: [mb, b, 2, h, w] //Apply sqrt to W/H in preparation for loss function INDArray predictedWHSqrt = Transforms.sqrt(predictedWH, true); @@ -236,11 +235,11 @@ public class Yolo2OutputLayer extends AbstractLayer gradientAndScore() { - return new Pair<>(gradient(), score()); + return new Pair<>(gradient(), this.getScore()); } @Override @@ -617,7 +616,7 @@ public class 
Yolo2OutputLayer extends AbstractLayer getPredictedObjects(INDArray networkOutput, double threshold){ - return YoloUtils.getPredictedObjects(layerConf().getBoundingBoxes(), networkOutput, threshold, 0.0); + return YoloUtils.getPredictedObjects(getTypedLayerConfiguration().getBoundingBoxes(), networkOutput, threshold, 0.0); } /** @@ -651,7 +650,7 @@ public class Yolo2OutputLayer extends AbstractLayer getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) { - ILossFunction lossFunction = layerConf().getLossFn(); + ILossFunction lossFunction = getTypedLayerConfiguration().getLossFn(); INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.BP_WORKING_MEM); - INDArray delta = lossFunction.computeGradient(labels2d, preOut, layerConf().getActivationFn(), maskArray); + INDArray delta = lossFunction.computeGradient(labels2d, preOut, getTypedLayerConfiguration().getActivationFn(), maskArray); org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer conf = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) getLayerConfiguration(); @@ -165,20 +164,20 @@ public class OCNNOutputLayer extends BaseOutputLayer sigmoid derivative - INDArray firstVertDerivV = layerConf().getActivationFn() + INDArray firstVertDerivV = getTypedLayerConfiguration().getActivationFn() .backprop(xTimesV.dup(),Nd4j.ones(input.dataType(), xTimesV.shape())) .getFirst().muliRowVector(getParam(W_KEY).neg()); firstVertDerivV = firstVertDerivV.muliColumnVector(delta) - .reshape('f',input.size(0),1,layerConf().getHiddenSize()); + .reshape('f',input.size(0),1, getTypedLayerConfiguration().getHiddenSize()); INDArray secondTermDerivV = input.reshape('f', input.size(0),getParam(V_KEY).size(0),1); @@ -251,7 +250,7 @@ public class OCNNOutputLayer extends BaseOutputLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - if (!layerConf().isCollapseDimensions() && epsilon.rank() != 2) { + if (!getTypedLayerConfiguration().isCollapseDimensions() && epsilon.rank() != 2) { val origShape = epsilon.shape(); //Don't collapse dims case: error should be [minibatch, vectorSize, 1] or [minibatch, channels, 1, 1] //Reshape it to 2d, to get rid of the 1s @@ -293,7 +292,7 @@ public class GlobalPoolingLayer extends AbstractLayer stateMap) { throw new UnsupportedOperationException("Not supported: cannot RnnTimeStep bidirectional layers therefore " + @@ -255,20 +272,14 @@ public class BidirectionalLayer implements RecurrentLayer { } @Override - public Collection getListeners() { - return fwd.getListeners(); + public Collection getTrainingListeners() { + return fwd.getTrainingListeners(); } @Override - public void setListeners(TrainingListener... listeners) { - fwd.setListeners(listeners); - bwd.setListeners(listeners); - } - - @Override - public void addListeners(TrainingListener... listener) { - fwd.addListeners(listener); - bwd.addListeners(listener); + public void addTrainingListeners(TrainingListener... 
listeners) { + fwd.addTrainingListeners(listeners); + bwd.addTrainingListeners(listeners); } @Override @@ -287,8 +298,8 @@ public class BidirectionalLayer implements RecurrentLayer { } @Override - public double score() { - return fwd.score() + bwd.score(); + public double getScore() { + return fwd.getScore() + bwd.getScore(); } @Override @@ -298,14 +309,10 @@ public class BidirectionalLayer implements RecurrentLayer { } @Override - public INDArray params() { + public INDArray getModelParams() { return paramsView; } - @Override - public TrainingConfig getConfig() { - return layerConfiguration; - } @Override public long numParams() { @@ -548,9 +555,9 @@ public class BidirectionalLayer implements RecurrentLayer { //No op } - public void setListeners(Collection listeners) { - fwd.setListeners(listeners.toArray(new TrainingListener[]{})); - bwd.setListeners(listeners.toArray(new TrainingListener[]{})); + public void addTrainingListeners(Collection listeners) { + fwd.addTrainingListeners(listeners.toArray(new TrainingListener[]{})); + bwd.addTrainingListeners(listeners.toArray(new TrainingListener[]{})); } @Override @@ -708,4 +715,11 @@ public class BidirectionalLayer implements RecurrentLayer { public void close(){ //No-op for individual layers } + /** + * @return 1d parameter vector + */ + @Override + public INDArray getParams() { + throw new RuntimeException("Not implemented."); + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java index ac5c57165..595dd0e2c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesBidirectionalLSTM.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.recurrent; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CacheMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -76,14 +75,14 @@ public class GravesBidirectionalLSTM fwdPass.fwdPassOutput = permuteIfNWC(fwdPass.fwdPassOutput); final Pair forwardsGradient = LSTMHelpers.backpropGradientHelper(this, this.layerConfiguration.getNetConfiguration(), - this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + this.getTypedLayerConfiguration().getGateActivationFn(), permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS, GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS, gradientViews, maskArray, true, - null, workspaceMgr, layerConf().isHelperAllowFallback()); + null, workspaceMgr, getTypedLayerConfiguration().isHelperAllowFallback()); @@ -91,14 +90,14 @@ public class GravesBidirectionalLSTM final Pair backwardsGradient = LSTMHelpers.backpropGradientHelper(this, this.layerConfiguration.getNetConfiguration(), - this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + 
this.getTypedLayerConfiguration().getGateActivationFn(), permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, backPass, false, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS, GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS, gradientViews, maskArray, true, - null, workspaceMgr, layerConf().isHelperAllowFallback()); + null, workspaceMgr, getTypedLayerConfiguration().isHelperAllowFallback()); forwardsGradient.setSecond(permuteIfNWC(forwardsGradient.getSecond())); backwardsGradient.setSecond(permuteIfNWC(backwardsGradient.getSecond())); @@ -118,7 +117,7 @@ public class GravesBidirectionalLSTM final Gradient correctOrderedGradient = new DefaultGradient(); - for (final String key : paramsTable.keySet()) { + for (final String key : getParamTable().keySet()) { correctOrderedGradient.setGradientFor(key, combinedGradient.getGradientFor(key)); } @@ -156,22 +155,22 @@ public class GravesBidirectionalLSTM cachedPassForward = null; } else { - forwardsEval = LSTMHelpers.activateHelper(this, this.layerConfiguration.getNetConfiguration(), this.layerConf().getGateActivationFn(), + forwardsEval = LSTMHelpers.activateHelper(this, this.layerConfiguration.getNetConfiguration(), this.getTypedLayerConfiguration().getGateActivationFn(), permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_FORWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS), getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_FORWARDS), training, null, null, forBackprop || (cacheMode != CacheMode.NONE && training), true, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_FORWARDS, maskArray, true, null, - forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); + forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, getTypedLayerConfiguration().isHelperAllowFallback()); - backwardsEval = LSTMHelpers.activateHelper(this, this.layerConfiguration.getNetConfiguration(), this.layerConf().getGateActivationFn(), + backwardsEval = LSTMHelpers.activateHelper(this, this.layerConfiguration.getNetConfiguration(), this.getTypedLayerConfiguration().getGateActivationFn(), permuteIfNWC(this.input), getParam(GravesBidirectionalLSTMParamInitializer.RECURRENT_WEIGHT_KEY_BACKWARDS), getParam(GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS), getParam(GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS), training, null, null, forBackprop || (cacheMode != CacheMode.NONE && training), false, GravesBidirectionalLSTMParamInitializer.INPUT_WEIGHT_KEY_BACKWARDS, maskArray, true, null, - forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); + forBackprop ? 
cacheMode : CacheMode.NONE, workspaceMgr, getTypedLayerConfiguration().isHelperAllowFallback()); forwardsEval.fwdPassOutput = permuteIfNWC(forwardsEval.fwdPassOutput); backwardsEval.fwdPassOutput = permuteIfNWC(backwardsEval.fwdPassOutput); @@ -216,10 +215,10 @@ public class GravesBidirectionalLSTM biasKey = GravesBidirectionalLSTMParamInitializer.BIAS_KEY_BACKWARDS; } - FwdPassReturn ret = LSTMHelpers.activateHelper(this, this.layerConfiguration.getNetConfiguration(), this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + FwdPassReturn ret = LSTMHelpers.activateHelper(this, this.layerConfiguration.getNetConfiguration(), this.getTypedLayerConfiguration().getGateActivationFn(), permuteIfNWC(this.input), getParam(recurrentKey), getParam(inputKey), getParam(biasKey), training, prevOutputActivations, prevMemCellState, forBackprop, forwards, inputKey, maskArray, true, - null, forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, layerConf().isHelperAllowFallback()); + null, forBackprop ? cacheMode : CacheMode.NONE, workspaceMgr, getTypedLayerConfiguration().isHelperAllowFallback()); ret.fwdPassOutput = permuteIfNWC(ret.fwdPassOutput); return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java index 5aedd780b..6626e927e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/GravesLSTM.java @@ -23,7 +23,6 @@ package org.deeplearning4j.nn.layers.recurrent; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.conf.CacheMode; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.params.GravesLSTMParamInitializer; @@ -84,11 +83,11 @@ public class GravesLSTM extends BaseRecurrentLayer p = LSTMHelpers.backpropGradientHelper(this, - this.layerConfiguration.getNetConfiguration(), this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + this.layerConfiguration.getNetConfiguration(), this.getTypedLayerConfiguration().getGateActivationFn(), permuteIfNWC(this.input), recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true, GravesLSTMParamInitializer.INPUT_WEIGHT_KEY, GravesLSTMParamInitializer.RECURRENT_WEIGHT_KEY, GravesLSTMParamInitializer.BIAS_KEY, gradientViews, maskArray, true, null, - workspaceMgr, layerConf().isHelperAllowFallback()); + workspaceMgr, getTypedLayerConfiguration().isHelperAllowFallback()); weightNoiseParams.clear(); p.setSecond(permuteIfNWC(backpropDropOutIfPresent(p.getSecond()))); @@ -129,11 +128,11 @@ public class GravesLSTM extends BaseRecurrentLayer p = LSTMHelpers.backpropGradientHelper(this, - getNetConfiguration(), this.layerConf().getGateActivationFn(), permuteIfNWC(this.input), + getNetConfiguration(), this.getTypedLayerConfiguration().getGateActivationFn(), permuteIfNWC(this.input), recurrentWeights, inputWeights, permuteIfNWC(epsilon), truncatedBPTT, tbpttBackwardLength, fwdPass, true, LSTMParamInitializer.INPUT_WEIGHT_KEY, LSTMParamInitializer.RECURRENT_WEIGHT_KEY, LSTMParamInitializer.BIAS_KEY, gradientViews, null, false, helper, workspaceMgr, - layerConf().isHelperAllowFallback()); + 
getTypedLayerConfiguration().isHelperAllowFallback()); weightNoiseParams.clear(); p.setSecond(permuteIfNWC(backpropDropOutIfPresent(p.getSecond()))); @@ -140,7 +139,7 @@ public class LSTM extends BaseRecurrentLayer= endIdx; iTimeIndex--) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java index de9d75928..e734212e8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/RnnLossLayer.java @@ -25,7 +25,6 @@ import lombok.Setter; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.api.MaskState; import org.deeplearning4j.nn.api.layers.IOutputLayer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -64,7 +63,7 @@ public class RnnLossLayer extends BaseLayer { } } - org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); + org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) getTypedLayerConfiguration(); bl.validateInput(input); Map phMap = new HashMap<>(); @@ -104,7 +103,7 @@ public class SameDiffLayer extends AbstractLayer { if(maskArray != null){ phMap.put(MASK_KEY, maskArray); } else { - phMap.put(MASK_KEY, layerConf().onesMaskForInput(input)); + phMap.put(MASK_KEY, getTypedLayerConfiguration().onesMaskForInput(input)); } //Configure memory management for SameDiff instance - use DL4J workspaces @@ -176,7 +175,7 @@ public class SameDiffLayer extends AbstractLayer { sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr); - org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); + org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) getTypedLayerConfiguration(); bl.validateInput(input); Map phMap = new HashMap<>(); @@ -185,7 +184,7 @@ public class SameDiffLayer extends AbstractLayer { if(maskArray != null){ phMap.put(MASK_KEY, maskArray); } else { - phMap.put(MASK_KEY, layerConf().onesMaskForInput(input)); + phMap.put(MASK_KEY, getTypedLayerConfiguration().onesMaskForInput(input)); } List requiredGrads = new ArrayList<>(paramTable.size() + 1); @@ -215,7 +214,7 @@ public class SameDiffLayer extends AbstractLayer { * @return the parameters of the neural network */ @Override - public INDArray params() { + public INDArray getModelParams() { return params; } @@ -272,7 +271,7 @@ public class SameDiffLayer extends AbstractLayer { @Override public void setBackpropGradientsViewArray(INDArray gradients) { this.gradients = gradients; - this.gradTable = layerConf().initializer().getGradientsFromFlattened(this.getLayerConfiguration(), gradients); + this.gradTable = getTypedLayerConfiguration().initializer().getGradientsFromFlattened(this.getLayerConfiguration(), gradients); } @Override @@ -298,7 +297,7 @@ public class SameDiffLayer extends AbstractLayer { protected void doInit(){ try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { - org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = 
(org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); + org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) getTypedLayerConfiguration(); sameDiff = SameDiff.create(); //Use SingleThreadArrayHolder so we can use views (also don't nede multithreading here, DL4J is not thread safe) sameDiff.setArrayHolders(new SingleThreadArrayHolder(), new SingleThreadArrayHolder(), false); @@ -307,7 +306,7 @@ public class SameDiffLayer extends AbstractLayer { long[] inputShape = input.shape().clone(); inputShape[0] = -1; SDVariable inputVar = sameDiff.placeHolder(INPUT_KEY, dataType, inputShape); - Map paramShapes = layerConf().getLayerParams().getParamShapes(); + Map paramShapes = getTypedLayerConfiguration().getLayerParams().getParamShapes(); Map params = new LinkedHashMap<>(); for (String s : paramShapes.keySet()) { val ps = paramShapes.get(s); @@ -336,7 +335,7 @@ public class SameDiffLayer extends AbstractLayer { @Override public Pair feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) { - org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf(); + org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) getTypedLayerConfiguration(); this.maskArray = maskArray; this.maskState = currentMaskState; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java index d3cc93049..60d4d4c7d 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffOutputLayer.java @@ -25,7 +25,6 @@ import lombok.Setter; import lombok.val; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.layers.IOutputLayer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; @@ -97,7 +96,7 @@ public class SameDiffOutputLayer extends AbstractLayer phMap = new HashMap<>(); phMap.put(INPUT_KEY, input); - if(!activations && layerConf().labelsRequired() && labels != null) { + if(!activations && getTypedLayerConfiguration().labelsRequired() && labels != null) { phMap.put(LABELS_KEY, labels); } - String s = activations ? layerConf().activationsVertexName() : outputVar.name(); + String s = activations ? getTypedLayerConfiguration().activationsVertexName() : outputVar.name(); INDArray out = sameDiff.outputSingle(phMap, s); @@ -153,7 +152,7 @@ public class SameDiffOutputLayer extends AbstractLayer backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { assertInputSet(true); - Preconditions.checkState(!layerConf().labelsRequired() || labels != null, "Cannot execute backprop: Labels are not set. " + + Preconditions.checkState(!getTypedLayerConfiguration().labelsRequired() || labels != null, "Cannot execute backprop: Labels are not set. 
" + "If labels are not required for this SameDiff output layer, override SameDiffOutputLayer.labelsRequired()" + " to return false instead"); Gradient g = new DefaultGradient(); @@ -228,7 +227,7 @@ public class SameDiffOutputLayer extends AbstractLayer paramShapes = layerConf().getLayerParams().getParamShapes(); + Map paramShapes = getTypedLayerConfiguration().getLayerParams().getParamShapes(); Map params = new LinkedHashMap<>(); for (String s : paramShapes.keySet()) { val ps = paramShapes.get(s); @@ -341,7 +340,7 @@ public class SameDiffOutputLayer extends AbstractLayer gradientAndScore() { - return new Pair<>(gradient(), score()); + return new Pair<>(gradient(), getScore()); } @Override @@ -146,7 +145,7 @@ public class CenterLossOutputLayer extends BaseOutputLayer getGradientsAndDelta(INDArray preOut, LayerWorkspaceMgr workspaceMgr) { - ILossFunction lossFunction = layerConf().getLossFn(); + ILossFunction lossFunction = getTypedLayerConfiguration().getLossFn(); INDArray labels2d = getLabels2d(workspaceMgr, ArrayType.BP_WORKING_MEM); if (labels2d.size(1) != preOut.size(1)) { throw new DL4JInvalidInputException( @@ -182,7 +181,7 @@ public class CenterLossOutputLayer extends BaseOutputLayer params; @Getter protected transient Map gradientViews; - protected double score = 0.0; protected ConvexOptimizer optimizer; protected Gradient gradient; - protected Collection trainingListeners = new ArrayList<>(); protected int index = 0; protected INDArray maskArray; protected Solver solver; - protected int[] encoderLayerSizes; protected int[] decoderLayerSizes; protected ReconstructionDistribution reconstructionDistribution; @@ -87,18 +84,15 @@ public class VariationalAutoencoder implements Layer { protected int numSamples; protected CacheMode cacheMode = CacheMode.NONE; protected DataType dataType; - protected boolean zeroedPretrainParamGradients = false; - protected Map weightNoiseParams = new HashMap<>(); - @Getter @Setter protected int iterationCount; @Getter @Setter protected int epochCount; - @Getter @Setter @NonNull private LayerConfiguration layerConfiguration; + private @Getter @Setter Collection trainingListeners; public VariationalAutoencoder(@NonNull LayerConfiguration layerConfiguration, DataType dataType) { this.layerConfiguration = layerConfiguration; @@ -119,6 +113,16 @@ public class VariationalAutoencoder implements Layer { .getNumSamples(); } + /** + * Get a reference to the network this layer is part of. 
+ * + * @return + */ + @Override + public IModel getNet() { + throw new RuntimeException("Not implemented."); + } + protected org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder layerConf() { return (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder) layerConfiguration; } @@ -175,7 +179,7 @@ public class VariationalAutoencoder implements Layer { } @Override - public double score() { + public double getScore() { return score; } @@ -277,7 +281,7 @@ public class VariationalAutoencoder implements Layer { this.score += logPTheta / numSamples; //If we have any training listeners (for example, for UI StatsListener - pass on activations) - if (trainingListeners != null && !trainingListeners.isEmpty() && l == 0) { //Note: only doing this on the *first* sample + if (getTrainingConfig() != null && !getTrainingListeners().isEmpty() && l == 0) { //Note: only doing this on the *first* sample Map activations = new LinkedHashMap<>(); for (int i = 0; i < fwd.encoderActivations.length; i++) { activations.put("e" + i, fwd.encoderActivations[i]); @@ -288,9 +292,9 @@ public class VariationalAutoencoder implements Layer { } activations.put(VariationalAutoencoderParamInitializer.PXZ_PREFIX, reconstructionDistribution.generateAtMean(pxzDistributionPreOut)); - if (!trainingListeners.isEmpty()) { + if (!getTrainingListeners().isEmpty()) { try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - for (TrainingListener tl : trainingListeners) { + for (TrainingListener tl : getTrainingListeners()) { tl.onForwardPass(this, activations); } } @@ -495,7 +499,7 @@ public class VariationalAutoencoder implements Layer { } @Override - public INDArray params() { + public INDArray getModelParams() { return paramsFlattened; } @@ -510,8 +514,13 @@ public class VariationalAutoencoder implements Layer { } @Override - public TrainingConfig getConfig() { - return layerConfiguration; + public void setParamTable(Map paramTable) { + this.params = paramTable; + } + + @Override + public ITraininableLayerConfiguration getTrainingConfig() { + return (BaseLayerConfiguration) layerConfiguration; } @Override @@ -519,6 +528,24 @@ public class VariationalAutoencoder implements Layer { return numParams(false); } + /** + * @return 1d parameter vector + */ + @Override + public INDArray getParams() { + throw new RuntimeException("Not implemented."); + } + + @Override + public void setParams(INDArray params) { + if (params.length() != this.paramsFlattened.length()) { + throw new IllegalArgumentException("Cannot set parameters: expected parameters vector of length " + + this.paramsFlattened.length() + " but got parameters array of length " + params.length() + + " " + layerId()); + } + this.paramsFlattened.assign(params); + } + @Override public long numParams(boolean backwards) { int ret = 0; @@ -530,16 +557,6 @@ public class VariationalAutoencoder implements Layer { return ret; } - @Override - public void setParams(INDArray params) { - if (params.length() != this.paramsFlattened.length()) { - throw new IllegalArgumentException("Cannot set parameters: expected parameters vector of length " - + this.paramsFlattened.length() + " but got parameters array of length " + params.length() - + " " + layerId()); - } - this.paramsFlattened.assign(params); - } - @Override public void setParamsViewArray(INDArray params) { if (this.params != null && params.length() != numParams()) @@ -577,7 +594,7 @@ public class VariationalAutoencoder implements Layer { @Override public Pair gradientAndScore() { - 
return new Pair<>(gradient(), score()); + return new Pair<>(gradient(), getScore()); } @Override @@ -695,7 +712,6 @@ public class VariationalAutoencoder implements Layer { return params.get(param); } - @Override public Map getParamTable(boolean backpropParamsOnly) { Map map = new LinkedHashMap<>(); @@ -712,11 +728,6 @@ public class VariationalAutoencoder implements Layer { return true; } - @Override - public void setParamTable(Map paramTable) { - this.params = paramTable; - } - @Override public void setParam(String key, INDArray val) { if (getParamTable().containsKey(key)) { @@ -844,15 +855,6 @@ public class VariationalAutoencoder implements Layer { return f.pzxMeanPreOut; } - @AllArgsConstructor - @Data - private static class VAEFwdHelper { - private INDArray[] encoderPreOuts; - private INDArray pzxMeanPreOut; - private INDArray[] encoderActivations; - } - - private VAEFwdHelper doForward(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) { assertInputSet(false); @@ -904,48 +906,8 @@ public class VariationalAutoencoder implements Layer { } @Override - public Collection getListeners() { - if (trainingListeners == null) { - return null; - } - - return new ArrayList<>(trainingListeners); - } - - @Override - public void setListeners(TrainingListener... listeners) { - setListeners(Arrays.asList(listeners)); - } - - public void setListeners(Collection listeners) { - if (trainingListeners == null) - trainingListeners = new ArrayList<>(); - else - trainingListeners.clear(); - if (trainingListeners == null) - trainingListeners = new ArrayList<>(); - else - trainingListeners.clear(); - - if (listeners != null && !listeners.isEmpty()) { - trainingListeners.addAll(listeners); - } - } - - - /** - * This method ADDS additional TrainingListener to existing listeners - * - * @param listeners - */ - @Override - public void addListeners(TrainingListener... 
listeners) { - if (this.trainingListeners == null) { - setListeners(listeners); - return; - } - - Collections.addAll(trainingListeners, listeners); + public int getIndex() { + return index; } @Override @@ -953,21 +915,11 @@ public class VariationalAutoencoder implements Layer { this.index = index; } - @Override - public int getIndex() { - return index; - } - @Override public void setInput(INDArray input, LayerWorkspaceMgr layerWorkspaceMgr) { this.input = input; } - @Override - public void setInputMiniBatchSize(int size) { - - } - @Override public int getInputMiniBatchSize() { if (input.size(0) > Integer.MAX_VALUE) @@ -976,8 +928,8 @@ public class VariationalAutoencoder implements Layer { } @Override - public void setMaskArray(INDArray maskArray) { - this.maskArray = maskArray; + public void setInputMiniBatchSize(int size) { + } @Override @@ -985,6 +937,11 @@ public class VariationalAutoencoder implements Layer { return maskArray; } + @Override + public void setMaskArray(INDArray maskArray) { + this.maskArray = maskArray; + } + @Override public boolean isPretrainLayer() { return true; @@ -1022,7 +979,8 @@ public class VariationalAutoencoder implements Layer { if (solver == null) { try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().model(this).configure(getNetConfiguration()).listeners(getListeners()).build(); + solver = new Solver.Builder().model(this).configure(getNetConfiguration()).listeners( + getTrainingListeners()).build(); } } this.optimizer = solver.getOptimizer(); @@ -1255,4 +1213,31 @@ public class VariationalAutoencoder implements Layer { public void close(){ //No-op for individual layers } + + /** + * Replace the TrainingListeners for this model + * + * @param listeners new listeners + */ + @Override + public void addTrainingListeners(TrainingListener... 
listeners) { + trainingListeners.addAll(List.of(listeners)); + } + +/** +* + * @param listeners + */ + @Override + public void addTrainingListeners(Collection listeners) { + trainingListeners.addAll(listeners); + } + + @AllArgsConstructor + @Data + private static class VAEFwdHelper { + private INDArray[] encoderPreOuts; + private INDArray pzxMeanPreOut; + private INDArray[] encoderActivations; + } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/wrapper/BaseWrapperLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/wrapper/BaseWrapperLayer.java index d27d9cfbb..497b08aaf 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/wrapper/BaseWrapperLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/wrapper/BaseWrapperLayer.java @@ -24,11 +24,13 @@ import java.util.Collection; import java.util.Map; import lombok.Data; import lombok.NonNull; +import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.MaskState; -import org.deeplearning4j.nn.api.TrainingConfig; +import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.conf.CacheMode; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.layers.AbstractLayer; @@ -36,17 +38,104 @@ import org.deeplearning4j.nn.layers.LayerHelper; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.deeplearning4j.optimize.api.TrainingListener; +import org.jetbrains.annotations.NotNull; import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.dataset.api.DataSet; +import org.nd4j.linalg.dataset.api.MultiDataSet; +import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; @Data public abstract class BaseWrapperLayer extends AbstractLayer { protected Layer underlying; - public BaseWrapperLayer(@NonNull Layer underlying) { this.underlying = underlying; + this.setLayerConfiguration( underlying.getLayerConfiguration() ); + } + + @Override + public BaseLayerConfiguration getTypedLayerConfiguration() { + return (BaseLayerConfiguration) underlying.getLayerConfiguration(); + } + + /** + * This method returns updater state (if applicable), null otherwise + * + * @return + */ + @Override + public INDArray updaterState() { + return underlying.updaterState(); + } + + /** + * This method fits model with a given DataSet + * + * @param dataSet + */ + @Override + public void fit(DataSet dataSet) { +underlying.fit(dataSet); + } + + /** + * This method fits model with a given MultiDataSet + * + * @param dataSet + */ + @Override + public void fit(MultiDataSet dataSet) { +underlying.fit(dataSet); + } + + /** + * This method fits model with a given DataSetIterator + * + * @param iterator + */ + @Override + public void fit(DataSetIterator iterator) { +underlying.fit(iterator); + } + + /** + * This method fits model with a given MultiDataSetIterator + * + * @param iterator + */ + @Override + public void fit(MultiDataSetIterator iterator) { +underlying.fit(iterator); + } + + /** + * @param netConfiguration + */ + @Override + public void setNetConfiguration(@NonNull NeuralNetConfiguration 
netConfiguration) { +underlying.setNetConfiguration(netConfiguration); + } + + + /** + * Get a reference to the network this layer is part of. + * + * @return + */ + @Override + public IModel getNet() { + return underlying.getNet(); + } + + /** + * @return 1d parameter vector + */ + @Override + public INDArray getParams() { + return underlying.getParams(); } /** @@ -96,19 +185,15 @@ public abstract class BaseWrapperLayer extends AbstractLayer { return underlying.activate(input, training, workspaceMgr); } + @NotNull @Override - public Collection getListeners() { - return underlying.getListeners(); + public Collection getTrainingListeners() { + return underlying.getTrainingListeners(); } @Override - public void setListeners(TrainingListener... listeners) { - underlying.setListeners(listeners); - } - - @Override - public void addListeners(TrainingListener... listener) { - underlying.addListeners(listener); + public void addTrainingListeners(TrainingListener... listeners) { + underlying.addTrainingListeners(listeners); } @Override @@ -127,8 +212,8 @@ public abstract class BaseWrapperLayer extends AbstractLayer { } @Override - public double score() { - return underlying.score(); + public double getScore() { + return underlying.getScore(); } @Override @@ -137,8 +222,8 @@ public abstract class BaseWrapperLayer extends AbstractLayer { } @Override - public INDArray params() { - return underlying.params(); + public INDArray getModelParams() { + return underlying.getParams(); } @Override @@ -333,8 +418,8 @@ public abstract class BaseWrapperLayer extends AbstractLayer { } @Override - public TrainingConfig getConfig() { - return underlying.getConfig(); + public ITraininableLayerConfiguration getTrainingConfig() { + return underlying.getTrainingConfig(); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java index 575ee27e9..2b27c0179 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/multilayer/MultiLayerNetwork.java @@ -20,28 +20,17 @@ package org.deeplearning4j.nn.multilayer; - -import java.io.File; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; +import com.fasterxml.jackson.annotation.JsonIdentityInfo; +import com.fasterxml.jackson.annotation.ObjectIdGenerators; +import java.io.*; +import java.util.*; import java.util.stream.Collectors; import lombok.Getter; import lombok.NonNull; import lombok.Setter; import lombok.extern.slf4j.Slf4j; import lombok.val; +import net.brutex.ai.dnn.api.IModel; import net.brutex.ai.dnn.networks.ArtificialNeuralNetwork; import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; @@ -49,30 +38,20 @@ import org.bytedeco.javacpp.Pointer; import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; import org.deeplearning4j.exception.DL4JException; import org.deeplearning4j.exception.DL4JInvalidInputException; -import org.deeplearning4j.nn.api.Classifier; -import org.deeplearning4j.nn.api.FwdPassType; -import 
org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.MaskState; -import org.deeplearning4j.nn.api.ModelAdapter; -import org.deeplearning4j.nn.api.TrainingConfig; +import org.deeplearning4j.nn.api.*; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.api.layers.IOutputLayer; import org.deeplearning4j.nn.api.layers.RecurrentLayer; -import org.deeplearning4j.nn.conf.BackpropType; -import org.deeplearning4j.nn.conf.CNN2DFormat; -import org.deeplearning4j.nn.conf.CacheMode; -import org.deeplearning4j.nn.conf.InputPreProcessor; -import org.deeplearning4j.nn.conf.NeuralNetBaseBuilderConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.RNNFormat; -import org.deeplearning4j.nn.conf.WorkspaceMode; +import org.deeplearning4j.nn.conf.*; import org.deeplearning4j.nn.conf.inputs.InputType; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.FeedForwardLayer; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional; import org.deeplearning4j.nn.gradient.DefaultGradient; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.layers.BaseLayer; import org.deeplearning4j.nn.layers.FrozenLayer; import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop; import org.deeplearning4j.nn.layers.LayerHelper; @@ -85,12 +64,7 @@ import org.deeplearning4j.optimize.Solver; import org.deeplearning4j.optimize.api.ConvexOptimizer; import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator; -import org.deeplearning4j.util.Convolution1DUtils; -import org.deeplearning4j.util.ConvolutionUtils; -import org.deeplearning4j.util.CrashReportingUtil; -import org.deeplearning4j.util.ModelSerializer; -import org.deeplearning4j.util.NetworkUtils; -import org.deeplearning4j.util.OutputLayerUtil; +import org.deeplearning4j.util.*; import org.jetbrains.annotations.NotNull; import org.nd4j.adapters.OutputAdapter; import org.nd4j.common.base.Preconditions; @@ -137,17 +111,18 @@ import org.nd4j.linalg.workspace.WorkspaceUtils; * above constitute what is known as a layer, and the transformative function is often referred to * as a unit. The intermediate states—often termed features—are used as the input into another * layer. - *
<p>
- * Through repetition of these steps, the artificial neural network learns multiple layers of + * + * <p>Through repetition of these steps, the artificial neural network learns multiple layers of * non-linear features, which it then combines in a final layer to create a prediction. - * <p>
- * The neural network learns by generating an error signal that measures the difference between the - * predictions of the network and the desired values and then using this error signal to change the - * weights (or parameters) so that predictions get more accurate. + * + * <p>
The neural network learns by generating an error signal that measures the difference between + * the predictions of the network and the desired values and then using this error signal to change + * the weights (or parameters) so that predictions get more accurate. */ @Slf4j -public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serializable, Classifier, - Layer { +@JsonIdentityInfo(generator = ObjectIdGenerators.IntSequenceGenerator.class, property = "@id") +public class MultiLayerNetwork extends ArtificialNeuralNetwork + implements Serializable, Classifier, Layer, ITrainableLayer { /** * Workspace for working memory for a single layer: forward pass and backward pass Note that this @@ -165,78 +140,79 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * every second layer */ protected static final String WS_LAYER_ACT_1 = "WS_LAYER_ACT_1"; + protected static final String WS_LAYER_ACT_2 = "WS_LAYER_ACT_2"; - /** - * Workspace for output methods that use OutputAdapter - */ + /** Workspace for output methods that use OutputAdapter */ protected static final String WS_OUTPUT_MEM = "WS_OUTPUT_MEM"; - /** - * Workspace for working memory in RNNs - opened and closed once per RNN time step - */ + /** Workspace for working memory in RNNs - opened and closed once per RNN time step */ protected static final String WS_RNN_LOOP_WORKING_MEM = "WS_RNN_LOOP_WORKING_MEM"; - protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = WorkspaceConfiguration.builder() - .initialSize(0) - .overallocationLimit(0.05) - .policyLearning(LearningPolicy.FIRST_LOOP) - .policyReset(ResetPolicy.BLOCK_LEFT) - .policySpill(SpillPolicy.REALLOCATE) - .policyAllocation(AllocationPolicy.OVERALLOCATE) - .build(); - protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = WorkspaceConfiguration.builder() - .initialSize(0).overallocationLimit(0.05).policyReset(ResetPolicy.BLOCK_LEFT) - .policyAllocation(AllocationPolicy.OVERALLOCATE).policySpill(SpillPolicy.REALLOCATE) - .policyLearning(LearningPolicy.FIRST_LOOP).build(); - //the hidden neural network layers (including output layer) + protected static final WorkspaceConfiguration WS_ALL_LAYERS_ACT_CONFIG = + WorkspaceConfiguration.builder() + .initialSize(0) + .overallocationLimit(0.05) + .policyLearning(LearningPolicy.FIRST_LOOP) + .policyReset(ResetPolicy.BLOCK_LEFT) + .policySpill(SpillPolicy.REALLOCATE) + .policyAllocation(AllocationPolicy.OVERALLOCATE) + .build(); + protected static final WorkspaceConfiguration WS_RNN_LOOP_WORKING_MEM_CONFIG = + WorkspaceConfiguration.builder() + .initialSize(0) + .overallocationLimit(0.05) + .policyReset(ResetPolicy.BLOCK_LEFT) + .policyAllocation(AllocationPolicy.OVERALLOCATE) + .policySpill(SpillPolicy.REALLOCATE) + .policyLearning(LearningPolicy.FIRST_LOOP) + .build(); + // the hidden neural network layers (including output layer) protected Layer[] layers; - - //Current training data: input features and labels + // Current training data: input features and labels protected INDArray input, labels; protected boolean initCalled = false; protected Collection trainingListeners = new ArrayList<>(); protected Gradient gradient; protected double score; - @Setter - protected boolean initDone = false; - protected INDArray flattenedParams; //Params for all layers are a view/subset of this array + @Setter protected boolean initDone = false; + protected INDArray flattenedParams; // Params for all layers are a view/subset of this array + @Getter - protected 
transient INDArray flattenedGradients; //Gradients for all layers are a view/subset of this array - protected boolean clearTbpttState = true; //Mainly for unit testing (should be enabled otherwise) + protected transient INDArray + flattenedGradients; // Gradients for all layers are a view/subset of this array + + protected boolean clearTbpttState = true; // Mainly for unit testing (should be enabled otherwise) protected transient ThreadLocal lastEtlTime = new ThreadLocal<>(); protected INDArray mask; - protected int layerIndex; //For LayerConfiguration.get/setIndex() - protected transient Solver solver; //Used to call optimizers during backprop - //Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers - @Getter - protected transient Map helperWorkspaces = new HashMap<>(); + protected int layerIndex; // For LayerConfiguration.get/setIndex() + protected transient Solver solver; // Used to call optimizers during backprop + // Workspaces for CUDNN. Pass to LayerWorkspaceMgr for re-use in cudnn helpers + @Getter protected transient Map helperWorkspaces = new HashMap<>(); protected WorkspaceConfiguration WS_LAYER_WORKING_MEM_CONFIG; protected WorkspaceConfiguration WS_LAYER_ACT_X_CONFIG; - public MultiLayerNetwork(@NotNull NeuralNetConfiguration conf) { super(conf); - //Working memory: should learn over course of: (a) full forward pass, and (b) full backward pass - //Working memory should be opened once per layer and once per preprocessor, for each of forward and backward passes - int numWorkingMem = 2 * (conf.getFlattenedLayerConfigurations().size() - + conf.getInputPreProcessors().size()); + // Working memory: should learn over course of: (a) full forward pass, and (b) full backward + // pass + // Working memory should be opened once per layer and once per preprocessor, for each of forward + // and backward passes + int numWorkingMem = + 2 * (conf.getFlattenedLayerConfigurations().size() + conf.getInputPreProcessors().size()); WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(numWorkingMem); - WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig( - conf.getFlattenedLayerConfigurations().size()); - - init(); + WS_LAYER_ACT_X_CONFIG = + getLayerActivationWSConfig(conf.getFlattenedLayerConfigurations().size()); } public MultiLayerNetwork(@NotNull NeuralNetBaseBuilderConfiguration conf) { - this(( NeuralNetConfiguration) conf); + this((NeuralNetConfiguration) conf); } - /** * Initialize the network based on the configuration (a NeuralNetConfiguration in JSON format) and * parameters array * - * @param conf the configuration json + * @param conf the configuration json * @param params the parameters for the network */ public MultiLayerNetwork(String conf, INDArray params) { @@ -248,7 +224,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Initialize the network based on the configuration and parameters array * - * @param conf the configuration + * @param conf the configuration * @param params the parameters */ public MultiLayerNetwork(NeuralNetConfiguration conf, INDArray params) { @@ -270,8 +246,10 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } protected static WorkspaceConfiguration getLayerActivationWSConfig(int numLayers) { - //Activations memory: opened once per layer - for every second layer (preprocessors are within the loop). 
- //Technically we could set learning to numLayers / 2, but will set to numLayers for simplicity, and also to + // Activations memory: opened once per layer - for every second layer (preprocessors are within + // the loop). + // Technically we could set learning to numLayers / 2, but will set to numLayers for simplicity, + // and also to // account for a backward pass return WorkspaceConfiguration.builder() .initialSize(0) @@ -285,19 +263,29 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Restore a MultiLayerNetwork to a file, saved using {@link #save(File)} or - * {@link ModelSerializer} + * Restore a MultiLayerNetwork to a file, saved using {@link #save(File)} or {@link + * ModelSerializer} * - * @param f File to load the network from + * @param f File to load the network from * @param loadUpdater If true: load the updater if it is available (i.e., the state array for - * momentum/Adam/rmsprop etc) - use false if no further training is - * required, or true if further training will be undertaken + * momentum/Adam/rmsprop etc) - use false if no further training is required, or + * true if further training will be undertaken * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams) */ public static MultiLayerNetwork load(File f, boolean loadUpdater) throws IOException { return ModelSerializer.restoreMultiLayerNetwork(f, loadUpdater); } + /** + * Get a reference to this neural network. + * + * @return + */ + @Override + public IModel getNet() { + return this; + } + /** * Return the configuration of this layer * @@ -305,7 +293,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial */ @Override public LayerConfiguration getLayerConfiguration() { - //TODO + // TODO throw new RuntimeException( "getLayerConfiguration cannot be called on a MultiLayerNetwork. This function is here because of inheritance from Layer (which should be fixed)."); } @@ -358,9 +346,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial lastEtlTime.set(time); } - /** - * Perform layerwise pretraining for one epoch - see {@link #pretrain(DataSetIterator, int)} - */ + /** Perform layerwise pretraining for one epoch - see {@link #pretrain(DataSetIterator, int)} */ public void pretrain(DataSetIterator iter) { pretrain(iter, 1); } @@ -368,9 +354,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Perform layerwise unsupervised training on all pre-trainable layers in the network (VAEs, * Autoencoders, etc), for the specified number of epochs each. For example, if numEpochs=3, then - * layer 0 will be fit for 3 epochs, followed by layer 1 for 3 epochs, and so on.
Note that - * pretraining will be performed on one layer after the other. To perform unsupervised training on - * a single layer, use {@link #pretrainLayer(int, DataSetIterator)} + * layer 0 will be fit for 3 epochs, followed by layer 1 for 3 epochs, and so on.
+ * Note that pretraining will be performed on one layer after the other. To perform unsupervised + * training on a single layer, use {@link #pretrainLayer(int, DataSetIterator)} * * @param iter Training data */ @@ -384,32 +370,33 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } - /** - * Fit for one epoch - see {@link #pretrainLayer(int, DataSetIterator, int)} - */ + /** Fit for one epoch - see {@link #pretrainLayer(int, DataSetIterator, int)} */ public void pretrainLayer(int layerIdx, DataSetIterator iter) { pretrainLayer(layerIdx, iter, 1); } /** * Perform layerwise unsupervised training on a single pre-trainable layer in the network (VAEs, - * Autoencoders, etc) for the specified number of epochs
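A minimal usage sketch of the pretraining entry points documented in this hunk, assuming an already initialized MultiLayerNetwork named net and a DataSetIterator named iter (both hypothetical; only the method signatures shown above are relied on):

    // fit every pretrainable layer (VAEs, autoencoders, ...) for 3 epochs each, one layer after the other
    net.pretrain(iter, 3);
    // fit only layer index 0 for 3 epochs; a no-op if layer 0 is not a pretrainable layer
    net.pretrainLayer(0, iter, 3);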
If the specified layer index (0 to - * numLayers - 1) is not a pretrainable layer, this is a no-op. + * Autoencoders, etc) for the specified number of epochs
+ * If the specified layer index (0 to numLayers - 1) is not a pretrainable layer, this is a no-op. * - * @param layerIdx Index of the layer to train (0 to numLayers-1) - * @param iter Training data + * @param layerIdx Index of the layer to train (0 to numLayers-1) + * @param iter Training data * @param numEpochs Number of epochs to fit the specified layer for */ public void pretrainLayer(int layerIdx, DataSetIterator iter, int numEpochs) { - Preconditions.checkState(numEpochs > 0, "Number of epochs (%s) must be a positive number", - numEpochs); + Preconditions.checkState( + numEpochs > 0, "Number of epochs (%s) must be a positive number", numEpochs); if (flattenedGradients == null) { initGradientsView(); } if (layerIdx >= layers.length) { throw new IllegalArgumentException( - "Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + layers.length + "Cannot pretrain layer: layerIdx (" + + layerIdx + + ") >= numLayers (" + + layers.length + ")"); } @@ -419,8 +406,10 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } if (numEpochs > 1 && !iter.resetSupported()) { - throw new IllegalStateException("Cannot fit multiple epochs (" + numEpochs - + ") on an iterator that doesn't support resetting"); + throw new IllegalStateException( + "Cannot fit multiple epochs (" + + numEpochs + + ") on an iterator that doesn't support resetting"); } if (!iter.hasNext() && iter.resetSupported()) { @@ -447,8 +436,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Perform layerwise unsupervised training on a single pre-trainable layer in the network (VAEs, - * Autoencoders, etc)
If the specified layer index (0 to numLayers - 1) is not a pretrainable - * layer, this is a no-op. + * Autoencoders, etc)
+ * If the specified layer index (0 to numLayers - 1) is not a pretrainable layer, this is a no-op. * * @param layerIdx Index of the layer to train (0 to numLayers-1) * @param features Training data array @@ -462,7 +451,10 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } if (layerIdx >= layers.length) { throw new IllegalArgumentException( - "Cannot pretrain layer: layerIdx (" + layerIdx + ") >= numLayers (" + layers.length + "Cannot pretrain layer: layerIdx (" + + layerIdx + + ") >= numLayers (" + + layers.length + ")"); } @@ -470,11 +462,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { - workspaceMgr = LayerWorkspaceMgr.builder() - .defaultWorkspace(WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + workspaceMgr = + LayerWorkspaceMgr.builder() + .defaultWorkspace(WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_FF_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); } workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); @@ -483,15 +478,17 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return; } - //Do forward pass to the layer to be pretrained + // Do forward pass to the layer to be pretrained INDArray outputOfPrevLayer; if (layerIdx == 0) { outputOfPrevLayer = input; } else { - //Yes, this part of training - but we'll do forward psas as inference mode when doing layerwise training + // Yes, this part of training - but we'll do forward psas as inference mode when doing + // layerwise training // to effectively freeze earlier layers and not apply dropout etc - outputOfPrevLayer = outputOfLayerDetached(false, FwdPassType.STANDARD, layerIndex - 1, - features, null, null, null); + outputOfPrevLayer = + outputOfLayerDetached( + false, FwdPassType.STANDARD, layerIndex - 1, features, null, null, null); } try (MemoryWorkspace ws = workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { @@ -500,9 +497,13 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (input.size(0) > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } - outputOfPrevLayer = getNetConfiguration().getInputPreProcess(layerIdx) - .preProcess(outputOfPrevLayer, (int) input.size(0), - LayerWorkspaceMgr.noWorkspaces(helperWorkspaces)); + outputOfPrevLayer = + getNetConfiguration() + .getInputPreProcess(layerIdx) + .preProcess( + outputOfPrevLayer, + (int) input.size(0), + LayerWorkspaceMgr.noWorkspaces(helperWorkspaces)); } layer.fit(outputOfPrevLayer, workspaceMgr); @@ -511,9 +512,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial @Override public int batchSize() { - //In 99+% of cases, the input and labels dimension 0 size should be identical - //The only real exceptions: space to batch, and batch to space layers - //In those cases, we should base it on the labels size, as this impacts gradient calculation + // In 99+% of cases, the input and labels dimension 0 size should be identical + // The only real exceptions: space to batch, and batch to space layers + // In those cases, we should base it on the labels size, as this impacts gradient calculation if (input.size(0) > Integer.MAX_VALUE || labels.size(0) > 
Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } @@ -531,12 +532,13 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Get one parameter array for the network.
In MultiLayerNetwork, parameters are keyed like - * "0_W" and "0_b" to mean "weights of layer index 0" and "biases of layer index 0" respectively. - * Numbers increment sequentially, and the suffixes ("W", "b" etc.) depend on the layer type, and - * are defined in the relevant parameter initializers for each layer.
Note that the returned - * INDArrays are views of the underlying network parameters, so modifications of the returned - * arrays will impact the parameters of the network. + * Get one parameter array for the network.
+ * In MultiLayerNetwork, parameters are keyed like "0_W" and "0_b" to mean "weights of layer index + * 0" and "biases of layer index 0" respectively. Numbers increment sequentially, and the suffixes + * ("W", "b" etc.) depend on the layer type, and are defined in the relevant parameter + * initializers for each layer.
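// Hedged sketch of the "layerIndex_parameterKey" convention described above: "0_W" is the weight
// matrix of layer 0 and "0_b" its bias. The returned arrays are views, so in-place changes write
// through to the network. Assumes an already initialised `net`; import paths follow upstream
// Deeplearning4j naming and may differ in this repository.
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

class GetParamSketch {
  static void inspectFirstLayer(MultiLayerNetwork net) {
    INDArray weights = net.getParam("0_W"); // weights of layer index 0
    INDArray biases  = net.getParam("0_b"); // biases of layer index 0
    System.out.println("0_W shape: " + java.util.Arrays.toString(weights.shape()));
    biases.assign(0.0); // a view: this zeroes the biases inside the network itself
  }
}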
+ * Note that the returned INDArrays are views of the underlying network parameters, so + * modifications of the returned arrays will impact the parameters of the network. * * @param param the key of the parameter * @return The specified parameter array for the network @@ -544,7 +546,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial */ @Override public INDArray getParam(String param) { - //Get params for MultiLayerNetwork sub layers. + // Get params for MultiLayerNetwork sub layers. int idx = param.indexOf('_'); if (idx == -1) { throw new IllegalStateException( @@ -556,18 +558,17 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return layers[layerIdx].getParam(newKey); } - /** - * Returns a map of all parameters in the network as per {@link #getParamTable()}.
Optionally - * (with backpropParamsOnly=true) only the 'backprop' parameters are returned - that is, any - * parameters involved only in unsupervised layerwise pretraining not standard inference/backprop - * are excluded from the returned list. + * Returns a map of all parameters in the network as per {@link #getParamTable()}.
+ * Optionally (with backpropParamsOnly=true) only the 'backprop' parameters are returned - that + * is, any parameters involved only in unsupervised layerwise pretraining not standard + * inference/backprop are excluded from the returned list. * * @param backpropParamsOnly If true, return backprop params only. If false: return all params * @return Parameters for the network */ public Map paramTable(boolean backpropParamsOnly) { - //Get all parameters from all layers + // Get all parameters from all layers Map allParams = new LinkedHashMap<>(); for (int i = 0; i < layers.length; i++) { Map paramMap = layers[i].getParamTable(backpropParamsOnly); @@ -579,63 +580,25 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return allParams; } - /** - * Intended for internal use - */ + /** Intended for internal use */ @Override public boolean updaterDivideByMinibatch(String paramName) { int idx = paramName.indexOf('_'); int layerIdx = Integer.parseInt(paramName.substring(0, idx)); String subName = paramName.substring(idx + 1); - return getLayer(layerIdx).updaterDivideByMinibatch(subName); + return ((BaseLayer) getLayer(layerIdx)).updaterDivideByMinibatch(subName); } /** - * Set the parameters of the netowrk. Note that the parameter keys must match the format as - * described in {@link #getParam(String)} and {@link #getParamTable()}. Note that the values of the - * parameters used as an argument to this method are copied - i.e., it is safe to later - * modify/reuse the values in the provided paramTable without this impacting the network. - * - * @param paramTable Parameters to set - */ - @Override - public void setParamTable(Map paramTable) { - Map currParamTable = getParamTable(); - if (!currParamTable.keySet().equals(paramTable.keySet())) { - throw new IllegalArgumentException( - "Cannot set param table: parameter keys do not match.\n" + "Current: " - + currParamTable.keySet() + "\nTo set: " + paramTable.keySet()); - } - - for (String s : paramTable.keySet()) { - INDArray curr = currParamTable.get(s); - INDArray toSet = paramTable.get(s); - if (!Arrays.equals(curr.shape(), toSet.shape())) { - throw new IllegalArgumentException( - "Cannot set parameter table: parameter \"" + s + "\" shapes " - + "do not match. Current = " + Arrays.toString(curr.shape()) + ", to set = " - + Arrays.toString(toSet.shape())); - } - } - - //Now that we've checked ALL params (to avoid leaving net in half-modified state) - for (String s : paramTable.keySet()) { - INDArray curr = currParamTable.get(s); - INDArray toSet = paramTable.get(s); - curr.assign(toSet); - } - } - - /** - * Set the values of a single parameter. See {@link #setParamTable(Map)} and - * {@link #getParam(String)} for more details. + * Set the values of a single parameter. See {@link #setParamTable(Map)} and {@link + * #getParam(String)} for more details. * * @param key the key of the parameter to set * @param val the new values for the parameter */ @Override public void setParam(String key, INDArray val) { - //Set params for MultiLayerNetwork sub layers. + // Set params for MultiLayerNetwork sub layers. int idx = key.indexOf('_'); if (idx == -1) { throw new IllegalStateException( @@ -663,81 +626,108 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * network; if no parameters array is specified, parameters will be initialized randomly according * to the network configuration. * - * @param parameters Network parameter. May be null. If null: randomly initialize. 
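// Hedged sketch of paramTable(boolean) and setParam(String, INDArray) as they appear in this
// patch: list every parameter array (optionally restricting to backprop parameters) and then
// overwrite one of them. Assumes an initialised `net`; the Map value type is INDArray as in
// upstream Deeplearning4j, and import paths may differ in this repository.
import java.util.Map;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

class ParamTableSketch {
  static void listAndOverwrite(MultiLayerNetwork net) {
    Map<String, INDArray> backpropParams = net.paramTable(true); // pretrain-only params excluded
    for (Map.Entry<String, INDArray> e : backpropParams.entrySet()) {
      System.out.println(e.getKey() + " -> " + e.getValue().length() + " values");
    }
    // Replace the layer-0 bias with zeros of the same shape.
    INDArray zeroBias = net.getParam("0_b").dup().assign(0.0);
    net.setParam("0_b", zeroBias);
  }
}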
+ * @param parameters Network parameter. May be null. If null: randomly initialize. * @param cloneParametersArray Whether the parameter array (if any) should be cloned, or used - * directly + * directly */ public void init(INDArray parameters, boolean cloneParametersArray) { if (initCalled) { + log.trace( + "Initialisation in {} has already been called. Ignoring additional call to init().", + getClass().getSimpleName()); return; } + /** + * Initialize the neural network configuration first. This also triggers inheritance of + * configuration setting where needed. + */ + getNetConfiguration().setNeuralNet(this); + getNetConfiguration() + .init(); // we cannot do this in constructor, as the config might be attached later. + DataType netDtype = getNetConfiguration().getDataType(); if (parameters != null && parameters.dataType() != netDtype) { - Preconditions.checkState(parameters.rank() == 2 && parameters.size(0) == 1, + Preconditions.checkState( + parameters.rank() == 2 && parameters.size(0) == 1, "Invalid parameters array: should be rank 2 with shape [1,numParams]. Got %ndShape", parameters); if (cloneParametersArray) { - try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { parameters = parameters.castTo(netDtype); } } else { throw new IllegalStateException( - "Error initializing network: Network datatype is set to " + netDtype - + " but provided array has datatype " + parameters.dataType() - + " with cloneParametersArray argument" + - " set to false. Cannot initialize net with specified datatype array if that array does not match network datatype"); + "Error initializing network: Network datatype is set to " + + netDtype + + " but provided array has datatype " + + parameters.dataType() + + " with cloneParametersArray argument" + + " set to false. Cannot initialize net with specified datatype array if that array does not match network datatype"); } } - + /** Set default Training and Inference Workspace modes unless set already */ if (getNetConfiguration().getTrainingWorkspaceMode() == null) { getNetConfiguration().setTrainingWorkspaceMode(WorkspaceMode.NONE); } - if (getNetConfiguration().getInferenceWorkspaceMode() == null) { getNetConfiguration().setInferenceWorkspaceMode(WorkspaceMode.NONE); } - + /** set default Cache mode, unless set already */ if (getNetConfiguration().getCacheMode() == null) { getNetConfiguration().setCacheMode(CacheMode.NONE); } - OneTimeLogger.info(log, + OneTimeLogger.info( + log, // Todo: Why not SLF4J? 
"Starting MultiLayerNetwork with WorkspaceModes set to [training: {}; inference: {}], cacheMode set to [{}]", getNetConfiguration().getTrainingWorkspaceMode(), getNetConfiguration().getInferenceWorkspaceMode(), getNetConfiguration().getCacheMode()); int nLayers = getNetConfiguration().getFlattenedLayerConfigurations().size(); - if (nLayers < 1) { throw new IllegalStateException("Unable to create network: number of layers is less than 1"); } + /** Initialize the array of Layers for this network using the number of LayerConfigurations */ if (this.layers == null || this.layers[0] == null) { if (this.layers == null) { this.layers = new Layer[nLayers]; } - //First: Work out total length of params + // First: Work out total length of params long paramLength = 0; val nParamsPerLayer = new long[nLayers]; for (int i = 0; i < nLayers; i++) { - LayerConfiguration layer_conf = getNetConfiguration().getFlattenedLayerConfigurations().get(i); - layer_conf.setDataType(netDtype); + LayerConfiguration layer_conf = + getNetConfiguration().getFlattenedLayerConfigurations().get(i); + // Test if Layer type has parameters (is inherited from BaseLayerConfiguration rather then + // LayerConfiguration + if (layer_conf instanceof BaseLayerConfiguration) + ((BaseLayerConfiguration) layer_conf).setDataType(netDtype); + nParamsPerLayer[i] = layer_conf.initializer().numParams(layer_conf); paramLength += nParamsPerLayer[i]; } + log.debug( + "Neural Network {} is initializes with a total number of {} parameters from {} layers.", + getClass().getSimpleName(), + paramLength, + nLayers); - //Create parameters array, if required + // Create parameters array, if required boolean initializeParams; if (parameters != null) { if (!parameters.isRowVectorOrScalar()) { throw new IllegalArgumentException("Invalid parameters: should be a row vector"); } if (parameters.length() != paramLength) { - throw new IllegalArgumentException("Invalid parameters: expected length " + paramLength - + ", got length " + parameters.length()); + throw new IllegalArgumentException( + "Invalid parameters: expected length " + + paramLength + + ", got length " + + parameters.length()); } if (cloneParametersArray) { @@ -751,12 +741,12 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial flattenedParams = Nd4j.create(netDtype, 1, paramLength); initializeParams = true; } else { - //Edge case: 0 params in network + // Edge case: 0 params in network flattenedParams = null; initializeParams = false; } - //Set RNG seed, for repeatability between initializations when set + // Set RNG seed, for repeatability between initializations when set if (initializeParams) { Nd4j.getRandom().setSeed(getNetConfiguration().getSeed()); } @@ -766,33 +756,43 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial for (int i = 0; i < nLayers; i++) { INDArray paramsView; if (nParamsPerLayer[i] > 0) { - paramsView = flattenedParams.get(NDArrayIndex.interval(0, 0, true), - NDArrayIndex.interval(paramCountSoFar, paramCountSoFar + nParamsPerLayer[i])); + paramsView = + flattenedParams.get( + NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(paramCountSoFar, paramCountSoFar + nParamsPerLayer[i])); } else { paramsView = null; } paramCountSoFar += nParamsPerLayer[i]; @NonNull LayerConfiguration lc = getNetConfiguration().getFlattenedLayerConfigurations().get(i); - layers[i] = lc.instantiate(lc.getNetConfiguration(), trainingListeners, i, paramsView, initializeParams, + layers[i] = + lc.instantiate( + 
lc.getNetConfiguration(), + trainingListeners, + i, + paramsView, + initializeParams, netDtype); } initCalled = true; } - //Set parameters in MultiLayerNetwork.getNetConfiguration() for later use in BaseOptimizer.setupSearchState() etc + // Set parameters in MultiLayerNetwork.getNetConfiguration() for later use in + // BaseOptimizer.setupSearchState() etc getNetConfiguration().clearNetWideVariable(); List variables = getNetConfiguration().netWideVariables(false); for (int i = 0; i < layers.length; i++) { if (layers[i] == null) { throw new IllegalStateException( - "Encountered null layer during initialization for layer " + i + - ": " + layers[i].getClass().getSimpleName() - + " initialization " + - "returned null layer?"); + "Encountered null layer during initialization for layer " + + i + + ": " + + layers[i].getClass().getSimpleName() + + " initialization " + + "returned null layer?"); } - - for (String s : layers[i].getNetConfiguration().netWideVariables()) { + for (String s : layers[i].getLayerConfiguration().getVariables()) { variables.add(i + "_" + s); } } @@ -800,14 +800,18 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial // now we init solver & optimizer if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) - .build(); + solver = + new Solver.Builder() + .configure(getNetConfiguration()) + .listeners(this.getTrainingListeners()) + .model(this) + .build(); solver.initOptimizer(); } } - //Mark that input modification is allowed. - //TODO When is it safe to NOT skip the very first layer? It's not always safe... + // Mark that input modification is allowed. + // TODO When is it safe to NOT skip the very first layer? It's not always safe... // For example dropout + iterating over List that is used for multiple epochs... for (int i = 1; i < layers.length; i++) { layers[i].allowInputModification(true); @@ -817,11 +821,12 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * This method allows you to specificy GradientsAccumulator instance to be used with this - * model
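// Hedged sketch of init(INDArray parameters, boolean cloneParametersArray), mirroring its use in
// clone() later in this patch: start a second, identically configured network from a copy of an
// existing network's flattened parameter vector. `fresh` is assumed to be a not-yet-initialised
// network built from the same configuration elsewhere; import paths follow upstream
// Deeplearning4j naming and may differ in this repository.
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

class InitFromParamsSketch {
  static void initFromExisting(MultiLayerNetwork trained, MultiLayerNetwork fresh) {
    INDArray params = trained.getModelParams().dup(); // detached copy of the 1 x numParams vector
    // cloneParametersArray = false: the (already copied) array is used directly as the backing view
    fresh.init(params, false);
  }
}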
+ * This method allows you to specify a GradientsAccumulator instance to be used with this model + *
*
* PLEASE NOTE: Do not use this method unless you understand how to use GradientsAccumulator & - * updates sharing.
PLEASE NOTE: Do not use this method on standalone model + * updates sharing.
+ * PLEASE NOTE: Do not use this method on standalone model * * @param accumulator Gradient accumulator to use for the network */ @@ -832,8 +837,12 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) - .build(); + solver = + new Solver.Builder() + .configure(getNetConfiguration()) + .listeners(this.getTrainingListeners()) + .model(this) + .build(); } } @@ -857,38 +866,49 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial int nLayers = layers.length; - //First: Work out total length of params + // First: Work out total length of params long paramLength = 0; val nParamsPerLayer = new long[nLayers]; for (int i = 0; i < nLayers; i++) { - LayerConfiguration layerConfiguration = getNetConfiguration().getFlattenedLayerConfigurations().get(i); - nParamsPerLayer[i] = layerConfiguration.initializer().numParams(layerConfiguration); //TODO better initialisation + LayerConfiguration layerConfiguration = + getNetConfiguration().getFlattenedLayerConfigurations().get(i); + nParamsPerLayer[i] = + layerConfiguration + .initializer() + .numParams(layerConfiguration); // TODO better initialisation paramLength += nParamsPerLayer[i]; } if (paramLength > 0) { - flattenedGradients = Nd4j.create(flattenedParams.dataType(), new long[]{1, paramLength}, - 'f'); //No need to initialize, as each layer will do it each iteration anyway + flattenedGradients = + Nd4j.create( + flattenedParams.dataType(), + new long[] {1, paramLength}, + 'f'); // No need to initialize, as each layer will do it each iteration anyway } long paramsSoFar = 0; for (int i = 0; i < layers.length; i++) { if (nParamsPerLayer[i] == 0) { - continue; //This layer doesn't have any parameters... + continue; // This layer doesn't have any parameters... } - INDArray thisLayerGradView = flattenedGradients.get(NDArrayIndex.interval(0, 0, true), - NDArrayIndex.interval(paramsSoFar, paramsSoFar + nParamsPerLayer[i])); + INDArray thisLayerGradView = + flattenedGradients.get( + NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(paramsSoFar, paramsSoFar + nParamsPerLayer[i])); layers[i].setBackpropGradientsViewArray(thisLayerGradView); paramsSoFar += nParamsPerLayer[i]; } } } - protected INDArray activationFromPrevLayer(int curr, INDArray input, boolean training, - LayerWorkspaceMgr mgr) { + protected INDArray activationFromPrevLayer( + int curr, INDArray input, boolean training, LayerWorkspaceMgr mgr) { if (getNetConfiguration().getInputPreProcess(curr) != null) { - input = getNetConfiguration().getInputPreProcess(curr) - .preProcess(input, getInputMiniBatchSize(), mgr); + input = + getNetConfiguration() + .getInputPreProcess(curr) + .preProcess(input, getInputMiniBatchSize(), mgr); } INDArray ret = layers[curr].activate(input, training, mgr); @@ -897,12 +917,12 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Calculate activation for few layers at once. Suitable for autoencoder partial activation. - *

- * In example: in 10-layer deep autoencoder, layers 0 - 4 inclusive are used for encoding part, + * + *

In example: in 10-layer deep autoencoder, layers 0 - 4 inclusive are used for encoding part, * and layers 5-9 inclusive are used for decoding part. * * @param from first layer to be activated, inclusive - * @param to last layer to be activated, inclusive + * @param to last layer to be activated, inclusive * @return the activation from the last layer */ public INDArray activateSelectedLayers(int from, int to, INDArray input) { @@ -917,7 +937,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } try { - LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces(helperWorkspaces); //TODO + LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces(helperWorkspaces); // TODO INDArray res = input; for (int l = from; l <= to; l++) { @@ -936,8 +956,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * activations of layer 0, and so on. * * @param train Training: if true, perform forward pass/inference at training time. Usually, - * inference is performed with train = false. This impacts whether dropout etc is - * applied or not. + * inference is performed with train = false. This impacts whether dropout etc is applied or + * not. * @return The list of activations for each layer, including the input */ public List feedForward(INDArray input, boolean train) { @@ -946,16 +966,16 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Compute activations from input to output of the output layer. As per - * {@link #feedForward(INDArray, boolean)} but using the inputs that have previously been set - * using {@link #setInput(INDArray)} + * Compute activations from input to output of the output layer. As per {@link + * #feedForward(INDArray, boolean)} but using the inputs that have previously been set using + * {@link #setInput(INDArray)} * * @return the list of activations for each layer */ public List feedForward(boolean train) { try { - return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layers.length - 1, - input, mask, null, true); + return ffToLayerActivationsDetached( + train, FwdPassType.STANDARD, false, layers.length - 1, input, mask, null, true); } catch (OutOfMemoryError e) { CrashReportingUtil.writeMemoryCrashDump(this, e); throw e; @@ -963,21 +983,21 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Perform feed-forward, optionally (not) clearing the layer input arrays.
Note: when using - * clearInputs=false, there can be some performance and memory overhead: this is because the - * arrays are defined outside of workspaces (which are enabled by default) - otherwise, - * old/invalidated arrays could still be accessed after calling this method. Consequently: Don't - * use clearInputs=false unless you have a use case that requires them to remain after - * feed-forward has been completed + * Perform feed-forward, optionally (not) clearing the layer input arrays.
+ * Note: when using clearInputs=false, there can be some performance and memory overhead: this is + * because the arrays are defined outside of workspaces (which are enabled by default) - + * otherwise, old/invalidated arrays could still be accessed after calling this method. + * Consequently: Don't use clearInputs=false unless you have a use case that requires them to + * remain after feed-forward has been completed * - * @param train training mode (true) or test mode (false) + * @param train training mode (true) or test mode (false) * @param clearInputs If false: don't clear the layer inputs * @return Activations from feed-forward */ public List feedForward(boolean train, boolean clearInputs) { try { - return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layers.length - 1, - input, mask, null, clearInputs); + return ffToLayerActivationsDetached( + train, FwdPassType.STANDARD, false, layers.length - 1, input, mask, null, clearInputs); } catch (OutOfMemoryError e) { CrashReportingUtil.writeMemoryCrashDump(this, e); throw e; @@ -985,21 +1005,20 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Compute the activations from the input to the specified layer.
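// Hedged sketch of the feedForward(...) variants above: the returned list holds the original input
// at index 0 and the activations of layer i at index i + 1. Assumes an initialised `net` and an
// input matrix shaped to its first layer; import paths follow upstream Deeplearning4j naming and
// may differ in this repository.
import java.util.List;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class FeedForwardSketch {
  static void inspectActivations(MultiLayerNetwork net) {
    INDArray input = Nd4j.rand(16, 784); // hypothetical minibatch of 16 examples
    List<INDArray> acts = net.feedForward(input, false); // false = inference mode (no dropout etc)
    System.out.println("list size = numLayers + 1: " + acts.size());
    INDArray lastLayerOut = acts.get(acts.size() - 1);
    System.out.println("output shape: " + java.util.Arrays.toString(lastLayerOut.shape()));
  }
}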
To compute activations for - * all layers, use feedForward(...) methods
Note: output list includes the original input. So - * list.get(0) is always the original input, and list.get(i+1) is the activations of the ith - * layer. + * Compute the activations from the input to the specified layer.
+ * To compute activations for all layers, use feedForward(...) methods
+ * Note: output list includes the original input. So list.get(0) is always the original input, and + * list.get(i+1) is the activations of the ith layer. * * @param layerNum Index of the last layer to calculate activations for. Layers are zero-indexed. - * feedForwardToLayer(i,input) will return the activations for layers 0..i - * (inclusive) - * @param input Input to the network + * feedForwardToLayer(i,input) will return the activations for layers 0..i (inclusive) + * @param input Input to the network * @return list of activations. */ public List feedForwardToLayer(int layerNum, INDArray input) { try { - return ffToLayerActivationsDetached(false, FwdPassType.STANDARD, false, layerNum, input, mask, - null, true); + return ffToLayerActivationsDetached( + false, FwdPassType.STANDARD, false, layerNum, input, mask, null, true); } catch (OutOfMemoryError e) { CrashReportingUtil.writeMemoryCrashDump(this, e); throw e; @@ -1007,24 +1026,22 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Compute the activations from the input to the specified layer.
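// Hedged sketch of feedForwardToLayer(int, INDArray): activate only layers 0..layerNum inclusive,
// for example to read an intermediate representation. As documented above, index 0 of the
// returned list is the input itself and index i + 1 is the activation of layer i. Assumes an
// initialised `net`; import paths follow upstream Deeplearning4j naming and may differ here.
import java.util.List;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

class FeedForwardToLayerSketch {
  static INDArray activationOfLayer(MultiLayerNetwork net, INDArray input, int layerNum) {
    List<INDArray> acts = net.feedForwardToLayer(layerNum, input);
    return acts.get(layerNum + 1); // activations of layer `layerNum`
  }
}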
To compute activations for - * all layers, use feedForward(...) methods
Note: output list includes the original input. So - * list.get(0) is always the original input, and list.get(i+1) is the activations of the ith - * layer. + * Compute the activations from the input to the specified layer.
+ * To compute activations for all layers, use feedForward(...) methods
+ * Note: output list includes the original input. So list.get(0) is always the original input, and + * list.get(i+1) is the activations of the ith layer. * * @param layerNum Index of the last layer to calculate activations for. Layers are zero-indexed. - * feedForwardToLayer(i,input) will return the activations for layers 0..i - * (inclusive) - * @param input Input to the network - * @param train true for training, false for test (i.e., false if using network after - * training) + * feedForwardToLayer(i,input) will return the activations for layers 0..i (inclusive) + * @param input Input to the network + * @param train true for training, false for test (i.e., false if using network after training) * @return list of activations. */ public List feedForwardToLayer(int layerNum, INDArray input, boolean train) { try { int layerVertexIdx = layers[layerNum].getIndex(); - return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layerVertexIdx, input, - mask, null, true); + return ffToLayerActivationsDetached( + train, FwdPassType.STANDARD, false, layerVertexIdx, input, mask, null, true); } catch (OutOfMemoryError e) { CrashReportingUtil.writeMemoryCrashDump(this, e); throw e; @@ -1033,30 +1050,33 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Compute the activations from the input to the specified layer, using the currently set input - * for the network.
To compute activations for all layers, use feedForward(...) methods
+ * for the network.
+ * To compute activations for all layers, use feedForward(...) methods
* Note: output list includes the original input. So list.get(0) is always the original input, and * list.get(i+1) is the activations of the ith layer. * * @param layerNum Index of the last layer to calculate activations for. Layers are zero-indexed. - * feedForwardToLayer(i,input) will return the activations for layers 0..i - * (inclusive) - * @param train true for training, false for test (i.e., false if using network after - * training) + * feedForwardToLayer(i,input) will return the activations for layers 0..i (inclusive) + * @param train true for training, false for test (i.e., false if using network after training) * @return list of activations. */ public List feedForwardToLayer(int layerNum, boolean train) { try { - return ffToLayerActivationsDetached(train, FwdPassType.STANDARD, false, layerNum, input, mask, - null, true); + return ffToLayerActivationsDetached( + train, FwdPassType.STANDARD, false, layerNum, input, mask, null, true); } catch (OutOfMemoryError e) { CrashReportingUtil.writeMemoryCrashDump(this, e); throw e; } } - protected void validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, + protected void validateArrayWorkspaces( + @NonNull LayerWorkspaceMgr mgr, + @NonNull INDArray array, + @NonNull ArrayType arrayType, int layerIdx, - boolean isPreprocessor, String op) { + boolean isPreprocessor, + String op) { try { mgr.validateArrayLocation(arrayType, array, false, layerIdx > 0); } catch (ND4JWorkspaceException e) { @@ -1068,11 +1088,17 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial clazz = layers[layerIdx].getClass().getName(); } throw new IllegalStateException( - op + ": array (" + arrayType + ") workspace validation failed (" + - (isPreprocessor ? "preprocessor" : "layer ") + layerIdx + (layerName != null ? - " - layer name \"" + - layerName + "\"" : "") + " - class: " + clazz - + ") - array is defined in incorrect workspace", e); + op + + ": array (" + + arrayType + + ") workspace validation failed (" + + (isPreprocessor ? "preprocessor" : "layer ") + + layerIdx + + (layerName != null ? " - layer name \"" + layerName + "\"" : "") + + " - class: " + + clazz + + ") - array is defined in incorrect workspace", + e); } } @@ -1081,46 +1107,54 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * workspace. Note that no workspace should be active externally when calling this method (an * exception will be thrown if a workspace is open externally) * - * @param train Training mode (true) or test/inference mode (false) - * @param fwdPassType Type of forward pass to perform (STANDARD or - * RNN_ACTIVATE_WITH_STORED_STATE only) - * @param storeLastForTBPTT ONLY used if fwdPassType == - * FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE - * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use - * numLayers-1 - * @param input Input to the network - * @param fMask Feature mask array. May be null. - * @param lMask Label mask array. May be null. - * @param clearInputs Whether the layer inputs should be cleared + * @param train Training mode (true) or test/inference mode (false) + * @param fwdPassType Type of forward pass to perform (STANDARD or RNN_ACTIVATE_WITH_STORED_STATE + * only) + * @param storeLastForTBPTT ONLY used if fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE + * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use numLayers-1 + * @param input Input to the network + * @param fMask Feature mask array. May be null. 
+ * @param lMask Label mask array. May be null. + * @param clearInputs Whether the layer inputs should be cleared * @return List of activations (including the input), detached from any workspace */ - protected synchronized List ffToLayerActivationsDetached(boolean train, + protected synchronized List ffToLayerActivationsDetached( + boolean train, @NonNull FwdPassType fwdPassType, - boolean storeLastForTBPTT, int layerIndex, @NonNull INDArray input, - INDArray fMask, INDArray lMask, boolean clearInputs) { + boolean storeLastForTBPTT, + int layerIndex, + @NonNull INDArray input, + INDArray fMask, + INDArray lMask, + boolean clearInputs) { setInput(input); setLayerMaskArrays(fMask, lMask); - //Verify that no workspace is open externally + // Verify that no workspace is open externally WorkspaceUtils.assertNoWorkspacesOpen( "Expected no workspace active in ffToLayerActivationsDetached"); LayerWorkspaceMgr workspaceMgr; - WorkspaceMode wsm = (train ? getNetConfiguration().getTrainingWorkspaceMode() - : getNetConfiguration().getInferenceWorkspaceMode()); + WorkspaceMode wsm = + (train + ? getNetConfiguration().getTrainingWorkspaceMode() + : getNetConfiguration().getInferenceWorkspaceMode()); if (wsm == WorkspaceMode.NONE) { workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { - workspaceMgr = LayerWorkspaceMgr.builder() - .noWorkspaceFor(ArrayType.ACTIVATIONS) - .with(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + workspaceMgr = + LayerWorkspaceMgr.builder() + .noWorkspaceFor(ArrayType.ACTIVATIONS) + .with(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_FF_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); if (input.isAttached()) { - //Don't leverage out of async DataSetIterator workspaces + // Don't leverage out of async DataSetIterator workspaces workspaceMgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); } @@ -1131,17 +1165,26 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); List out = new ArrayList<>(); - out.add(workspaceMgr.leverageTo(ArrayType.INPUT, - input)); //Should be unnecessary (and no op), if layer is implemented correctly + out.add( + workspaceMgr.leverageTo( + ArrayType.INPUT, + input)); // Should be unnecessary (and no op), if layer is implemented correctly for (int i = 0; i <= layerIndex; i++) { - try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered( - ArrayType.FF_WORKING_MEM)) { + try (MemoryWorkspace wsFFWorking = + workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { if (getNetConfiguration().getInputPreProcess(i) != null) { - input = getNetConfiguration().getInputPreProcess(i) - .preProcess(input, getInputMiniBatchSize(), workspaceMgr); - //Validation: Exception if invalid (bad preprocessor implementation) - validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, true, + input = + getNetConfiguration() + .getInputPreProcess(i) + .preProcess(input, getInputMiniBatchSize(), workspaceMgr); + // Validation: Exception if invalid (bad preprocessor implementation) + validateArrayWorkspaces( + workspaceMgr, + input, + ArrayType.ACTIVATIONS, + 
i, + true, "Feed forward to layer (inference)"); } @@ -1149,15 +1192,17 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial input = layers[i].activate(input, train, workspaceMgr); } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) { if (layers[i] instanceof RecurrentLayer) { - input = ((RecurrentLayer) layers[i]).rnnActivateUsingStoredState(input, train, - storeLastForTBPTT, workspaceMgr); + input = + ((RecurrentLayer) layers[i]) + .rnnActivateUsingStoredState(input, train, storeLastForTBPTT, workspaceMgr); } else if (layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) { RecurrentLayer rl = (RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying(); input = rl.rnnActivateUsingStoredState(input, train, storeLastForTBPTT, workspaceMgr); } else if (layers[i] instanceof MultiLayerNetwork) { - List temp = ((MultiLayerNetwork) layers[i]).rnnActivateUsingStoredState(input, - train, storeLastForTBPTT); + List temp = + ((MultiLayerNetwork) layers[i]) + .rnnActivateUsingStoredState(input, train, storeLastForTBPTT); input = temp.get(temp.size() - 1); } else { input = layers[i].activate(input, train, workspaceMgr); @@ -1167,8 +1212,13 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial "Forward pass type not supported for this method: " + fwdPassType); } - //Validation: Exception if invalid (bad layer implementation) - validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, false, + // Validation: Exception if invalid (bad layer implementation) + validateArrayWorkspaces( + workspaceMgr, + input, + ArrayType.ACTIVATIONS, + i, + false, "Feed forward to layer (inference)"); out.add(input); @@ -1184,25 +1234,29 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Feed-forward through the network at training time - returning a list of all activations in a * workspace (WS_ALL_LAYERS_ACT) if workspaces are enabled for training; or detached if no - * workspaces are used.
Note: if using workspaces for training, this method requires that - * WS_ALL_LAYERS_ACT is open externally.
If using NO workspaces, requires that no external - * workspace is open
Note that this method does NOT clear the inputs to each layer - instead, - * they are in the WS_ALL_LAYERS_ACT workspace for use in later backprop. + * workspaces are used.
+ * Note: if using workspaces for training, this method requires that WS_ALL_LAYERS_ACT is open + * externally.
+ * If using NO workspaces, requires that no external workspace is open
+ * Note that this method does NOT clear the inputs to each layer - instead, they are in the + * WS_ALL_LAYERS_ACT workspace for use in later backprop. * - * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use - * numLayers-1 - * @param fwdPassType Type of forward pass to perform (STANDARD or - * RNN_ACTIVATE_WITH_STORED_STATE only) - * @param storeLastForTBPTT ONLY used if fwdPassType == - * FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE - * @param input Input to network - * @param fMask Feature mask array. May be null - * @param lMask Label mask aray. May be null. + * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use numLayers-1 + * @param fwdPassType Type of forward pass to perform (STANDARD or RNN_ACTIVATE_WITH_STORED_STATE + * only) + * @param storeLastForTBPTT ONLY used if fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE + * @param input Input to network + * @param fMask Feature mask array. May be null + * @param lMask Label mask aray. May be null. * @return */ - protected synchronized List ffToLayerActivationsInWs(int layerIndex, - @NonNull FwdPassType fwdPassType, boolean storeLastForTBPTT, - @NonNull INDArray input, INDArray fMask, INDArray lMask) { + protected synchronized List ffToLayerActivationsInWs( + int layerIndex, + @NonNull FwdPassType fwdPassType, + boolean storeLastForTBPTT, + @NonNull INDArray input, + INDArray fMask, + INDArray lMask) { setInput(input); setLayerMaskArrays(fMask, lMask); @@ -1212,44 +1266,55 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial "Expected no workspace active in ffToLayerActivationsInWs when training workspace is set to NONE"); workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { - workspaceMgr = LayerWorkspaceMgr.builder() - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + workspaceMgr = + LayerWorkspaceMgr.builder() + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_FF_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); if (input.isAttached()) { - //Don't leverage out of async DataSetIterator workspaces + // Don't leverage out of async DataSetIterator workspaces workspaceMgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); } if (getNetConfiguration().getCacheMode() != CacheMode.NONE) { - //For now: store cache mode activations in activations workspace + // For now: store cache mode activations in activations workspace workspaceMgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); - workspaceMgr.setWorkspace(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, - WS_LAYER_WORKING_MEM_CONFIG); + workspaceMgr.setWorkspace( + ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG); } - WorkspaceUtils.assertOpenAndActive(WS_ALL_LAYERS_ACT, + WorkspaceUtils.assertOpenAndActive( + WS_ALL_LAYERS_ACT, "ffToLayerActivationsInWs method requires workspace WS_ALL_LAYERS_ACT to be open"); } workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); List out = 
new ArrayList<>(); - out.add(workspaceMgr.leverageTo(ArrayType.INPUT, input)); //Probably unnecessary usually + out.add(workspaceMgr.leverageTo(ArrayType.INPUT, input)); // Probably unnecessary usually boolean traceLog = log.isTraceEnabled(); for (int i = 0; i <= layerIndex; i++) { - try (MemoryWorkspace wsFFWorking = workspaceMgr.notifyScopeEntered( - ArrayType.FF_WORKING_MEM)) { + try (MemoryWorkspace wsFFWorking = + workspaceMgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { if (getNetConfiguration().getInputPreProcess(i) != null) { - input = getNetConfiguration().getInputPreProcess(i) - .preProcess(input, getInputMiniBatchSize(), workspaceMgr); - //Validation: Exception if invalid (bad preprocessor implementation) - validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, true, + input = + getNetConfiguration() + .getInputPreProcess(i) + .preProcess(input, getInputMiniBatchSize(), workspaceMgr); + // Validation: Exception if invalid (bad preprocessor implementation) + validateArrayWorkspaces( + workspaceMgr, + input, + ArrayType.ACTIVATIONS, + i, + true, "Feed forward to layer (training)"); } @@ -1261,15 +1326,17 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial input = layers[i].activate(input, true, workspaceMgr); } else if (fwdPassType == FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE) { if (layers[i] instanceof RecurrentLayer) { - input = ((RecurrentLayer) layers[i]).rnnActivateUsingStoredState(input, true, - storeLastForTBPTT, workspaceMgr); + input = + ((RecurrentLayer) layers[i]) + .rnnActivateUsingStoredState(input, true, storeLastForTBPTT, workspaceMgr); } else if (layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) { RecurrentLayer rl = (RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying(); input = rl.rnnActivateUsingStoredState(input, true, storeLastForTBPTT, workspaceMgr); } else if (layers[i] instanceof MultiLayerNetwork) { - List temp = ((MultiLayerNetwork) layers[i]).rnnActivateUsingStoredState(input, - true, storeLastForTBPTT); + List temp = + ((MultiLayerNetwork) layers[i]) + .rnnActivateUsingStoredState(input, true, storeLastForTBPTT); input = temp.get(temp.size() - 1); } else { input = layers[i].activate(input, true, workspaceMgr); @@ -1283,10 +1350,27 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial throw new IllegalStateException("LayerConfiguration " + i + " returned null activations"); } - //Validation: Exception if invalid (bad layer implementation) - validateArrayWorkspaces(workspaceMgr, input, ArrayType.ACTIVATIONS, i, false, + // Validation: Exception if invalid (bad layer implementation) + validateArrayWorkspaces( + workspaceMgr, + input, + ArrayType.ACTIVATIONS, + i, + false, "Feed forward to layer (training)"); - validateArrayWorkspaces(workspaceMgr, layers[i].input(), ArrayType.INPUT, i, false, + if (layers[i].input() == null) { + log.error( + "Input for layer {} at index {} cannot be null.", + layers[i].getLayerConfiguration().getLayerName(), + i); + throw new RuntimeException("Layer input is null."); + } + validateArrayWorkspaces( + workspaceMgr, + layers[i].input(), + ArrayType.INPUT, + i, + false, "Feed forward to layer (training)"); out.add(input); @@ -1302,92 +1386,110 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Provide the output of the specified layer, detached from any workspace. 
This is most commonly - * used at inference/test time, and is more memory efficient than - * {@link #ffToLayerActivationsDetached(boolean, FwdPassType, boolean, int, INDArray, INDArray, - * INDArray, boolean)} and - * {@link #ffToLayerActivationsInWs(int, FwdPassType, boolean, INDArray, INDArray, INDArray)}.
+ * used at inference/test time, and is more memory efficient than {@link + * #ffToLayerActivationsDetached(boolean, FwdPassType, boolean, int, INDArray, INDArray, INDArray, + * boolean)} and {@link #ffToLayerActivationsInWs(int, FwdPassType, boolean, INDArray, INDArray, + * INDArray)}.
* This method clears all layer inputs. - *

- * NOTE: in general, no workspaces should be activated externally for this method! This method + * + *

NOTE: in general, no workspaces should be activated externally for this method! This method * handles the workspace activation as required * - * @param train Training mode (true) or test/inference mode (false) - * @param fwdPassType Type of forward pass to perform (STANDARD, RNN_TIMESTEP or - * RNN_ACTIVATE_WITH_STORED_STATE) - * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use - * numLayers-1 - * @param input Input to the network - * @param featureMask Input/feature mask array. May be null. - * @param labelsMask Labels mask array. May be null + * @param train Training mode (true) or test/inference mode (false) + * @param fwdPassType Type of forward pass to perform (STANDARD, RNN_TIMESTEP or + * RNN_ACTIVATE_WITH_STORED_STATE) + * @param layerIndex Index (inclusive) to stop forward pass at. For all layers, use numLayers-1 + * @param input Input to the network + * @param featureMask Input/feature mask array. May be null. + * @param labelsMask Labels mask array. May be null * @param outputWorkspace Optional - if provided, outputs should be placed in this workspace. - * NOTE: this workspace must be open + * NOTE: this workspace must be open * @return Output of the specified layer, detached from any workspace */ - protected INDArray outputOfLayerDetached(boolean train, @NonNull FwdPassType fwdPassType, - int layerIndex, @NonNull INDArray input, - INDArray featureMask, INDArray labelsMask, MemoryWorkspace outputWorkspace) { + protected INDArray outputOfLayerDetached( + boolean train, + @NonNull FwdPassType fwdPassType, + int layerIndex, + @NonNull INDArray input, + INDArray featureMask, + INDArray labelsMask, + MemoryWorkspace outputWorkspace) { setInput(input); setLayerMaskArrays(featureMask, labelsMask); - /* - Idea here: we want to minimize memory, and return only the final array - Approach to do this: keep activations in memory only as long as we need them. - In MultiLayerNetwork, the output activations of layer X are used as input to layer X+1 - Which means: the workspace for layer X has to be open for both layers X and X+1 forward pass. + /* + Idea here: we want to minimize memory, and return only the final array + Approach to do this: keep activations in memory only as long as we need them. + In MultiLayerNetwork, the output activations of layer X are used as input to layer X+1 + Which means: the workspace for layer X has to be open for both layers X and X+1 forward pass. - Here, we'll use two workspaces for activations: - 1. For even index layers, activations WS that opens on start of even layer fwd pass, closes at end of odd layer fwd pass - 2. For odd index layers, activations WS that opens on start of odd layer fwd pass, closes at end of even layer fwd pass + Here, we'll use two workspaces for activations: + 1. For even index layers, activations WS that opens on start of even layer fwd pass, closes at end of odd layer fwd pass + 2. 
For odd index layers, activations WS that opens on start of odd layer fwd pass, closes at end of even layer fwd pass - Additionally, we'll reconfigure the workspace manager for the *final* layer, so that we don't have to detach - */ + Additionally, we'll reconfigure the workspace manager for the *final* layer, so that we don't have to detach + */ if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace) { - WorkspaceUtils.assertNoWorkspacesOpen("Expected no workspace active in outputOfLayerDetached", - true); + WorkspaceUtils.assertNoWorkspacesOpen( + "Expected no workspace active in outputOfLayerDetached", true); } else { - Preconditions.checkState(outputWorkspace.isScopeActive(), - "Workspace \"" + outputWorkspace.getId() + - "\" was provided for the network/layer outputs. When provided, this workspace must be opened before " - + - "calling the output method; furthermore, closing the workspace is the responsibility of the user"); + Preconditions.checkState( + outputWorkspace.isScopeActive(), + "Workspace \"" + + outputWorkspace.getId() + + "\" was provided for the network/layer outputs. When provided, this workspace must be opened before " + + "calling the output method; furthermore, closing the workspace is the responsibility of the user"); } LayerWorkspaceMgr mgrEven; LayerWorkspaceMgr mgrOdd; - WorkspaceMode wsm = train ? getNetConfiguration().getTrainingWorkspaceMode() - : getNetConfiguration().getInferenceWorkspaceMode(); + WorkspaceMode wsm = + train + ? getNetConfiguration().getTrainingWorkspaceMode() + : getNetConfiguration().getInferenceWorkspaceMode(); if (wsm == WorkspaceMode.NONE) { mgrEven = LayerWorkspaceMgr.noWorkspaces(); mgrOdd = mgrEven; - //Check for external workspace - doesn't make sense to have one with workspace mode NONE + // Check for external workspace - doesn't make sense to have one with workspace mode NONE if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) { - throw new IllegalStateException("Workspace \"" + outputWorkspace.getId() + - "\" was provided for the network/layer outputs, however " + (train ? "training" - : "inference") + - " workspace mode is set to NONE. Cannot put output activations into the specified workspace if" - + - "workspaces are disabled for the network. use getNetConfiguration().setTraining/InferenceWorkspaceMode(WorkspaceMode.ENABLED)"); + throw new IllegalStateException( + "Workspace \"" + + outputWorkspace.getId() + + "\" was provided for the network/layer outputs, however " + + (train ? "training" : "inference") + + " workspace mode is set to NONE. Cannot put output activations into the specified workspace if" + + "workspaces are disabled for the network. 
use getNetConfiguration().setTraining/InferenceWorkspaceMode(WorkspaceMode.ENABLED)"); } } else { - mgrEven = LayerWorkspaceMgr.builder() - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_1, WS_LAYER_ACT_X_CONFIG) - .with(ArrayType.INPUT, WS_LAYER_ACT_2, - WS_LAYER_ACT_X_CONFIG) //Inputs should always be in the previous WS - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + mgrEven = + LayerWorkspaceMgr.builder() + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_1, WS_LAYER_ACT_X_CONFIG) + .with( + ArrayType.INPUT, + WS_LAYER_ACT_2, + WS_LAYER_ACT_X_CONFIG) // Inputs should always be in the previous WS + .with( + ArrayType.RNN_FF_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); - mgrOdd = LayerWorkspaceMgr.builder() - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG) - .with(ArrayType.INPUT, WS_LAYER_ACT_1, - WS_LAYER_ACT_X_CONFIG) //Inputs should always be in the previous WS - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + mgrOdd = + LayerWorkspaceMgr.builder() + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG) + .with( + ArrayType.INPUT, + WS_LAYER_ACT_1, + WS_LAYER_ACT_X_CONFIG) // Inputs should always be in the previous WS + .with( + ArrayType.RNN_FF_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); } mgrEven.setHelperWorkspacePointers(helperWorkspaces); mgrOdd.setHelperWorkspacePointers(helperWorkspaces); @@ -1407,64 +1509,74 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); } - //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) - //Hence: put inputs in working memory + // Edge case: for first layer with dropout, inputs can't be in previous workspace (as it + // hasn't been opened yet) + // Hence: put inputs in working memory if (i == 0 && wsm != WorkspaceMode.NONE) { mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG); } - try (MemoryWorkspace wsFFWorking = mgr.notifyScopeEntered( - ArrayType.FF_WORKING_MEM)) { //Working memory: opened/closed once per layer - //Activations workspaces: opened/closed every second layer. - //So mgrEven (WS_LAYER_ACT_1) open at start of 0, 2, 4, 8; closed at end of 1, 3, 5, 7 etc - //and mgrOdd (WS_LAYER_ACT_2) opened at start of 1, 3, 5, 7; closed at end of 2, 4, 6, 8 etc + try (MemoryWorkspace wsFFWorking = + mgr.notifyScopeEntered( + ArrayType.FF_WORKING_MEM)) { // Working memory: opened/closed once per layer + // Activations workspaces: opened/closed every second layer. + // So mgrEven (WS_LAYER_ACT_1) open at start of 0, 2, 4, 8; closed at end of 1, 3, 5, 7 + // etc + // and mgrOdd (WS_LAYER_ACT_2) opened at start of 1, 3, 5, 7; closed at end of 2, 4, 6, 8 + // etc temp = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS); - //Note that because we're opening activation workspaces not in a simple nested order, we'll manually - // override the previous workspace setting. 
Otherwise, when we close these workspaces, the "current" + // Note that because we're opening activation workspaces not in a simple nested order, + // we'll manually + // override the previous workspace setting. Otherwise, when we close these workspaces, the + // "current" // workspace may be set to the incorrect one temp.setPreviousWorkspace(initialWorkspace); if (i == 0 && input.isAttached()) { - //Don't leverage out of async DataSetIterator workspaces + // Don't leverage out of async DataSetIterator workspaces mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); } if (getNetConfiguration().getInputPreProcess(i) != null) { - input = getNetConfiguration().getInputPreProcess(i) - .preProcess(input, getInputMiniBatchSize(), mgr); - //Validation: Exception if invalid (bad preprocessor implementation) - validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, - "Output of layer (inference)"); + input = + getNetConfiguration() + .getInputPreProcess(i) + .preProcess(input, getInputMiniBatchSize(), mgr); + // Validation: Exception if invalid (bad preprocessor implementation) + validateArrayWorkspaces( + mgr, input, ArrayType.ACTIVATIONS, i, true, "Output of layer (inference)"); } if (i == layerIndex) { if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) { - //Place activations in user-specified workspace - mgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), + // Place activations in user-specified workspace + mgr.setWorkspace( + ArrayType.ACTIVATIONS, + outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration()); } else { - //Final activations: should be detached + // Final activations: should be detached mgr.setScopedOutFor(ArrayType.ACTIVATIONS); } } if (fwdPassType == FwdPassType.STANDARD) { - //Standard feed-forward case - if (i > 0 && ConvolutionUtils.layerHasConvolutionLayout( - layers[i - 1].getLayerConfiguration()) + // Standard feed-forward case + if (i > 0 + && ConvolutionUtils.layerHasConvolutionLayout(layers[i - 1].getLayerConfiguration()) && ConvolutionUtils.layerHasConvolutionLayout(layers[i].getLayerConfiguration())) { - CNN2DFormat preLayerFormat = ConvolutionUtils.getFormatForLayer( - layers[i - 1].getLayerConfiguration()); - CNN2DFormat currLayerFormat = ConvolutionUtils.getFormatForLayer( - layers[i].getLayerConfiguration()); + CNN2DFormat preLayerFormat = + ConvolutionUtils.getFormatForLayer(layers[i - 1].getLayerConfiguration()); + CNN2DFormat currLayerFormat = + ConvolutionUtils.getFormatForLayer(layers[i].getLayerConfiguration()); if (preLayerFormat != currLayerFormat) { - //NHWC case + // NHWC case if (preLayerFormat == CNN2DFormat.NCHW) { input = input.permute(0, 3, 1, 2); } - //NCHW case + // NCHW case else if (preLayerFormat == CNN2DFormat.NHWC) { input = input.permute(0, 2, 3, 1); @@ -1475,26 +1587,25 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } input = layers[i].activate(input, train, mgr); - } else if (i > 0 && Convolution1DUtils.hasRnnDataFormat( - layers[i - 1].getLayerConfiguration()) + } else if (i > 0 + && Convolution1DUtils.hasRnnDataFormat(layers[i - 1].getLayerConfiguration()) && Convolution1DUtils.hasRnnDataFormat(layers[i].getLayerConfiguration())) { - RNNFormat preLayerFormat = Convolution1DUtils.getRnnFormatFromLayer( - layers[i - 1].getLayerConfiguration()); - RNNFormat currLayerFormat = Convolution1DUtils.getRnnFormatFromLayer( - layers[i].getLayerConfiguration()); - //permute for next layer + RNNFormat preLayerFormat = + 
Convolution1DUtils.getRnnFormatFromLayer(layers[i - 1].getLayerConfiguration()); + RNNFormat currLayerFormat = + Convolution1DUtils.getRnnFormatFromLayer(layers[i].getLayerConfiguration()); + // permute for next layer if (preLayerFormat != currLayerFormat) { input = input.permute(0, 2, 1); } input = layers[i].activate(input, train, mgr); - } else { input = layers[i].activate(input, train, mgr); } } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) { - //rnnTimeStep case + // rnnTimeStep case if (layers[i] instanceof RecurrentLayer) { input = ((RecurrentLayer) layers[i]).rnnTimeStep(reshapeTimeStepInput(input), mgr); } else if (layers[i] instanceof BaseWrapperLayer @@ -1511,9 +1622,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial "Unsupported forward pass type for this method: " + fwdPassType); } layers[i].clear(); - //Validation: Exception if invalid (bad layer implementation) - validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, - "Output of layer (inference)"); + // Validation: Exception if invalid (bad layer implementation) + validateArrayWorkspaces( + mgr, input, ArrayType.ACTIVATIONS, i, false, "Output of layer (inference)"); if (wsActCloseNext != null) { wsActCloseNext.close(); @@ -1526,11 +1637,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); } - //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) - //Hence: put inputs in working memory -> set back to default for next use of workspace mgr + // Edge case: for first layer with dropout, inputs can't be in previous workspace (as it + // hasn't been opened yet) + // Hence: put inputs in working memory -> set back to default for next use of workspace mgr if (i == 0 && wsm != WorkspaceMode.NONE) { - mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, - WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS + mgr.setWorkspace( + ArrayType.INPUT, + WS_LAYER_ACT_2, + WS_LAYER_ACT_X_CONFIG); // Inputs should always be in the previous WS } } } catch (Throwable t2) { @@ -1549,9 +1663,10 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } if (temp != null) { - //Should only be non-null on exception + // Should only be non-null on exception while (temp.isScopeActive()) { - //For safety, should never occur in theory: a single close() call may not be sufficient, if + // For safety, should never occur in theory: a single close() call may not be sufficient, + // if // workspace scope was borrowed and not properly closed when exception occurred try { temp.close(); @@ -1579,9 +1694,10 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial WorkspaceUtils.assertNoWorkspacesOpen( "Expected no workspace active at the end of outputOfLayerDetached", true); } else { - Preconditions.checkState(outputWorkspace.isScopeActive(), - "Expected output workspace to still be open" + - "at end of outputOfLayerDetached, but it is closed. This suggests an implementation or layer workspace problem"); + Preconditions.checkState( + outputWorkspace.isScopeActive(), + "Expected output workspace to still be open" + + "at end of outputOfLayerDetached, but it is closed. 
This suggests an implementation or layer workspace problem"); } } @@ -1624,8 +1740,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * Compute the activations from the input to the output layer, given mask arrays (that may be * null) The masking arrays are used in situations such an one-to-many and many-to-one rucerrent * neural network (RNN) designs, as well as for supporting time series of varying lengths within - * the same minibatch for RNNs. Other than mask arrays, this is equivalent to calling - * {@link #feedForward(INDArray, boolean)} with train = false + * the same minibatch for RNNs. Other than mask arrays, this is equivalent to calling {@link + * #feedForward(INDArray, boolean)} with train = false */ public List feedForward(INDArray input, INDArray featuresMask, INDArray labelsMask) { setLayerMaskArrays(featuresMask, labelsMask); @@ -1641,14 +1757,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial @Override public Pair gradientAndScore() { - return new Pair<>(gradient(), score()); + return new Pair<>(gradient(), getScore()); } /** * Clone the MultiLayerNetwork * * @return A cloned MultiLayerNetwork with a copy of the configuration, parameters and updater - * identical to the current network. + * identical to the current network. */ @Override public MultiLayerNetwork clone() { @@ -1657,10 +1773,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } NeuralNetConfiguration conf = this.getNetConfiguration().clone(); MultiLayerNetwork ret = new MultiLayerNetwork(conf); - ret.init(this.params().dup(), false); + ret.init(this.getModelParams().dup(), false); if (solver != null) { - //If solver is null: updater hasn't been initialized -> getUpdater call will force initialization, however + // If solver is null: updater hasn't been initialized -> getUpdater call will force + // initialization, however Updater u = this.getUpdater(); INDArray updaterState = u.getStateViewArray(); if (updaterState != null) { @@ -1669,7 +1786,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } if (hasAFrozenLayer()) { - //correct layers to frozen layers + // correct layers to frozen layers Layer[] clonedLayers = ret.getLayers(); for (int i = 0; i < layers.length; i++) { if (layers[i] instanceof FrozenLayer) { @@ -1691,84 +1808,28 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * @deprecated To be removed. Use {@link #params()} instead + * @deprecated To be removed. Use {@link #getModelParams()} instead */ @Deprecated public INDArray params(boolean backwardOnly) { - return params(); + return getModelParams(); } /** * Returns a 1 x m vector where the vector is composed of a flattened vector of all of the - * parameters in the network.
See {@link #getParam(String)} and {@link #getParamTable()} for a - * more useful/interpretable representation of the parameters.
Note that the parameter vector - * is not a copy, and changes to the returned INDArray will impact the network parameters. + * parameters in the network.
+ * See {@link #getParam(String)} and {@link #getParamTable()} for a more useful/interpretable + * representation of the parameters.
+ * Note that the parameter vector is not a copy, and changes to the returned INDArray will impact + * the network parameters. * * @return the parameters for this neural net */ @Override - public INDArray params() { + public INDArray getModelParams() { return flattenedParams; } - /** - * The param table - * - * @return - */ - @Override - public Map getParamTable() { - return null; - } - - /** - * Table of parameters by key, for backprop. For many models (dense layers, etc) - all parameters - * are backprop parameters - * - * @param backpropParamsOnly If true, return backprop params only. If false: return all params - * (equivalent to paramsTable()) - */ - @Override - public Map getParamTable(boolean backpropParamsOnly) { - return null; - } - - /** - * Set the parameters for this model. This expects a linear ndarray which then be unpacked - * internally relative to the expected ordering of the model.
See also: - * {@link #setParamTable(Map)} and {@link #setParam(String, INDArray)} - * - * @param params the parameters for the model - */ - @Override - public void setParams(INDArray params) { - if (flattenedParams == params) { - return; //No op - } - - if (flattenedParams != null && params.length() == flattenedParams.length()) { - if (params != flattenedParams) { - flattenedParams.assign(params); - } - } else { - if (flattenedParams == null) { - flattenedParams = params.dup(); - } - int idx = 0; - for (int i = 0; i < getLayers().length; i++) { - Layer layer = getLayer(i); - long range = layer.numParams(); - if (range <= 0) { - continue; //Some layers: no parameters (subsampling, etc) - } - INDArray get = params.get(NDArrayIndex.interval(0, 0, true), - NDArrayIndex.interval(idx, range + idx)); - layer.setParams(get); - idx += range; - } - } - } - @Override public void setParamsViewArray(INDArray params) { throw new UnsupportedOperationException("Not yet implemented"); @@ -1786,14 +1847,16 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (layer.numParams() == 0) { continue; } - layer.setBackpropGradientsViewArray(gradients.get(NDArrayIndex.interval(0, 0, true), - NDArrayIndex.interval(paramsSoFar, paramsSoFar + layer.numParams()))); + layer.setBackpropGradientsViewArray( + gradients.get( + NDArrayIndex.interval(0, 0, true), + NDArrayIndex.interval(paramsSoFar, paramsSoFar + layer.numParams()))); paramsSoFar += layer.numParams(); } } @Override - public TrainingConfig getConfig() { + public ITraininableLayerConfiguration getTrainingConfig() { throw new UnsupportedOperationException("Not supported"); } @@ -1807,14 +1870,58 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (!isInitCalled()) { init(); } - return flattenedParams == null ? 0 : flattenedParams.length(); //Maybe nul for 0 params net + return flattenedParams == null ? 0 : flattenedParams.length(); // Maybe nul for 0 params net + } + + /** + * @return 1d parameter vector + */ + @Override + public INDArray getParams() { + throw new RuntimeException("Calling getParams on the MultiLazerNetwork !?"); + } + + /** + * Set the parameters for this model. This expects a linear ndarray which then be unpacked + * internally relative to the expected ordering of the model.
+ * See also: {@link #setParamTable(Map)} and {@link #setParam(String, INDArray)} + * + * @param params the parameters for the model + */ + @Override + public void setParams(INDArray params) { + if (flattenedParams == params) { + return; // No op + } + + if (flattenedParams != null && params.length() == flattenedParams.length()) { + if (params != flattenedParams) { + flattenedParams.assign(params); + } + } else { + if (flattenedParams == null) { + flattenedParams = params.dup(); + } + int idx = 0; + for (int i = 0; i < getLayers().length; i++) { + Layer layer = getLayer(i); + long range = layer.numParams(); + if (range <= 0) { + continue; // Some layers: no parameters (subsampling, etc) + } + INDArray get = + params.get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.interval(idx, range + idx)); + layer.setParams(get); + idx += range; + } + } } /** * Returns the number of parameters in the network * * @param backwards If true: exclude any parameters uned only in unsupervised layerwise training - * (such as the decoder parameters in an autoencoder) + * (such as the decoder parameters in an autoencoder) * @return The number of parameters */ @Override @@ -1843,15 +1950,16 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * Perform minibatch training on all minibatches in the DataSetIterator, for the specified number * of epochs. Equvalent to calling {@link #fit(DataSetIterator)} numEpochs times in a loop * - * @param iterator Training data (DataSetIterator). Iterator must support resetting + * @param iterator Training data (DataSetIterator). Iterator must support resetting * @param numEpochs Number of training epochs, >= 1 */ public void fit(@NonNull DataSetIterator iterator, int numEpochs) { - Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", - numEpochs); - Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), - "Cannot perform multiple epochs training using" + - "iterator thas does not support resetting (iterator.resetSupported() returned false)"); + Preconditions.checkArgument( + numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", numEpochs); + Preconditions.checkArgument( + numEpochs == 1 || iterator.resetSupported(), + "Cannot perform multiple epochs training using" + + "iterator thas does not support resetting (iterator.resetSupported() returned false)"); for (int i = 0; i < numEpochs; i++) { fit(iterator); @@ -1859,9 +1967,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Perform minibatch training on all minibatches in the DataSetIterator for 1 epoch.
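For orientation, a minimal sketch of the renamed parameter accessors in this hunk (params() becoming getModelParams(), plus the re-added setParams()); the network variable net is an illustrative assumption and is not defined in this file:

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;

class ParamsSketch {
  static void roundTrip(MultiLayerNetwork net) {
    // Flattened 1 x m view of all parameters; not a copy, so in-place writes affect the network.
    INDArray view = net.getModelParams();

    // Detach a copy, perturb it, and push it back through setParams():
    INDArray copy = view.dup();
    copy.muli(0.5);          // e.g. scale every weight and bias
    net.setParams(copy);     // unpacked layer by layer in the network's parameter ordering
  }
}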
Note that - * this method does not do layerwise pretraining.
For pretraining use method pretrain.. - * {@link #pretrain(DataSetIterator)}
+ * Perform minibatch training on all minibatches in the DataSetIterator for 1 epoch.
+ * Note that this method does not do layerwise pretraining.
+ * For layerwise pretraining, use the {@link #pretrain(DataSetIterator)} method.<br>
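For context, a minimal training-loop sketch against the two fit(DataSetIterator, ...) overloads in this hunk; the MNIST iterator, batch size and epoch count are illustrative assumptions, and the iterator must support reset() whenever more than one epoch is requested:

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

class FitSketch {
  static void train(MultiLayerNetwork net) throws Exception {
    DataSetIterator trainData = new MnistDataSetIterator(64, true, 12345);
    net.fit(trainData, 5);   // 5 epochs; resetSupported() must return true for numEpochs > 1
    net.fit(trainData);      // or one additional epoch via the single-epoch overload
  }
}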
* * @param iterator Training data (DataSetIterator) */ @@ -1876,12 +1984,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } private synchronized void fitHelper(DataSetIterator iterator) { - // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate + // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where + // appropriate DataSetIterator iter; boolean destructable = false; if (iterator.asyncSupported()) { - iter = new AsyncDataSetIterator(iterator, - Math.min(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true); + iter = + new AsyncDataSetIterator( + iterator, Math.min(Nd4j.getAffinityManager().getNumberOfDevices() * 2, 2), true); destructable = true; } else { iter = iterator; @@ -1895,20 +2005,26 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { - workspaceMgr = LayerWorkspaceMgr.builder() - .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - //Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM - // as these should be closed by the time updaters are executed - //Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this - .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .build(); + workspaceMgr = + LayerWorkspaceMgr.builder() + .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_FF_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_BP_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + // Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or + // FF/BP_WORKING_MEM + // as these should be closed by the time updaters are executed + // Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this + .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .build(); } workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); @@ -1933,8 +2049,12 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial boolean hasMaskArrays = next.hasMaskArrays(); if (getNetConfiguration().getBackpropType() == BackpropType.TruncatedBPTT) { - doTruncatedBPTT(next.getFeatures(), next.getLabels(), next.getFeaturesMaskArray(), - next.getLabelsMaskArray(), workspaceMgr); + doTruncatedBPTT( + next.getFeatures(), + next.getLabels(), + next.getFeaturesMaskArray(), + next.getLabelsMaskArray(), + workspaceMgr); } else { if (hasMaskArrays) { setLayerMaskArrays(next.getFeaturesMaskArray(), 
next.getLabelsMaskArray()); @@ -1945,12 +2065,16 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) - .build(); + solver = + new Solver.Builder() + .configure(getNetConfiguration()) + .listeners(this.getTrainingListeners()) + .model(this) + .build(); } } - //TODO CACHE + // TODO CACHE solver.optimize(workspaceMgr); } @@ -1981,16 +2105,15 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * Calculate parameter gradients and input activation gradients given the input and labels, and * optionally mask arrays * - * @param features Features for gradient calculation - * @param label Labels for gradient - * @param fMask Features mask array (may be null) + * @param features Features for gradient calculation + * @param label Labels for gradient + * @param fMask Features mask array (may be null) * @param labelMask Label mask array (may be null) * @return A pair of gradient arrays: parameter gradients (in Gradient object) and input - * activation gradients + * activation gradients */ - public Pair calculateGradients(@NonNull INDArray features, - @NonNull INDArray label, - INDArray fMask, INDArray labelMask) { + public Pair calculateGradients( + @NonNull INDArray features, @NonNull INDArray label, INDArray fMask, INDArray labelMask) { try { return calculateGradientsHelper(features, label, fMask, labelMask); } catch (OutOfMemoryError e) { @@ -1999,9 +2122,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } - private Pair calculateGradientsHelper(INDArray features, INDArray label, - INDArray fMask, - INDArray labelMask) { + private Pair calculateGradientsHelper( + INDArray features, INDArray label, INDArray fMask, INDArray labelMask) { setInput(features); setLabels(label); setLayerMaskArrays(fMask, labelMask); @@ -2010,42 +2132,51 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { mgr = LayerWorkspaceMgr.noWorkspaces(); } else { - mgr = LayerWorkspaceMgr.builder() - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + mgr = + LayerWorkspaceMgr.builder() + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_FF_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_BP_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); if (getNetConfiguration().getCacheMode() != null) { - //For now: store cache mode activations in activations workspace + // For now: store 
cache mode activations in activations workspace mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); } } mgr.setHelperWorkspacePointers(helperWorkspaces); - //Calculate activations (which are stored in each layer, and used in backprop) + // Calculate activations (which are stored in each layer, and used in backprop) try (MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) { - //First: do a feed-forward through the network - //Note that we don't actually need to do the full forward pass through the output layer right now; but we do + // First: do a feed-forward through the network + // Note that we don't actually need to do the full forward pass through the output layer right + // now; but we do // need the input to the output layer to be set (such that backprop can be done) - List activations = ffToLayerActivationsInWs(layers.length - 2, FwdPassType.STANDARD, - false, input, mask, fMask); + List activations = + ffToLayerActivationsInWs( + layers.length - 2, FwdPassType.STANDARD, false, input, mask, fMask); if (!trainingListeners.isEmpty()) { - //TODO: We possibly do want output layer activations in some cases here... + // TODO: We possibly do want output layer activations in some cases here... for (TrainingListener tl : trainingListeners) { tl.onForwardPass(this, activations); } } INDArray inputToOutputLayer = activations.get(activations.size() - 1); if (getNetConfiguration().getInputPreProcess(layers.length - 1) != null) { - inputToOutputLayer = getNetConfiguration().getInputPreProcess(layers.length - 1) - .preProcess(inputToOutputLayer, getInputMiniBatchSize(), mgr); - //Validate activations location + inputToOutputLayer = + getNetConfiguration() + .getInputPreProcess(layers.length - 1) + .preProcess(inputToOutputLayer, getInputMiniBatchSize(), mgr); + // Validate activations location } getOutputLayer().setInput(inputToOutputLayer, mgr); @@ -2062,18 +2193,16 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * network learning) (b) backpropGradient (layer method, for when MultiLayerNetwork is used as a * layer) * - * @param epsilon Errors (technically errors .* activations). Not used if - * withOutputLayer = true - * @param withOutputLayer if true: assume last layer is output layer, and calculate errors - * based on labels. In this case, the epsilon input is not used - * (may/should be null). If false: calculate backprop gradients + * @param epsilon Errors (technically errors .* activations). Not used if withOutputLayer = true + * @param withOutputLayer if true: assume last layer is output layer, and calculate errors based + * on labels. In this case, the epsilon input is not used (may/should be null). If false: + * calculate backprop gradients * @param returnInputActGrad If true: terun the input activation gradients (detached). 
False: - * don't return + * don't return * @return Gradients and the error (epsilon) at the input */ - protected Pair calcBackpropGradients(INDArray epsilon, - boolean withOutputLayer, boolean tbptt, - boolean returnInputActGrad) { + protected Pair calcBackpropGradients( + INDArray epsilon, boolean withOutputLayer, boolean tbptt, boolean returnInputActGrad) { if (flattenedGradients == null) { initGradientsView(); } @@ -2087,63 +2216,82 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial mgrEven = LayerWorkspaceMgr.noWorkspaces(); mgrOdd = mgrEven; WorkspaceUtils.assertNoWorkspacesOpen( - "Expected no workspace active in calcBackpropGradients when " + - "training workspace is set to none"); + "Expected no workspace active in calcBackpropGradients when " + + "training workspace is set to none"); } else { - /* - Workspaces for backprop in MLN share some features with outputOfLayerDetached, in terms of the - "two alternating workspaces" idea (but for activation gradients here, instead of activations there). + /* + Workspaces for backprop in MLN share some features with outputOfLayerDetached, in terms of the + "two alternating workspaces" idea (but for activation gradients here, instead of activations there). - Workspace design for backprop: - First: we calculate all activations, and ensure they are in WS_ALL_LAYERS_ACT. We assume this is done - EXTERNALLY to this method - Then: we iterate backwards over layers. + Workspace design for backprop: + First: we calculate all activations, and ensure they are in WS_ALL_LAYERS_ACT. We assume this is done + EXTERNALLY to this method + Then: we iterate backwards over layers. - Activations gradient workspaces: opened/closed every second layer. - mgrEven (WS_LAYER_ACT_1) activation grad WS opens at start of 8, 4, 2, 0; closed at end of 7, 5, 3, 1 etc - mgrOdd (WS_LAYER_ACT_2) activation grad WS opens at start of 7, 3, 5, 1; closed at end of 6, 4, 2, 0 etc + Activations gradient workspaces: opened/closed every second layer. + mgrEven (WS_LAYER_ACT_1) activation grad WS opens at start of 8, 4, 2, 0; closed at end of 7, 5, 3, 1 etc + mgrOdd (WS_LAYER_ACT_2) activation grad WS opens at start of 7, 3, 5, 1; closed at end of 6, 4, 2, 0 etc - */ + */ - mgrEven = LayerWorkspaceMgr.builder() - //Activations in context of backprop (preOut methods etc) are not used outside of the layer itself - .with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, - WS_ALL_LAYERS_ACT_CONFIG) //Usually not required here. Exception: OutputLayer dropout - .with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_1, WS_LAYER_ACT_X_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + mgrEven = + LayerWorkspaceMgr.builder() + // Activations in context of backprop (preOut methods etc) are not used outside of the + // layer itself + .with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with( + ArrayType.INPUT, + WS_ALL_LAYERS_ACT, + WS_ALL_LAYERS_ACT_CONFIG) // Usually not required here. 
Exception: OutputLayer + // dropout + .with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_1, WS_LAYER_ACT_X_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_FF_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_BP_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); - mgrOdd = LayerWorkspaceMgr.builder() - //Activations in context of backprop (preOut methods etc) are not used outside of the layer itself - .with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, - WS_ALL_LAYERS_ACT_CONFIG) //Usually not required here. Exception: OutputLayer dropout - .with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + mgrOdd = + LayerWorkspaceMgr.builder() + // Activations in context of backprop (preOut methods etc) are not used outside of the + // layer itself + .with(ArrayType.ACTIVATIONS, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with( + ArrayType.INPUT, + WS_ALL_LAYERS_ACT, + WS_ALL_LAYERS_ACT_CONFIG) // Usually not required here. Exception: OutputLayer + // dropout + .with(ArrayType.ACTIVATION_GRAD, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_FF_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_BP_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); if (epsilon == null) { - //If epsilon is non-null: external errors use case -> inputs are already detached - WorkspaceUtils.assertOpenActiveAndCurrent(WS_ALL_LAYERS_ACT, - "calcBackpropGradients method requires workspace WS_ALL_LAYERS_ACT" + - " to be open when workspaces are used"); + // If epsilon is non-null: external errors use case -> inputs are already detached + WorkspaceUtils.assertOpenActiveAndCurrent( + WS_ALL_LAYERS_ACT, + "calcBackpropGradients method requires workspace WS_ALL_LAYERS_ACT" + + " to be open when workspaces are used"); } } mgrEven.setHelperWorkspacePointers(helperWorkspaces); mgrOdd.setHelperWorkspacePointers(helperWorkspaces); - //calculate and apply the backward gradient for every layer + // calculate and apply the backward gradient for every layer /* * Skip the output layer for the indexing and just loop backwards updating the coefficients for each layer. * (when withOutputLayer == true) @@ -2154,7 +2302,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * This interpretation transpose a few things to get mini batch because ND4J is rows vs columns organization for params */ int numLayers = getnLayers(); - //Store gradients is a list; used to ensure iteration order in DefaultGradient linked hash map. 
i.e., layer 0 first instead of output layer + // Store gradients is a list; used to ensure iteration order in DefaultGradient linked hash map. + // i.e., layer 0 first instead of output layer LinkedList> gradientList = new LinkedList<>(); Pair currPair = null; @@ -2191,55 +2340,78 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial outputLayer.setLabels(labels); } - //Open activation gradients WS *then* BP working memory, so BP working memory is opened last for use in layers + // Open activation gradients WS *then* BP working memory, so BP working memory is opened + // last for use in layers wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD); - try (MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered( - ArrayType.BP_WORKING_MEM)) { + try (MemoryWorkspace wsBPWorking = + workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) { - //Note that because we're opening activation workspaces not in a simple nested order, we'll manually - // override the previous workspace setting. Otherwise, when we close these workspaces, the "current" + // Note that because we're opening activation workspaces not in a simple nested order, + // we'll manually + // override the previous workspace setting. Otherwise, when we close these workspaces, the + // "current" // workspace may be set to the incorrect one wsActGradTemp.setPreviousWorkspace(initialWorkspace); wsBPWorking.setPreviousWorkspace(initialWorkspace); - INDArray eps = (i == layers.length - 1 ? epsilon - : currPair.getRight()); //eps is null for OutputLayer + INDArray eps = + (i == layers.length - 1 + ? epsilon + : currPair.getRight()); // eps is null for OutputLayer if (!tbptt) { - //Standard case + // Standard case currPair = layers[i].backpropGradient(eps, workspaceMgr); } else { - //TBPTT gradient + // TBPTT gradient if (layers[i] instanceof RecurrentLayer) { - currPair = ((RecurrentLayer) layers[i]).tbpttBackpropGradient(currPair.getSecond(), - getNetConfiguration().getTbpttBackLength(), workspaceMgr); + currPair = + ((RecurrentLayer) layers[i]) + .tbpttBackpropGradient( + currPair.getSecond(), + getNetConfiguration().getTbpttBackLength(), + workspaceMgr); } else { currPair = layers[i].backpropGradient(currPair.getSecond(), workspaceMgr); } } if (currPair.getSecond() != null) { - //Edge case: may be null for Embedding layer, for example - validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, + // Edge case: may be null for Embedding layer, for example + validateArrayWorkspaces( + workspaceMgr, + currPair.getSecond(), + ArrayType.ACTIVATION_GRAD, i, - false, "Backprop"); + false, + "Backprop"); } - for (Map.Entry entry : currPair.getFirst().gradientForVariable() - .entrySet()) { + for (Map.Entry entry : + currPair.getFirst().gradientForVariable().entrySet()) { String origName = entry.getKey(); multiGradientKey = i + "_" + origName; - gradientList.addLast(new Triple<>(multiGradientKey, entry.getValue(), - currPair.getFirst().flatteningOrderForVariable(origName))); + gradientList.addLast( + new Triple<>( + multiGradientKey, + entry.getValue(), + currPair.getFirst().flatteningOrderForVariable(origName))); } if (getNetConfiguration().getInputPreProcess(i) != null) { - currPair = new Pair<>(currPair.getFirst(), - this.getNetConfiguration().getInputPreProcess(i) - .backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr)); + currPair = + new Pair<>( + currPair.getFirst(), + this.getNetConfiguration() + .getInputPreProcess(i) + 
.backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr)); if (i > 0 && currPair.getSecond() != null) { - validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, + validateArrayWorkspaces( + workspaceMgr, + currPair.getSecond(), + ArrayType.ACTIVATION_GRAD, i, - true, "Backprop"); + true, + "Backprop"); } } @@ -2278,7 +2450,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } if (wsActGradTemp != null) { - //Should only be non-null on exception + // Should only be non-null on exception try { wsActGradTemp.close(); } catch (Throwable t2) { @@ -2302,18 +2474,19 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { WorkspaceUtils.assertNoWorkspacesOpen( - "Expected no workspace active in calcBackpropGradients when " + - "training workspace is set to none"); + "Expected no workspace active in calcBackpropGradients when " + + "training workspace is set to none"); } else { if (epsilon == null) { - //If epsilon != null: external errors use case (inputs are detached instead) - WorkspaceUtils.assertOpenActiveAndCurrent(WS_ALL_LAYERS_ACT, - "calcBackpropGradients: WS_ALL_LAYERS_ACT is no" + - " longer the currently open/active workspace"); + // If epsilon != null: external errors use case (inputs are detached instead) + WorkspaceUtils.assertOpenActiveAndCurrent( + WS_ALL_LAYERS_ACT, + "calcBackpropGradients: WS_ALL_LAYERS_ACT is no" + + " longer the currently open/active workspace"); } } - //Add gradients to Gradients (map), in correct order + // Add gradients to Gradients (map), in correct order for (Triple triple : gradientList) { gradient.setGradientFor(triple.getFirst(), triple.getSecond(), triple.getThird()); } @@ -2321,19 +2494,25 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return new Pair<>(gradient, currPair.getSecond()); } - protected void doTruncatedBPTT(INDArray input, INDArray labels, INDArray featuresMaskArray, - INDArray labelsMaskArray, LayerWorkspaceMgr workspaceMgr) { + protected void doTruncatedBPTT( + INDArray input, + INDArray labels, + INDArray featuresMaskArray, + INDArray labelsMaskArray, + LayerWorkspaceMgr workspaceMgr) { if (input.rank() != 3 || labels.rank() != 3) { log.warn( "Cannot do truncated BPTT with non-3d inputs or labels. 
Expect input with shape [miniBatchSize,nIn,timeSeriesLength], got " - + Arrays.toString(input.shape()) + "\tand labels with shape " + + Arrays.toString(input.shape()) + + "\tand labels with shape " + Arrays.toString(labels.shape())); return; } if (input.size(2) != labels.size(2)) { log.warn( "Input and label time series have different lengths: {} input length, {} label length", - input.size(2), labels.size(2)); + input.size(2), + labels.size(2)); return; } @@ -2342,7 +2521,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial val timeSeriesLength = input.size(2); long nSubsets = timeSeriesLength / fwdLen; if (timeSeriesLength % fwdLen != 0) { - nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20) + nSubsets++; // Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, + // 1 of size 20) } rnnClearPreviousState(); @@ -2357,8 +2537,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (startTimeIdx > Integer.MAX_VALUE || endTimeIdx > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } - INDArray[] subsets = getSubsetsForTbptt((int) startTimeIdx, (int) endTimeIdx, input, labels, - featuresMaskArray, labelsMaskArray); + INDArray[] subsets = + getSubsetsForTbptt( + (int) startTimeIdx, + (int) endTimeIdx, + input, + labels, + featuresMaskArray, + labelsMaskArray); setInput(subsets[0]); setLabels(subsets[1]); @@ -2366,13 +2552,17 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) - .build(); + solver = + new Solver.Builder() + .configure(getNetConfiguration()) + .listeners(this.getTrainingListeners()) + .model(this) + .build(); } } solver.optimize(workspaceMgr); - //Finally, update the state of the RNN layers: + // Finally, update the state of the RNN layers: updateRnnStateWithTBPTTState(); } @@ -2380,30 +2570,36 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial clearLayerMaskArrays(); } - private INDArray[] getSubsetsForTbptt(int startTimeIdx, int endTimeIdx, INDArray input, + private INDArray[] getSubsetsForTbptt( + int startTimeIdx, + int endTimeIdx, + INDArray input, INDArray labels, - INDArray fMask, INDArray lMask) { + INDArray fMask, + INDArray lMask) { INDArray[] out = new INDArray[4]; - out[0] = input.get(NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); - out[1] = labels.get(NDArrayIndex.all(), NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); + out[0] = + input.get( + NDArrayIndex.all(), + NDArrayIndex.all(), + NDArrayIndex.interval(startTimeIdx, endTimeIdx)); + out[1] = + labels.get( + NDArrayIndex.all(), + NDArrayIndex.all(), + NDArrayIndex.interval(startTimeIdx, endTimeIdx)); if (fMask != null) { - out[2] = fMask.get(NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); + out[2] = fMask.get(NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx)); } if (lMask != null) { - out[3] = lMask.get(NDArrayIndex.all(), - NDArrayIndex.interval(startTimeIdx, endTimeIdx)); + out[3] = lMask.get(NDArrayIndex.all(), NDArrayIndex.interval(startTimeIdx, endTimeIdx)); } return out; } - /** - * Intended for internal/developer use - */ + /** Intended for internal/developer use */ public void 
updateRnnStateWithTBPTTState() { for (int i = 0; i < layers.length; i++) { if (layers[i] instanceof RecurrentLayer) { @@ -2420,43 +2616,16 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * * @return listeners set for this network */ - public Collection getListeners() { + public Collection getTrainingListeners() { return trainingListeners; } - @Override - public void setListeners(TrainingListener ... listeners) { - if (layers == null) { - init(); - } - for (Layer layer : layers) { - layer.setListeners(listeners); - } - - if (solver != null) { - solver.setListeners(List.of(listeners)); - } - - this.trainingListeners.clear(); - if (listeners != null) { - this.trainingListeners.addAll(List.of(listeners)); - } - } - /** * @param listeners */ @Override - public void setListeners(Collection listeners) { - setListeners(listeners.toArray(new TrainingListener[]{})); - } - - /** - * @deprecated Use {@link #getListeners()} - */ - @Deprecated - public Collection getTrainingListeners() { - return trainingListeners; + public void addTrainingListeners(Collection listeners) { + this.addTrainingListeners(listeners.toArray(new TrainingListener[] {})); } /** @@ -2465,7 +2634,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * @param listeners */ @Override - public void addListeners(TrainingListener... listeners) { + public void addTrainingListeners(TrainingListener... listeners) { Collections.addAll(trainingListeners, listeners); // fixme this is wrong, since it removes existing listeners from the solver @@ -2476,8 +2645,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Usable only for classification networks in conjunction with OutputLayer. Cannot be used with - * RnnOutputLayer, CnnLossLayer, or networks used for regression.
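As a usage note for the listener API renamed in this hunk (setListeners/addListeners becoming addTrainingListeners), a small sketch; the listener type and print frequency are illustrative assumptions:

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

class ListenerSketch {
  static void attach(MultiLayerNetwork net) {
    // Print the score every 10 iterations; listeners are also propagated to the solver when it exists.
    net.addTrainingListeners(new ScoreIterationListener(10));
  }
}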
To get the raw output - * activations of the output layer, use {@link #output(INDArray)} or similar.
+ * RnnOutputLayer, CnnLossLayer, or networks used for regression.
+ * To get the raw output activations of the output layer, use {@link #output(INDArray)} or + * similar.
*
* Equivalent to argmax(this.output(input)): Returns the predicted class indices corresponding to * the predictions for each example in the features array. @@ -2493,7 +2663,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial throw new ND4JArraySizeException(); } - Preconditions.checkState(output.rank() == 2, + Preconditions.checkState( + output.rank() == 2, "predict(INDArray) method can only be used on rank 2 output - got array with rank %s", output.rank()); return output.argMax(1).toIntVector(); @@ -2505,7 +2676,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial */ @Override public List predict(org.nd4j.linalg.dataset.api.DataSet dataSet) { - Preconditions.checkState(dataSet.getLabelNamesList() != null, + Preconditions.checkState( + dataSet.getLabelNamesList() != null, "This method can only be used when the DataSet contains a label name list"); int[] intRet = predict(dataSet.getFeatures()); List ret = new ArrayList<>(); @@ -2518,26 +2690,28 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Fit the model for one iteration on the provided data * - * @param data the examples to classify (one example in each row) + * @param data the examples to classify (one example in each row) * @param labels the example labels(a binary outcome matrix) */ @Override public void fit(INDArray data, INDArray labels) { + if (!initCalled) init(); fit(data, labels, null, null); } /** * Fit the model for one iteration on the provided data * - * @param features the examples to classify (one example in each row) - * @param labels the example labels(a binary outcome matrix) + * @param features the examples to classify (one example in each row) + * @param labels the example labels(a binary outcome matrix) * @param featuresMask The mask array for the features (used for variable length time series, - * etc). May be null. - * @param labelsMask The mask array for the labels (used for variable length time series, etc). - * May be null. + * etc). May be null. + * @param labelsMask The mask array for the labels (used for variable length time series, etc). + * May be null. 
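For orientation, a minimal fit-then-predict sketch matching the javadoc above; the feature and label arrays are random placeholders and assume a two-class OutputLayer network with one example per row:

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class PredictSketch {
  static int[] fitAndPredict(MultiLayerNetwork net) {
    INDArray features = Nd4j.rand(8, 4);     // 8 examples, 4 features each (toy data)
    INDArray labels = Nd4j.zeros(8, 2);      // one-hot labels for 2 classes
    labels.putScalar(0, 0, 1.0);             // mark class 0 for the first example (toy data)

    net.fit(features, labels);               // one fit iteration on this minibatch
    return net.predict(features);            // argmax over output(features), one class index per row
  }
}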
*/ - public synchronized void fit(INDArray features, INDArray labels, INDArray featuresMask, - INDArray labelsMask) { + public synchronized void fit( + INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) { + if (!initCalled) init(); try { fitHelper(features, labels, featuresMask, labelsMask); } catch (OutOfMemoryError e) { @@ -2546,10 +2720,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } - private void fitHelper(INDArray features, INDArray labels, INDArray featuresMask, - INDArray labelsMask) { + private void fitHelper( + INDArray features, INDArray labels, INDArray featuresMask, INDArray labelsMask) { + if (!initCalled) init(); if (numParams() == 0) { - //No op: can't fit a network with 0 parameters + // No op: can't fit a network with 0 parameters return; } @@ -2562,14 +2737,16 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (getNetConfiguration().getTrainingWorkspaceMode() == null) { workspaceMgr = LayerWorkspaceMgr.noWorkspaces(); } else { - workspaceMgr = LayerWorkspaceMgr.builder() - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - //Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or FF/BP_WORKING_MEM - // these should be closed by the time updaters are executed - //Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this - .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .build(); + workspaceMgr = + LayerWorkspaceMgr.builder() + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + // Note for updater working memory, we have the option to re-use WS_ALL_LAYERS_ACT or + // FF/BP_WORKING_MEM + // these should be closed by the time updaters are executed + // Generally, WS_ALL_LAYERS_ACT will be the larger of the two, so we'll use this + .with(ArrayType.UPDATER_WORKING_MEM, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .build(); } workspaceMgr.setHelperWorkspacePointers(helperWorkspaces); @@ -2578,11 +2755,15 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } else { if (solver == null) { try (MemoryWorkspace wsO = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) - .build(); + solver = + new Solver.Builder() + .configure(getNetConfiguration()) + .listeners(this.getTrainingListeners()) + .model(this) + .build(); } } - //TODO CACHE WORKSPACE, IF USED??? + // TODO CACHE WORKSPACE, IF USED??? 
solver.optimize(workspaceMgr); } @@ -2603,7 +2784,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial */ @Override public void fit(org.nd4j.linalg.dataset.api.DataSet data) { - fit(data.getFeatures(), data.getLabels(), data.getFeaturesMaskArray(), + if (!initCalled) init(); + fit( + data.getFeatures(), + data.getLabels(), + data.getFeaturesMaskArray(), data.getLabelsMaskArray()); } @@ -2611,10 +2796,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * Fit the model for one iteration on the provided data * * @param examples the examples to classify (one example in each row) - * @param labels the labels for each example (the number of labels must match + * @param labels the labels for each example (the number of labels must match */ @Override public void fit(INDArray examples, int[] labels) { + if (!initCalled) init(); org.deeplearning4j.nn.conf.layers.OutputLayer layerConf = (org.deeplearning4j.nn.conf.layers.OutputLayer) getOutputLayer().getLayerConfiguration(); @@ -2630,8 +2816,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * * @param input Input to the network * @param train whether the output is test or train. This mainly affect hyper parameters such as - * dropout and batch normalization, which have different behaviour for test vs. - * train + * dropout and batch normalization, which have different behaviour for test vs. train * @return The network predictions - i.e., the activations of the final layer */ public INDArray output(INDArray input, TrainingMode train) { @@ -2644,8 +2829,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * * @param input Input to the network * @param train whether the output is test or train. This mainly affect hyper parameters such as - * dropout and batch normalization, which have different behaviour for test vs. - * train + * dropout and batch normalization, which have different behaviour for test vs. train * @return The network predictions - i.e., the activations of the final layer */ public INDArray output(INDArray input, boolean train) { @@ -2657,54 +2841,64 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * situations such as one-to-many and many-to-one recurrent neural network (RNN) designs, as well * as for supporting time series of varying lengths within the same minibatch. */ - public INDArray output(INDArray input, boolean train, INDArray featuresMask, - INDArray labelsMask) { + public INDArray output( + INDArray input, boolean train, INDArray featuresMask, INDArray labelsMask) { return output(input, train, featuresMask, labelsMask, null); } /** - * Get the network output, which is optionally placed in the specified memory workspace.
If no - * memory workspace is provided, the output will be detached (not in any workspace).
If a - * memory workspace is provided, the output activation array (i.e., the INDArray returned by this - * method) will be placed in the specified workspace. This workspace must be opened by the user - * before calling this method - and the user is responsible for (a) closing this workspace, and - * (b) ensuring the output array is not used out of scope (i.e., not used after closing the + * Get the network output, which is optionally placed in the specified memory workspace.
+ * If no memory workspace is provided, the output will be detached (not in any workspace).
+ * If a memory workspace is provided, the output activation array (i.e., the INDArray returned by + * this method) will be placed in the specified workspace. This workspace must be opened by the + * user before calling this method - and the user is responsible for (a) closing this workspace, + * and (b) ensuring the output array is not used out of scope (i.e., not used after closing the * workspace to which it belongs - as this is likely to cause either an exception when used, or a * crash). * - * @param input Input to the network - * @param train True for train, false otherwise + * @param input Input to the network + * @param train True for train, false otherwise * @param outputWorkspace May be null. If not null: the workspace MUST be opened before calling - * this method. + * this method. * @return The output/activations from the network (either detached or in the specified workspace - * if provided) + * if provided) */ public INDArray output(INDArray input, boolean train, MemoryWorkspace outputWorkspace) { return output(input, train, null, null, outputWorkspace); } /** - * Get the network output, which is optionally placed in the specified memory workspace.
If no - * memory workspace is provided, the output will be detached (not in any workspace).
If a - * memory workspace is provided, the output activation array (i.e., the INDArray returned by this - * method) will be placed in the specified workspace. This workspace must be opened by the user - * before calling this method - and the user is responsible for (a) closing this workspace, and - * (b) ensuring the output array is not used out of scope (i.e., not used after closing the + * Get the network output, which is optionally placed in the specified memory workspace.
+ * If no memory workspace is provided, the output will be detached (not in any workspace).
+ * If a memory workspace is provided, the output activation array (i.e., the INDArray returned by + * this method) will be placed in the specified workspace. This workspace must be opened by the + * user before calling this method - and the user is responsible for (a) closing this workspace, + * and (b) ensuring the output array is not used out of scope (i.e., not used after closing the * workspace to which it belongs - as this is likely to cause either an exception when used, or a * crash). * - * @param input Input to the network - * @param train True for train, false otherwise + * @param input Input to the network + * @param train True for train, false otherwise * @param outputWorkspace May be null. If not null: the workspace MUST be opened before calling - * this method. + * this method. * @return The output/activations from the network (either detached or in the specified workspace - * if provided) + * if provided) */ - public synchronized INDArray output(INDArray input, boolean train, INDArray featuresMask, - INDArray labelsMask, MemoryWorkspace outputWorkspace) { + public synchronized INDArray output( + INDArray input, + boolean train, + INDArray featuresMask, + INDArray labelsMask, + MemoryWorkspace outputWorkspace) { try { - return outputOfLayerDetached(train, FwdPassType.STANDARD, layers.length - 1, input, - featuresMask, labelsMask, outputWorkspace); + return outputOfLayerDetached( + train, + FwdPassType.STANDARD, + layers.length - 1, + input, + featuresMask, + labelsMask, + outputWorkspace); } catch (OutOfMemoryError e) { CrashReportingUtil.writeMemoryCrashDump(this, e); throw e; @@ -2713,24 +2907,32 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * This method uses provided OutputAdapter to return custom object built from INDArray - *
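To make the workspace contract above concrete, a hedged sketch of the caller-managed pattern; the workspace id and feature array are assumptions, and the standard ND4J workspace-manager API is assumed to be available in this repository:

import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class WorkspaceOutputSketch {
  static void infer(MultiLayerNetwork net, INDArray features) {
    // The caller opens the workspace, hands it to output(), and is responsible for closing it.
    try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("inference_ws")) {
      INDArray out = net.output(features, false, ws);                // activations live in "inference_ws"
      System.out.println("mean activation: " + out.meanNumber());    // use the result only while ws is open
    }
    // After the try block the workspace is closed; the returned array must not be used here.
  }
}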

- * PLEASE NOTE: This method uses dedicated Workspace for output generation to avoid redundant + * + *

PLEASE NOTE: This method uses dedicated Workspace for output generation to avoid redundant * allocations * - * @param inputs Input arrays to the netwonk - * @param inputMasks Optional input mask arrays (may be null) - * @param labelMasks Optional label mask arrays (may be null + * @param inputs Input arrays to the netwonk + * @param inputMasks Optional input mask arrays (may be null) + * @param labelMasks Optional label mask arrays (may be null * @param outputAdapter OutputAdapter instance - * @param T extends Object + * @param T extends Object * @return T instance produced by OutputAdapter */ - public synchronized T output(@NonNull INDArray inputs, INDArray inputMasks, - INDArray labelMasks, @NonNull OutputAdapter outputAdapter) { - try (val ws = Nd4j.getWorkspaceManager() - .getAndActivateWorkspace(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)) { + public synchronized T output( + @NonNull INDArray inputs, + INDArray inputMasks, + INDArray labelMasks, + @NonNull OutputAdapter outputAdapter) { + try (val ws = + Nd4j.getWorkspaceManager() + .getAndActivateWorkspace(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM)) { if (outputAdapter instanceof ModelAdapter) { - return ((ModelAdapter) outputAdapter).apply(this, new INDArray[]{inputs}, - new INDArray[]{inputMasks}, new INDArray[]{labelMasks}); + return ((ModelAdapter) outputAdapter) + .apply( + this, + new INDArray[] {inputs}, + new INDArray[] {inputMasks}, + new INDArray[] {labelMasks}); } else { return outputAdapter.apply(output(inputs, false, inputMasks, labelMasks, ws)); } @@ -2739,8 +2941,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Perform inference on the provided input/features - i.e., perform forward pass using the - * provided input/features and return the output of the final layer. Equivalent to - * {@link #output(INDArray, boolean)} with train=false - i.e., this method is used for inference. + * provided input/features and return the output of the final layer. Equivalent to {@link + * #output(INDArray, boolean)} with train=false - i.e., this method is used for inference. * * @param input Input to the network * @return The network predictions - i.e., the activations of the final layer @@ -2751,13 +2953,14 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Generate the output for all examples/batches in the input iterator, and concatenate them into a - * single array. See {@link #output(INDArray)}
NOTE 1: The output array can require a - * considerable amount of memory for iterators with a large number of examples
NOTE 2: This - * method cannot be used for variable length time series outputs, as this would require padding - * arrays for some outputs, or returning a mask array (which cannot be done with this method). For - * variable length time series applications, use one of the other output methods. This method also - * cannot be used with fully convolutional networks with different output sizes (for example, - * segmentation on different input image sizes). + * single array. See {@link #output(INDArray)}
+ * NOTE 1: The output array can require a considerable amount of memory for iterators with a large + * number of examples
+ * NOTE 2: This method cannot be used for variable length time series outputs, as this would + * require padding arrays for some outputs, or returning a mask array (which cannot be done with + * this method). For variable length time series applications, use one of the other output + * methods. This method also cannot be used with fully convolutional networks with different + * output sizes (for example, segmentation on different input image sizes). * * @param iterator Data to pass through the network * @return output for all examples in the iterator, concatenated into a @@ -2780,31 +2983,34 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (firstOutputShape == null) { firstOutputShape = output.shape(); } else { - //Validate that shapes are the same (may not be, for some RNN variable length time series applications) + // Validate that shapes are the same (may not be, for some RNN variable length time series + // applications) long[] currShape = output.shape(); - Preconditions.checkState(firstOutputShape.length == currShape.length, - "Error during forward pass:" + - "different minibatches have different output array ranks - first minibatch shape %s, last minibatch shape %s", - firstOutputShape, currShape); - for (int i = 1; i < currShape.length; - i++) { //Skip checking minibatch dimension, fine if this varies - Preconditions.checkState(firstOutputShape[i] == currShape[i], - "Current output shape does not match first" + - " output array shape at position %s: all dimensions must match other than the first dimension.\n" - + - " For variable length output size/length use cases such as for RNNs with multiple sequence lengths," - + - " use one of the other (non iterator) output methods. First batch output shape: %s, current batch output shape: %s", - i, firstOutputShape, currShape); + Preconditions.checkState( + firstOutputShape.length == currShape.length, + "Error during forward pass:" + + "different minibatches have different output array ranks - first minibatch shape %s, last minibatch shape %s", + firstOutputShape, + currShape); + for (int i = 1; + i < currShape.length; + i++) { // Skip checking minibatch dimension, fine if this varies + Preconditions.checkState( + firstOutputShape[i] == currShape[i], + "Current output shape does not match first" + + " output array shape at position %s: all dimensions must match other than the first dimension.\n" + + " For variable length output size/length use cases such as for RNNs with multiple sequence lengths," + + " use one of the other (non iterator) output methods. First batch output shape: %s, current batch output shape: %s", + i, + firstOutputShape, + currShape); } } } return Nd4j.concat(0, outList.toArray(new INDArray[outList.size()])); } - /** - * Equivalent to {@link #output(DataSetIterator, boolean)} with train=false - */ + /** Equivalent to {@link #output(DataSetIterator, boolean)} with train=false */ public INDArray output(DataSetIterator iterator) { return output(iterator, false); } @@ -2812,7 +3018,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Perform inference and then calculate the F1 score of the output(input) vs. the labels. 
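A brief sketch of the iterator-based output overload described above; the test iterator is an illustrative assumption, and all minibatch outputs are concatenated along dimension 0, so every dimension other than the minibatch dimension must match across batches:

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

class IteratorOutputSketch {
  static INDArray allOutputs(MultiLayerNetwork net) throws Exception {
    DataSetIterator testData = new MnistDataSetIterator(64, false, 12345);
    // One row of activations per test example, concatenated over all minibatches.
    return net.output(testData, false);
  }
}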
* - * @param input the input to perform inference with + * @param input the input to perform inference with * @param labels the true labels * @return the score for the given input,label pairs */ @@ -2836,8 +3042,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Sets the input and labels and calculates the score (value of the output layer loss function - * plus l1/l2 if applicable) for the prediction with respect to the true labels
This is - * equivalent to {@link #score(DataSet, boolean)} with training==false. + * plus l1/l2 if applicable) for the prediction with respect to the true labels
+ * This is equivalent to {@link #score(DataSet, boolean)} with training==false. * * @param data the data to score * @return the score for the given input,label pairs @@ -2851,10 +3057,10 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * Sets the input and labels and calculates the score (value of the output layer loss function * plus l1/l2 if applicable) for the prediction with respect to the true labels
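For illustration, a brief hedged sketch of how the training flag changes the score reported by the two overloads above (variable names are assumed):

    // Sketch: 'net' is an initialized MultiLayerNetwork, 'data' a DataSet with labels set.
    double testScore  = net.score(data);        // equivalent to net.score(data, false)
    double trainScore = net.score(data, true);  // training-time score: dropout/dropconnect (if configured) are applied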
* - * @param data data to calculate score for + * @param data data to calculate score for * @param training If true: score during training. If false: score at test time. This can affect - * the application of certain features, such as dropout and dropconnect (which are - * applied at training time only) + * the application of certain features, such as dropout and dropconnect (which are applied at + * training time only) * @return the score (value of the loss function) */ public double score(DataSet data, boolean training) { @@ -2874,40 +3080,54 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (!(getOutputLayer() instanceof IOutputLayer)) { throw new IllegalStateException( - "Cannot calculate score if final layer is not an instance of IOutputLayer. " + - "Final layer is of type: " + getOutputLayer().getClass()); + "Cannot calculate score if final layer is not an instance of IOutputLayer. " + + "Final layer is of type: " + + getOutputLayer().getClass()); } - WorkspaceMode wsm = (training ? getNetConfiguration().getTrainingWorkspaceMode() - : getNetConfiguration().getInferenceWorkspaceMode()); + WorkspaceMode wsm = + (training + ? getNetConfiguration().getTrainingWorkspaceMode() + : getNetConfiguration().getInferenceWorkspaceMode()); LayerWorkspaceMgr mgr; if (wsm == WorkspaceMode.NONE) { mgr = LayerWorkspaceMgr.noWorkspaces(); } else { - mgr = LayerWorkspaceMgr.builder() - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - //TODO we can probably optimize this - .noWorkspaceFor(ArrayType.ACTIVATIONS) - .noWorkspaceFor(ArrayType.INPUT) - .build(); + mgr = + LayerWorkspaceMgr.builder() + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_FF_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + // TODO we can probably optimize this + .noWorkspaceFor(ArrayType.ACTIVATIONS) + .noWorkspaceFor(ArrayType.INPUT) + .build(); } mgr.setHelperWorkspacePointers(helperWorkspaces); - INDArray inputToOutputLayer = outputOfLayerDetached(training, FwdPassType.STANDARD, - layers.length - 2, data.getFeatures(), - data.getFeaturesMaskArray(), data.getLabelsMaskArray(), null); + INDArray inputToOutputLayer = + outputOfLayerDetached( + training, + FwdPassType.STANDARD, + layers.length - 2, + data.getFeatures(), + data.getFeaturesMaskArray(), + data.getLabelsMaskArray(), + null); if (data.getFeatures().size(0) > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } IOutputLayer ol = (IOutputLayer) getOutputLayer(); if (getNetConfiguration().getInputPreProcess(layers.length - 1) != null) { - inputToOutputLayer = getNetConfiguration().getInputPreProcess(layers.length - 1) - .preProcess(inputToOutputLayer, (int) data.getFeatures().size(0), mgr); + inputToOutputLayer = + getNetConfiguration() + .getInputPreProcess(layers.length - 1) + .preProcess(inputToOutputLayer, (int) data.getFeatures().size(0), mgr); } - ol.setInput(inputToOutputLayer, mgr); //Feedforward doesn't include output layer for efficiency + ol.setInput(inputToOutputLayer, mgr); // Feedforward doesn't include output layer for efficiency ol.setLabels(data.getLabels()); double score; try (MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { @@ -2939,14 +3159,15 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * Calculate the score for each example 
in a DataSet individually. Unlike {@link #score(DataSet)} * and {@link #score(DataSet, boolean)} this method does not average/sum over examples. This * method allows for examples to be scored individually (at test time only), which may be useful - * for example for autoencoder architectures and the like.
Each row of the output (assuming - * addRegularizationTerms == true) is equivalent to calling score(DataSet) with a single example. + * for example for autoencoder architectures and the like.
+ * Each row of the output (assuming addRegularizationTerms == true) is equivalent to calling + * score(DataSet) with a single example. * - * @param data The data to score + * @param data The data to score * @param addRegularizationTerms If true: add l1/l2 regularization terms (if any) to the score. If - * false: don't add regularization terms + * false: don't add regularization terms * @return An INDArray (column vector) of size input.numRows(); the ith entry is the score (loss - * value) of the ith example + * value) of the ith example */ public INDArray scoreExamples(DataSet data, boolean addRegularizationTerms) { try { @@ -2958,13 +3179,19 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } private INDArray scoreExamplesHelper(DataSet data, boolean addRegularizationTerms) { - INDArray inputLast = outputOfLayerDetached(false, FwdPassType.STANDARD, layers.length - 2, - data.getFeatures(), - data.getFeaturesMaskArray(), data.getLabelsMaskArray(), null); + INDArray inputLast = + outputOfLayerDetached( + false, + FwdPassType.STANDARD, + layers.length - 2, + data.getFeatures(), + data.getFeaturesMaskArray(), + data.getLabelsMaskArray(), + null); setLabels(data.getLabels()); setLayerMaskArrays(data.getFeaturesMaskArray(), data.getLabelsMaskArray()); - //TODO we might want workspaces here? + // TODO we might want workspaces here? LayerWorkspaceMgr mgr = LayerWorkspaceMgr.noWorkspaces(); INDArray out; @@ -2975,9 +3202,10 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (data.getFeatures().size(0) > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } - inputLast = getNetConfiguration().getInputPreProcess(layers.length - 1) - .preProcess(inputLast, - (int) data.getFeatures().size(0), mgr); + inputLast = + getNetConfiguration() + .getInputPreProcess(layers.length - 1) + .preProcess(inputLast, (int) data.getFeatures().size(0), mgr); } ol.setLabels(data.getLabels()); ol.setInput(inputLast, mgr); @@ -3010,13 +3238,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * @return the score of the model (relative to the objective function) */ @Override - public double score() { + public double getScore() { return score; } - /** - * Intended for developer/internal use - */ + /** Intended for developer/internal use */ public void setScore(double score) { this.score = score; } @@ -3031,71 +3257,80 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (!(getOutputLayer() instanceof IOutputLayer)) { throw new DL4JException( "Cannot calculate gradient and score with respect to labels: final layer is not an IOutputLayer. " - + - "Final layer class: " + getOutputLayer().getClass() - + ". To calculate gradients and fit a network " + - "using backpropagation, the final layer must be an output layer"); + + "Final layer class: " + + getOutputLayer().getClass() + + ". To calculate gradients and fit a network " + + "using backpropagation, the final layer must be an output layer"); } - //Note: Workspace manager is only ose here for score calculation... other workspace managers are used in the + // Note: Workspace manager is only ose here for score calculation... 
other workspace managers + // are used in the // various FF/backprop methds LayerWorkspaceMgr mgr; if (getNetConfiguration().getTrainingWorkspaceMode() == WorkspaceMode.NONE) { mgr = LayerWorkspaceMgr.noWorkspaces(); } else { - mgr = LayerWorkspaceMgr.builder() - .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) - .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .with(ArrayType.RNN_BP_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, - WS_RNN_LOOP_WORKING_MEM_CONFIG) - .build(); + mgr = + LayerWorkspaceMgr.builder() + .with(ArrayType.INPUT, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.ACTIVATIONS, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG) + .with(ArrayType.FF_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with(ArrayType.BP_WORKING_MEM, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_FF_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .with( + ArrayType.RNN_BP_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM, + WS_RNN_LOOP_WORKING_MEM_CONFIG) + .build(); if (getNetConfiguration().getCacheMode() != null) { - //For now: store cache mode activations in activations workspace + // For now: store cache mode activations in activations workspace mgr.setWorkspace(ArrayType.FF_CACHE, WS_ALL_LAYERS_ACT, WS_ALL_LAYERS_ACT_CONFIG); } } boolean tbptt = getNetConfiguration().getBackpropType() == BackpropType.TruncatedBPTT; - FwdPassType fwdType = (tbptt ? FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE - : FwdPassType.STANDARD); + FwdPassType fwdType = + (tbptt ? FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE : FwdPassType.STANDARD); synchronizeIterEpochCounts(); - //Calculate activations (which are stored in each layer, and used in backprop) + // Calculate activations (which are stored in each layer, and used in backprop) try (MemoryWorkspace ws = mgr.notifyScopeEntered(ArrayType.ACTIVATIONS)) { - //First: do a feed-forward through the network - //Note that we don't actually need to do the full forward pass through the output layer right now; but we do + // First: do a feed-forward through the network + // Note that we don't actually need to do the full forward pass through the output layer right + // now; but we do // need the input to the output layer to be set (such that backprop can be done) - List activations = ffToLayerActivationsInWs(layers.length - 2, fwdType, tbptt, - input, mask, null); + List activations = + ffToLayerActivationsInWs(layers.length - 2, fwdType, tbptt, input, mask, null); if (!trainingListeners.isEmpty()) { - //TODO: We possibly do want output layer activations in some cases here... + // TODO: We possibly do want output layer activations in some cases here... 
for (TrainingListener tl : trainingListeners) { tl.onForwardPass(this, activations); } } INDArray inputToOutputLayer = activations.get(activations.size() - 1); if (getNetConfiguration().getInputPreProcess(layers.length - 1) != null) { - inputToOutputLayer = getNetConfiguration().getInputPreProcess(layers.length - 1) - .preProcess(inputToOutputLayer, getInputMiniBatchSize(), mgr); - //Validate activations location + inputToOutputLayer = + getNetConfiguration() + .getInputPreProcess(layers.length - 1) + .preProcess(inputToOutputLayer, getInputMiniBatchSize(), mgr); + // Validate activations location } getOutputLayer().setInput(inputToOutputLayer, mgr); - //Then: compute gradients + // Then: compute gradients Pair pair = calcBackpropGradients(null, true, false, false); this.gradient = (pair == null ? null : pair.getFirst()); - //Calculate score + // Calculate score try (MemoryWorkspace wsFF = mgr.notifyScopeEntered(ArrayType.FF_WORKING_MEM)) { double r = calcRegularizationScore(true); score = ((IOutputLayer) getOutputLayer()).computeScore(r, true, mgr); } - //Listeners + // Listeners if (!trainingListeners.isEmpty()) { try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { for (TrainingListener tl : trainingListeners) { @@ -3105,13 +3340,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } - //Clear the post noise/dropconnect parameters on the output layer + // Clear the post noise/dropconnect parameters on the output layer getOutputLayer().clearNoiseWeightParams(); } - /** - * Clear the inputs. Clears optimizer state. - */ + /** Clear the inputs. Clears optimizer state. */ public void clear() { for (Layer layer : layers) { layer.clear(); @@ -3147,15 +3380,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return ret; } - - /** - * See {@link #setParams(INDArray)} - */ + /** See {@link #setParams(INDArray)} */ public void setParameters(INDArray params) { setParams(params); } - public INDArray getLabels() { return labels; } @@ -3215,9 +3444,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } public Layer getLayer(int i) { - Preconditions.checkArgument(i >= 0 && i < layers.length, - "Invalid layer index: layer index must be 0" + - " to %s (inclusive), got index %s", layers.length - 1, i); + Preconditions.checkArgument( + i >= 0 && i < layers.length, + "Invalid layer index: layer index must be 0" + " to %s (inclusive), got index %s", + layers.length - 1, + i); return layers[i]; } @@ -3268,19 +3499,18 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial throw new UnsupportedOperationException("Not supported"); } - //========== - //LayerConfiguration methods + // ========== + // LayerConfiguration methods @Override - public Pair feedForwardMaskArray(INDArray maskArray, - MaskState currentMaskState, - int minibatchSize) { + public Pair feedForwardMaskArray( + INDArray maskArray, MaskState currentMaskState, int minibatchSize) { if (maskArray == null) { for (int i = 0; i < layers.length; i++) { layers[i].feedForwardMaskArray(null, null, minibatchSize); } } else { - //Do a forward pass through each preprocessor and layer + // Do a forward pass through each preprocessor and layer for (int i = 0; i < layers.length; i++) { InputPreProcessor preProcessor = getNetConfiguration().getInputPreProcess(i); @@ -3321,23 +3551,19 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return Type.MULTILAYER; } - /** - * 
Equivalent to {@link #output(INDArray)} using the input set via {@link #setInput(INDArray)} - */ + /** Equivalent to {@link #output(INDArray)} using the input set via {@link #setInput(INDArray)} */ public INDArray activate(TrainingMode training) { return output(input, training == TrainingMode.TRAIN); } - /** - * Equivalent to {@link #output(INDArray, TrainingMode)} - */ + /** Equivalent to {@link #output(INDArray, TrainingMode)} */ public INDArray activate(INDArray input, TrainingMode training) { return output(input, training == TrainingMode.TRAIN); } @Override - public Pair backpropGradient(INDArray epsilon, - LayerWorkspaceMgr workspaceMgr) { + public Pair backpropGradient( + INDArray epsilon, LayerWorkspaceMgr workspaceMgr) { if (getOutputLayer() instanceof IOutputLayer) { throw new UnsupportedOperationException( "Cannot calculate gradients based on epsilon with OutputLayer"); @@ -3408,7 +3634,6 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } // Update layerwise gradient view setBackpropGradientsViewArray(gradient.gradient()); - } @Override @@ -3445,50 +3670,61 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * If this MultiLayerNetwork contains one or more RNN layers: conduct forward pass (prediction) * but using previous stored state for any RNN layers. The activations for the final step are also - * stored in the RNN layers for use next time rnnTimeStep() is called.
This method can be used - * to generate output one or more steps at a time instead of always having to do forward pass from - * t=0. Example uses are for streaming data, and for generating samples from network output one - * step at a time (where samples are then fed back into the network as input)
If no previous - * state is present in RNN layers (i.e., initially or after calling rnnClearPreviousState()), the - * default initialization (usually 0) is used.
Supports mini-batch (i.e., multiple - * predictions/forward pass in parallel) as well as for single examples.
+   * stored in the RNN layers for use next time rnnTimeStep() is called.<br>
+   * This method can be used to generate output one or more steps at a time instead of always having
+   * to do forward pass from t=0. Example uses are for streaming data, and for generating samples
+   * from network output one step at a time (where samples are then fed back into the network as
+   * input)<br>
+   * If no previous state is present in RNN layers (i.e., initially or after calling
+   * rnnClearPreviousState()), the default initialization (usually 0) is used.<br>
+   * Supports mini-batch (i.e., multiple predictions/forward pass in parallel) as well as for single
+   * examples.<br>
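A minimal stepping sketch of the behaviour described above; the sizes, step count and the feeding back of raw activations are assumptions for illustration (a real application would typically sample from the output distribution):

    // Sketch: generate output one time step at a time using the stored RNN state.
    int nIn = 10;                              // assumed input size of the first layer
    int numSteps = 20;                         // assumed number of steps to generate
    net.rnnClearPreviousState();               // start from the default (zero) state
    INDArray step = Nd4j.zeros(1, nIn, 1);     // [miniBatchSize=1, inputSize, 1 time step]
    for (int t = 0; t < numSteps; t++) {
      INDArray out = net.rnnTimeStep(step);    // uses and then updates the stored state
      step = out;                              // toy feedback loop, assumes nOut == nIn
    }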
* * @param input Input to network. May be for one or multiple time steps. For single time step: - * input has shape [miniBatchSize,inputSize] or [miniBatchSize,inputSize,1]. - * miniBatchSize=1 for single example.
For multiple time steps: - * [miniBatchSize,inputSize,inputTimeSeriesLength] + * input has shape [miniBatchSize,inputSize] or [miniBatchSize,inputSize,1]. miniBatchSize=1 + * for single example.
+ * For multiple time steps: [miniBatchSize,inputSize,inputTimeSeriesLength] * @return Output activations. If output is RNN layer (such as RnnOutputLayer): if input has shape - * [miniBatchSize,inputSize] i.e., is 2d, output has shape [miniBatchSize,outputSize] (i.e., also - * 2d).
Otherwise output is 3d [miniBatchSize,outputSize,inputTimeSeriesLength] when using - * RnnOutputLayer. + * [miniBatchSize,inputSize] i.e., is 2d, output has shape [miniBatchSize,outputSize] (i.e., + * also 2d).
+ * Otherwise output is 3d [miniBatchSize,outputSize,inputTimeSeriesLength] when using + * RnnOutputLayer. * @see #rnnTimeStep(INDArray, MemoryWorkspace) For outputting the activations in the specified - * workspace + * workspace */ public INDArray rnnTimeStep(INDArray input) { return rnnTimeStep(input, null); } /** - * See {@link #rnnTimeStep(INDArray)} for details
If no memory workspace is provided, the - * output will be detached (not in any workspace).
If a memory workspace is provided, the - * output activation array (i.e., the INDArray returned by this method) will be placed in the - * specified workspace. This workspace must be opened by the user before calling this method - and - * the user is responsible for (a) closing this workspace, and (b) ensuring the output array is - * not used out of scope (i.e., not used after closing the workspace to which it belongs - as this - * is likely to cause either an exception when used, or a crash). + * See {@link #rnnTimeStep(INDArray)} for details
+   * If no memory workspace is provided, the output will be detached (not in any workspace).<br>
+ * If a memory workspace is provided, the output activation array (i.e., the INDArray returned by + * this method) will be placed in the specified workspace. This workspace must be opened by the + * user before calling this method - and the user is responsible for (a) closing this workspace, + * and (b) ensuring the output array is not used out of scope (i.e., not used after closing the + * workspace to which it belongs - as this is likely to cause either an exception when used, or a + * crash). * - * @param input Input activations + * @param input Input activations * @param outputWorkspace Output workspace. May be null * @return The output/activations from the network (either detached or in the specified workspace - * if provided) + * if provided) */ public INDArray rnnTimeStep(INDArray input, MemoryWorkspace outputWorkspace) { try { boolean inputIs2d = input.rank() == 2; - INDArray out = outputOfLayerDetached(false, FwdPassType.RNN_TIMESTEP, layers.length - 1, - input, null, null, outputWorkspace); + INDArray out = + outputOfLayerDetached( + false, + FwdPassType.RNN_TIMESTEP, + layers.length - 1, + input, + null, + null, + outputWorkspace); if (inputIs2d && out.rank() == 3 && layers[layers.length - 1].type() == Type.RECURRENT) { - //Return 2d output with shape [miniBatchSize,nOut] + // Return 2d output with shape [miniBatchSize,nOut] // instead of 3d output with shape [miniBatchSize,nOut,1] return out.tensorAlongDimension(0, 1, 0); } @@ -3540,9 +3776,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial r.rnnSetPreviousState(state); } - /** - * Clear the previous state of the RNN layers (if any). - */ + /** Clear the previous state of the RNN layers (if any). */ public void rnnClearPreviousState() { if (layers == null) { return; @@ -3560,21 +3794,28 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Similar to rnnTimeStep and feedForward() methods. Difference here is that this method:
(a) - * like rnnTimeStep does forward pass using stored state for RNN layers, and
(b) unlike - * rnnTimeStep does not modify the RNN layer state
Therefore multiple calls to this method - * with the same input should have the same output.
Typically used during training only. Use - * rnnTimeStep for prediction/forward pass at test time. + * Similar to rnnTimeStep and feedForward() methods. Difference here is that this method:
+   * (a) like rnnTimeStep does forward pass using stored state for RNN layers, and<br>
+   * (b) unlike rnnTimeStep does not modify the RNN layer state<br>
+   * Therefore multiple calls to this method with the same input should have the same output.<br>
+ * Typically used during training only. Use rnnTimeStep for prediction/forward pass at test time. * - * @param input Input to network - * @param training Whether training or not + * @param input Input to network + * @param training Whether training or not * @param storeLastForTBPTT set to true if used as part of truncated BPTT training * @return Activations for each layer (including input, as per feedforward() etc) */ - public List rnnActivateUsingStoredState(INDArray input, boolean training, - boolean storeLastForTBPTT) { - return ffToLayerActivationsDetached(training, FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE, - storeLastForTBPTT, layers.length - 1, input, mask, null, false); + public List rnnActivateUsingStoredState( + INDArray input, boolean training, boolean storeLastForTBPTT) { + return ffToLayerActivationsDetached( + training, + FwdPassType.RNN_ACTIVATE_WITH_STORED_STATE, + storeLastForTBPTT, + layers.length - 1, + input, + mask, + null, + false); } /** @@ -3586,12 +3827,15 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return getUpdater(true); } - /** - * Set the updater for the MultiLayerNetwork - */ + /** Set the updater for the MultiLayerNetwork */ public void setUpdater(Updater updater) { if (solver == null) { - solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this).build(); + solver = + new Solver.Builder() + .configure(getNetConfiguration()) + .listeners(this.getTrainingListeners()) + .model(this) + .build(); } solver.getOptimizer().setUpdater(updater); } @@ -3599,9 +3843,13 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial public Updater getUpdater(boolean initializeIfReq) { if (solver == null && initializeIfReq) { synchronized (this) { - if (solver == null) { //May have been created while waiting for lock - solver = new Solver.Builder().configure(getNetConfiguration()).listeners(getListeners()).model(this) - .build(); + if (solver == null) { // May have been created while waiting for lock + solver = + new Solver.Builder() + .configure(getNetConfiguration()) + .listeners(this.getTrainingListeners()) + .model(this) + .build(); solver.getOptimizer().setUpdater(UpdaterCreator.getUpdater(this)); } } @@ -3615,17 +3863,17 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Set the mask arrays for features and labels. Mask arrays are typically used in situations such * as one-to-many and many-to-one learning with recurrent neural networks, as well as for - * supporting time series of varying lengths within the same minibatch.
For example, with RNN - * data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and outputs of shape - * [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have shape - * [miniBatchSize,timeSeriesLength] and contain values 0 or 1 at each element (to specify whether - * a given input/example is present - or merely padding - at a given time step).
- * NOTE: This method is not usually used directly. Instead, methods such as - * {@link #feedForward(INDArray, INDArray, INDArray)} and - * {@link #output(INDArray, boolean, INDArray, INDArray)} handle setting of masking internally. + * supporting time series of varying lengths within the same minibatch.
+ * For example, with RNN data sets with input of shape [miniBatchSize,nIn,timeSeriesLength] and + * outputs of shape [miniBatchSize,nOut,timeSeriesLength], the features and mask arrays will have + * shape [miniBatchSize,timeSeriesLength] and contain values 0 or 1 at each element (to specify + * whether a given input/example is present - or merely padding - at a given time step).
+ * NOTE: This method is not usually used directly. Instead, methods such as {@link + * #feedForward(INDArray, INDArray, INDArray)} and {@link #output(INDArray, boolean, INDArray, + * INDArray)} handle setting of masking internally. * * @param featuresMaskArray Mask array for features (input) - * @param labelsMaskArray Mask array for labels (output) + * @param labelsMaskArray Mask array for labels (output) * @see #clearLayerMaskArrays() */ public void setLayerMaskArrays(INDArray featuresMaskArray, INDArray labelsMaskArray) { @@ -3634,29 +3882,28 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (featuresMaskArray.size(0) > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } - //New approach: use feedForwardMaskArray method + // New approach: use feedForwardMaskArray method feedForwardMaskArray(featuresMaskArray, MaskState.Active, (int) featuresMaskArray.size(0)); + /* + //feedforward layers below a RNN layer: need the input (features) mask array + //Reason: even if the time series input is zero padded, the output from the dense layers are + // non-zero (i.e., activationFunction(0*weights + bias) != 0 in general) + //This assumes that the time series input is masked - i.e., values are 0 at the padded time steps, + // so we don't need to do anything for the recurrent layer - /* - //feedforward layers below a RNN layer: need the input (features) mask array - //Reason: even if the time series input is zero padded, the output from the dense layers are - // non-zero (i.e., activationFunction(0*weights + bias) != 0 in general) - //This assumes that the time series input is masked - i.e., values are 0 at the padded time steps, - // so we don't need to do anything for the recurrent layer + //Now, if mask array is 2d -> need to reshape to 1d (column vector) in the exact same order + // as is done for 3d -> 2d time series reshaping + INDArray reshapedFeaturesMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(featuresMaskArray); - //Now, if mask array is 2d -> need to reshape to 1d (column vector) in the exact same order - // as is done for 3d -> 2d time series reshaping - INDArray reshapedFeaturesMask = TimeSeriesUtils.reshapeTimeSeriesMaskToVector(featuresMaskArray); + for( int i=0; i See {@link #setLayerMaskArrays(INDArray, INDArray)} - * for details on mask arrays. + * Remove the mask arrays from all layers.
+ * See {@link #setLayerMaskArrays(INDArray, INDArray)} for details on mask arrays. */ public void clearLayerMaskArrays() { for (Layer layer : layers) { @@ -3720,7 +3967,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * @deprecated To be removed - use {@link #evaluateROC(DataSetIterator, int)} to enforce selection - * of appropriate ROC/threshold configuration + * of appropriate ROC/threshold configuration */ @Deprecated public T evaluateROC(DataSetIterator iterator) { @@ -3731,23 +3978,23 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * Evaluate the network (must be a binary classifier) on the specified data, using the {@link ROC} * class * - * @param iterator Data to evaluate on + * @param iterator Data to evaluate on * @param rocThresholdSteps Number of threshold steps to use with {@link ROC} - see that class for - * details. + * details. * @return ROC evaluation on the given dataset */ public T evaluateROC(DataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(); if (getNetConfiguration().isValidateOutputLayerConfig()) { - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), - ROC.class); + OutputLayerUtil.validateOutputLayerForClassifierEvaluation( + outputLayer.getLayerConfiguration(), ROC.class); } return (T) doEvaluation(iterator, new org.deeplearning4j.eval.ROC(rocThresholdSteps))[0]; } /** * @deprecated To be removed - use {@link #evaluateROCMultiClass(DataSetIterator, int)} to enforce - * selection of appropriate ROC/threshold configuration + * selection of appropriate ROC/threshold configuration */ @Deprecated public T evaluateROCMultiClass(DataSetIterator iterator) { @@ -3757,19 +4004,19 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Evaluate the network on the specified data, using the {@link ROCMultiClass} class * - * @param iterator Data to evaluate on + * @param iterator Data to evaluate on * @param rocThresholdSteps Number of threshold steps to use with {@link ROCMultiClass} * @return Multi-class ROC evaluation on the given dataset */ - public T evaluateROCMultiClass(DataSetIterator iterator, - int rocThresholdSteps) { + public T evaluateROCMultiClass( + DataSetIterator iterator, int rocThresholdSteps) { Layer outputLayer = getOutputLayer(); if (getNetConfiguration().isValidateOutputLayerConfig()) { - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), - ROCMultiClass.class); + OutputLayerUtil.validateOutputLayerForClassifierEvaluation( + outputLayer.getLayerConfiguration(), ROCMultiClass.class); } - return (T) doEvaluation(iterator, - new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0]; + return (T) + doEvaluation(iterator, new org.deeplearning4j.eval.ROCMultiClass(rocThresholdSteps))[0]; } /** @@ -3786,8 +4033,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } - public T[] doEvaluationHelper(DataSetIterator iterator, - T... evaluations) { + public T[] doEvaluationHelper( + DataSetIterator iterator, T... evaluations) { if (!iterator.hasNext() && iterator.resetSupported()) { iterator.reset(); } @@ -3796,22 +4043,26 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial iterator.asyncSupported() ? 
new AsyncDataSetIterator(iterator, 2, true) : iterator; WorkspaceMode cMode = getNetConfiguration().getTrainingWorkspaceMode(); - getNetConfiguration().setTrainingWorkspaceMode( - getNetConfiguration().getInferenceWorkspaceMode()); + getNetConfiguration() + .setTrainingWorkspaceMode(getNetConfiguration().getInferenceWorkspaceMode()); - //First: let's determine if we should do 'split feed forward' for long time series - //The idea: RNN 20k time steps. Train using TBPTT length 100 -> 200 segments of length 100. If we naively - // just use .output(INDArray) here, then our memory requirements are 200x larger than if we did the same + // First: let's determine if we should do 'split feed forward' for long time series + // The idea: RNN 20k time steps. Train using TBPTT length 100 -> 200 segments of length 100. If + // we naively + // just use .output(INDArray) here, then our memory requirements are 200x larger than if we did + // the same // evaluation in segments... - //Only do this if TBPTT is enabled - if not, it means we can train without TBPTT and hence should be able + // Only do this if TBPTT is enabled - if not, it means we can train without TBPTT and hence + // should be able // to test without splitting also - boolean useRnnSegments = (getNetConfiguration().getBackpropType() - == BackpropType.TruncatedBPTT); + boolean useRnnSegments = + (getNetConfiguration().getBackpropType() == BackpropType.TruncatedBPTT); MemoryWorkspace outputWs; if (getNetConfiguration().getInferenceWorkspaceMode() == WorkspaceMode.ENABLED) { - outputWs = Nd4j.getWorkspaceManager() - .getWorkspaceForCurrentThread(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM); + outputWs = + Nd4j.getWorkspaceManager() + .getWorkspaceForCurrentThread(WS_ALL_LAYERS_ACT_CONFIG, WS_OUTPUT_MEM); } else { outputWs = new DummyWorkspace(); } @@ -3830,10 +4081,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial List meta = next.getExampleMetaData(); if (!useRnnSegments) { - //Standard/non-RNN case: + // Standard/non-RNN case: try (MemoryWorkspace ws = outputWs.notifyScopeEntered()) { - INDArray out = outputOfLayerDetached(false, FwdPassType.STANDARD, layers.length - 1, - features, fMask, lMask, ws); + INDArray out = + outputOfLayerDetached( + false, FwdPassType.STANDARD, layers.length - 1, features, fMask, lMask, ws); try (MemoryWorkspace wsO = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { for (T evaluation : evaluations) { @@ -3844,12 +4096,13 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } else { rnnClearPreviousState(); - //Get subset of features and labels: + // Get subset of features and labels: val fwdLen = getNetConfiguration().getTbpttFwdLength(); val tsLength = features.size(2); long nSubsets = tsLength / fwdLen; if (tsLength % fwdLen != 0) { - nSubsets++; //Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size 100, 1 of size 20) + nSubsets++; // Example: 100 fwdLen with timeSeriesLength=120 -> want 2 subsets (1 of size + // 100, 1 of size 20) } for (int i = 0; i < nSubsets; i++) { val startTimeIdx = i * fwdLen; @@ -3858,8 +4111,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial if (endTimeIdx > Integer.MAX_VALUE) { throw new ND4JArraySizeException(); } - INDArray[] subsets = getSubsetsForTbptt(startTimeIdx, (int) endTimeIdx, features, labels, - fMask, lMask); + INDArray[] subsets = + getSubsetsForTbptt(startTimeIdx, (int) endTimeIdx, features, labels, fMask, lMask); setLayerMaskArrays(subsets[2], 
subsets[3]); @@ -3874,7 +4127,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } } - //Clear inputs, masks etc. Important to avoid leaking invalidated/out of scope arrays between iterations + // Clear inputs, masks etc. Important to avoid leaking invalidated/out of scope arrays between + // iterations clearLayersStates(); } @@ -3893,7 +4147,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * * @param iterator Data to undertake evaluation on * @return Evaluation object, summarizing the results of the evaluation on the provided - * DataSetIterator + * DataSetIterator */ public Evaluation evaluate(DataSetIterator iterator, List labelsList) { return evaluate(iterator, labelsList, 1); @@ -3924,8 +4178,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial fit(ds); } else { throw new DL4JInvalidInputException( - "MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array." + - "Please consider use of ComputationGraph"); + "MultiLayerNetwork can't handle MultiDataSet with more than 1 features or labels array." + + "Please consider use of ComputationGraph"); } } @@ -3934,15 +4188,16 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * number of epochs. Equvalent to calling {@link #fit(MultiDataSetIterator)} numEpochs times in a * loop * - * @param iterator Training data (DataSetIterator). Iterator must support resetting + * @param iterator Training data (DataSetIterator). Iterator must support resetting * @param numEpochs Number of training epochs, >= 1 */ public void fit(@NonNull MultiDataSetIterator iterator, int numEpochs) { - Preconditions.checkArgument(numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", - numEpochs); - Preconditions.checkArgument(numEpochs == 1 || iterator.resetSupported(), - "Cannot perform multiple epochs training using" + - "iterator has does not support resetting (iterator.resetSupported() returned false)"); + Preconditions.checkArgument( + numEpochs > 0, "Number of epochs much be > 0. Got numEpochs = %s", numEpochs); + Preconditions.checkArgument( + numEpochs == 1 || iterator.resetSupported(), + "Cannot perform multiple epochs training using" + + "iterator has does not support resetting (iterator.resetSupported() returned false)"); for (int i = 0; i < numEpochs; i++) { fit(iterator); @@ -3950,9 +4205,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Perform minibatch training on all minibatches in the MultiDataSetIterator.
Note: The - * MultiDataSets in the MultiDataSetIterator must have exactly 1 input and output array (as - * MultiLayerNetwork only supports 1 input and 1 output) + * Perform minibatch training on all minibatches in the MultiDataSetIterator.
+ * Note: The MultiDataSets in the MultiDataSetIterator must have exactly 1 input and output array + * (as MultiLayerNetwork only supports 1 input and 1 output) * * @param iterator Training data (DataSetIterator). Iterator must support resetting */ @@ -3970,11 +4225,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * Evaluate the network (for classification) on the provided data set, with top N accuracy in * addition to standard accuracy. For 'standard' accuracy evaluation only, use topN = 1 * - * @param iterator Iterator (data) to evaluate on + * @param iterator Iterator (data) to evaluate on * @param labelsList List of labels. May be null. - * @param topN N value for top N accuracy evaluation + * @param topN N value for top N accuracy evaluation * @return Evaluation object, summarizing the results of the evaluation on the provided - * DataSetIterator + * DataSetIterator */ public Evaluation evaluate(DataSetIterator iterator, List labelsList, int topN) { if (layers == null || !(getOutputLayer() instanceof IOutputLayer)) { @@ -3984,13 +4239,13 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial try { labelsList = iterator.getLabels(); } catch (Throwable t) { - } //Ignore, maybe UnsupportedOperationException etc + } // Ignore, maybe UnsupportedOperationException etc } Layer outputLayer = getOutputLayer(); if (getNetConfiguration().isValidateOutputLayerConfig()) { - OutputLayerUtil.validateOutputLayerForClassifierEvaluation(outputLayer.getLayerConfiguration(), - Evaluation.class); + OutputLayerUtil.validateOutputLayerForClassifierEvaluation( + outputLayer.getLayerConfiguration(), Evaluation.class); } Evaluation e = new org.deeplearning4j.eval.Evaluation(labelsList, topN); @@ -4036,10 +4291,17 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial List lines = new ArrayList<>(); if (inputType == null) { - lines.add(new String[]{"LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape"}); + lines.add(new String[] {"LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape"}); } else { - lines.add(new String[]{"LayerName (LayerType)", "nIn,nOut", "TotalParams", "ParamsShape", - "InputShape", "OutputShape"}); + lines.add( + new String[] { + "LayerName (LayerType)", + "nIn,nOut", + "TotalParams", + "ParamsShape", + "InputShape", + "OutputShape" + }); } int[] maxLength = new int[inputType == null ? 
4 : 6]; String[] header = lines.get(0); @@ -4070,8 +4332,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial inputType = preProcessor.getOutputType(inputType); inShape += "--> " + inputType.toString(); } - outType = currentLayer.getLayerConfiguration() - .getOutputType(currentLayer.getIndex(), inputType); + outType = + currentLayer.getLayerConfiguration().getOutputType(currentLayer.getIndex(), inputType); outShape = outType.toString(); inputType = outType; } @@ -4084,8 +4346,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } else { try { in = String.valueOf(((FeedForwardLayer) currentLayer.getLayerConfiguration()).getNIn()); - out = String.valueOf( - ((FeedForwardLayer) currentLayer.getLayerConfiguration()).getNOut()); + out = + String.valueOf(((FeedForwardLayer) currentLayer.getLayerConfiguration()).getNOut()); } catch ( Exception e) { // Some layers, like PReLU, are just BaseLayers (but have parameters) } @@ -4099,17 +4361,24 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } if (currentLayer instanceof FrozenLayer) { frozenParams += currentLayer.numParams(); - classNameArr = ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName() - .split("\\."); + classNameArr = + ((FrozenLayer) currentLayer).getInsideLayer().getClass().getName().split("\\."); className = "Frozen " + classNameArr[classNameArr.length - 1]; } String[] line; if (inputType == null) { - line = new String[]{name + " (" + className + ")", in + "," + out, paramCount, paramShape}; + line = new String[] {name + " (" + className + ")", in + "," + out, paramCount, paramShape}; } else { - line = new String[]{name + " (" + className + ")", in + "," + out, paramCount, paramShape, - inShape, outShape}; + line = + new String[] { + name + " (" + className + ")", + in + "," + out, + paramCount, + paramShape, + inShape, + outShape + }; } for (int i = 0; i < line.length; i++) { maxLength[i] = Math.max(maxLength[i], line[i] == null ? 0 : line[i].length()); @@ -4133,8 +4402,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial sbFormat.append("\n"); String format = sbFormat.toString(); - ret.append(StringUtils.repeat("=", totalLength)) - .append("\n"); + ret.append(StringUtils.repeat("=", totalLength)).append("\n"); boolean first = true; for (String[] line : lines) { @@ -4147,9 +4415,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } ret.append(StringUtils.repeat("-", totalLength)); - ret.append(String.format("\n%30s %,d", "Total Parameters: ", params().length())); + ret.append(String.format("\n%30s %,d", "Total Parameters: ", getModelParams().length())); ret.append( - String.format("\n%30s %,d", "Trainable Parameters: ", params().length() - frozenParams)); + String.format( + "\n%30s %,d", + "ITrainableLayer Parameters: ", getModelParams().length() - frozenParams)); ret.append(String.format("\n%30s %,d", "Frozen Parameters: ", frozenParams)); ret.append("\n"); ret.append(StringUtils.repeat("=", totalLength)); @@ -4162,9 +4432,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * minibatch size. Note that when using workspaces or CuDNN, the network should be trained for * some iterations so that the memory workspaces have time to initialize. Without this, the memory * requirements during training may be underestimated. - *

- * Note also that this is the same information that is generated during an OOM crash when training - * or performing inference. + * + *

Note also that this is the same information that is generated during an OOM crash when + * training or performing inference. * * @param minibatch Minibatch size to estimate memory for * @param inputType Input type to the network @@ -4174,9 +4444,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial return CrashReportingUtil.generateMemoryStatus(this, minibatch, inputType); } - /** - * This method just makes sure there's no state preserved within layers - */ + /** This method just makes sure there's no state preserved within layers */ public void clearLayersStates() { for (Layer layer : layers) { layer.clear(); @@ -4186,14 +4454,15 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Increment the epoch count (in the underlying {@link NeuralNetConfiguration} by 1). Note that - * this is done automatically when using iterator-based fitting methods, such as - * {@link #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, + * this is done automatically when using iterator-based fitting methods, such as {@link + * #fit(DataSetIterator)}. However, when using non-iterator fit methods (DataSet, * INDArray/INDArray etc), the network has no way to know when one epoch ends and another starts. - * In such situations, this method can be used to increment the epoch counter.
Note that the - * epoch counter is used for situations such as some learning rate schedules, and the like. - *

- * The current epoch count can be obtained using - * {@code NeuralNetConfiguration.getLayerwiseConfiguration().getEpochCount()} + * In such situations, this method can be used to increment the epoch counter.
+ * Note that the epoch counter is used for situations such as some learning rate schedules, and + * the like. + * + *
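A short hedged sketch of the manual bookkeeping described above, for the case where fitting is done directly on arrays (variable names and the epoch count are placeholders):

    // Sketch: advance the epoch counter by hand when no iterator-based fit method is used.
    for (int epoch = 0; epoch < numEpochs; epoch++) {
      net.fit(features, labels);     // one pass over the data, no DataSetIterator involved
      net.incrementEpochCount();     // keeps epoch-based schedules (e.g. learning rate schedules) in sync
    }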

The current epoch count can be obtained using {@code + * NeuralNetConfiguration.getLayerwiseConfiguration().getEpochCount()} */ public void incrementEpochCount() { getNetConfiguration().setEpochCount(getNetConfiguration().getEpochCount() + 1); @@ -4201,7 +4470,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } protected void synchronizeIterEpochCounts() { - //TODO: this is necessary for some schedules - but the redundant values are a little ugly... + // TODO: this is necessary for some schedules - but the redundant values are a little ugly... int currIter = getIterationCount(); int currEpoch = getEpochCount(); for (Layer l : layers) { @@ -4226,9 +4495,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Save the MultiLayerNetwork to a file. Restore using {@link #load(File, boolean)}. * - * @param f File to save the network to + * @param f File to save the network to * @param saveUpdater If true: save the updater (i.e., the state array for momentum/Adam/rmsprop - * etc), which should usually be saved if further training is required + * etc), which should usually be saved if further training is required * @see ModelSerializer ModelSerializer for more details (and saving/loading via streams) * @see #save(File, boolean) */ @@ -4255,14 +4524,16 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * @return The network, set to use the specified datatype for the parameters and activations */ public MultiLayerNetwork convertDataType(@NonNull DataType dataType) { - Preconditions.checkState(dataType.isFPType(), - "Invalid DataType: %s. Can only convert network to a floating point type", dataType); - if (dataType == params().dataType()) { + Preconditions.checkState( + dataType.isFPType(), + "Invalid DataType: %s. Can only convert network to a floating point type", + dataType); + if (dataType == getModelParams().dataType()) { return this; } try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - INDArray newParams = params().castTo(dataType); + INDArray newParams = getModelParams().castTo(dataType); String jsonConfig = getNetConfiguration().toJson(); NeuralNetConfiguration newConf = NeuralNetConfiguration.fromJson(jsonConfig); newConf.setDataType(dataType); @@ -4297,9 +4568,9 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Set the learning rate schedule for all layers in the network to the specified schedule. This * schedule will replace any/all existing schedules, and also any fixed learning rate values.
- * Note that the iteration/epoch counts will not be reset. Use - * {@link NeuralNetConfiguration#setIterationCount(int)} and - * {@link NeuralNetConfiguration#setEpochCount(int)} if this is required + * Note that the iteration/epoch counts will not be reset. Use {@link + * NeuralNetConfiguration#setIterationCount(int)} and {@link + * NeuralNetConfiguration#setEpochCount(int)} if this is required * * @param newLr New learning rate schedule for all layers * @see #setLearningRate(ISchedule) @@ -4320,7 +4591,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial * to be set to a new LR * * @param layerNumber Number of the layer to set the LR for - * @param newLr New learning rate for a single layer + * @param newLr New learning rate for a single layer * @see #setLearningRate(ISchedule) * @see #setLearningRate(int, double) */ @@ -4331,13 +4602,15 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Set the learning rate schedule for a single layer in the network to the specified value.
* Note also that {@link #setLearningRate(ISchedule)} should also be used in preference, when all - * layers need to be set to a new LR schedule.
This schedule will replace any/all existing - * schedules, and also any fixed learning rate values.
Note also that the iteration/epoch - * counts will not be reset. Use {@link NeuralNetConfiguration#setIterationCount(int)} and - * {@link NeuralNetConfiguration#setEpochCount(int)} if this is required + * layers need to be set to a new LR schedule.
+ * This schedule will replace any/all existing schedules, and also any fixed learning rate values. + *
+ * Note also that the iteration/epoch counts will not be reset. Use {@link + * NeuralNetConfiguration#setIterationCount(int)} and {@link + * NeuralNetConfiguration#setEpochCount(int)} if this is required * * @param layerNumber Number of the layer to set the LR schedule for - * @param newLr New learning rate for a single layer + * @param newLr New learning rate for a single layer * @see #setLearningRate(ISchedule) * @see #setLearningRate(int, double) */ @@ -4358,11 +4631,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Return the layer size (number of units) for the specified layer.
Note that the meaning of - * the "layer size" can depend on the type of layer. For example:
- DenseLayer, OutputLayer, - * recurrent layers: number of units (nOut configuration option)
- ConvolutionLayer: the - * channels (number of channels)
- Subsampling layers, global pooling layers, etc: size of 0 - * is always returned
+   * Return the layer size (number of units) for the specified layer.<br>
+   * Note that the meaning of the "layer size" can depend on the type of layer. For example:<br>
+   * - DenseLayer, OutputLayer, recurrent layers: number of units (nOut configuration option)<br>
+   * - ConvolutionLayer: the channels (number of channels)<br>
+   * - Subsampling layers, global pooling layers, etc: size of 0 is always returned<br>
* * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive * @return Size of the layer @@ -4370,8 +4643,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial public int layerSize(int layer) { if (layer < 0 || layer > layers.length) { throw new IllegalArgumentException( - "Invalid layer index: " + layer + ". LayerConfiguration index must be between 0 and " - + (layers.length - 1) + " inclusive"); + "Invalid layer index: " + + layer + + ". LayerConfiguration index must be between 0 and " + + (layers.length - 1) + + " inclusive"); } LayerConfiguration conf = layers[layer].getLayerConfiguration(); if (conf == null || !(conf instanceof FeedForwardLayer)) { @@ -4386,12 +4662,12 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } /** - * Return the input size (number of inputs) for the specified layer.
Note that the meaning of - * the "input size" can depend on the type of layer. For example:
- DenseLayer, OutputLayer, - * etc: the feature vector size (nIn configuration option)
- Recurrent layers: the feature - * vector size per time step (nIn configuration option)
- ConvolutionLayer: the - * channels (number of channels)
- Subsampling layers, global pooling layers, etc: size of 0 - * is always returned
+   * Return the input size (number of inputs) for the specified layer.<br>
+   * Note that the meaning of the "input size" can depend on the type of layer. For example:<br>
+   * - DenseLayer, OutputLayer, etc: the feature vector size (nIn configuration option)<br>
+   * - Recurrent layers: the feature vector size per time step (nIn configuration option)<br>
+   * - ConvolutionLayer: the channels (number of channels)<br>
+   * - Subsampling layers, global pooling layers, etc: size of 0 is always returned<br>
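To make the size semantics above concrete, a small illustrative sketch (the printed values depend entirely on the network configuration; the accessor names are the usual MultiLayerNetwork ones and are assumed to be unchanged here):

    // Sketch: print input size and unit count for every layer of an initialized network 'net'.
    for (int i = 0; i < net.getLayers().length; i++) {
      int in  = net.layerInputSize(i);   // nIn for dense/recurrent layers, channels for CNNs, otherwise 0
      int out = net.layerSize(i);        // nOut for dense/recurrent layers, channels for CNNs, otherwise 0
      System.out.println("layer " + i + ": " + in + " -> " + out);
    }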
* * @param layer Index of the layer to get the size of. Must be in range 0 to nLayers-1 inclusive * @return Size of the layer @@ -4399,8 +4675,11 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial public int layerInputSize(int layer) { if (layer < 0 || layer > layers.length) { throw new IllegalArgumentException( - "Invalid layer index: " + layer + ". LayerConfiguration index must be between 0 and " - + (layers.length - 1) + " inclusive"); + "Invalid layer index: " + + layer + + ". LayerConfiguration index must be between 0 and " + + (layers.length - 1) + + " inclusive"); } LayerConfiguration conf = layers[layer].getLayerConfiguration(); if (conf == null || !(conf instanceof FeedForwardLayer)) { @@ -4416,42 +4695,34 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial /** * Indicates whether some other object is "equal to" this one. - *

- * The {@code equals} method implements an equivalence relation on non-null object references: + * + *

The {@code equals} method implements an equivalence relation on non-null object references: + * *

    - *
  • It is reflexive: for any non-null reference value - * {@code x}, {@code x.equals(x)} should return - * {@code true}. - *
  • It is symmetric: for any non-null reference values - * {@code x} and {@code y}, {@code x.equals(y)} - * should return {@code true} if and only if - * {@code y.equals(x)} returns {@code true}. - *
  • It is transitive: for any non-null reference values - * {@code x}, {@code y}, and {@code z}, if - * {@code x.equals(y)} returns {@code true} and - * {@code y.equals(z)} returns {@code true}, then - * {@code x.equals(z)} should return {@code true}. - *
  • It is consistent: for any non-null reference values - * {@code x} and {@code y}, multiple invocations of - * {@code x.equals(y)} consistently return {@code true} - * or consistently return {@code false}, provided no - * information used in {@code equals} comparisons on the - * objects is modified. - *
  • For any non-null reference value {@code x}, - * {@code x.equals(null)} should return {@code false}. + *
  • It is reflexive: for any non-null reference value {@code x}, {@code x.equals(x)} + * should return {@code true}. + *
  • It is symmetric: for any non-null reference values {@code x} and {@code y}, {@code + * x.equals(y)} should return {@code true} if and only if {@code y.equals(x)} returns {@code + * true}. + *
  • It is transitive: for any non-null reference values {@code x}, {@code y}, and + * {@code z}, if {@code x.equals(y)} returns {@code true} and {@code y.equals(z)} returns + * {@code true}, then {@code x.equals(z)} should return {@code true}. + *
  • It is consistent: for any non-null reference values {@code x} and {@code y}, + * multiple invocations of {@code x.equals(y)} consistently return {@code true} or + * consistently return {@code false}, provided no information used in {@code equals} + * comparisons on the objects is modified. + *
  • For any non-null reference value {@code x}, {@code x.equals(null)} should return {@code + * false}. *
- *

- * The {@code equals} method for class {@code Object} implements - * the most discriminating possible equivalence relation on objects; - * that is, for any non-null reference values {@code x} and - * {@code y}, this method returns {@code true} if and only - * if {@code x} and {@code y} refer to the same object - * ({@code x == y} has the value {@code true}). - *

- * Note that it is generally necessary to override the {@code hashCode} - * method whenever this method is overridden, so as to maintain the - * general contract for the {@code hashCode} method, which states - * that equal objects must have equal hash codes. + * + *

The {@code equals} method for class {@code Object} implements the most discriminating + * possible equivalence relation on objects; that is, for any non-null reference values {@code x} + * and {@code y}, this method returns {@code true} if and only if {@code x} and {@code y} refer to + * the same object ({@code x == y} has the value {@code true}). + * + *

Note that it is generally necessary to override the {@code hashCode} method whenever this + * method is overridden, so as to maintain the general contract for the {@code hashCode} method, + * which states that equal objects must have equal hash codes. * * @param obj the reference object with which to compare. * @return {@code true} if this object is the same as the obj argument; {@code false} otherwise. @@ -4465,9 +4736,8 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial } if (obj instanceof MultiLayerNetwork) { MultiLayerNetwork network = (MultiLayerNetwork) obj; - boolean paramsEquals = network.params().equals(params()); - boolean confEquals = getNetConfiguration().equals( - network.getNetConfiguration()); + boolean paramsEquals = network.getModelParams().equals(getModelParams()); + boolean confEquals = getNetConfiguration().equals(network.getNetConfiguration()); boolean updaterEquals = getUpdater().equals(network.getUpdater()); return paramsEquals && confEquals && updaterEquals; } @@ -4481,15 +4751,17 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException { val mln = ModelSerializer.restoreMultiLayerNetwork(ois, true); - this.setNetConfiguration( mln.getNetConfiguration().clone() ); + this.setNetConfiguration(mln.getNetConfiguration().clone()); this.init(); this.flattenedParams.assign(mln.flattenedParams); - int numWorkingMem = 2 * (getNetConfiguration().getFlattenedLayerConfigurations().size() - + getNetConfiguration().getInputPreProcessors().size()); + int numWorkingMem = + 2 + * (getNetConfiguration().getFlattenedLayerConfigurations().size() + + getNetConfiguration().getInputPreProcessors().size()); WS_LAYER_WORKING_MEM_CONFIG = getLayerWorkingMemWSConfig(numWorkingMem); - WS_LAYER_ACT_X_CONFIG = getLayerActivationWSConfig( - getNetConfiguration().getFlattenedLayerConfigurations().size()); + WS_LAYER_ACT_X_CONFIG = + getLayerActivationWSConfig(getNetConfiguration().getFlattenedLayerConfigurations().size()); if (mln.getUpdater() != null && mln.getUpdater(false).getStateViewArray() != null) { this.getUpdater(true).getStateViewArray().assign(mln.getUpdater(false).getStateViewArray()); @@ -4503,7 +4775,7 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial */ @Override public void close() { - //Close the INDArray and dealloc + // Close the INDArray and dealloc if (flattenedParams.closeable()) { flattenedParams.close(); } @@ -4533,5 +4805,4 @@ public class MultiLayerNetwork extends ArtificialNeuralNetwork implements Serial public String toString() { return getNetConfiguration().toString(); } - } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java index 3de33be57..32b05a04c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java @@ -22,9 +22,7 @@ package org.deeplearning4j.nn.params; import lombok.val; import org.deeplearning4j.nn.api.AbstractParamInitializer; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import 
org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.PReLULayer; import org.deeplearning4j.nn.weights.IWeightInit; @@ -99,7 +97,7 @@ public class PReLUParamInitializer extends AbstractParamInitializer { @Override public Map init(LayerConfiguration conf, INDArray paramsView, boolean initializeParams) { - if (!(conf instanceof BaseLayer)) + if (!(conf instanceof BaseLayerConfiguration)) throw new IllegalArgumentException("unsupported layer type: " + conf.getClass().getName()); Map params = Collections.synchronizedMap(new LinkedHashMap()); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/WrapperLayerParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/WrapperLayerParamInitializer.java index 7cb7059c8..5744e70ad 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/WrapperLayerParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/WrapperLayerParamInitializer.java @@ -21,10 +21,8 @@ package org.deeplearning4j.nn.params; import org.deeplearning4j.nn.api.AbstractParamInitializer; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; -import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayer; +import org.deeplearning4j.nn.conf.layers.wrapper.BaseWrapperLayerConfiguration; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.List; @@ -99,8 +97,8 @@ public class WrapperLayerParamInitializer extends AbstractParamInitializer { } private LayerConfiguration underlying(LayerConfiguration layer){ - while (layer instanceof BaseWrapperLayer) { - layer = ((BaseWrapperLayer)layer).getUnderlying(); + while (layer instanceof BaseWrapperLayerConfiguration) { + layer = ((BaseWrapperLayerConfiguration)layer).getUnderlying(); } return layer; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java index 73a31b96b..5d68bd890 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java @@ -42,7 +42,7 @@ import org.deeplearning4j.nn.conf.WorkspaceMode; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.dropout.Dropout; import org.deeplearning4j.nn.conf.dropout.IDropout; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerValidation; @@ -146,7 +146,7 @@ public class FineTuneConfiguration { if (layerConfiguration != null) { //As per NeuralNetConfiguration.configureLayer and LayerValidation.configureBaseLayer: only copy dropout to base layers // this excludes things like subsampling and activation layers - if (dropout != null && layerConfiguration instanceof BaseLayer) { + if (dropout != null && layerConfiguration instanceof BaseLayerConfiguration) { IDropout d = dropout.orElse(null); if (d != null) { d = d.clone(); //Clone to avoid shared state between layers @@ -158,8 +158,8 @@ public class FineTuneConfiguration { } } - 
if (layerConfiguration != null && layerConfiguration instanceof BaseLayer) { - BaseLayer bl = (BaseLayer) layerConfiguration; + if (layerConfiguration != null && layerConfiguration instanceof BaseLayerConfiguration) { + BaseLayerConfiguration bl = (BaseLayerConfiguration) layerConfiguration; if (activationFn != null) { bl.setActivationFn(activationFn); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java index 8cc50854b..663420f0a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java @@ -322,7 +322,7 @@ public class TransferLearning { if (numParams > 0) { params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); org.deeplearning4j.nn.api.Layer someLayer = layerConf.instantiate(layerConf.getNetConfiguration(), null, 0, params, true, dataType); - appendParams.add(someLayer.params()); + appendParams.add(someLayer.getParams()); appendConfs.add(someLayer.getLayerConfiguration()); } else { appendConfs.add(layerConf); @@ -400,9 +400,9 @@ public class TransferLearning { for (int i = 0; i < origModel.getnLayers(); i++) { if (origModel.getLayer(i).numParams() > 0) { //dup only if params are there - editedParams.add(origModel.getLayer(i).params().dup()); + editedParams.add(origModel.getLayer(i).getParams().dup()); } else { - editedParams.add(origModel.getLayer(i).params()); + editedParams.add(origModel.getLayer(i).getParams()); } } //apply changes to nout/nin if any in sorted order and save to editedParams @@ -467,7 +467,7 @@ public class TransferLearning { long numParams = layerImpl.initializer().numParams(layerConf); INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf.getNetConfiguration(), null, 0, params, true, dataType); - editedParams.set(layerNum, someLayer.params()); + editedParams.set(layerNum, someLayer.getParams()); } @@ -485,7 +485,7 @@ public class TransferLearning { long numParams = layerImpl.initializer().numParams(layerConf); INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); org.deeplearning4j.nn.api.Layer someLayer = layerImpl.instantiate(layerConf.getNetConfiguration(), null, 0, params, true, dataType); - editedParams.set(layerNum, someLayer.params()); + editedParams.set(layerNum, someLayer.getParams()); if (layerNum + 1 < editedConfs.size()) { layerConf = editedConfs.get(layerNum + 1); @@ -498,7 +498,7 @@ public class TransferLearning { if (numParams > 0) { params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); someLayer = layerImpl.instantiate(layerConf.getNetConfiguration(), null, 0, params, true, dataType); - editedParams.set(layerNum + 1, someLayer.params()); + editedParams.set(layerNum + 1, someLayer.getParams()); } } } @@ -979,11 +979,11 @@ public class TransferLearning { continue; //some layers have no params if (editedVertices.contains(layerName)) continue; //keep the changed params - INDArray origParams = origGraph.getLayer(layerName).params(); + INDArray origParams = origGraph.getLayer(layerName).getParams(); layer.setParams(origParams.dup()); //copy over origGraph params } } else { - newGraph.setParams(origGraph.params()); + 
newGraph.setParams(origGraph.getModelParams()); } //Freeze layers as necessary. Note: we can't simply say "everything before frozen layer X needs to be frozen diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java index effc48ad4..bd6cc18a3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearningHelper.java @@ -295,7 +295,7 @@ public class TransferLearningHelper { unFrozenSubsetMLN.init(); //copy over params for (int i = frozenInputLayer + 1; i < origMLN.getnLayers(); i++) { - unFrozenSubsetMLN.getLayer(i - frozenInputLayer - 1).setParams(origMLN.getLayer(i).params()); + unFrozenSubsetMLN.getLayer(i - frozenInputLayer - 1).setParams(origMLN.getLayer(i).getParams()); } //unFrozenSubsetMLN.setListeners(origMLN.getListeners()); } @@ -413,7 +413,7 @@ public class TransferLearningHelper { for (GraphVertex aVertex : unFrozenSubsetGraph.getVertices()) { if (!aVertex.hasLayer()) continue; - origGraph.getVertex(aVertex.getVertexName()).getLayer().setParams(aVertex.getLayer().params()); + origGraph.getVertex(aVertex.getVertexName()).getLayer().setParams(aVertex.getLayer().getParams()); } } @@ -421,13 +421,13 @@ public class TransferLearningHelper { for (GraphVertex aVertex : unFrozenSubsetGraph.getVertices()) { if (!aVertex.hasLayer()) continue; - aVertex.getLayer().setParams(origGraph.getLayer(aVertex.getVertexName()).params()); + aVertex.getLayer().setParams(origGraph.getLayer(aVertex.getVertexName()).getParams()); } } private void copyParamsFromSubsetMLNToOrig() { for (int i = frozenInputLayer + 1; i < origMLN.getnLayers(); i++) { - origMLN.getLayer(i).setParams(unFrozenSubsetMLN.getLayer(i - frozenInputLayer - 1).params()); + origMLN.getLayer(i).setParams(unFrozenSubsetMLN.getLayer(i - frozenInputLayer - 1).getParams()); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java index dfcef372c..cec9da44a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java @@ -22,7 +22,7 @@ package org.deeplearning4j.nn.updater; import lombok.Getter; import net.brutex.ai.dnn.api.IModel; -import org.deeplearning4j.nn.api.Trainable; +import org.deeplearning4j.nn.api.ITrainableLayer; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.gradient.DefaultGradient; @@ -47,7 +47,7 @@ import java.util.*; public abstract class BaseMultiLayerUpdater implements Updater { protected final T network; - protected Map layersByName; + protected Map layersByName; protected final List updaterBlocks; protected INDArray updaterStateViewArray; protected boolean initializedMinibatchDivision; @@ -64,19 +64,19 @@ public abstract class BaseMultiLayerUpdater implements Updater */ public BaseMultiLayerUpdater(T network, INDArray updaterState) { this.network = network; - Trainable[] layers = getOrderedLayers(); //May also include vertices + ITrainableLayer[] layers = getOrderedLayers(); //May also include vertices int updaterStateSize = 0; //Iterate 
through layers, and variables for each layer. //While the updater configuration is the same: combine into one op, rather than doing a lot of smaller // (yet identical) ops. - Trainable lastLayer = null; + ITrainableLayer lastLayer = null; String lastVariable = null; UpdaterBlock currentBlock = null; updaterBlocks = new ArrayList<>(); - INDArray paramsView = network.params(); + INDArray paramsView = network.getModelParams(); INDArray gradientView = getFlattenedGradientsView(); int paramsViewSoFar = 0; int currentUpdaterOffset = 0; @@ -87,8 +87,8 @@ public abstract class BaseMultiLayerUpdater implements Updater for (int j = 0; j < variables.size(); j++) { String var = variables.get(j); long paramSizeThisVariable = layerParamTable.get(var).length(); - IUpdater u = layers[i].getConfig().getUpdaterByParam(var); - Preconditions.checkNotNull(u, "Updater for parameter %s, layer \"%s\" was null", var, layers[i].getConfig().getLayerName()); + IUpdater u = layers[i].getTrainingConfig().getUpdaterByParam(var); + Preconditions.checkNotNull(u, "Updater for parameter %s, layer \"%s\" was null", var, layers[i].getTrainingConfig().getLayerName()); int updaterStateSizeThisVariable = (int) u.stateSize(paramSizeThisVariable); INDArray gradientViewSubset = null; @@ -145,7 +145,7 @@ public abstract class BaseMultiLayerUpdater implements Updater updaterRequiresInit = false; } else if (updaterStateSize > 0) { //May be 0 if all SGD or NONE updaters, for example - updaterStateViewArray = Nd4j.createUninitialized(network.params().dataType(), new long[] {1, updaterStateSize}, Nd4j.order()); + updaterStateViewArray = Nd4j.createUninitialized(network.getModelParams().dataType(), new long[] {1, updaterStateSize}, Nd4j.order()); updaterRequiresInit = true; } @@ -183,7 +183,7 @@ public abstract class BaseMultiLayerUpdater implements Updater * @return Array of layers, in the correct order (i.e., same order as the parameter/gradient/updater flattening * order - input to output for MultiLayerNetwork, or topological order for ComputationGraph) */ - protected abstract Trainable[] getOrderedLayers(); + protected abstract ITrainableLayer[] getOrderedLayers(); /** * @return The flattened gradient view array for the model @@ -220,7 +220,7 @@ public abstract class BaseMultiLayerUpdater implements Updater } @Override - public void setStateViewArray(Trainable layer, INDArray viewArray, boolean initialize) { + public void setStateViewArray(ITrainableLayer layer, INDArray viewArray, boolean initialize) { this.setStateViewArray(viewArray); } @@ -241,7 +241,7 @@ public abstract class BaseMultiLayerUpdater implements Updater } @Override - public void update(Trainable layer, Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr) { + public void update(ITrainableLayer layer, Gradient gradient, int iteration, int epoch, int batchSize, LayerWorkspaceMgr workspaceMgr) { update(gradient, iteration, epoch, batchSize, workspaceMgr); } @@ -266,9 +266,9 @@ public abstract class BaseMultiLayerUpdater implements Updater //Split up the gradients on a per-layer basis, for pre-apply Map layerGradients = new HashMap<>(); - Trainable[] layers = getOrderedLayers(); + ITrainableLayer[] layers = getOrderedLayers(); if (layers.length == 1 && isSingleLayerUpdater()) { - layerGradients.put(layers[0].getConfig().getLayerName(), gradient); + layerGradients.put(layers[0].getTrainingConfig().getLayerName(), gradient); } else { for (Map.Entry gradientPair : gradient.gradientForVariable().entrySet()) { String key = 
gradientPair.getKey(); @@ -296,7 +296,7 @@ public abstract class BaseMultiLayerUpdater implements Updater //PRE apply (gradient clipping, etc): done on a per-layer basis for (Map.Entry entry : layerGradients.entrySet()) { String layerName = entry.getKey(); - Trainable layer = layersByName.get(layerName); + ITrainableLayer layer = layersByName.get(layerName); preApply(layer, layerGradients.get(layerName), iteration); } @@ -350,7 +350,7 @@ public abstract class BaseMultiLayerUpdater implements Updater long paramsSoFar = 0; long currentStart = 0; long currentEnd = 0; - for(Trainable t : getOrderedLayers()){ + for(ITrainableLayer t : getOrderedLayers()){ Set layerParams = t.getParamTable(false).keySet(); Map paramTable = t.getParamTable(false); for(String s : layerParams) { @@ -389,18 +389,18 @@ public abstract class BaseMultiLayerUpdater implements Updater * @param gradient Gradient to update * @param iteration The current iteration (i.e., number of parameter updates so far) */ - public void preApply(Trainable layer, Gradient gradient, int iteration) { + public void preApply(ITrainableLayer layer, Gradient gradient, int iteration) { - if (layer.getConfig() == null || layer.numParams() == 0) { + if (layer.getTrainingConfig() == null || layer.numParams() == 0) { //ILayer does not have parameters -> no gradient return; } - GradientNormalization normalization = layer.getConfig().getGradientNormalization(); + GradientNormalization normalization = layer.getTrainingConfig().getGradientNormalization(); if (normalization == null || normalization == GradientNormalization.None) return; //no op - final double threshold = layer.getConfig().getGradientNormalizationThreshold(); + final double threshold = layer.getTrainingConfig().getGradientNormalizationThreshold(); INDArray layerGradientView = layer.getGradientsViewArray(); switch (normalization) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java index f27e7dcfa..3dafbb3f9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/LayerUpdater.java @@ -22,7 +22,7 @@ package org.deeplearning4j.nn.updater; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Trainable; +import org.deeplearning4j.nn.api.ITrainableLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.api.ndarray.INDArray; @@ -46,8 +46,8 @@ public class LayerUpdater extends BaseMultiLayerUpdater { } @Override - protected Trainable[] getOrderedLayers() { - return new Trainable[] {network}; + protected ITrainableLayer[] getOrderedLayers() { + return new ITrainableLayer[] {network}; } @Override @@ -57,7 +57,7 @@ public class LayerUpdater extends BaseMultiLayerUpdater { @Override protected INDArray getParams() { - return network.params(); + return network.getParams(); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/MultiLayerUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/MultiLayerUpdater.java index f43aa85d2..1027f5003 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/MultiLayerUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/MultiLayerUpdater.java @@ -22,8 +22,8 @@ package org.deeplearning4j.nn.updater; import 
lombok.Getter; import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.nn.api.ITrainableLayer; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Trainable; import org.deeplearning4j.nn.api.Updater; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.nd4j.linalg.api.ndarray.INDArray; @@ -49,9 +49,9 @@ public class MultiLayerUpdater extends BaseMultiLayerUpdater } @Override - protected Trainable[] getOrderedLayers() { + protected ITrainableLayer[] getOrderedLayers() { Layer[] layers = network.getLayers(); - Trainable[] t = new Trainable[layers.length]; + ITrainableLayer[] t = new ITrainableLayer[layers.length]; System.arraycopy(layers, 0, t, 0, layers.length); return t; } @@ -66,7 +66,7 @@ public class MultiLayerUpdater extends BaseMultiLayerUpdater @Override protected INDArray getParams() { - return network.params(); + return network.getModelParams(); } @Override diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterBlock.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterBlock.java index 3366a48f9..7b496468f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterBlock.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterBlock.java @@ -22,18 +22,11 @@ package org.deeplearning4j.nn.updater; import lombok.AllArgsConstructor; import lombok.Data; -import lombok.val; -import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.api.Trainable; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.BaseLayer; -import org.deeplearning4j.nn.layers.FrozenLayer; +import org.deeplearning4j.nn.api.ITrainableLayer; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.GradientUpdater; import org.nd4j.linalg.learning.regularization.Regularization; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.ArrayList; import java.util.List; @@ -56,7 +49,7 @@ public class UpdaterBlock { @AllArgsConstructor @Data public static class ParamState { - private final Trainable layer; + private final ITrainableLayer layer; private final String paramName; private final int paramOffsetStart; private final int paramOffsetEnd; @@ -89,7 +82,7 @@ public class UpdaterBlock { if (gradientUpdater == null) { ParamState varState = layersAndVariablesInBlock.get(0); String varName = varState.getParamName(); - gradientUpdater = varState.getLayer().getConfig().getUpdaterByParam(varName).instantiate(updaterView, + gradientUpdater = varState.getLayer().getTrainingConfig().getUpdaterByParam(varName).instantiate(updaterView, updaterViewRequiresInitialization); //UpdaterUtils.getGradientUpdater(varState.getLayer(), varState.getParamName()); } } @@ -97,7 +90,7 @@ public class UpdaterBlock { public boolean isPretrainUpdaterBlock() { //All in block should be the same layer, and all be pretrain params ParamState vs = layersAndVariablesInBlock.get(0); - return vs.getLayer().getConfig().isPretrainParam(vs.getParamName()); + return vs.getLayer().getTrainingConfig().isPretrainParam(vs.getParamName()); } public boolean skipDueToPretrainConfig( boolean isLayerUpdater) { @@ -148,7 +141,7 @@ public class UpdaterBlock { //Second: apply learning rate policy. 
Note that by definition we have the same LR policy for every single // variable in the block - Trainable l0 = layersAndVariablesInBlock.get(0).getLayer(); + ITrainableLayer l0 = layersAndVariablesInBlock.get(0).getLayer(); if (l0.numParams() == 0) { //No params for this layer return; @@ -194,10 +187,10 @@ public class UpdaterBlock { * @param gradientView Gradient view array for the layer + param * @param paramsView Parameter view array for the layer + param */ - protected void applyRegularization(Regularization.ApplyStep step, Trainable layer, String paramName, INDArray gradientView, INDArray paramsView, int iter, int epoch, double lr) { + protected void applyRegularization(Regularization.ApplyStep step, ITrainableLayer layer, String paramName, INDArray gradientView, INDArray paramsView, int iter, int epoch, double lr) { //TODO: do this for multiple contiguous params/layers (fewer, larger ops) - List l = layer.getConfig().getRegularizationByParam(paramName); + List l = layer.getTrainingConfig().getRegularizationByParam(paramName); if(l != null && !l.isEmpty()){ for(Regularization r : l){ if(r.applyStep() == step){ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterCreator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterCreator.java index 14850eafb..11573daa0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterCreator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterCreator.java @@ -28,13 +28,18 @@ import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater; /** * - * - * @author Adam Gibson + * Create an {@link org.deeplearning4j.nn.api.Updater} based on the provided {@link IModel}. */ public class UpdaterCreator { private UpdaterCreator() {} + /** + * Create an Updater for a given model type. This is either {@link ComputationGraphUpdater} or + * {@link MultiLayerUpdater} or a {@link LayerUpdater}. 
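+   * <p>Usage sketch (assumes an already initialized {@code MultiLayerNetwork} instance {@code net};
+   * the variable name is illustrative only):
+   * <pre>{@code
+   * org.deeplearning4j.nn.api.Updater updater = UpdaterCreator.getUpdater(net);
+   * // net is a MultiLayerNetwork, so the returned instance is a MultiLayerUpdater
+   * }</pre>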
+ * @param layer + * @return + */ public static org.deeplearning4j.nn.api.Updater getUpdater(IModel layer) { if (layer instanceof MultiLayerNetwork) { return new MultiLayerUpdater((MultiLayerNetwork) layer); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterUtils.java index 73bd5410e..14a2a54de 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/UpdaterUtils.java @@ -20,16 +20,16 @@ package org.deeplearning4j.nn.updater; -import org.deeplearning4j.nn.api.Trainable; -import org.deeplearning4j.nn.api.TrainingConfig; +import org.deeplearning4j.nn.api.ITrainableLayer; +import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.nd4j.linalg.learning.config.IUpdater; public class UpdaterUtils { - public static boolean updaterConfigurationsEquals(Trainable layer1, String param1, Trainable layer2, String param2) { - TrainingConfig l1 = layer1.getConfig(); - TrainingConfig l2 = layer2.getConfig(); + public static boolean updaterConfigurationsEquals(ITrainableLayer layer1, String param1, ITrainableLayer layer2, String param2) { + ITraininableLayerConfiguration l1 = layer1.getTrainingConfig(); + ITraininableLayerConfiguration l2 = layer2.getTrainingConfig(); IUpdater u1 = l1.getUpdaterByParam(param1); IUpdater u2 = l2.getUpdaterByParam(param2); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/graph/ComputationGraphUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/graph/ComputationGraphUpdater.java index 1c39f52e1..952258bcf 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/graph/ComputationGraphUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/graph/ComputationGraphUpdater.java @@ -20,7 +20,7 @@ package org.deeplearning4j.nn.updater.graph; -import org.deeplearning4j.nn.api.Trainable; +import org.deeplearning4j.nn.api.ITrainableLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.GraphVertex; import org.deeplearning4j.nn.updater.BaseMultiLayerUpdater; @@ -31,7 +31,7 @@ import java.util.HashMap; public class ComputationGraphUpdater extends BaseMultiLayerUpdater { - protected Trainable[] orderedLayers; + protected ITrainableLayer[] orderedLayers; public ComputationGraphUpdater(ComputationGraph graph) { this(graph, null); @@ -41,14 +41,14 @@ public class ComputationGraphUpdater extends BaseMultiLayerUpdater(); - Trainable[] layers = getOrderedLayers(); - for (Trainable l : layers) { - layersByName.put(l.getConfig().getLayerName(), l); + ITrainableLayer[] layers = getOrderedLayers(); + for (ITrainableLayer l : layers) { + layersByName.put(l.getTrainingConfig().getLayerName(), l); } } @Override - protected Trainable[] getOrderedLayers() { + protected ITrainableLayer[] getOrderedLayers() { if (orderedLayers != null) { return orderedLayers; } @@ -57,7 +57,7 @@ public class ComputationGraphUpdater extends BaseMultiLayerUpdater pair) { INDArray gradient = pair.getFirst().gradient(conf.netWideVariables()); - INDArray params = model.params().dup(); //Need dup here: params returns an array that isn't a copy (hence changes to this are problematic for line search methods) + INDArray params = model.getModelParams().dup(); //Need dup here: params returns an array that isn't a copy (hence 
changes to this are problematic for line search methods) searchState.put(GRADIENT_KEY, gradient); searchState.put(SCORE_KEY, pair.getSecond()); searchState.put(PARAMS_KEY, params); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java index 3a8fa9bdc..80a94c6e6 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/LBFGS.java @@ -71,7 +71,7 @@ public class LBFGS extends BaseOptimizer { @Override public void postStep(INDArray gradient) { INDArray previousParameters = (INDArray) searchState.get("oldparams"); - INDArray parameters = model.params(); + INDArray parameters = model.getModelParams(); INDArray previousGradient = (INDArray) searchState.get(GRADIENT_KEY); LinkedList rho = (LinkedList) searchState.get("rho"); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/StochasticGradientDescent.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/StochasticGradientDescent.java index ee7070f01..e0de12fe9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/StochasticGradientDescent.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/StochasticGradientDescent.java @@ -54,7 +54,7 @@ public class StochasticGradientDescent extends BaseOptimizer { log.info("Applying external updates before FF..."); // we'll just fire off params update process - accumulator.applyUpdate(stepFunction, model.params(), Nd4j.createUninitialized(model.params().shape(), model.params().ordering()), false); + accumulator.applyUpdate(stepFunction, model.getModelParams(), Nd4j.createUninitialized(model.getModelParams().shape(), model.getModelParams().ordering()), false); } } @@ -62,7 +62,7 @@ public class StochasticGradientDescent extends BaseOptimizer { Gradient gradient = pair.getFirst(); - INDArray params = model.params(); + INDArray params = model.getModelParams(); // if optimizer has GradientsAccumulator defined - go for it if (accumulator != null) { @@ -87,7 +87,7 @@ public class StochasticGradientDescent extends BaseOptimizer { // if there's no update available - just go on then } else { - // if accumulator isn't used - we just to for direct updates application + // if accumulator isn't used - we just go for direct updates application stepFunction.step(params, gradient.gradient()); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java index 7684caa6a..16e8a97e7 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java @@ -172,7 +172,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist public static long getOptimalBufferSize(IModel model, int numWorkers, int queueSize) { - return getOptimalBufferSize(model.params().length(), numWorkers, queueSize); + return getOptimalBufferSize(model.getModelParams().length(), numWorkers, queueSize); } @Override diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java index 56f7d3b7f..b2e10ece5 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/CrashReportingUtil.java @@ -204,7 +204,7 @@ public class CrashReportingUtil { StringBuilder sb = genericMemoryStatus(); int bytesPerElement; - switch (isMLN ? mln.params().dataType() : cg.params().dataType()){ + switch (isMLN ? mln.getModelParams().dataType() : cg.getModelParams().dataType()){ case DOUBLE: bytesPerElement = 8; break; @@ -260,7 +260,7 @@ public class CrashReportingUtil { } long sumMem = 0; - long nParams = net.params().length(); + long nParams = net.getModelParams().length(); sb.append("\n----- Network Information -----\n") .append(f("Network # Parameters", nParams)) .append(fBytes("Parameter Memory", bytesPerElement * nParams)); @@ -334,9 +334,9 @@ public class CrashReportingUtil { //Listener info: Collection listeners; if(isMLN){ - listeners = mln.getListeners(); + listeners = mln.getTrainingListeners(); } else { - listeners = cg.getListeners(); + listeners = cg.getTrainingListeners(); } sb.append("\n----- Network Training Listeners -----\n"); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java index e763d30bf..8649bcd19 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java @@ -152,10 +152,10 @@ public class ModelSerializer { ZipEntry coefficients = new ZipEntry(COEFFICIENTS_BIN); zipfile.putNextEntry(coefficients); DataOutputStream dos = new DataOutputStream(new BufferedOutputStream(zipfile)); - INDArray params = model.params(); + INDArray params = model.getModelParams(); if(params != null) { try { - Nd4j.write(model.params(), dos); + Nd4j.write(model.getModelParams(), dos); } finally { dos.flush(); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java index 900a516cd..f19dd8a47 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/NetworkUtils.java @@ -22,11 +22,11 @@ package org.deeplearning4j.util; import lombok.extern.slf4j.Slf4j; import net.brutex.ai.dnn.api.IModel; -import org.deeplearning4j.nn.api.Trainable; +import org.deeplearning4j.nn.api.ITrainableLayer; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.InputPreProcessor; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.graph.vertex.GraphVertex; @@ -86,7 +86,7 @@ public class NetworkUtils { ComputationGraph cg = new ComputationGraph(conf); cg.init(); - cg.setParams(net.params()); + cg.setParams(net.getModelParams()); //Also copy across updater state: INDArray updaterState = net.getUpdater().getStateViewArray(); @@ -123,8 +123,8 @@ public class NetworkUtils { 
private static void setLearningRate(MultiLayerNetwork net, int layerNumber, double newLr, ISchedule newLrSchedule, boolean refreshUpdater) { LayerConfiguration l = net.getLayer(layerNumber).getLayerConfiguration(); - if (l instanceof BaseLayer) { - BaseLayer bl = (BaseLayer) l; + if (l instanceof BaseLayerConfiguration) { + BaseLayerConfiguration bl = (BaseLayerConfiguration) l; IUpdater u = bl.getIUpdater(); if (u != null && u.hasLearningRate()) { if (newLrSchedule != null) { @@ -205,8 +205,8 @@ public class NetworkUtils { LayerConfiguration l = net.getLayer(layerNumber).getLayerConfiguration(); int iter = net.getIterationCount(); int epoch = net.getEpochCount(); - if (l instanceof BaseLayer) { - BaseLayer bl = (BaseLayer) l; + if (l instanceof BaseLayerConfiguration) { + BaseLayerConfiguration bl = (BaseLayerConfiguration) l; IUpdater u = bl.getIUpdater(); if (u != null && u.hasLearningRate()) { double d = u.getLearningRate(iter, epoch); @@ -245,8 +245,8 @@ public class NetworkUtils { private static void setLearningRate(ComputationGraph net, String layerName, double newLr, ISchedule newLrSchedule, boolean refreshUpdater) { LayerConfiguration l = net.getLayer(layerName).getLayerConfiguration(); - if (l instanceof BaseLayer) { - BaseLayer bl = (BaseLayer) l; + if (l instanceof BaseLayerConfiguration) { + BaseLayerConfiguration bl = (BaseLayerConfiguration) l; IUpdater u = bl.getIUpdater(); if (u != null && u.hasLearningRate()) { if (newLrSchedule != null) { @@ -327,8 +327,8 @@ public class NetworkUtils { LayerConfiguration l = net.getLayer(layerName).getLayerConfiguration(); int iter = net.getComputationGraphConfiguration().getIterationCount(); int epoch = net.getComputationGraphConfiguration().getEpochCount(); - if (l instanceof BaseLayer) { - BaseLayer bl = (BaseLayer) l; + if (l instanceof BaseLayerConfiguration) { + BaseLayerConfiguration bl = (BaseLayerConfiguration) l; IUpdater u = bl.getIUpdater(); if (u != null && u.hasLearningRate()) { double d = u.getLearningRate(iter, epoch); @@ -499,7 +499,7 @@ public class NetworkUtils { } - private static int getId(Trainable trainable){ + private static int getId(ITrainableLayer trainable){ if(trainable instanceof GraphVertex){ GraphVertex gv = (GraphVertex)trainable; return gv.getVertexIndex(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java index 76f06b556..1829fbd40 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/util/OutputLayerUtil.java @@ -20,6 +20,7 @@ package org.deeplearning4j.util; +import lombok.NonNull; import org.deeplearning4j.exception.DL4JInvalidConfigException; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.conf.layers.objdetect.Yolo2OutputLayer; @@ -148,7 +149,7 @@ public class OutputLayerUtil { return lf instanceof LossMCXENT || lf instanceof LossBinaryXENT; } - public static boolean activationExceedsZeroOneRange(IActivation activation, boolean isLossLayer){ + public static boolean activationExceedsZeroOneRange(@NonNull IActivation activation, boolean isLossLayer){ if(OUTSIDE_ZERO_ONE_RANGE.contains(activation.getClass())){ //Note: we're intentionally excluding identity here, for situations like dense(softmax) -> loss(identity) @@ -174,8 +175,8 @@ public class OutputLayerUtil { //Check that the activation function provides probabilities. 
This can't catch everything, but should catch a few // of the common mistakes users make - if(outputLayer instanceof BaseLayer){ - BaseLayer bl = (BaseLayer)outputLayer; + if(outputLayer instanceof BaseLayerConfiguration){ + BaseLayerConfiguration bl = (BaseLayerConfiguration)outputLayer; boolean isOutputLayer = outputLayer instanceof OutputLayer || outputLayer instanceof RnnOutputLayer || outputLayer instanceof CenterLossOutputLayer; if(activationExceedsZeroOneRange(bl.getActivationFn(), !isOutputLayer)){ diff --git a/cavis-dnn/cavis-dnn-nn/src/main/resources/simplelogger.properties b/cavis-dnn/cavis-dnn-nn/src/main/resources/simplelogger.properties index 93090cbc4..51c081db4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/resources/simplelogger.properties +++ b/cavis-dnn/cavis-dnn-nn/src/main/resources/simplelogger.properties @@ -19,4 +19,7 @@ # # -org.slf4j.simpleLogger.defaultLogLevel = trace \ No newline at end of file +org.slf4j.simpleLogger.defaultLogLevel = debug + +org.slf4j.simpleLogger.log.org.deeplearning4j.optimize.listeners = info +org.slf4j.simplelogger.log.org.nd4j.linalg.dataset = info \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java index 9ca79badc..0b7ce4627 100644 --- a/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java +++ b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java @@ -32,21 +32,27 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ActivationLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; +import org.deeplearning4j.nn.conf.layers.OutputLayer; +import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInitXavier; +import org.deeplearning4j.optimize.listeners.ScoreToChartListener; import org.junit.jupiter.api.Test; import org.nd4j.common.primitives.Pair; +import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationLReLU; +import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.learning.config.Adam; +import org.nd4j.linalg.lossfunctions.LossFunctions; class dnnTest { @Test void testFFLayer() { - int numFeatures = 128; - int batchSize = 10; - int numRows = 1000; + int numFeatures = 6; + int batchSize = 5; + int numRows = 100; AtomicInteger cnt = new AtomicInteger(0); FloatsDataSetIterator iterator = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize); @@ -55,40 +61,52 @@ class dnnTest { NeuralNetConfiguration conf = NeuralNetConfiguration.builder().build(); /** - * NeuralNetConfiguration confxx = NeuralNetConfiguration.builder() - * .seed(42) - * .updater(UPDATER) - * .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) - * .gradientNormalizationThreshold(GRADIENT_THRESHOLD) - * .weightInit(WeightInit.XAVIER) - * .activation(Activation.IDENTITY) - * .list(genLayers()) - * .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) - * // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, X_DIM, CHANNELS)) - * .build(); + * NeuralNetConfiguration confxx = NeuralNetConfiguration.builder() .seed(42) .updater(UPDATER) + * .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + * .gradientNormalizationThreshold(GRADIENT_THRESHOLD) 
.weightInit(WeightInit.XAVIER) + * .activation(Activation.IDENTITY) .list(genLayers()) .inputType(InputType.convolutional(X_DIM, + * Y_DIM, CHANNELS)) // .inputPreProcessor("CNN1", new FeedForwardToCnnPreProcessor(Y_DIM, + * X_DIM, CHANNELS)) .build(); */ /** - * new DenseLayer.Builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(), - * new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), - * new DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), - * new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), - * new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(), - * new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), - * new DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH) + * new + * DenseLayer.Builder().nIn(INPUT).nOut(X_DIM*Y_DIM*CHANNELS).weightInit(WeightInit.NORMAL).build(), + * new ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), new + * DenseLayer.Builder().nIn(X_DIM*Y_DIM*CHANNELS).nOut(X_DIM*Y_DIM).build(), new + * ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), new + * DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM).build(), new + * ActivationLayer.Builder(new ActivationLReLU(0.2)).build(), new + * DenseLayer.Builder().nIn(X_DIM*Y_DIM).nOut(X_DIM*Y_DIM*CHANNELS).activation(Activation.TANH) */ - NN.net() - .seed(42) - .updater( Adam.builder().learningRate(0.0002).beta1(0.5).build() ) - .gradientNormalization( GradientNormalization.RenormalizeL2PerLayer) - .gradientNormalizationThreshold( 100 ) - .weightInitFn( new WeightInitXavier() ) - .activationFn( new ActivationIdentity() ) - .inputType( InputType.convolutional( 28, 28, 1)) - .layer( new DenseLayer.Builder().nIn(10).nOut(20).build() ) - .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build() ) - ; + NeuralNetConfiguration network = + NN.net() + .seed(42) + .updater(Adam.builder().learningRate(0.0002).beta1(0.5).build()) + .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) + .gradientNormalizationThreshold(100) + .weightInitFn(new WeightInitXavier()) + .activationFn(new ActivationSigmoid()) + // .inputType(InputType.convolutional(28, 28, 1)) + .layer(new DenseLayer.Builder().nIn(6).nOut(20).build()) + .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) + .layer(new DenseLayer.Builder().nIn(20).nOut(40).build()) + .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) + .layer(new DenseLayer.Builder().nIn(40).nOut(12).build()) + .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) + .layer(new DenseLayer.Builder().nIn(12).nOut(8).build()) + .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) + .layer(new OutputLayer.Builder(LossFunctions.LossFunction.SQUARED_LOSS).activation(Activation.SIGMOID).nOut(6).build()) + .build(); + MultiLayerNetwork net = new MultiLayerNetwork(network); + net.addTrainingListeners(new ScoreToChartListener("dnnTest")); + FloatsDataSetIterator dset = new FloatsDataSetIterator(floatIterable(numRows, numFeatures), batchSize); + + for (int i = 0; i < 2000000; i++) { + net.fit(dset); + System.out.println("Score: " + net.getScore()); + } } protected static Iterable> floatIterable(final int totalRows, final int numColumns) { @@ -108,8 +126,8 @@ class dnnTest { float[] features = new float[numColumns]; float[] labels = new float[numColumns]; for (int i = 0; i < numColumns; i++) { - features[i] = (float) i; - labels[i] = RandomUtils.nextFloat(0, 5); 
+ features[i] = RandomUtils.nextFloat(0, 3); + labels[i] = (float) features[i] + 1; } return Pair.makePair(features, labels); } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainer.java b/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainer.java index 7af10085b..6808d4145 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper-parameterserver/src/main/java/org/deeplearning4j/parallelism/parameterserver/ParameterServerTrainer.java @@ -58,7 +58,7 @@ public class ParameterServerTrainer extends DefaultTrainer { log.info("Sending parameters"); //send the updated params - parameterServerClient.pushNDArray(getModel().params()); + parameterServerClient.pushNDArray(getModel().getModelParams()); } @Override @@ -77,7 +77,7 @@ public class ParameterServerTrainer extends DefaultTrainer { log.info("About to send params in"); //send the updated params - parameterServerClient.pushNDArray(getModel().params()); + parameterServerClient.pushNDArray(getModel().getModelParams()); log.info("Sent params"); } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java index 73261f155..25a364b36 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/EarlyStoppingParallelTrainer.java @@ -91,16 +91,16 @@ public class EarlyStoppingParallelTrainer implements IEarlySto // adjust UI listeners AveragingTrainingListener trainerListener = new AveragingTrainingListener(this); if (model instanceof MultiLayerNetwork) { - Collection listeners = ((MultiLayerNetwork) model).getListeners(); + Collection listeners = ((MultiLayerNetwork) model).getTrainingListeners(); Collection newListeners = new LinkedList<>(listeners); newListeners.add(trainerListener); - model.setListeners(newListeners.toArray(new TrainingListener[]{})); + model.addTrainingListeners(newListeners.toArray(new TrainingListener[]{})); } else if (model instanceof ComputationGraph) { - Collection listeners = ((ComputationGraph) model).getListeners(); + Collection listeners = ((ComputationGraph) model).getTrainingListeners(); Collection newListeners = new LinkedList<>(listeners); newListeners.add(trainerListener); - model.setListeners(newListeners.toArray(new TrainingListener[]{})); + model.addTrainingListeners(newListeners.toArray(new TrainingListener[]{})); } this.wrapper = new ParallelWrapper.Builder<>(model).workers(workers).prefetchBuffer(prefetchBuffer) @@ -327,7 +327,7 @@ public class EarlyStoppingParallelTrainer implements IEarlySto @Override public void iterationDone(IModel model, int iteration, int epoch) { //Check per-iteration termination conditions - double latestScore = model.score(); + double latestScore = model.getScore(); trainer.setLatestScore(latestScore); for (IterationTerminationCondition c : esConfig.getIterationTerminationConditions()) { if (c.terminate(latestScore)) { diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java 
b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java index 0c1515109..571002280 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/InplaceParallelInference.java @@ -185,7 +185,7 @@ public class InplaceParallelInference extends ParallelInference { isMLN = sourceModel instanceof MultiLayerNetwork; // we clone params only if we're not on the same device - val params = rootDevice ? sourceModel.params() : sourceModel.params().unsafeDuplication(true); + val params = rootDevice ? sourceModel.getModelParams() : sourceModel.getModelParams().unsafeDuplication(true); // and moving it to specified device (only if NOT root if (!rootDevice) diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java index 9d2c76a23..242a9f731 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java @@ -462,7 +462,7 @@ public class ParallelInference { this.replicatedModel.init(); synchronized (locker) { - this.replicatedModel.setParams(protoModel.params().unsafeDuplication(true)); + this.replicatedModel.setParams(protoModel.getModelParams().unsafeDuplication(true)); Nd4j.getExecutioner().commit(); } @@ -476,7 +476,7 @@ public class ParallelInference { this.replicatedModel.init(); synchronized (locker) { - this.replicatedModel.setParams(protoModel.params().unsafeDuplication(true)); + this.replicatedModel.setParams(protoModel.getModelParams().unsafeDuplication(true)); Nd4j.getExecutioner().commit(); } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java index 5a880872d..921b9b49e 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java @@ -345,8 +345,8 @@ public class ParallelWrapper implements AutoCloseable { List params = new ArrayList<>(); for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) { - params.add(zoo[cnt].getModel().params()); - score += zoo[cnt].getModel().score(); + params.add(zoo[cnt].getModel().getModelParams()); + score += zoo[cnt].getModel().getScore(); } Nd4j.averageAndPropagate(null, params); @@ -956,11 +956,11 @@ public class ParallelWrapper implements AutoCloseable { List modelListeners = null; if (model instanceof MultiLayerNetwork) { - modelListeners = new ArrayList<>(((MultiLayerNetwork) model).getListeners()); - model.setListeners(new TrainingListener[]{}); + modelListeners = new ArrayList<>(((MultiLayerNetwork) model).getTrainingListeners()); + model.addTrainingListeners(new TrainingListener[]{}); } else if (model instanceof ComputationGraph) { - modelListeners = new ArrayList<>(((ComputationGraph) model).getListeners()); - model.setListeners(new TrainingListener[]{}); + modelListeners = new ArrayList<>(((ComputationGraph) model).getTrainingListeners()); + model.addTrainingListeners(new TrainingListener[]{}); } if 
(modelListeners != null && !modelListeners.isEmpty()) { diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContext.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContext.java index 663cb148c..6bafcd4cd 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContext.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/factory/SymmetricTrainerContext.java @@ -74,6 +74,6 @@ public class SymmetricTrainerContext implements TrainerContext { @Override public void finalizeTraining(IModel originalModel, IModel... models) { // we CAN avarage here, but for now we'll just push first model params to original model - originalModel.setParams(models[0].params()); + originalModel.setParams(models[0].getModelParams()); } } diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java index 2a1cf4d4e..522b3548a 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/trainer/DefaultTrainer.java @@ -147,7 +147,7 @@ public class DefaultTrainer extends Thread implements Trainer { if (replicatedModel instanceof MultiLayerNetwork) { - replicatedModel.setParams(model.params().unsafeDuplication(true)); + replicatedModel.setParams(model.getModelParams().unsafeDuplication(true)); Updater updater = ((MultiLayerNetwork) model).getUpdater(); INDArray view = updater.getStateViewArray(); @@ -161,7 +161,7 @@ public class DefaultTrainer extends Thread implements Trainer { updater.setStateViewArray((MultiLayerNetwork) replicatedModel, viewD, false); } } else if (replicatedModel instanceof ComputationGraph) { - replicatedModel.setParams(model.params().unsafeDuplication(true)); + replicatedModel.setParams(model.getModelParams().unsafeDuplication(true)); ComputationGraphUpdater updater = ((ComputationGraph) model).getUpdater(); INDArray view = updater.getStateViewArray(); @@ -278,7 +278,7 @@ public class DefaultTrainer extends Thread implements Trainer { } configureListeners(uuid, oldListeners, replicatedListeners); - this.replicatedModel.setListeners(replicatedListeners.toArray(new TrainingListener[]{})); + this.replicatedModel.addTrainingListeners(replicatedListeners.toArray(new TrainingListener[]{})); } @Override @@ -305,7 +305,7 @@ public class DefaultTrainer extends Thread implements Trainer { // we replicate original model params & updater state, just in case it's pre-trained model try { modelLock.writeLock().lock(); - replicatedModel.setParams(originalModel.params().unsafeDuplication(true)); + replicatedModel.setParams(originalModel.getModelParams().unsafeDuplication(true)); Updater updaterReplica = ((MultiLayerNetwork) replicatedModel).getUpdater(); Updater updaterOrigina = ((MultiLayerNetwork) originalModel).getUpdater(); @@ -338,7 +338,7 @@ public class DefaultTrainer extends Thread implements Trainer { // we replicate original model params & updater state, just in case it's pre-trained model try { modelLock.writeLock().lock(); - replicatedModel.setParams(originalModel.params().unsafeDuplication(true)); + 
replicatedModel.setParams(originalModel.getModelParams().unsafeDuplication(true)); ComputationGraphUpdater updaterReplica = ((ComputationGraph) replicatedModel).getUpdater(); ComputationGraphUpdater updaterOrigina = ((ComputationGraph) originalModel).getUpdater(); @@ -389,7 +389,7 @@ public class DefaultTrainer extends Thread implements Trainer { Nd4j.getExecutioner().commit(); // we ensure memory is updated on host side - Nd4j.getAffinityManager().ensureLocation(replicatedModel.params(), + Nd4j.getAffinityManager().ensureLocation(replicatedModel.getModelParams(), AffinityManager.Location.HOST); if (replicatedModel instanceof MultiLayerNetwork) { @@ -427,7 +427,7 @@ public class DefaultTrainer extends Thread implements Trainer { Nd4j.getExecutioner().commit(); // we ensure memory is updated on host side - Nd4j.getAffinityManager().ensureLocation(replicatedModel.params(), + Nd4j.getAffinityManager().ensureLocation(replicatedModel.getModelParams(), AffinityManager.Location.HOST); ComputationGraphUpdater updaterReplica = ((ComputationGraph) replicatedModel).getUpdater(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java index e64e6d06f..d952ebf4a 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/InplaceParallelInferenceTest.java @@ -65,7 +65,7 @@ public class InplaceParallelInferenceTest extends BaseDL4JTest { for (val m : models) { assertNotNull(m); - assertEquals(net.params(), m.params()); + assertEquals(net.getModelParams(), m.getModelParams()); } val conf2 = NeuralNetConfiguration.builder() @@ -80,7 +80,7 @@ public class InplaceParallelInferenceTest extends BaseDL4JTest { val net2 = new ComputationGraph(conf2); net2.init(); - assertNotEquals(net.params(), net2.params()); + assertNotEquals(net.getModelParams(), net2.getModelParams()); pi.updateModel(net2); @@ -90,7 +90,7 @@ public class InplaceParallelInferenceTest extends BaseDL4JTest { for (val m : models2) { assertNotNull(m); - assertEquals(net2.params(), m.params()); + assertEquals(net2.getModelParams(), m.getModelParams()); } } finally { pi.shutdown(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java index 5f1ac9a7a..cdf908911 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelInferenceTest.java @@ -790,7 +790,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { // model can be null for some of the workers yet, due to race condition if (m != null) { Thread.sleep(500); - assertEquals( net.params(), m.params(), "Failed at model [" + cnt0 + "]"); + assertEquals( net.getModelParams(), m.getModelParams(), "Failed at model [" + cnt0 + "]"); passed = true; } cnt0++; @@ -818,7 +818,7 @@ public class ParallelInferenceTest extends BaseDL4JTest { cnt0 = 0; for (val m:modelsAfter) { assertNotNull( m, "Failed at model [" + cnt0 + "]"); - assertEquals( net2.params(), m.params(), "Failed at model [" + cnt0++ + "]"); + assertEquals( 
net2.getModelParams(), m.getModelParams(), "Failed at model [" + cnt0++ + "]"); } inf.shutdown(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java index b74262dd2..471cafbfd 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/ParallelWrapperTest.java @@ -26,7 +26,6 @@ import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.eval.Evaluation; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -112,7 +111,7 @@ public class ParallelWrapperTest extends BaseDL4JTest { .build(); log.info("Train model...."); - model.setListeners(new ScoreIterationListener(100)); + model.addTrainingListeners(new ScoreIterationListener(100)); long timeX = System.currentTimeMillis(); // optionally you might want to use MultipleEpochsIterator instead of manually iterating/resetting over your iterator diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java index 799f2dfd7..4389d8f68 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestListeners.java @@ -96,7 +96,7 @@ public class TestListeners extends BaseDL4JTest { model.init(); StatsStorage ss = new InMemoryStatsStorage(); - model.setListeners(new TestListener(), new StatsListener(ss)); + model.addTrainingListeners(new TestListener(), new StatsListener(ss)); testListenersForModel(model, null); @@ -119,7 +119,7 @@ public class TestListeners extends BaseDL4JTest { model.init(); StatsStorage ss = new InMemoryStatsStorage(); - model.setListeners(new TestListener(), new StatsListener(ss)); + model.addTrainingListeners(new TestListener(), new StatsListener(ss)); testListenersForModel(model, null); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java index a3b97339b..a003d99fb 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStopping.java @@ -107,7 +107,7 @@ public class TestParallelEarlyStopping extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(50, 600); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); @@ -140,7 +140,7 
@@ public class TestParallelEarlyStopping extends BaseDL4JTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(1)); + net.addTrainingListeners(new ScoreIterationListener(1)); DataSetIterator irisIter = new IrisDataSetIterator(10, 150); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); diff --git a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java index 9d8fe7c70..66a9b76c4 100644 --- a/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java +++ b/cavis-dnn/cavis-dnn-parallelwrapper/src/test/java/org/deeplearning4j/parallelism/TestParallelEarlyStoppingUI.java @@ -67,7 +67,7 @@ public class TestParallelEarlyStoppingUI extends BaseDL4JTest { // it's important that the UI can report results from parallel training // there's potential for StatsListener to fail if certain properties aren't set in the model StatsStorage statsStorage = new InMemoryStatsStorage(); - net.setListeners(new StatsListener(statsStorage)); + net.addTrainingListeners(new StatsListener(statsStorage)); uiServer.attach(statsStorage); DataSetIterator irisIter = new IrisDataSetIterator(50, 500); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java index 84c7cf753..4849ee142 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/graph/SparkComputationGraph.java @@ -370,7 +370,7 @@ public class SparkComputationGraph extends SparkListenable { */ public double calculateScore(JavaRDD data, boolean average, int minibatchSize) { JavaRDD> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGDataSet(conf.toJson(), - sc.broadcast(network.params()), minibatchSize)); + sc.broadcast(network.getModelParams()), minibatchSize)); //Reduce to a single tuple, with example count + sum of scores Tuple2 countAndSumScores = rdd.reduce(new LongDoubleReduceFunction()); @@ -405,7 +405,7 @@ public class SparkComputationGraph extends SparkListenable { */ public double calculateScoreMultiDataSet(JavaRDD data, boolean average, int minibatchSize) { JavaRDD> rdd = data.mapPartitions(new ScoreFlatMapFunctionCGMultiDataSet(conf.toJson(), - sc.broadcast(network.params()), minibatchSize)); + sc.broadcast(network.getModelParams()), minibatchSize)); //Reduce to a single tuple, with example count + sum of scores Tuple2 countAndSumScores = rdd.reduce(new LongDoubleReduceFunction()); if (average) { @@ -476,7 +476,7 @@ public class SparkComputationGraph extends SparkListenable { */ public JavaDoubleRDD scoreExamplesMultiDataSet(JavaRDD data, boolean includeRegularizationTerms, int batchSize) { - return data.mapPartitionsToDouble(new ScoreExamplesFunction(sc.broadcast(network.params()), + return data.mapPartitionsToDouble(new ScoreExamplesFunction(sc.broadcast(network.getModelParams()), sc.broadcast(conf.toJson()), includeRegularizationTerms, batchSize)); } @@ -527,7 +527,7 @@ public class SparkComputationGraph extends 
SparkListenable { * @return Network output given the input, by key */ public JavaPairRDD feedForwardWithKey(JavaPairRDD featuresData, int batchSize) { - return featuresData.mapPartitionsToPair(new GraphFeedForwardWithKeyFunction(sc.broadcast(network.params()), + return featuresData.mapPartitionsToPair(new GraphFeedForwardWithKeyFunction(sc.broadcast(network.getModelParams()), sc.broadcast(conf.toJson()), batchSize)); } @@ -554,7 +554,7 @@ public class SparkComputationGraph extends SparkListenable { */ public JavaPairRDD scoreExamplesMultiDataSet(JavaPairRDD data, boolean includeRegularizationTerms, int batchSize) { - return data.mapPartitionsToPair(new ScoreExamplesWithKeyFunction(sc.broadcast(network.params()), + return data.mapPartitionsToPair(new ScoreExamplesWithKeyFunction(sc.broadcast(network.getModelParams()), sc.broadcast(conf.toJson()), includeRegularizationTerms, batchSize)); } @@ -820,7 +820,7 @@ public class SparkComputationGraph extends SparkListenable { */ public T[] doEvaluation(JavaRDD data, int evalNumWorkers, int evalBatchSize, T... emptyEvaluations) { IEvaluateFlatMapFunction evalFn = new IEvaluateFlatMapFunction<>(true, sc.broadcast(conf.toJson()), - SparkUtils.asByteArrayBroadcast(sc, network.params()), evalNumWorkers, evalBatchSize, emptyEvaluations); + SparkUtils.asByteArrayBroadcast(sc, network.getModelParams()), evalNumWorkers, evalBatchSize, emptyEvaluations); JavaRDD evaluations = data.mapPartitions(evalFn); return evaluations.treeAggregate(null, new IEvaluateAggregateFunction(), new IEvaluateAggregateFunction()); @@ -844,7 +844,7 @@ public class SparkComputationGraph extends SparkListenable { public T[] doEvaluationMDS(JavaRDD data, int evalNumWorkers, int evalBatchSize, T... emptyEvaluations) { Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaulation workers: require at least 1 - got %s", evalNumWorkers); IEvaluateMDSFlatMapFunction evalFn = new IEvaluateMDSFlatMapFunction<>(sc.broadcast(conf.toJson()), - SparkUtils.asByteArrayBroadcast(sc, network.params()), evalNumWorkers, evalBatchSize, emptyEvaluations); + SparkUtils.asByteArrayBroadcast(sc, network.getModelParams()), evalNumWorkers, evalBatchSize, emptyEvaluations); JavaRDD evaluations = data.mapPartitions(evalFn); return evaluations.treeAggregate(null, new IEvaluateAggregateFunction(), new IEvaluateAggregateFunction()); @@ -906,7 +906,7 @@ public class SparkComputationGraph extends SparkListenable { protected IEvaluation[] doEvaluation(JavaRDD data, int evalNumWorkers, int evalBatchSize, DataSetLoader loader, MultiDataSetLoader mdsLoader, IEvaluation... 
emptyEvaluations){ IEvaluateMDSPathsFlatMapFunction evalFn = new IEvaluateMDSPathsFlatMapFunction(sc.broadcast(conf.toJson()), - SparkUtils.asByteArrayBroadcast(sc, network.params()), evalNumWorkers, evalBatchSize, loader, mdsLoader, + SparkUtils.asByteArrayBroadcast(sc, network.getModelParams()), evalNumWorkers, evalBatchSize, loader, mdsLoader, BroadcastHadoopConfigHolder.get(sc), emptyEvaluations); Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaulation workers: require at least 1 - got %s", evalNumWorkers); JavaRDD evaluations = data.mapPartitions(evalFn); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java index 2a0c7b655..890c62c3d 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/multilayer/SparkDl4jMultiLayer.java @@ -430,7 +430,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { * @see MultiLayerNetwork#scoreExamples(DataSet, boolean) */ public JavaDoubleRDD scoreExamples(JavaRDD data, boolean includeRegularizationTerms, int batchSize) { - return data.mapPartitionsToDouble(new ScoreExamplesFunction(sc.broadcast(network.params()), + return data.mapPartitionsToDouble(new ScoreExamplesFunction(sc.broadcast(network.getModelParams()), sc.broadcast(conf.toJson()), includeRegularizationTerms, batchSize)); } @@ -466,7 +466,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { */ public JavaPairRDD scoreExamples(JavaPairRDD data, boolean includeRegularizationTerms, int batchSize) { - return data.mapPartitionsToPair(new ScoreExamplesWithKeyFunction(sc.broadcast(network.params()), + return data.mapPartitionsToPair(new ScoreExamplesWithKeyFunction(sc.broadcast(network.getModelParams()), sc.broadcast(conf.toJson()), includeRegularizationTerms, batchSize)); } @@ -494,7 +494,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { */ public JavaPairRDD feedForwardWithMaskAndKey(JavaPairRDD> featuresDataAndMask, int batchSize) { return featuresDataAndMask - .mapPartitionsToPair(new FeedForwardWithKeyFunction(sc.broadcast(network.params()), + .mapPartitionsToPair(new FeedForwardWithKeyFunction(sc.broadcast(network.getModelParams()), sc.broadcast(conf.toJson()), batchSize)); } @@ -708,7 +708,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { */ public T[] doEvaluation(JavaRDD data, int evalNumWorkers, int evalBatchSize, T... emptyEvaluations) { IEvaluateFlatMapFunction evalFn = new IEvaluateFlatMapFunction<>(false, sc.broadcast(conf.toJson()), - SparkUtils.asByteArrayBroadcast(sc, network.params()), evalNumWorkers, evalBatchSize, emptyEvaluations); + SparkUtils.asByteArrayBroadcast(sc, network.getModelParams()), evalNumWorkers, evalBatchSize, emptyEvaluations); JavaRDD evaluations = data.mapPartitions(evalFn); return evaluations.treeAggregate(null, new IEvaluateAggregateFunction(), new IEvaluationReduceFunction()); } @@ -771,7 +771,7 @@ public class SparkDl4jMultiLayer extends SparkListenable { protected IEvaluation[] doEvaluation(JavaRDD data, int evalNumWorkers, int evalBatchSize, DataSetLoader loader, MultiDataSetLoader mdsLoader, IEvaluation... 
emptyEvaluations){ Configuration config = sc.hadoopConfiguration(); IEvaluateMDSPathsFlatMapFunction evalFn = new IEvaluateMDSPathsFlatMapFunction(sc.broadcast(conf.toJson()), - SparkUtils.asByteArrayBroadcast(sc, network.params()), evalNumWorkers, evalBatchSize, loader, mdsLoader, + SparkUtils.asByteArrayBroadcast(sc, network.getModelParams()), evalNumWorkers, evalBatchSize, loader, mdsLoader, BroadcastHadoopConfigHolder.get(sc), emptyEvaluations); Preconditions.checkArgument(evalNumWorkers > 0, "Invalid number of evaulation workers: require at least 1 - got %s", evalNumWorkers); JavaRDD evaluations = data.mapPartitions(evalFn); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java index d3fb3355f..38a7e5bd8 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingMaster.java @@ -276,7 +276,7 @@ public class ParameterAveragingTrainingMaster @Override public ParameterAveragingTrainingWorker getWorkerInstance(SparkDl4jMultiLayer network) { NetBroadcastTuple tuple = new NetBroadcastTuple(network.getNetwork().getNetConfiguration(), - network.getNetwork().params(), network.getNetwork().getUpdater().getStateViewArray()); + network.getNetwork().getModelParams(), network.getNetwork().getUpdater().getStateViewArray()); if (collectTrainingStats) stats.logBroadcastStart(); @@ -293,7 +293,7 @@ public class ParameterAveragingTrainingMaster @Override public ParameterAveragingTrainingWorker getWorkerInstance(SparkComputationGraph graph) { NetBroadcastTuple tuple = new NetBroadcastTuple(graph.getNetwork().getComputationGraphConfiguration(), - graph.getNetwork().params(), graph.getNetwork().getUpdater().getStateViewArray()); + graph.getNetwork().getModelParams(), graph.getNetwork().getUpdater().getStateViewArray()); if (collectTrainingStats) stats.logBroadcastStart(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java index 2322ba5c2..3b3b9f9b3 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/main/java/org/deeplearning4j/spark/impl/paramavg/ParameterAveragingTrainingWorker.java @@ -172,9 +172,9 @@ public class ParameterAveragingTrainingWorker extends BaseTrainingWorker irisData = getIris(); @@ -130,7 +130,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { .lossFunction(LossFunctions.LossFunction.MSE).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(5)); + net.addTrainingListeners(new ScoreIterationListener(5)); JavaRDD irisData = getIris(); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); @@ -169,7 +169,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new 
MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(5)); + net.addTrainingListeners(new ScoreIterationListener(5)); JavaRDD irisData = getIris(); @@ -215,7 +215,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(5)); + net.addTrainingListeners(new ScoreIterationListener(5)); JavaRDD irisData = getIris(); @@ -252,7 +252,7 @@ public class TestEarlyStoppingSpark extends BaseSparkTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build()) .build(); MultiLayerNetwork net = new MultiLayerNetwork(conf); - net.setListeners(new ScoreIterationListener(5)); + net.addTrainingListeners(new ScoreIterationListener(5)); JavaRDD irisData = getIris(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java index f0e1fefb1..1a196af4f 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/TestEarlyStoppingSparkCompGraph.java @@ -78,7 +78,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); - net.setListeners(new ScoreIterationListener(5)); + net.addTrainingListeners(new ScoreIterationListener(5)); JavaRDD irisData = getIris(); @@ -132,7 +132,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { .lossFunction(LossFunctions.LossFunction.MSE).build(), "in") .setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); - net.setListeners(new ScoreIterationListener(5)); + net.addTrainingListeners(new ScoreIterationListener(5)); JavaRDD irisData = getIris(); EarlyStoppingModelSaver saver = new InMemoryModelSaver<>(); @@ -173,7 +173,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); - net.setListeners(new ScoreIterationListener(5)); + net.addTrainingListeners(new ScoreIterationListener(5)); JavaRDD irisData = getIris(); @@ -221,7 +221,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); - net.setListeners(new ScoreIterationListener(5)); + net.addTrainingListeners(new ScoreIterationListener(5)); JavaRDD irisData = getIris(); @@ -260,7 +260,7 @@ public class TestEarlyStoppingSparkCompGraph extends BaseSparkTest { .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in") .setOutputs("0").build(); ComputationGraph net = new ComputationGraph(conf); - net.setListeners(new ScoreIterationListener(5)); + net.addTrainingListeners(new ScoreIterationListener(5)); JavaRDD irisData = getIris(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java 
b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java index 3c3e2c46f..97f6a2c89 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/customlayer/layer/CustomLayer.java @@ -56,7 +56,7 @@ public class CustomLayer extends FeedForwardLayer { boolean initializeParams, DataType networkDataType) { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); CustomLayerImpl ret = new CustomLayerImpl(lconf, networkDataType); - ret.setListeners(trainingListeners); + ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java index 109add55d..20727ed03 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/graph/TestSparkComputationGraph.java @@ -40,7 +40,6 @@ import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.graph.ComputationGraph; import org.deeplearning4j.nn.weights.WeightInit; -import org.deeplearning4j.optimize.api.TrainingListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.spark.BaseSparkTest; import org.deeplearning4j.spark.api.RDDTrainingApproach; @@ -272,9 +271,9 @@ public class TestSparkComputationGraph extends BaseSparkTest { sparkNet3.fit(rdd); - INDArray p1 = sparkNet1.getNetwork().params(); - INDArray p2 = sparkNet2.getNetwork().params(); - INDArray p3 = sparkNet3.getNetwork().params(); + INDArray p1 = sparkNet1.getNetwork().getModelParams(); + INDArray p2 = sparkNet2.getNetwork().getModelParams(); + INDArray p3 = sparkNet3.getNetwork().getModelParams(); sparkNet1.getTrainingMaster().deleteTempFiles(sc); sparkNet2.getTrainingMaster().deleteTempFiles(sc); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java index 6b22acca7..7a638199c 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/multilayer/TestMiscFunctions.java @@ -239,7 +239,7 @@ public class TestMiscFunctions extends BaseSparkTest { JavaPairRDD reconstr = rdd.mapPartitionsToPair(new VaeReconstructionProbWithKeyFunction( - sc.broadcast(net.params()), sc.broadcast(mlc.toJson()), true, 16, 128)); + sc.broadcast(net.getModelParams()), sc.broadcast(mlc.toJson()), true, 16, 128)); Map l = reconstr.collectAsMap(); @@ -282,7 +282,7 @@ public class TestMiscFunctions extends BaseSparkTest { JavaPairRDD reconstrErrors = rdd.mapPartitionsToPair(new 
VaeReconstructionErrorWithKeyFunction( - sc.broadcast(net.params()), sc.broadcast(mlc.toJson()), 16)); + sc.broadcast(net.getModelParams()), sc.broadcast(mlc.toJson()), 16)); Map l = reconstrErrors.collectAsMap(); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java index 277c4a133..7ba980f62 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestCompareParameterAveragingSparkVsSingleMachine.java @@ -191,7 +191,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { MultiLayerNetwork net = new MultiLayerNetwork(getConf(12345, new RmsProp(0.5))); net.init(); - INDArray initialParams = net.params().dup(); + INDArray initialParams = net.getModelParams().dup(); for (int i = 0; i < seeds.length; i++) { DataSet ds = getOneDataSet(miniBatchSize, seeds[i]); @@ -199,13 +199,13 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { net.setUpdater(null); net.fit(ds); } - INDArray finalParams = net.params().dup(); + INDArray finalParams = net.getModelParams().dup(); //Do training on Spark with one executor, for 3 separate minibatches TrainingMaster tm = getTrainingMaster(1, miniBatchSize, saveUpdater); SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, getConf(12345, new RmsProp(0.5)), tm); sparkNet.setCollectTrainingStats(true); - INDArray initialSparkParams = sparkNet.getNetwork().params().dup(); + INDArray initialSparkParams = sparkNet.getNetwork().getModelParams().dup(); for (int i = 0; i < seeds.length; i++) { List list = getOneDataSetAsIndividalExamples(miniBatchSize, seeds[i]); @@ -214,7 +214,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { sparkNet.fit(rdd); } - INDArray finalSparkParams = sparkNet.getNetwork().params().dup(); + INDArray finalSparkParams = sparkNet.getNetwork().getModelParams().dup(); assertEquals(initialParams, initialSparkParams); assertNotEquals(initialParams, finalParams); @@ -245,7 +245,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { ComputationGraph net = new ComputationGraph(getGraphConf(12345, new RmsProp(0.5))); net.init(); - INDArray initialParams = net.params().dup(); + INDArray initialParams = net.getModelParams().dup(); for (int i = 0; i < seeds.length; i++) { DataSet ds = getOneDataSet(miniBatchSize, seeds[i]); @@ -253,14 +253,14 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { net.setUpdater(null); net.fit(ds); } - INDArray finalParams = net.params().dup(); + INDArray finalParams = net.getModelParams().dup(); //Do training on Spark with one executor, for 3 separate minibatches TrainingMaster tm = getTrainingMaster(1, miniBatchSize, saveUpdater); SparkComputationGraph sparkNet = new SparkComputationGraph(sc, getGraphConf(12345, new RmsProp(0.5)), tm); sparkNet.setCollectTrainingStats(true); - INDArray initialSparkParams = sparkNet.getNetwork().params().dup(); + INDArray initialSparkParams = sparkNet.getNetwork().getModelParams().dup(); for (int i = 0; i < seeds.length; i++) { List list = getOneDataSetAsIndividalExamples(miniBatchSize, seeds[i]); @@ -269,7 +269,7 @@ public class 
TestCompareParameterAveragingSparkVsSingleMachine { sparkNet.fit(rdd); } - INDArray finalSparkParams = sparkNet.getNetwork().params().dup(); + INDArray finalSparkParams = sparkNet.getNetwork().getModelParams().dup(); assertEquals(initialParams, initialSparkParams); assertNotEquals(initialParams, finalParams); @@ -304,7 +304,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { MultiLayerNetwork net = new MultiLayerNetwork(getConf(12345, new Sgd(0.5))); net.init(); - INDArray initialParams = net.params().dup(); + INDArray initialParams = net.getModelParams().dup(); // executioner.addToWatchdog(initialParams, "initialParams"); @@ -314,7 +314,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { net.setUpdater(null); net.fit(ds); } - INDArray finalParams = net.params().dup(); + INDArray finalParams = net.getModelParams().dup(); //Do training on Spark with one executor, for 3 separate minibatches // TrainingMaster tm = getTrainingMaster(1, miniBatchSizePerWorker, saveUpdater); @@ -325,7 +325,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { .rddTrainingApproach(RDDTrainingApproach.Export).build(); SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, getConf(12345, new Sgd(0.5)), tm); sparkNet.setCollectTrainingStats(true); - INDArray initialSparkParams = sparkNet.getNetwork().params().dup(); + INDArray initialSparkParams = sparkNet.getNetwork().getModelParams().dup(); // executioner.addToWatchdog(initialSparkParams, "initialSparkParams"); @@ -339,7 +339,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { // System.out.println(sparkNet.getSparkTrainingStats().statsAsString()); sparkNet.getSparkTrainingStats().statsAsString(); - INDArray finalSparkParams = sparkNet.getNetwork().params().dup(); + INDArray finalSparkParams = sparkNet.getNetwork().getModelParams().dup(); // System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat())); // System.out.println("Initial (Spark) params: " @@ -353,7 +353,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { double sparkScore = sparkNet.getScore(); assertTrue(sparkScore > 0.0); - assertEquals(net.score(), sparkScore, 1e-3); + assertEquals(net.getScore(), sparkScore, 1e-3); } finally { sc.stop(); } @@ -386,7 +386,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { MultiLayerNetwork net = new MultiLayerNetwork(getConfCNN(12345, new Sgd(0.5))); net.init(); - INDArray initialParams = net.params().dup(); + INDArray initialParams = net.getModelParams().dup(); for (int i = 0; i < seeds.length; i++) { DataSet ds = getOneDataSetCNN(miniBatchSizePerWorker * nWorkers, seeds[i]); @@ -394,7 +394,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { net.setUpdater(null); net.fit(ds); } - INDArray finalParams = net.params().dup(); + INDArray finalParams = net.getModelParams().dup(); //Do training on Spark with one executor, for 3 separate minibatches ParameterAveragingTrainingMaster tm = new ParameterAveragingTrainingMaster.Builder(1) @@ -403,7 +403,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { .rddTrainingApproach(RDDTrainingApproach.Export).build(); SparkDl4jMultiLayer sparkNet = new SparkDl4jMultiLayer(sc, getConfCNN(12345, new Sgd(0.5)), tm); sparkNet.setCollectTrainingStats(true); - INDArray initialSparkParams = sparkNet.getNetwork().params().dup(); + INDArray initialSparkParams = sparkNet.getNetwork().getModelParams().dup(); for (int i = 0; i < seeds.length; i++) { List list = @@ 
-416,7 +416,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { // System.out.println(sparkNet.getSparkTrainingStats().statsAsString()); sparkNet.getSparkTrainingStats().statsAsString(); - INDArray finalSparkParams = sparkNet.getNetwork().params().dup(); + INDArray finalSparkParams = sparkNet.getNetwork().getModelParams().dup(); // System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat())); // System.out.println("Initial (Spark) params: " @@ -429,7 +429,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { double sparkScore = sparkNet.getScore(); assertTrue(sparkScore > 0.0); - assertEquals(net.score(), sparkScore, 1e-3); + assertEquals(net.getScore(), sparkScore, 1e-3); } finally { sc.stop(); } @@ -464,7 +464,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { ComputationGraph net = new ComputationGraph(getGraphConf(12345, new Sgd(0.5))); net.init(); - INDArray initialParams = net.params().dup(); + INDArray initialParams = net.getModelParams().dup(); // executioner.addToWatchdog(initialParams, "initialParams"); for (int i = 0; i < seeds.length; i++) { @@ -473,14 +473,14 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { net.setUpdater(null); net.fit(ds); } - INDArray finalParams = net.params().dup(); + INDArray finalParams = net.getModelParams().dup(); // executioner.addToWatchdog(finalParams, "finalParams"); //Do training on Spark with one executor, for 3 separate minibatches TrainingMaster tm = getTrainingMaster(1, miniBatchSizePerWorker, saveUpdater); SparkComputationGraph sparkNet = new SparkComputationGraph(sc, getGraphConf(12345, new Sgd(0.5)), tm); sparkNet.setCollectTrainingStats(true); - INDArray initialSparkParams = sparkNet.getNetwork().params().dup(); + INDArray initialSparkParams = sparkNet.getNetwork().getModelParams().dup(); // executioner.addToWatchdog(initialSparkParams, "initialSparkParams"); @@ -494,7 +494,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { // System.out.println(sparkNet.getSparkTrainingStats().statsAsString()); sparkNet.getSparkTrainingStats().statsAsString(); - INDArray finalSparkParams = sparkNet.getNetwork().params().dup(); + INDArray finalSparkParams = sparkNet.getNetwork().getModelParams().dup(); // executioner.addToWatchdog(finalSparkParams, "finalSparkParams"); float[] fp = finalParams.data().asFloat(); @@ -512,7 +512,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { double sparkScore = sparkNet.getScore(); assertTrue(sparkScore > 0.0); - assertEquals(net.score(), sparkScore, 1e-3); + assertEquals(net.getScore(), sparkScore, 1e-3); } finally { sc.stop(); } @@ -545,7 +545,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { ComputationGraph net = new ComputationGraph(getGraphConfCNN(12345, new Sgd(0.5))); net.init(); - INDArray initialParams = net.params().dup(); + INDArray initialParams = net.getModelParams().dup(); for (int i = 0; i < seeds.length; i++) { DataSet ds = getOneDataSetCNN(miniBatchSizePerWorker * nWorkers, seeds[i]); @@ -553,13 +553,13 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { net.setUpdater(null); net.fit(ds); } - INDArray finalParams = net.params().dup(); + INDArray finalParams = net.getModelParams().dup(); //Do training on Spark with one executor, for 3 separate minibatches TrainingMaster tm = getTrainingMaster(1, miniBatchSizePerWorker, saveUpdater); SparkComputationGraph sparkNet = new SparkComputationGraph(sc, getGraphConfCNN(12345, new 
Sgd(0.5)), tm); sparkNet.setCollectTrainingStats(true); - INDArray initialSparkParams = sparkNet.getNetwork().params().dup(); + INDArray initialSparkParams = sparkNet.getNetwork().getModelParams().dup(); for (int i = 0; i < seeds.length; i++) { List list = @@ -572,7 +572,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { // System.out.println(sparkNet.getSparkTrainingStats().statsAsString()); sparkNet.getSparkTrainingStats().statsAsString(); - INDArray finalSparkParams = sparkNet.getNetwork().params().dup(); + INDArray finalSparkParams = sparkNet.getNetwork().getModelParams().dup(); // System.out.println("Initial (Local) params: " + Arrays.toString(initialParams.data().asFloat())); // System.out.println("Initial (Spark) params: " + Arrays.toString(initialSparkParams.data().asFloat())); @@ -584,7 +584,7 @@ public class TestCompareParameterAveragingSparkVsSingleMachine { double sparkScore = sparkNet.getScore(); assertTrue(sparkScore > 0.0); - assertEquals(net.score(), sparkScore, 1e-3); + assertEquals(net.getScore(), sparkScore, 1e-3); } finally { sc.stop(); } diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java index 42fc1112c..e4a720a51 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-core/src/test/java/org/deeplearning4j/spark/impl/paramavg/TestSparkMultiLayerParameterAveraging.java @@ -38,7 +38,7 @@ import org.deeplearning4j.nn.api.Layer; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.layers.BaseLayer; +import org.deeplearning4j.nn.conf.layers.BaseLayerConfiguration; import org.deeplearning4j.nn.conf.layers.BatchNormalization; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -191,11 +191,11 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { new ParameterAveragingTrainingMaster(true, numExecutors(), 1, 5, 1, 0)); MultiLayerNetwork networkCopy = sparkNetCopy.fit(data); - INDArray expectedParams = networkCopy.params(); + INDArray expectedParams = networkCopy.getModelParams(); SparkDl4jMultiLayer sparkNet = getBasicNetwork(); MultiLayerNetwork network = sparkNet.fit(data); - INDArray actualParams = network.params(); + INDArray actualParams = network.getModelParams(); assertEquals(expectedParams.size(1), actualParams.size(1)); } @@ -210,14 +210,14 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { MultiLayerNetwork netCopy = sparkNet.getNetwork().clone(); netCopy.fit(data); - IUpdater expectedUpdater = ((BaseLayer) netCopy.getLayerConfiguration()).getIUpdater(); - double expectedLR = ((Nesterovs)((BaseLayer) netCopy.getLayerConfiguration()).getIUpdater()).getLearningRate(); - double expectedMomentum = ((Nesterovs)((BaseLayer) netCopy.getLayerConfiguration()).getIUpdater()).getMomentum(); + IUpdater expectedUpdater = ((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getIUpdater(); + double expectedLR = ((Nesterovs)((BaseLayerConfiguration) 
netCopy.getLayerConfiguration()).getIUpdater()).getLearningRate(); + double expectedMomentum = ((Nesterovs)((BaseLayerConfiguration) netCopy.getLayerConfiguration()).getIUpdater()).getMomentum(); - IUpdater actualUpdater = ((BaseLayer) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater(); + IUpdater actualUpdater = ((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater(); sparkNet.fit(sparkData); - double actualLR = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater()).getLearningRate(); - double actualMomentum = ((Nesterovs)((BaseLayer) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater()).getMomentum(); + double actualLR = ((Nesterovs)((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater()).getLearningRate(); + double actualMomentum = ((Nesterovs)((BaseLayerConfiguration) sparkNet.getNetwork().getLayerConfiguration()).getIUpdater()).getMomentum(); assertEquals(expectedUpdater, actualUpdater); assertEquals(expectedLR, actualLR, 0.01); @@ -474,11 +474,11 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { paths.add(path); } - INDArray paramsBefore = sparkNet.getNetwork().params().dup(); + INDArray paramsBefore = sparkNet.getNetwork().getModelParams().dup(); JavaRDD pathRdd = sc.parallelize(paths); sparkNet.fitPaths(pathRdd); - INDArray paramsAfter = sparkNet.getNetwork().params().dup(); + INDArray paramsAfter = sparkNet.getNetwork().getModelParams().dup(); assertNotEquals(paramsBefore, paramsAfter); SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); @@ -545,11 +545,11 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { paths.add(path); } - INDArray paramsBefore = sparkNet.getNetwork().params().dup(); + INDArray paramsBefore = sparkNet.getNetwork().getModelParams().dup(); JavaRDD pathRdd = sc.parallelize(paths); sparkNet.fitPaths(pathRdd); - INDArray paramsAfter = sparkNet.getNetwork().params().dup(); + INDArray paramsAfter = sparkNet.getNetwork().getModelParams().dup(); assertNotEquals(paramsBefore, paramsAfter); Thread.sleep(200); @@ -635,11 +635,11 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { paths.add(path); } - INDArray paramsBefore = sparkNet.getNetwork().params().dup(); + INDArray paramsBefore = sparkNet.getNetwork().getModelParams().dup(); JavaRDD pathRdd = sc.parallelize(paths); sparkNet.fitPaths(pathRdd); - INDArray paramsAfter = sparkNet.getNetwork().params().dup(); + INDArray paramsAfter = sparkNet.getNetwork().getModelParams().dup(); assertNotEquals(paramsBefore, paramsAfter); SparkTrainingStats stats = sparkNet.getSparkTrainingStats(); @@ -657,11 +657,11 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { paths.add(path); } - paramsBefore = sparkNet.getNetwork().params().dup(); + paramsBefore = sparkNet.getNetwork().getModelParams().dup(); pathRdd = sc.parallelize(paths); sparkNet.fitPathsMultiDataSet(pathRdd); - paramsAfter = sparkNet.getNetwork().params().dup(); + paramsAfter = sparkNet.getNetwork().getModelParams().dup(); assertNotEquals(paramsBefore, paramsAfter); stats = sparkNet.getSparkTrainingStats(); @@ -731,9 +731,9 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { sparkNet3.fit(rdd); - INDArray p1 = sparkNet1.getNetwork().params(); - INDArray p2 = sparkNet2.getNetwork().params(); - INDArray p3 = sparkNet3.getNetwork().params(); + INDArray p1 = sparkNet1.getNetwork().getModelParams(); + INDArray p2 = 
sparkNet2.getNetwork().getModelParams(); + INDArray p3 = sparkNet3.getNetwork().getModelParams(); sparkNet1.getTrainingMaster().deleteTempFiles(sc); sparkNet2.getTrainingMaster().deleteTempFiles(sc); diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java index 5bb21442c..f26ae7e66 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/pw/SharedTrainingWrapper.java @@ -239,7 +239,7 @@ public class SharedTrainingWrapper { List listeners = worker.getListeners(); if(listeners != null){ - model.setListeners(listeners.toArray(new TrainingListener[]{})); + model.addTrainingListeners(listeners.toArray(new TrainingListener[]{})); StatsStorageRouter r = worker.getRouter(); if(r != null){ for(TrainingListener l : listeners){ @@ -319,7 +319,7 @@ public class SharedTrainingWrapper { consumer = UpdatesConsumer.builder() .numWorkers(numWorkers) .accumulator(accumulator) - .params(model.params()) + .params(model.getModelParams()) .build(); accumulator.setExternalSource(consumer.getUpdatesQueue()); @@ -382,7 +382,7 @@ public class SharedTrainingWrapper { // if we're going to extend iteratation for debugging purposes - let's do that here if (trainingConfiguration.getDebugLongerIterations() > 0) { log.warn("Adding SleepyListener: {} ms", trainingConfiguration.getDebugLongerIterations()); - model.addListeners(SleepyTrainingListener.builder() + model.addTrainingListeners(SleepyTrainingListener.builder() .timerIteration(trainingConfiguration.getDebugLongerIterations()).build()); } @@ -416,7 +416,7 @@ public class SharedTrainingWrapper { val mParams = modelParamsSupplier.get(); if (mParams != null) { log.info("Updating model params to the most recent ones..."); - originalModel.params().assign(mParams); + originalModel.getModelParams().assign(mParams); } // ok. 
attaching accumulator to model @@ -520,7 +520,7 @@ public class SharedTrainingWrapper { val taAveraged = mh.getAverageThresholdAlgorithm(); // FIXME: fill stats here - val result = SharedTrainingResult.builder().aggregationsCount(1).scoreSum(originalModel.score()) + val result = SharedTrainingResult.builder().aggregationsCount(1).scoreSum(originalModel.getScore()) .updaterStateArray(updaterState).listenerMetaData(new ArrayList<>()) .listenerStaticInfo(new ArrayList<>()).listenerUpdates(new ArrayList<>()) .minibatchesPerExecutor(Collections.singletonMap(SparkUtils.getSparkExecutorId(), iteratorDataSetCount.get().get())) diff --git a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java index bb291c0b8..5d9dd9d33 100644 --- a/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java +++ b/cavis-dnn/cavis-dnn-spark/cavis-dnn-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingMaster.java @@ -263,7 +263,7 @@ public class SharedTrainingMaster extends BaseTrainingMaster T[] doEvaluation(MultiDataSetIterator iterator, - T... evaluations) { + T... evaluations) { return null; } @@ -394,9 +397,6 @@ public class BarnesHutTsne implements IModel { return null; } - @Override - public void addListeners(TrainingListener... listener) {//no op - } public Map getParamTable() { return null; @@ -417,7 +417,8 @@ public class BarnesHutTsne implements IModel { } @Override - public void clear() {} + public void clear() { + } @Override public void applyConstraints(int iteration, int epoch) { @@ -440,6 +441,7 @@ public class BarnesHutTsne implements IModel { /** * Symmetrize the value matrix + * * @param rowP * @param colP * @param valP @@ -454,7 +456,8 @@ public class BarnesHutTsne implements IModel { workspaceConfigurationExternal, workspaceExternal); - try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { + try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ + { for (int n = 0; n < N; n++) { int begin = rowP.getInt(n); int end = rowP.getInt(n + 1); @@ -487,7 +490,7 @@ public class BarnesHutTsne implements IModel { for (int n = 0; n < N; n++) { for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) { boolean present = false; - for (int m = rowP.getInt(colP.getInt(i)); m < rowP.getInt(colP.getInt(i)+1); m++) { + for (int m = rowP.getInt(colP.getInt(i)); m < rowP.getInt(colP.getInt(i) + 1); m++) { if (colP.getInt(m) == n) { present = true; if (n <= colP.getInt(i)) { @@ -570,7 +573,7 @@ public class BarnesHutTsne implements IModel { * @param listeners */ - public void setListeners(Collection listeners) { + public void addTrainingListeners(Collection listeners) { } @@ -580,7 +583,7 @@ public class BarnesHutTsne implements IModel { * @param listeners */ @Override - public void setListeners(TrainingListener... listeners) { + public void addTrainingListeners(TrainingListener... 
listeners) { } @@ -615,7 +618,8 @@ public class BarnesHutTsne implements IModel { private INDArray staticData; - public Initializer() {} + public Initializer() { + } public Initializer(INDArray input) { this.staticData = input; @@ -654,7 +658,8 @@ public class BarnesHutTsne implements IModel { workspaceExternal); - try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { + try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ + { x.divi(x.maxNumber()); @@ -697,6 +702,7 @@ public class BarnesHutTsne implements IModel { /** * An individual iteration + * * @param p the probabilities that certain points * are near each other * @param i the iteration (primarily for debugging purposes) @@ -705,7 +711,9 @@ public class BarnesHutTsne implements IModel { update(gradient().getGradientFor(Y_GRAD), Y_GRAD); } - static double sign_tsne(double x) { return (x == .0 ? .0 : (x < .0 ? -1.0 : 1.0)); } + static double sign_tsne(double x) { + return (x == .0 ? .0 : (x < .0 ? -1.0 : 1.0)); + } @Override @@ -717,7 +725,8 @@ public class BarnesHutTsne implements IModel { workspaceConfigurationExternal, workspaceExternal); - try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { + try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ + { INDArray yGrads = gradient; if (gains == null) @@ -726,12 +735,11 @@ public class BarnesHutTsne implements IModel { //Nd4j.getExecutioner().exec(new BarnesHutGains(gains, gains, yGrads, yIncs)); // Copied from Reference for (int i = 0; i < yGrads.rows(); ++i) { - for (int j = 0; j < yGrads.columns(); ++j) { - if (sign_tsne(yGrads.getDouble(i,j)) == sign_tsne(yIncs.getDouble(i,j))) { - gains.putScalar(new int[]{i,j}, gains.getDouble(i,j)*0.8); - } - else { - gains.putScalar(new int[]{i,j}, gains.getDouble(i,j)+0.2); + for (int j = 0; j < yGrads.columns(); ++j) { + if (sign_tsne(yGrads.getDouble(i, j)) == sign_tsne(yIncs.getDouble(i, j))) { + gains.putScalar(new int[]{i, j}, gains.getDouble(i, j) * 0.8); + } else { + gains.putScalar(new int[]{i, j}, gains.getDouble(i, j) + 0.2); } } } @@ -759,8 +767,9 @@ public class BarnesHutTsne implements IModel { /** * Save the model as a file with a csv format, adding the label as the last column. + * * @param labels - * @param path the path to write + * @param path the path to write * @throws IOException */ public void saveAsFile(List labels, String path) throws IOException { @@ -805,6 +814,7 @@ public class BarnesHutTsne implements IModel { write.flush(); } } + /** * Plot tsne * @@ -823,7 +833,7 @@ public class BarnesHutTsne implements IModel { @Override - public double score() { + public double getScore() { /*MemoryWorkspace workspace = workspaceMode == WorkspaceMode.NONE ? 
new DummyWorkspace() @@ -832,7 +842,8 @@ public class BarnesHutTsne implements IModel { workspaceExternal); - try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { + try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ + { // Get estimate of normalization term @@ -871,7 +882,7 @@ public class BarnesHutTsne implements IModel { } @Override - public INDArray params() { + public INDArray getModelParams() { return null; } @@ -912,7 +923,7 @@ public class BarnesHutTsne implements IModel { } @Override - public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr){ + public void fit(INDArray data, LayerWorkspaceMgr workspaceMgr) { fit(data); } @@ -937,7 +948,8 @@ public class BarnesHutTsne implements IModel { workspaceExternal); - try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { + try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ + { if (yIncs == null) @@ -967,7 +979,7 @@ public class BarnesHutTsne implements IModel { @Override public Pair gradientAndScore() { - return new Pair<>(gradient(), score()); + return new Pair<>(gradient(), getScore()); } @Override @@ -1128,7 +1140,7 @@ public class BarnesHutTsne implements IModel { return this; } - public Builder workspaceMode(WorkspaceMode workspaceMode){ + public Builder workspaceMode(WorkspaceMode workspaceMode) { this.workspaceMode = workspaceMode; return this; } @@ -1143,7 +1155,7 @@ public class BarnesHutTsne implements IModel { @Override - public void close(){ + public void close() { //No-op } @@ -1153,7 +1165,34 @@ public class BarnesHutTsne implements IModel { * @return training listener */ @Override - public Collection getListeners() { + public Collection getTrainingListeners() { return null; } -} + + @Override + public ITraininableLayerConfiguration getTrainingConfig() { + throw new UnsupportedOperationException("Not supported"); + } + + @Override + public INDArray getParams() { + throw new RuntimeException("Not supported"); + + } + + /** + * DL4J layers typically produce the sum of the gradients during the backward pass for each layer, and if required + * (if minibatch=true) then divide by the minibatch size.
+ * However, there are some exceptions, such as the batch norm mean/variance estimate parameters: these "gradients" + * are actually not gradients, but are updates to be applied directly to the parameter vector. Put another way, + * most gradients should be divided by the minibatch to get the average; some "gradients" are actually final updates + * already, and should not be divided by the minibatch size. + * + * @param paramName Name of the parameter + * @return True if gradients should be divided by minibatch (most params); false otherwise (edge cases like batch norm mean/variance estimates) + */ + @Override + public boolean updaterDivideByMinibatch(String paramName) { + return false; + } +} \ No newline at end of file diff --git a/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/ManualTests.java b/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/ManualTests.java index 7c5de3bbb..415cd7ac2 100644 --- a/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/ManualTests.java +++ b/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/ManualTests.java @@ -30,7 +30,6 @@ import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.Word2Vec; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration.NeuralNetConfigurationBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ConvolutionLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; @@ -151,7 +150,7 @@ public class ManualTests { log.info("Train model...."); - model.setListeners(new ScoreIterationListener(listenerFreq), new ConvolutionalIterationListener(listenerFreq)); + model.addTrainingListeners(new ScoreIterationListener(listenerFreq), new ConvolutionalIterationListener(listenerFreq)); while (lfw.hasNext()) { lfwNext = lfw.next(); @@ -279,7 +278,7 @@ public class ManualTests { */ log.info("Train model...."); - model.setListeners(new ConvolutionalIterationListener(1)); + model.addTrainingListeners(new ConvolutionalIterationListener(1)); //((NativeOpExecutioner) Nd4j.getExecutioner()).getLoop().setOmpNumThreads(8); @@ -339,7 +338,7 @@ public class ManualTests { model.init(); log.info("Train model...."); - model.setListeners(new ConvolutionalIterationListener(1)); + model.addTrainingListeners(new ConvolutionalIterationListener(1)); for (int i = 0; i < nEpochs; i++) { model.fit(mnistTrain); diff --git a/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java b/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java index e545ff53b..b53e55c9c 100644 --- a/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java +++ b/cavis-ui/cavis-ui-common/src/test/java/org/deeplearning4j/ui/weights/TestConvolutionalListener.java @@ -74,7 +74,7 @@ public class TestConvolutionalListener { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.setListeners(new ConvolutionalIterationListener(1), new ScoreIterationListener(1)); + net.addTrainingListeners(new ConvolutionalIterationListener(1), new ScoreIterationListener(1)); for (int i = 0; i < 10; i++) { net.fit(mnistTrain.next()); @@ -82,7 +82,7 @@ public class TestConvolutionalListener { } ComputationGraph cg = net.toComputationGraph(); - cg.setListeners(new ConvolutionalIterationListener(1), new ScoreIterationListener(1)); + 
cg.addTrainingListeners(new ConvolutionalIterationListener(1), new ScoreIterationListener(1)); for (int i = 0; i < 10; i++) { cg.fit(mnistTrain.next()); Thread.sleep(1000); diff --git a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java index 3797b6550..70144bfd3 100644 --- a/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java +++ b/cavis-ui/cavis-ui-model/src/main/java/org/deeplearning4j/ui/model/stats/BaseStatsListener.java @@ -29,7 +29,6 @@ import org.deeplearning4j.core.storage.StatsStorageRouter; import org.deeplearning4j.core.storage.StorageMetaData; import org.deeplearning4j.core.storage.listener.RoutingIterationListener; import org.deeplearning4j.nn.api.Layer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.graph.ComputationGraph; @@ -419,7 +418,7 @@ public abstract class BaseStatsListener implements RoutingIterationListener { } //--- General --- - report.reportScore(model.score()); //Always report score + report.reportScore(model.getScore()); //Always report score if (updateConfig.collectLearningRates()) { Map lrs = new HashMap<>(); diff --git a/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java b/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java index 9b1a4801e..24e5e1ed9 100644 --- a/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java +++ b/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestStatsListener.java @@ -64,9 +64,9 @@ public class TestStatsListener extends BaseDL4JTest { StatsStorage ss = new MapDBStatsStorage(); //in-memory if (useJ7) { - net.setListeners(new J7StatsListener(ss, 1)); + net.addTrainingListeners(new J7StatsListener(ss, 1)); } else { - net.setListeners(new StatsListener(ss, 1)); + net.addTrainingListeners(new StatsListener(ss, 1)); } diff --git a/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java b/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java index d5b1a116b..1dc5cb1a6 100644 --- a/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java +++ b/cavis-ui/cavis-ui-model/src/test/java/org/deeplearning4j/ui/stats/TestTransferStatsCollection.java @@ -56,7 +56,7 @@ public class TestTransferStatsCollection extends BaseDL4JTest { new FineTuneConfiguration.Builder().updater(new Sgd(0.01)).build()) .setFeatureExtractor(0).build(); - net2.setListeners(new StatsListener(new InMemoryStatsStorage())); + net2.addTrainingListeners(new StatsListener(new InMemoryStatsStorage())); //Previosuly: failed on frozen layers net2.fit(new DataSet(Nd4j.rand(8, 10), Nd4j.rand(8, 10))); diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index 2ca083a4d..c0df01142 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -1174,8 +1174,8 @@ public class TrainModule implements UIModule { 
layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerSize"), String.valueOf(ffl.getNOut())}); } - if (layer instanceof BaseLayer) { - BaseLayer bl = (BaseLayer) layer; + if (layer instanceof BaseLayerConfiguration) { + BaseLayerConfiguration bl = (BaseLayerConfiguration) layer; activationFn = bl.getActivationFn().toString(); long nParams = layer.initializer().numParams(bl.getLayer()); layerInfoRows.add(new String[]{i18N.getMessage("train.model.layerinfotable.layerNParams"), diff --git a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java index 8bae39055..09dbe9846 100644 --- a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java +++ b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestRemoteReceiver.java @@ -142,7 +142,7 @@ public class TestRemoteReceiver extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); try(RemoteUIStatsStorageRouter ssr = new RemoteUIStatsStorageRouter("http://localhost:9000")) { - net.setListeners(new StatsListener(ssr), new ScoreIterationListener(1)); + net.addTrainingListeners(new StatsListener(ssr), new ScoreIterationListener(1)); DataSetIterator iter = new IrisDataSetIterator(150, 150); diff --git a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java index 694a557bc..d51f74aba 100644 --- a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java +++ b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUI.java @@ -112,7 +112,7 @@ public class TestVertxUI extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.setListeners(new StatsListener(ss), new ScoreIterationListener(1)); + net.addTrainingListeners(new StatsListener(ss), new ScoreIterationListener(1)); DataSetIterator iter = new IrisDataSetIterator(150, 150); @@ -142,7 +142,7 @@ public class TestVertxUI extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.setListeners(new StatsListener(ss, 1), new ScoreIterationListener(1)); + net.addTrainingListeners(new StatsListener(ss, 1), new ScoreIterationListener(1)); DataSetIterator iter = new IrisDataSetIterator(150, 150); @@ -171,7 +171,7 @@ public class TestVertxUI extends BaseDL4JTest { ComputationGraph net = new ComputationGraph(conf); net.init(); - net.setListeners(new StatsListener(ss), new ScoreIterationListener(1)); + net.addTrainingListeners(new StatsListener(ss), new ScoreIterationListener(1)); DataSetIterator iter = new IrisDataSetIterator(150, 150); @@ -195,7 +195,7 @@ public class TestVertxUI extends BaseDL4JTest { StatsStorage ss1 = new InMemoryStatsStorage(); - net.setListeners(new StatsListener(ss1, 1, "ss1")); + net.addTrainingListeners(new StatsListener(ss1, 1, "ss1")); DataSetIterator iter = new IrisDataSetIterator(150, 150); @@ -204,7 +204,7 @@ public class TestVertxUI extends BaseDL4JTest { } StatsStorage ss2 = new InMemoryStatsStorage(); - net.setListeners(new StatsListener(ss2, 1, "ss2")); + net.addTrainingListeners(new StatsListener(ss2, 1, "ss2")); for (int i = 0; i < 4; i++) { net.fit(iter); diff --git a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java index bc1ae16a8..e17681c4c 100644 --- 
a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java +++ b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIManual.java @@ -108,7 +108,7 @@ public class TestVertxUIManual extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - net.setListeners(new StatsListener(ss), new ScoreIterationListener(1)); + net.addTrainingListeners(new StatsListener(ss), new ScoreIterationListener(1)); DataSetIterator iter = new IrisDataSetIterator(150, 150); @@ -203,7 +203,7 @@ public class TestVertxUIManual extends BaseDL4JTest { StatsListener statsListener = new StatsListener(ss, 1); statsListener.setSessionID(sessionId); - net.setListeners(statsListener, new ScoreIterationListener(1)); + net.addTrainingListeners(statsListener, new ScoreIterationListener(1)); uIServer.attach(ss); DataSetIterator iter = new IrisDataSetIterator(150, 150); diff --git a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java index 7da17dafd..fb21f9561 100644 --- a/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java +++ b/cavis-ui/cavis-ui-vertx/src/test/java/org/deeplearning4j/ui/TestVertxUIMultiSession.java @@ -100,7 +100,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { StatsListener statsListener = new StatsListener(ss, 1); statsListener.setSessionID(sessionId); - net.setListeners(statsListener, new ScoreIterationListener(1)); + net.addTrainingListeners(statsListener, new ScoreIterationListener(1)); uIServer.attach(ss); DataSetIterator iter = new IrisDataSetIterator(150, 150); @@ -164,7 +164,7 @@ public class TestVertxUIMultiSession extends BaseDL4JTest { StatsListener statsListener = new StatsListener(ss, 1); statsListener.setSessionID(sessionId); - net.setListeners(statsListener, new ScoreIterationListener(1)); + net.addTrainingListeners(statsListener, new ScoreIterationListener(1)); uIServer.attach(ss); DataSetIterator iter = new IrisDataSetIterator(150, 150); diff --git a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java index 0e6fdfb38..8d6b94d54 100644 --- a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java +++ b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestImageNet.java @@ -108,7 +108,7 @@ public class TestImageNet extends BaseDL4JTest { assertEquals("golden retriever", predictions.get(0).get(0).getLabel()); // clean up for current model - initializedModel.params().close(); + initializedModel.getModelParams().close(); Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); System.gc(); @@ -134,7 +134,7 @@ public class TestImageNet extends BaseDL4JTest { } // clean up for current model - initializedModel.params().close(); + initializedModel.getModelParams().close(); Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); System.gc(); @@ -159,7 +159,7 @@ public class TestImageNet extends BaseDL4JTest { assertEquals("dog", classPrediction.getLabel()); } - initializedModel.params().close(); + initializedModel.getModelParams().close(); } } diff --git a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java index 2bf9e7ed1..27eb6e23d 100644 --- 
a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java +++ b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestInstantiation.java @@ -201,7 +201,7 @@ public class TestInstantiation extends BaseDL4JTest { // clean up for current model Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); - initializedModel.params().close(); + initializedModel.getModelParams().close(); for(INDArray arr : result){ arr.close(); } @@ -271,7 +271,7 @@ public class TestInstantiation extends BaseDL4JTest { Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread(); f.close(); l.close(); - initializedModel.params().close(); + initializedModel.getModelParams().close(); initializedModel.getFlattenedGradients().close(); System.gc(); } diff --git a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java index 7a26046cd..d759151e8 100644 --- a/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java +++ b/cavis-zoo/cavis-zoo-models/src/test/java/org/deeplearning4j/zoo/TestUtils.java @@ -46,7 +46,7 @@ public class TestUtils { MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); assertEquals(net.getNetConfiguration(), restored.getNetConfiguration()); - assertEquals(net.params(), restored.params()); + assertEquals(net.getModelParams(), restored.getModelParams()); return restored; } catch (IOException e){ @@ -66,7 +66,7 @@ public class TestUtils { ComputationGraph restored = ModelSerializer.restoreComputationGraph(bais, true); assertEquals(net.getComputationGraphConfiguration(), restored.getComputationGraphConfiguration()); - assertEquals(net.params(), restored.params()); + assertEquals(net.getModelParams(), restored.getModelParams()); return restored; } catch (IOException e){ From 1f2e82d3eff0b50ab40be359adfb81017e697bf6 Mon Sep 17 00:00:00 2001 From: brian Date: Fri, 14 Apr 2023 13:24:19 +0200 Subject: [PATCH 124/126] Playing with some new code 2 - clean build/test Signed-off-by: brian --- .../src/test/java/net/brutex/gan/App.java | 2 +- .../nn/conf/layers/LayerBuilderTest.java | 2 +- .../nn/conf/layers/LayerConfigTest.java | 8 +- .../layers/LayerConfigValidationTest.java | 2 +- .../samediff/testlayers/SameDiffConv.java | 2 +- .../samediff/testlayers/SameDiffDense.java | 2 +- .../TransferLearningCompGraphTest.java | 6 +- .../TransferLearningMLNTest.java | 14 +- .../nn/updater/TestUpdaters.java | 5 +- .../regressiontest/RegressionTest050.java | 12 +- .../regressiontest/RegressionTest060.java | 12 +- .../regressiontest/RegressionTest071.java | 12 +- .../regressiontest/RegressionTest080.java | 12 +- .../regressiontest/RegressionTest100a.java | 10 +- .../regressiontest/RegressionTest100b3.java | 10 +- .../regressiontest/RegressionTest100b4.java | 20 +- .../regressiontest/RegressionTest100b6.java | 20 +- .../KerasInitilizationTest.java | 2 +- .../advanced/activation/KerasPReLUTest.java | 2 +- .../KerasAtrousConvolution1DTest.java | 2 +- .../KerasAtrousConvolution2DTest.java | 2 +- .../convolution/KerasConvolution1DTest.java | 2 +- .../convolution/KerasConvolution2DTest.java | 2 +- .../convolution/KerasConvolution3DTest.java | 2 +- .../convolution/KerasDeconvolution2DTest.java | 2 +- .../KerasDepthwiseConvolution2DTest.java | 2 +- .../KerasSeparableConvolution2DTest.java | 2 +- .../keras/layers/core/KerasDenseTest.java | 2 +- .../keras/layers/recurrent/KerasLSTMTest.java | 3 +- 
.../layers/recurrent/KerasSimpleRnnTest.java | 2 +- .../dnn/api/INeuralNetworkConfiguration.java | 27 +- .../conf/ComputationGraphConfiguration.java | 4 +- .../NeuralNetBaseBuilderConfiguration.java | 817 ++++++++---------- .../nn/conf/NeuralNetConfiguration.java | 110 +-- .../conf/layers/BaseLayerConfiguration.java | 32 +- .../nn/conf/layers/ConvolutionLayer.java | 4 +- .../nn/conf/layers/DenseLayer.java | 7 +- .../conf/layers/EmbeddingSequenceLayer.java | 4 +- .../nn/conf/layers/LayerConfiguration.java | 39 +- .../nn/conf/layers/LocallyConnected1D.java | 2 +- .../nn/conf/layers/LocallyConnected2D.java | 2 +- .../nn/conf/layers/PReLULayer.java | 2 +- .../conf/layers/RecurrentAttentionLayer.java | 2 +- .../samediff/AbstractSameDiffLayer.java | 2 +- .../conf/layers/samediff/SameDiffVertex.java | 2 +- .../variational/VariationalAutoencoder.java | 1 + .../conf/serde/BaseNetConfigDeserializer.java | 4 +- ...utationGraphConfigurationDeserializer.java | 4 +- .../NeuralNetConfigurationDeserializer.java | 2 +- .../nn/layers/AbstractLayer.java | 5 + .../deeplearning4j/nn/layers/BaseLayer.java | 1 - .../nn/layers/ocnn/OCNNParamInitializer.java | 3 +- .../variational/VariationalAutoencoder.java | 41 +- .../params/Convolution3DParamInitializer.java | 3 +- .../params/ConvolutionParamInitializer.java | 2 +- .../Deconvolution3DParamInitializer.java | 3 +- .../params/DeconvolutionParamInitializer.java | 3 +- .../nn/params/DefaultParamInitializer.java | 6 +- .../DepthwiseConvolutionParamInitializer.java | 4 +- ...avesBidirectionalLSTMParamInitializer.java | 10 +- .../nn/params/GravesLSTMParamInitializer.java | 6 +- .../nn/params/LSTMParamInitializer.java | 5 +- .../nn/params/PReLUParamInitializer.java | 2 +- .../SeparableConvolutionParamInitializer.java | 6 +- .../nn/params/SimpleRnnParamInitializer.java | 6 +- ...ariationalAutoencoderParamInitializer.java | 3 +- .../FineTuneConfiguration.java | 6 +- .../nn/transferlearning/TransferLearning.java | 12 +- .../nn/updater/BaseMultiLayerUpdater.java | 7 +- .../java/net/brutex/ai/dnn/api/dnnTest.java | 6 +- .../ui/module/train/TrainModule.java | 2 +- .../deeplearning4j/zoo/model/Darknet19.java | 1 + .../deeplearning4j/zoo/model/ResNet50.java | 2 +- 73 files changed, 647 insertions(+), 743 deletions(-) diff --git a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java index aba07ef0d..0287d32a9 100644 --- a/brutex-extended-tests/src/test/java/net/brutex/gan/App.java +++ b/brutex-extended-tests/src/test/java/net/brutex/gan/App.java @@ -118,7 +118,7 @@ public class App { .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalizationThreshold(GRADIENT_THRESHOLD) //.weightInit(WeightInit.XAVIER) - .weightInitFn(new WeightInitXavier()) + .weightInit(WeightInit.XAVIER) .activation(Activation.IDENTITY) .layersFromArray(genLayers()) .inputType(InputType.convolutional(X_DIM, Y_DIM, CHANNELS)) diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java index 3ae5d8bd0..680681920 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerBuilderTest.java @@ -74,7 +74,7 @@ public class LayerBuilderTest extends BaseDL4JTest { checkSerialization(layer); assertEquals(act, layer.getActivationFn()); - 
assertEquals(weight.getWeightInitFunction(), layer.getWeightInitFn()); + assertEquals(weight.getWeightInitFunction(), layer.getWeightInit()); assertEquals(new Dropout(dropOut), layer.getIDropout()); assertEquals(updater, layer.getIUpdater()); assertEquals(gradNorm, layer.getGradientNormalization()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java index 28d17c150..7777475e6 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigTest.java @@ -99,8 +99,8 @@ public class LayerConfigTest extends BaseDL4JTest { MultiLayerNetwork net = new MultiLayerNetwork(conf); net.init(); - assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn()); - assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn()); + assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInit()); + assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInit()); assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0); assertEquals(1, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0); @@ -117,8 +117,8 @@ public class LayerConfigTest extends BaseDL4JTest { net = new MultiLayerNetwork(conf); net.init(); - assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInitFn()); - assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInitFn()); + assertEquals(new WeightInitDistribution(defaultDistribution), ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getWeightInit()); + assertEquals(new WeightInitDistribution(overriddenDistribution), ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getWeightInit()); assertEquals(1, ((BaseLayerConfiguration) conf.getConf(0).getLayer()).getBiasInit(), 0.0); assertEquals(0, ((BaseLayerConfiguration) conf.getConf(1).getLayer()).getBiasInit(), 0.0); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java index dae839a06..b813b2b5f 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/conf/layers/LayerConfigValidationTest.java @@ -185,7 +185,7 @@ public class LayerConfigValidationTest extends BaseDL4JTest { layerConf1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration(); assertEquals(expectedAdamMeanDecay, ((Adam) layerConf1.getIUpdater()).getBeta1(), 1e-3); assertEquals(expectedAdamVarDecay, ((Adam) layerConf1.getIUpdater()).getBeta2(), 1e-3); - assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInitFn()); + assertEquals(new WeightInitDistribution(expectedDist), layerConf1.getWeightInit()); assertNull(TestUtils.getL1Reg(layerConf1.getRegularization())); 
assertNull(TestUtils.getL2Reg(layerConf1.getRegularization())); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java index 6fe2cf15e..f8a2f173b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffConv.java @@ -157,7 +157,7 @@ public class SameDiffConv extends SameDiffLayer { public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { NeuralNetConfiguration clone = globalConfig.clone().build(); if (activation == null) { - activation = SameDiffLayerUtils.fromIActivation(clone.getActivationFn()); + activation = SameDiffLayerUtils.fromIActivation(clone.getActivation()); } if (cm == null) { cm = clone.getConvolutionMode(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java index d0a176d63..e1799443d 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/layers/samediff/testlayers/SameDiffDense.java @@ -119,7 +119,7 @@ public class SameDiffDense extends SameDiffLayer { public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { NeuralNetConfiguration clone = globalConfig.clone().build(); if(activation == null){ - activation = SameDiffLayerUtils.fromIActivation(clone.getActivationFn()); + activation = SameDiffLayerUtils.fromIActivation(clone.getActivation()); } } diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java index a8bfd8d97..954e8ed18 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningCompGraphTest.java @@ -141,9 +141,9 @@ public class TransferLearningCompGraphTest extends BaseDL4JTest { BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getLayer("layer0").getLayerConfiguration()); BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getLayer("layer1").getLayerConfiguration()); BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getLayer("layer3").getLayerConfiguration()); - assertEquals(bl0.getWeightInitFn(), new WeightInitDistribution(new NormalDistribution(1, 1e-1))); - assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); - assertEquals(bl1.getWeightInitFn(), new WeightInitXavier()); + assertEquals(bl0.getWeightInit(), new WeightInitDistribution(new NormalDistribution(1, 1e-1))); + assertEquals(bl1.getWeightInit(), new WeightInitXavier()); + assertEquals(bl1.getWeightInit(), new WeightInitXavier()); ComputationGraph modelExpectedArch = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In") .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In") diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java index 88e8d5d01..7d10a3bc7 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/transferlearning/TransferLearningMLNTest.java @@ -163,14 +163,14 @@ public class TransferLearningMLNTest extends BaseDL4JTest { BaseLayerConfiguration bl0 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(0).getLayer()); BaseLayerConfiguration bl1 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(1).getLayer()); BaseLayerConfiguration bl3 = ((BaseLayerConfiguration) modelNow.getNetConfiguration().getConf(3).getLayer()); - assertEquals(bl0.getWeightInitFn().getClass(), WeightInitXavier.class); + assertEquals(bl0.getWeightInit().getClass(), WeightInitXavier.class); try { - assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInitFn()), + assertEquals(JsonMappers.getMapper().writeValueAsString(bl1.getWeightInit()), JsonMappers.getMapper().writeValueAsString(new WeightInitDistribution(new NormalDistribution(1, 1e-1)))); } catch (JsonProcessingException e) { throw new RuntimeException(e); } - assertEquals(bl3.getWeightInitFn(), new WeightInitXavier()); + assertEquals(bl3.getWeightInit(), new WeightInitXavier()); //modelNow should have the same architecture as modelExpectedArch assertArrayEquals(modelExpectedArch.getModelParams().shape(), modelNow.getModelParams().shape()); @@ -506,13 +506,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest { BaseLayerConfiguration l0 = (BaseLayerConfiguration) net.getLayer(0).getLayerConfiguration(); assertEquals(new Adam(1e-4), l0.getIUpdater()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); - assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertEquals(0.1, TestUtils.getL1(l0), 1e-6); BaseLayerConfiguration l1 = (BaseLayerConfiguration) net.getLayer(1).getLayerConfiguration(); assertEquals(new Adam(1e-4), l1.getIUpdater()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); - assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l1.getWeightInit()); assertEquals(0.2, TestUtils.getL2(l1), 1e-6); assertEquals(BackpropType.Standard, conf.getBackpropType()); @@ -521,13 +521,13 @@ public class TransferLearningMLNTest extends BaseDL4JTest { l0 = (BaseLayerConfiguration) net2.getLayer(0).getLayerConfiguration(); assertEquals(new Adam(2e-2), l0.getIUpdater()); assertEquals(Activation.TANH.getActivationFunction(), l0.getActivationFn()); - assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertEquals(0.1, TestUtils.getL1(l0), 1e-6); l1 = (BaseLayerConfiguration) net2.getLayer(1).getLayerConfiguration(); assertEquals(new Adam(2e-2), l1.getIUpdater()); assertEquals(Activation.HARDSIGMOID.getActivationFunction(), l1.getActivationFn()); - assertEquals(new WeightInitRelu(), l1.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l1.getWeightInit()); assertEquals(0.2, TestUtils.getL2(l1), 1e-6); assertEquals(BackpropType.TruncatedBPTT, net2.getNetConfiguration().getBackpropType()); diff --git 
a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java index f92e34bf2..ce7d713dc 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/nn/updater/TestUpdaters.java @@ -37,6 +37,7 @@ import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.params.PretrainParamInitializer; import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater; +import org.deeplearning4j.nn.weights.WeightInit; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.nd4j.linalg.activations.Activation; @@ -940,7 +941,9 @@ public class TestUpdaters extends BaseDL4JTest { List blocks; NeuralNetConfiguration conf = - NeuralNetConfiguration.builder().updater(new Adam(0.5)).list() + NeuralNetConfiguration.builder() + .updater(new Adam(0.5)) + .weightInit(WeightInit.NORMAL) .layer(0, new VariationalAutoencoder.Builder().nIn(8).nOut(12) .encoderLayerSizes(10, 11).decoderLayerSizes(13, 14).build()) .build(); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java index 773ccbae8..a771e414b 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest050.java @@ -72,7 +72,7 @@ public class RegressionTest050 extends BaseDL4JTest { assertEquals("relu", l0.getActivationFn().toString()); assertEquals(3, l0.getNIn()); assertEquals(4, l0.getNOut()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater()); assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6); @@ -81,7 +81,7 @@ public class RegressionTest050 extends BaseDL4JTest { assertTrue(l1.getLossFn() instanceof LossMCXENT); assertEquals(4, l1.getNIn()); assertEquals(5, l1.getNOut()); - assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater()); assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); @@ -106,7 +106,7 @@ public class RegressionTest050 extends BaseDL4JTest { assertTrue(l0.getActivationFn() instanceof ActivationLReLU); assertEquals(3, l0.getNIn()); assertEquals(4, l0.getNOut()); - assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); + assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(new Dropout(0.6), l0.getIDropout()); @@ -118,7 +118,7 @@ public class RegressionTest050 extends BaseDL4JTest { assertTrue(l1.getLossFn() instanceof LossMSE); assertEquals(4, l1.getNIn()); assertEquals(5, l1.getNOut()); - assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); + assertEquals(new 
WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater()); assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6); assertEquals(new Dropout(0.6), l1.getIDropout()); @@ -145,7 +145,7 @@ public class RegressionTest050 extends BaseDL4JTest { assertEquals("tanh", l0.getActivationFn().toString()); assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNOut()); - assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); @@ -165,7 +165,7 @@ public class RegressionTest050 extends BaseDL4JTest { assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); assertEquals(26 * 26 * 3, l2.getNIn()); assertEquals(5, l2.getNOut()); - assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java index c75c11d11..8d6dae94a 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest060.java @@ -74,7 +74,7 @@ public class RegressionTest060 extends BaseDL4JTest { assertEquals("relu", l0.getActivationFn().toString()); assertEquals(3, l0.getNIn()); assertEquals(4, l0.getNOut()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater()); assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6); @@ -83,7 +83,7 @@ public class RegressionTest060 extends BaseDL4JTest { assertTrue(l1.getLossFn() instanceof LossMCXENT); assertEquals(4, l1.getNIn()); assertEquals(5, l1.getNOut()); - assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new Nesterovs(0.15, 0.9), l1.getIUpdater()); assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); @@ -108,7 +108,7 @@ public class RegressionTest060 extends BaseDL4JTest { assertTrue(l0.getActivationFn() instanceof ActivationLReLU); assertEquals(3, l0.getNIn()); assertEquals(4, l0.getNOut()); - assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); + assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(new Dropout(0.6), l0.getIDropout()); @@ -122,7 +122,7 @@ public class RegressionTest060 extends BaseDL4JTest { assertTrue(l1.getLossFn() instanceof LossMSE); assertEquals(4, l1.getNIn()); assertEquals(5, l1.getNOut()); - assertEquals(new WeightInitDistribution(new 
NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); + assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater()); assertEquals(0.15, ((RmsProp)l1.getIUpdater()).getLearningRate(), 1e-6); assertEquals(new Dropout(0.6), l1.getIDropout()); @@ -151,7 +151,7 @@ public class RegressionTest060 extends BaseDL4JTest { assertEquals("tanh", l0.getActivationFn().toString()); assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNOut()); - assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); @@ -171,7 +171,7 @@ public class RegressionTest060 extends BaseDL4JTest { assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); //TODO assertEquals(26 * 26 * 3, l2.getNIn()); assertEquals(5, l2.getNOut()); - assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java index 63ea30e49..8589b7de2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest071.java @@ -75,7 +75,7 @@ public class RegressionTest071 extends BaseDL4JTest { assertEquals("relu", l0.getActivationFn().toString()); assertEquals(3, l0.getNIn()); assertEquals(4, l0.getNOut()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new Nesterovs(0.15, 0.9), l0.getIUpdater()); assertEquals(0.15, ((Nesterovs)l0.getIUpdater()).getLearningRate(), 1e-6); @@ -84,7 +84,7 @@ public class RegressionTest071 extends BaseDL4JTest { assertTrue(l1.getLossFn() instanceof LossMCXENT); assertEquals(4, l1.getNIn()); assertEquals(5, l1.getNOut()); - assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); @@ -109,7 +109,7 @@ public class RegressionTest071 extends BaseDL4JTest { assertTrue(l0.getActivationFn() instanceof ActivationLReLU); assertEquals(3, l0.getNIn()); assertEquals(4, l0.getNOut()); - assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); + assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(new Dropout(0.6), l0.getIDropout()); @@ -123,7 +123,7 @@ public class RegressionTest071 extends BaseDL4JTest { assertTrue(l1.getLossFn() instanceof LossMSE); assertEquals(4, 
l1.getNIn()); assertEquals(5, l1.getNOut()); - assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); + assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l1.getIUpdater()); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertEquals(new Dropout(0.6), l1.getIDropout()); @@ -152,7 +152,7 @@ public class RegressionTest071 extends BaseDL4JTest { assertEquals("tanh", l0.getActivationFn().toString()); assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNOut()); - assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); assertArrayEquals(new int[] {2, 2}, l0.getKernelSize()); @@ -172,7 +172,7 @@ public class RegressionTest071 extends BaseDL4JTest { assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); //TODO assertEquals(26 * 26 * 3, l2.getNIn()); assertEquals(5, l2.getNOut()); - assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertEquals(new RmsProp(0.15, 0.96, RmsProp.DEFAULT_RMSPROP_EPSILON), l0.getIUpdater()); assertEquals(0.15, ((RmsProp)l0.getIUpdater()).getLearningRate(), 1e-6); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java index 010ac9733..90cb2c126 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest080.java @@ -74,7 +74,7 @@ public class RegressionTest080 extends BaseDL4JTest { assertTrue(l0.getActivationFn() instanceof ActivationReLU); assertEquals(3, l0.getNIn()); assertEquals(4, l0.getNOut()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertTrue(l0.getIUpdater() instanceof Nesterovs); Nesterovs n = (Nesterovs) l0.getIUpdater(); assertEquals(0.9, n.getMomentum(), 1e-6); @@ -87,7 +87,7 @@ public class RegressionTest080 extends BaseDL4JTest { assertTrue(l1.getLossFn() instanceof LossMCXENT); assertEquals(4, l1.getNIn()); assertEquals(5, l1.getNOut()); - assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertTrue(l1.getIUpdater() instanceof Nesterovs); assertEquals(0.9, ((Nesterovs)l1.getIUpdater()).getMomentum(), 1e-6); assertEquals(0.15, ((Nesterovs)l1.getIUpdater()).getLearningRate(), 1e-6); @@ -113,7 +113,7 @@ public class RegressionTest080 extends BaseDL4JTest { assertTrue(l0.getActivationFn() instanceof ActivationLReLU); assertEquals(3, l0.getNIn()); assertEquals(4, l0.getNOut()); - assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInitFn()); + assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l0.getWeightInit()); assertTrue(l0.getIUpdater() instanceof RmsProp); RmsProp r = (RmsProp) l0.getIUpdater(); assertEquals(0.96, r.getRmsDecay(), 1e-6); @@ -130,7 +130,7 @@ public class RegressionTest080 extends BaseDL4JTest { assertTrue(l1.getLossFn() instanceof LossMSE); assertEquals(4, 
l1.getNIn()); assertEquals(5, l1.getNOut()); - assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInitFn()); + assertEquals(new WeightInitDistribution(new NormalDistribution(0.1, 1.2)), l1.getWeightInit()); assertTrue(l1.getIUpdater() instanceof RmsProp); r = (RmsProp) l1.getIUpdater(); assertEquals(0.96, r.getRmsDecay(), 1e-6); @@ -162,7 +162,7 @@ public class RegressionTest080 extends BaseDL4JTest { assertTrue(l0.getActivationFn() instanceof ActivationTanH); assertEquals(3, l0.getNIn()); assertEquals(3, l0.getNOut()); - assertEquals(new WeightInitRelu(), l0.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l0.getWeightInit()); assertTrue(l0.getIUpdater() instanceof RmsProp); RmsProp r = (RmsProp) l0.getIUpdater(); assertEquals(0.96, r.getRmsDecay(), 1e-6); @@ -185,7 +185,7 @@ public class RegressionTest080 extends BaseDL4JTest { assertTrue(l2.getLossFn() instanceof LossNegativeLogLikelihood); assertEquals(26 * 26 * 3, l2.getNIn()); assertEquals(5, l2.getNOut()); - assertEquals(new WeightInitRelu(), l2.getWeightInitFn()); + assertEquals(new WeightInitRelu(), l2.getWeightInit()); assertTrue(l2.getIUpdater() instanceof RmsProp); r = (RmsProp) l2.getIUpdater(); assertEquals(0.96, r.getRmsDecay(), 1e-6); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java index 6b6558c48..6555c5eec 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100a.java @@ -89,21 +89,21 @@ public class RegressionTest100a extends BaseDL4JTest { GravesLSTM l0 = (GravesLSTM) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(200, l0.getNOut()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new RmsProp(0.1), l0.getIUpdater()); GravesLSTM l1 = (GravesLSTM) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(200, l1.getNOut()); - assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l1)); assertEquals(new RmsProp(0.1), l1.getIUpdater()); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(77, l2.getNOut()); - assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l2.getWeightInit()); assertEquals(new WeightDecay(0.001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new RmsProp(0.1), l0.getIUpdater()); @@ -139,7 +139,7 @@ public class RegressionTest100a extends BaseDL4JTest { assertEquals(32, l0.getNOut()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new Adam(0.05), l0.getIUpdater()); @@ -175,7 +175,7 @@ public class 
RegressionTest100a extends BaseDL4JTest { assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); - assertEquals(new WeightInitXavier(), cl.getWeightInitFn()); + assertEquals(new WeightInitXavier(), cl.getWeightInit()); assertArrayEquals(new int[]{1,1}, cl.getKernelSize()); assertArrayEquals(new int[]{1,1}, cl.getKernelSize()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java index 829fc8c2b..223b7be91 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b3.java @@ -124,21 +124,21 @@ public class RegressionTest100b3 extends BaseDL4JTest { LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(200, l0.getNOut()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new Adam(0.005), l0.getIUpdater()); LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(200, l1.getNOut()); - assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l1)); assertEquals(new Adam(0.005), l1.getIUpdater()); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(77, l2.getNOut()); - assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l2.getWeightInit()); assertEquals(new WeightDecay(0.0001, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new Adam(0.005), l0.getIUpdater()); @@ -174,7 +174,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { assertEquals(32, l0.getNOut()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new WeightDecay(1e-4, false), TestUtils.getWeightDecayReg(l0)); assertEquals(new Adam(1e-3), l0.getIUpdater()); @@ -210,7 +210,7 @@ public class RegressionTest100b3 extends BaseDL4JTest { assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); - assertEquals(new WeightInitXavier(), cl.getWeightInitFn()); + assertEquals(new WeightInitXavier(), cl.getWeightInit()); assertArrayEquals(new int[]{1,1}, cl.getKernelSize()); assertArrayEquals(new int[]{1,1}, cl.getKernelSize()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java index b1247b3c1..6cdede6bd 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java +++ 
b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b4.java @@ -142,21 +142,21 @@ public class RegressionTest100b4 extends BaseDL4JTest { LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(200, l0.getNOut()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new Adam(0.005), l0.getIUpdater()); LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(200, l1.getNOut()); - assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new Adam(0.005), l1.getIUpdater()); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(77, l2.getNOut()); - assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l2.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new Adam(0.005), l2.getIUpdater()); @@ -192,7 +192,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertEquals(32, l0.getNOut()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new Adam(1e-3), l0.getIUpdater()); @@ -229,7 +229,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); - assertEquals(new WeightInitXavier(), cl.getWeightInitFn()); + assertEquals(new WeightInitXavier(), cl.getWeightInit()); assertArrayEquals(new int[]{1, 1}, cl.getKernelSize()); INDArray outExp; @@ -260,7 +260,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationReLU(), l0.getActivationFn()); assertEquals(4, l0.getNOut()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new Adam(0.005), l0.getIUpdater()); assertArrayEquals(new int[]{3, 3}, l0.getKernelSize()); @@ -271,7 +271,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationReLU(), l1.getActivationFn()); assertEquals(8, l1.getNOut()); - assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new Adam(0.005), l1.getIUpdater()); assertArrayEquals(new int[]{3, 3}, l1.getKernelSize()); @@ -297,7 +297,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) 
net.getLayer(5).getLayerConfiguration(); assertEquals(new ActivationReLU(), l5.getActivationFn()); assertEquals(16, l5.getNOut()); - assertEquals(new WeightInitXavier(), l5.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l5.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5)); assertEquals(new Adam(0.005), l5.getIUpdater()); assertArrayEquals(new int[]{3, 3}, l5.getKernelSize()); @@ -318,7 +318,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration(); assertEquals(4, l8.getNOut()); - assertEquals(new WeightInitXavier(), l8.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l8.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8)); assertEquals(new Adam(0.005), l8.getIUpdater()); assertArrayEquals(new int[]{4, 4}, l8.getKernelSize()); @@ -327,7 +327,7 @@ public class RegressionTest100b4 extends BaseDL4JTest { assertArrayEquals(new int[]{0, 0}, l8.getPadding()); CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration(); - assertEquals(new WeightInitXavier(), l9.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l9.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9)); assertEquals(new Adam(0.005), l9.getIUpdater()); assertEquals(new LossMAE(), l9.getLossFn()); diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java index f00b9c437..c0ee3dca2 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java @@ -124,21 +124,21 @@ public class RegressionTest100b6 extends BaseDL4JTest { LSTM l0 = (LSTM) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationTanH(), l0.getActivationFn()); assertEquals(200, l0.getNOut()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new Adam(0.005), l0.getIUpdater()); LSTM l1 = (LSTM) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationTanH(), l1.getActivationFn()); assertEquals(200, l1.getNOut()); - assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new Adam(0.005), l1.getIUpdater()); RnnOutputLayer l2 = (RnnOutputLayer) net.getLayer(2).getLayerConfiguration(); assertEquals(new ActivationSoftmax(), l2.getActivationFn()); assertEquals(77, l2.getNOut()); - assertEquals(new WeightInitXavier(), l2.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l2.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l2)); assertEquals(new Adam(0.005), l2.getIUpdater()); @@ -174,7 +174,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertEquals(32, l0.getNOut()); assertArrayEquals(new int[]{256, 256}, l0.getEncoderLayerSizes()); assertArrayEquals(new int[]{256, 256}, l0.getDecoderLayerSizes()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); 
assertEquals(new Adam(1e-3), l0.getIUpdater()); @@ -210,7 +210,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertEquals(nBoxes * (5 + nClasses), cl.getNOut()); assertEquals(new ActivationIdentity(), cl.getActivationFn()); assertEquals(ConvolutionMode.Same, cl.getConvolutionMode()); - assertEquals(new WeightInitXavier(), cl.getWeightInitFn()); + assertEquals(new WeightInitXavier(), cl.getWeightInit()); assertArrayEquals(new int[]{1, 1}, cl.getKernelSize()); INDArray outExp; @@ -240,7 +240,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { ConvolutionLayer l0 = (ConvolutionLayer) net.getLayer(0).getLayerConfiguration(); assertEquals(new ActivationReLU(), l0.getActivationFn()); assertEquals(4, l0.getNOut()); - assertEquals(new WeightInitXavier(), l0.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l0.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l0)); assertEquals(new Adam(0.005), l0.getIUpdater()); assertArrayEquals(new int[]{3, 3}, l0.getKernelSize()); @@ -251,7 +251,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { SeparableConvolution2D l1 = (SeparableConvolution2D) net.getLayer(1).getLayerConfiguration(); assertEquals(new ActivationReLU(), l1.getActivationFn()); assertEquals(8, l1.getNOut()); - assertEquals(new WeightInitXavier(), l1.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l1.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l1)); assertEquals(new Adam(0.005), l1.getIUpdater()); assertArrayEquals(new int[]{3, 3}, l1.getKernelSize()); @@ -277,7 +277,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { DepthwiseConvolution2D l5 = (DepthwiseConvolution2D) net.getLayer(5).getLayerConfiguration(); assertEquals(new ActivationReLU(), l5.getActivationFn()); assertEquals(16, l5.getNOut()); - assertEquals(new WeightInitXavier(), l5.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l5.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l5)); assertEquals(new Adam(0.005), l5.getIUpdater()); assertArrayEquals(new int[]{3, 3}, l5.getKernelSize()); @@ -298,7 +298,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { ConvolutionLayer l8 = (ConvolutionLayer) net.getLayer(8).getLayerConfiguration(); assertEquals(4, l8.getNOut()); - assertEquals(new WeightInitXavier(), l8.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l8.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l8)); assertEquals(new Adam(0.005), l8.getIUpdater()); assertArrayEquals(new int[]{4, 4}, l8.getKernelSize()); @@ -307,7 +307,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertArrayEquals(new int[]{0, 0}, l8.getPadding()); CnnLossLayer l9 = (CnnLossLayer) net.getLayer(9).getLayerConfiguration(); - assertEquals(new WeightInitXavier(), l9.getWeightInitFn()); + assertEquals(new WeightInitXavier(), l9.getWeightInit()); assertEquals(new L2Regularization(0.0001), TestUtils.getL2Reg(l9)); assertEquals(new Adam(0.005), l9.getIUpdater()); assertEquals(new LossMAE(), l9.getLossFn()); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java index e97a1685e..eec8658cc 100644 --- 
a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/KerasInitilizationTest.java @@ -167,7 +167,7 @@ public class KerasInitilizationTest extends BaseDL4JTest { layerConfig.put(conf.getLAYER_FIELD_KERAS_VERSION(), kerasVersion); DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer(); - assertEquals(dl4jInitializer, layer.getWeightInitFn()); + assertEquals(dl4jInitializer, layer.getWeightInit()); } } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java index 202e06426..053eb1fab 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/advanced/activation/KerasPReLUTest.java @@ -79,7 +79,7 @@ public class KerasPReLUTest extends BaseDL4JTest { PReLULayer layer = kerasPReLU.getPReLULayer(); assertArrayEquals(layer.getInputShape(), new long[] {3, 5, 4}); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(layerName, layer.getLayerName()); } diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java index f5e25ea9f..10330113c 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution1DTest.java @@ -100,7 +100,7 @@ public class KerasAtrousConvolution1DTest extends BaseDL4JTest { Convolution1DLayer layer = new KerasAtrousConvolution1D(layerConfig).getAtrousConvolution1D(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java index f2eebb8f2..7f1d65b3b 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasAtrousConvolution2DTest.java @@ -114,7 +114,7 @@ public class KerasAtrousConvolution2DTest extends BaseDL4JTest { ConvolutionLayer layer = new KerasAtrousConvolution2D(layerConfig).getAtrousConvolution2D(); 
assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java index 994d3affe..b8629573f 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution1DTest.java @@ -122,7 +122,7 @@ public class KerasConvolution1DTest extends BaseDL4JTest { Convolution1DLayer layer = new KerasConvolution1D(layerConfig).getConvolution1DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java index b92ab0432..4ba12c10f 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution2DTest.java @@ -123,7 +123,7 @@ public class KerasConvolution2DTest extends BaseDL4JTest { ConvolutionLayer layer = new KerasConvolution2D(layerConfig).getConvolution2DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java index c36b0351d..f52939947 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasConvolution3DTest.java @@ -119,7 +119,7 @@ public class KerasConvolution3DTest extends BaseDL4JTest { ConvolutionLayer layer = new KerasConvolution3D(layerConfig).getConvolution3DLayer(); assertEquals(ACTIVATION_DL4J, 
layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java index c0db1c47b..9fecab86c 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDeconvolution2DTest.java @@ -123,7 +123,7 @@ public class KerasDeconvolution2DTest extends BaseDL4JTest { Deconvolution2D layer = new KerasDeconvolution2D(layerConfig).getDeconvolution2DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java index 4dc4856c0..eef103f98 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasDepthwiseConvolution2DTest.java @@ -128,7 +128,7 @@ public class KerasDepthwiseConvolution2DTest extends BaseDL4JTest { DepthwiseConvolution2D layer = kerasLayer.getDepthwiseConvolution2DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(DEPTH_MULTIPLIER, layer.getDepthMultiplier()); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java index 54f50a478..9745ff5ed 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/convolution/KerasSeparableConvolution2DTest.java @@ -130,7 +130,7 @@ public class KerasSeparableConvolution2DTest extends BaseDL4JTest { SeparableConvolution2D layer = new 
KerasSeparableConvolution2D(layerConfig).getSeparableConvolution2DLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(DEPTH_MULTIPLIER, layer.getDepthMultiplier()); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java index c9c70e5ff..637ce5915 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasDenseTest.java @@ -89,7 +89,7 @@ public class KerasDenseTest extends BaseDL4JTest { DenseLayer layer = new KerasDense(layerConfig, false).getDenseLayer(); assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java index 7ce6bf0b3..1bfc3a4ce 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasLSTMTest.java @@ -38,7 +38,6 @@ import org.deeplearning4j.nn.weights.WeightInitXavier; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -131,7 +130,7 @@ public class KerasLSTMTest extends BaseDL4JTest { } assertEquals(ACTIVATION_DL4J, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); diff --git a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java index c8e8287fb..1b143a706 100644 --- a/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java +++ b/cavis-dnn/cavis-dnn-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/layers/recurrent/KerasSimpleRnnTest.java @@ -101,7 +101,7 @@ public class KerasSimpleRnnTest extends BaseDL4JTest { (SimpleRnn) ((LastTimeStep) new 
KerasSimpleRnn(layerConfig).getSimpleRnnLayer()).getUnderlying(); assertEquals(ACTIVATION, layer.getActivationFn().toString()); assertEquals(LAYER_NAME, layer.getLayerName()); - assertEquals(INIT_DL4J, layer.getWeightInitFn()); + assertEquals(INIT_DL4J, layer.getWeightInit()); assertEquals(L1_REGULARIZATION, KerasTestUtils.getL1(layer), 0.0); assertEquals(L2_REGULARIZATION, KerasTestUtils.getL2(layer), 0.0); assertEquals(new Dropout(DROPOUT_DL4J), layer.getIDropout()); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java index 02ae2d45f..3c679267c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/net/brutex/ai/dnn/api/INeuralNetworkConfiguration.java @@ -28,6 +28,7 @@ import org.deeplearning4j.nn.conf.NeuralNetConfiguration; public interface INeuralNetworkConfiguration extends Serializable, Cloneable { INeuralNetworkConfiguration clone(); + void init(); /** @@ -35,28 +36,4 @@ public interface INeuralNetworkConfiguration extends Serializable, Cloneable { * @return */ IModel getNet(); -} -/** - /** - * Provides a flat list of all embedded layer configurations, this - * can only be called after the layer is initialized or {@link #getLayerConfigurations()} is - * called. - * - * @return unstacked layer configurations - - List getLayerConfigurations(); - - - /** - * This uncollables any stacked layer configurations within building blocks like - * @link BuildingBlockLayer} - - void calculateInnerLayerConfigurations(); - - /** - * An implementation should provide a method to validate the network - * @return true if no errors found; false otherwise - - boolean isValid(); -} -**/ \ No newline at end of file + } \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java index e5e94ef3c..dac126dd7 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/ComputationGraphConfiguration.java @@ -259,7 +259,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { */ private static void handleLegacyWeightInitFromJson(String json, LayerConfiguration layer, ObjectMapper mapper, JsonNode vertices) { if (layer instanceof BaseLayerConfiguration - && ((BaseLayerConfiguration) layer).getWeightInitFn() == null) { + && ((BaseLayerConfiguration) layer).getWeightInit() == null) { String layerName = layer.getLayerName(); try { @@ -291,7 +291,7 @@ public class ComputationGraphConfiguration implements Serializable, Cloneable { if (weightInit != null) { final IWeightInit wi = WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist); - ((BaseLayerConfiguration) layer).setWeightInitFn(wi); + ((BaseLayerConfiguration) layer).setWeightInit(wi); } } catch (IOException e) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java index a11c21adc..4f9a9bb1f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java +++ 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetBaseBuilderConfiguration.java @@ -74,343 +74,274 @@ import org.nd4j.linalg.learning.regularization.WeightDecay; * and their hyperparameters. Hyperparameters are variables that determine how a neural network * learns. They include how many times to update the weights of the model, how to initialize those * weights, which activation function to attach to the nodes, which optimization algorithm to use, - * and how fast the model should learn. This is what one configuration would look like: - *

- * - * NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
- * .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)
- * .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
- * .updater(new Sgd(0.05)) //... other hyperparameters
- * .backprop(true)
- * .build();

- * - * With Deeplearning4j, you add a layer - * by calling layer on the NeuralNetConfiguration.NeuralNetConfigurationBuilder(), specifying its place in the order of + * and how fast the model should learn. This is what one configuration would look like:
+ *
+ * NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
+ * .weightInit(WeightInit.XAVIER) .activation(Activation.RELU)
+ * .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
+ * .updater(new Sgd(0.05)) //... other hyperparameters
+ * .backprop(true)
+ * .build();
+ *
+ * With Deeplearning4j, you add a layer by calling layer on the + * NeuralNetConfiguration.NeuralNetConfigurationBuilder(), specifying its place in the order of * layers (the zero-indexed layer below is the input layer), the number of input and output nodes, - * nIn and nOut, as well as the type: DenseLayer.

- * - * .layer(0, new DenseLayer.Builder().nIn(784).nOut(250)
- * .build())

- * - * Once you've configured your net, you train the - * model with model.fit. + * nIn and nOut, as well as the type: DenseLayer.
+ *
+ * .layer(0, new DenseLayer.Builder().nIn(784).nOut(250)
+ * .build())
+ *
+ * Once you've configured your net, you train the model with model.fit. */ - @Data @Slf4j @EqualsAndHashCode(exclude = {"iterationCount", "epochCount"}) @JsonIgnoreProperties(ignoreUnknown = true) -//The inner builder, that we can then extend ... -@SuperBuilder //TODO fix access +// The inner builder, that we can then extend ... +@SuperBuilder // TODO fix access public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetworkConfiguration { private static final int DEFAULT_TBPTT_LENGTH = 20; - /** * Set constraints to be applied to all layers. Default: no constraints.
- * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, - * etc). These constraints are applied at each iteration, after the parameters have been updated.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. + * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm + * regularization, etc). These constraints are applied at each iteration, after the parameters + * have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. * * @param constraints Constraints to apply to all weight parameters of all layers */ - @lombok.Builder.Default - protected final List contrainWeights = new ArrayList<>(); - - - + @lombok.Builder.Default protected final List contrainWeights = new ArrayList<>(); /** * Set constraints to be applied to all layers. Default: no constraints.
- * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, - * etc). These constraints are applied at each iteration, after the parameters have been updated.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. + * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm + * regularization, etc). These constraints are applied at each iteration, after the parameters + * have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. * * @param constraints Constraints to apply to all bias parameters of all layers */ - @lombok.Builder.Default - protected final List biasConstraints = new ArrayList<>(); + @lombok.Builder.Default protected final List biasConstraints = new ArrayList<>(); /** * Set constraints to be applied to all layers. Default: no constraints.
- * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, - * etc). These constraints are applied at each iteration, after the parameters have been updated.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. + * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm + * regularization, etc). These constraints are applied at each iteration, after the parameters + * have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. * * @param constraints Constraints to apply to all parameters of all layers */ @lombok.Builder.Default protected final List allParamContraints = new ArrayList<>(); /** - * This is a basic concept, a neural network is made of layers, but also can use - * another neural network as a building block. When the configuration is initialized, those - * building blocks will be flattened into a single list of layers. - * Internal ordered list of layers and inner neural networks. If the object is a NeuralNetConfiguration, - * each configuration must contain at least one layer. + * This is a basic concept, a neural network is made of layers, but also can use another neural + * network as a building block. When the configuration is initialized, those building blocks will + * be flattened into a single list of layers. Internal ordered list of layers and inner neural + * networks. If the object is a NeuralNetConfiguration, each configuration must contain at least + * one layer. */ @Getter @lombok.Builder.Default protected final List innerConfigurations = new ArrayList<>(); - @Getter - @Setter - @NonNull - @lombok.Builder.Default - @Deprecated + + @Getter @Setter @NonNull @lombok.Builder.Default @Deprecated protected WorkspaceMode trainingWorkspaceMode = WorkspaceMode.ENABLED; - @Getter - @Setter - @NonNull - @lombok.Builder.Default - @Deprecated + + @Getter @Setter @NonNull @lombok.Builder.Default @Deprecated protected WorkspaceMode inferenceWorkspaceMode = WorkspaceMode.ENABLED; /** * The type of backprop. Default setting is used for most networks (MLP, CNN etc), but optionally * truncated BPTT can be used for training recurrent neural networks. If using TruncatedBPTT make * sure you set both tBPTTForwardLength() and tBPTTBackwardLength() */ - @Getter - @Setter - @NonNull - @lombok.Builder.Default + @Getter @Setter @NonNull @lombok.Builder.Default protected BackpropType backpropType = BackpropType.Standard; - @Getter - @lombok.Builder.Default + + @Getter @lombok.Builder.Default protected Map inputPreProcessors = new HashMap<>(); /** * When doing truncated BPTT: how many steps of forward pass should we do before doing (truncated) - * backprop?
Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
Typically - * tBPTTForwardLength parameter is same as the tBPTTBackwardLength parameter, but may be larger - * than it in some circumstances (but never smaller)
Ideally your training data time series - * length should be divisible by this This is the k1 parameter on pg23 of + * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
+ * Typically the tBPTTForwardLength parameter is the same as the tBPTTBackwardLength parameter, but may be + larger than it in some circumstances (but never smaller)
+ * Ideally your training data time series length should be divisible by this. This is the k1 + parameter on pg23 of
http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf * * @param forwardLength Forward length > 0, >= backwardLength */ - @Getter - @Setter - @lombok.Builder.Default - protected int tbpttFwdLength = 20; + @Getter @Setter @lombok.Builder.Default protected int tbpttFwdLength = 20; /** - * When doing truncated BPTT: how many steps of backward should we do?
Only applicable when - * doing backpropType(BackpropType.TruncatedBPTT)
This is the k2 parameter on pg23 of + * Only applicable when doing backpropType(BackpropType.TruncatedBPTT)
+ * This is the k2 parameter on pg23 of
http://www.cs.utoronto.ca/~ilya/pubs/ilya_sutskever_phd_thesis.pdf * * @param backwardLength <= forwardLength */ - @Getter - @Setter - @lombok.Builder.Default - protected int tbpttBackLength = 20; - //Counter for the number of parameter updates so far - // This is important for learning rate schedules, for example, and is stored here to ensure it is persisted + @Getter @Setter @lombok.Builder.Default protected int tbpttBackLength = 20; + // Counter for the number of parameter updates so far + // This is important for learning rate schedules, for example, and is stored here to ensure it is + // persisted // for Spark and model serialization - @Getter - @Setter - @lombok.Builder.Default - protected int iterationCount = 0; - //Counter for the number of epochs completed so far. Used for per-epoch schedules - @Getter - @Setter - @lombok.Builder.Default - protected int epochCount = 0; - @lombok.Builder.Default - protected double dampingFactor = 100; - //gradient keys used for ensuring order when getting and setting the gradient - //@lombok.Builder.Default - //protected List variables = new ArrayList<>(); - @Getter - @Setter - @lombok.Builder.Default - private boolean miniBatch = false; - /** - * A seed for this network, will be random if not specified. - */ - @Getter - @Setter - @lombok.Builder.Default - private long seed = new Random().nextLong(); + @Getter @Setter @lombok.Builder.Default protected int iterationCount = 0; + // Counter for the number of epochs completed so far. Used for per-epoch schedules + @Getter @Setter @lombok.Builder.Default protected int epochCount = 0; + @lombok.Builder.Default protected double dampingFactor = 100; + // gradient keys used for ensuring order when getting and setting the gradient + // @lombok.Builder.Default + // protected List variables = new ArrayList<>(); + @Getter @Setter @lombok.Builder.Default private boolean miniBatch = false; + /** A seed for this network, will be random if not specified. */ + @Getter @Setter @lombok.Builder.Default private long seed = new Random().nextLong(); /** * The default {@link CacheMode} for this configuration. Will be set to "NONE" if not specified * otherwise. This method defines how/if preOutput cache is handled: NONE: cache disabled (default * value) HOST: Host memory will be used DEVICE: GPU memory will be used (on CPU backends effect * will be the same as for HOST) - *

- * Valid values are
CacheMode.NONE,
CacheMode.HOST or
CacheMode.DEVICE
+ * + *

Valid values are
+ * CacheMode.NONE,
+ * CacheMode.HOST or
+ * CacheMode.DEVICE
* * @param cacheMode */ - @NonNull - @Getter - @Setter - @lombok.Builder.Default - private CacheMode cacheMode = CacheMode.NONE; + @NonNull @Getter @Setter @lombok.Builder.Default private CacheMode cacheMode = CacheMode.NONE; /** * The name for this configuration. Defaults to "Anonymous INeuralNetworkConfiguration" if it is * not specified. */ - @lombok.Builder.Default - @Getter - private String name = "Anonymous INeuralNetworkConfiguration"; - /** - * The {@link InputType} of the data for this network configuration - */ - @Getter - @Setter - private InputType inputType; + @lombok.Builder.Default @Getter private String name = "Anonymous INeuralNetworkConfiguration"; + /** The {@link InputType} of the data for this network configuration */ + @Getter @Setter private InputType inputType; /** * Set the DataType for the network parameters and activations for all layers in the network. * Default: Float * * @param dataType Datatype to use for parameters and activations */ - @Getter - @Setter - @lombok.Builder.Default - @NonNull - private DataType dataType = DataType.FLOAT; + @Getter @Setter @lombok.Builder.Default @NonNull private DataType dataType = DataType.FLOAT; /** * Whether to override the nIn configuration forcibly upon construction. Default value is true. * * @return builder pattern */ - @Getter - @Setter - @lombok.Builder.Default - private boolean overrideNinUponBuild = true; + @Getter @Setter @lombok.Builder.Default private boolean overrideNinUponBuild = true; /** * Enabled by default. If enabled, the output layer configuration will be validated, to throw an - * exception on likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.
If - * disabled (false) no output layer validation will be performed.
Disabling this validation is - * not recommended, as the configurations that fail validation usually will not be able to learn - * correctly. However, the option to disable this validation is provided for advanced users when - * creating non-standard architectures. + * exception on likely invalid outputs - such as softmax + nOut=1, or LossMCXENT + Tanh.
+ * If disabled (false) no output layer validation will be performed.
+ * Disabling this validation is not recommended, as the configurations that fail validation + * usually will not be able to learn correctly. However, the option to disable this validation is + * provided for advanced users when creating non-standard architectures. * * @param validate If true: validate output layer configuration. False: don't validate */ - @Getter - @Setter - @lombok.Builder.Default - private boolean validateOutputLayerConfig = true; + @Getter @Setter @lombok.Builder.Default private boolean validateOutputLayerConfig = true; /** * Enabled by default. If enabled, an exception will be throw when using the (invalid) combination * of truncated backpropagation through time (TBPTT) with either a GlobalPoolingLayer or - * LastTimeStepLayer.
It is possible to disable this validation to allow what is almost - * certainly an invalid configuration to be used, however this is not recommended. + * LastTimeStepLayer.
+ * It is possible to disable this validation to allow what is almost certainly an invalid + * configuration to be used, however this is not recommended. * * @param validate Whether TBPTT validation should be performed */ - @Getter - @Setter - @lombok.Builder.Default - private boolean validateTbpttConfig = true; + @Getter @Setter @lombok.Builder.Default private boolean validateTbpttConfig = true; /** * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} or - * {@link org.nd4j.linalg.learning.config.Nesterovs}
Note: values set by this method will be - * applied to all applicable layers in the network, unless a different value is explicitly set on - * a given layer. In other words: values set via this method are used as the default value, and - * can be overridden on a per-layer basis. + * {@link org.nd4j.linalg.learning.config.Nesterovs}
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. * * @param updater Updater to use */ - @Getter - @Setter - private IUpdater updater; + @Getter @Setter private IUpdater updater; /** * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping - * etc. See {@link GradientNormalization} for details
Note: values set by this method will be - * applied to all applicable layers in the network, unless a different value is explicitly set on - * a given layer. In other words: values set via this method are used as the default value, and - * can be overridden on a per-layer basis. + * etc. See {@link GradientNormalization} for details
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. * * @param gradientNormalization Type of normalization to use. Defaults to None. * @see GradientNormalization */ - @Getter - @Setter - @NonNull - @lombok.Builder.Default + @Getter @Setter @NonNull @lombok.Builder.Default private GradientNormalization gradientNormalization = GradientNormalization.None; /** * Threshold for gradient normalization, only used for GradientNormalization.ClipL2PerLayer, * GradientNormalization.ClipL2PerParamType, and - * GradientNormalization.ClipElementWiseAbsoluteValue
Not used otherwise.
L2 threshold for - * first two types of clipping, or absolute value threshold for last type of clipping.
Note: - * values set by this method will be applied to all applicable layers in the network, unless a - * different value is explicitly set on a given layer. In other words: values set via this method - * are used as the default value, and can be overridden on a per-layer basis. + * GradientNormalization.ClipElementWiseAbsoluteValue
+ * Not used otherwise.
+ * L2 threshold for first two types of clipping, or absolute value threshold for last type of + * clipping.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. */ - @Getter - @Setter - private double gradientNormalizationThreshold; + @Getter @Setter private double gradientNormalizationThreshold; /** - * Activation function / neuron non-linearity
Note: values set by this method will be applied - * to all applicable layers in the network, unless a different value is explicitly set on a given - * layer. In other words: values set via this method are used as the default value, and can be - * overridden on a per-layer basis. + * Activation function / neuron non-linearity
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. */ - @Getter - @Setter - private IActivation activation; - //whether to constrain the gradient to unit norm or not - @Getter - @Setter - private StepFunction stepFunction; - @Getter - @Setter - @lombok.Builder.Default - private OptimizationAlgorithm optimizationAlgo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT; - @Getter - @Setter - @lombok.Builder.Default - private int maxNumLineSearchIterations = 5; + @Getter @Setter private IActivation activation; + // whether to constrain the gradient to unit norm or not + @Getter @Setter private StepFunction stepFunction; + + @Getter @Setter @lombok.Builder.Default + private OptimizationAlgorithm optimizationAlgo = + OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT; + + @Getter @Setter @lombok.Builder.Default private int maxNumLineSearchIterations = 5; /** - * Set the regularization for the parameters (excluding biases) - for example {@link WeightDecay}
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis.
+ * Set the regularization for the parameters (excluding biases) - for example {@link WeightDecay} + *
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis.
* - * @param regularization Regularization to apply for the network parameters/weights (excluding biases) + * @param regularization Regularization to apply for the network parameters/weights (excluding + * biases) */ - @Getter - @lombok.Builder.Default - private List regularization = new ArrayList<>(); + @Getter @lombok.Builder.Default private List regularization = new ArrayList<>(); /** * Set the regularization for the biases only - for example {@link WeightDecay}
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis.
* * @param regularizationBias Regularization to apply for the network biases only */ - @Getter - @lombok.Builder.Default + @Getter @lombok.Builder.Default private List regularizationBias = new ArrayList<>(); - @Getter - @Setter - @lombok.Builder.Default - private IUpdater iUpdater = new Sgd(); + + @Getter @Setter @lombok.Builder.Default private IUpdater iUpdater = new Sgd(); /** * Gradient updater configuration, for the biases only. If not set, biases will use the updater as * set by {@link #setIUpdater(IUpdater)}
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. + * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. * * @param updater Updater to use for bias parameters */ - @Getter - @Setter - @lombok.Builder.Default - private IUpdater biasUpdater = null; - @Getter - @Setter - @lombok.Builder.Default + @Getter @Setter @lombok.Builder.Default private IUpdater biasUpdater = null; + + @Getter @Setter @lombok.Builder.Default private IActivation activationFn = new ActivationSigmoid(); /** * Weight initialization scheme to use, for initial weight values Note: values set by this method @@ -418,96 +349,83 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor * set on a given layer. In other words: values set via this method are used as the default value, * and can be overridden on a per-layer basis. */ - @Getter - @Setter - @lombok.Builder.Default - private IWeightInit weightInitFn = new WeightInitXavier(); + @Getter @Setter @lombok.Builder.Default private IWeightInit weightInit = new WeightInitXavier(); /** - * Sets the convolution mode for convolutional layers, which impacts padding and output sizes. - * See {@link ConvolutionMode} for details. Defaults to ConvolutionMode.TRUNCATE
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. + * Sets the convolution mode for convolutional layers, which impacts padding and output sizes. See + * {@link ConvolutionMode} for details. Defaults to ConvolutionMode.TRUNCATE
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. + * * @param convolutionMode Convolution mode to use */ - @Getter - @Setter - @lombok.Builder.Default + @Getter @Setter @lombok.Builder.Default private ConvolutionMode convolutionMode = ConvolutionMode.Truncate; /** - * Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage of cuDNN. - * See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but "NO_WORKSPACE" uses less memory. - *
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. + * Sets the cuDNN algo mode for convolutional layers, which impacts performance and memory usage + * of cuDNN. See {@link ConvolutionLayer.AlgoMode} for details. Defaults to "PREFER_FASTEST", but + * "NO_WORKSPACE" uses less memory.
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. + * * @param cudnnAlgoMode cuDNN algo mode to use */ - @Getter - @Setter - @lombok.Builder.Default + @Getter @Setter @lombok.Builder.Default private ConvolutionLayer.AlgoMode cudnnAlgoMode = ConvolutionLayer.AlgoMode.PREFER_FASTEST; - @Getter - @Setter - @lombok.Builder.Default - private boolean minimize = true; + + @Getter @Setter @lombok.Builder.Default private boolean minimize = true; /** * Set the dropout for all layers in this network
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. - * * Dropout probability. This is the probability of retaining each input activation value for a layer. - * * dropOut(x) will keep an input activation with probability x, and set to 0 with probability 1-x.
- * * dropOut(0.0) is a special value / special case - when set to 0.0., dropout is disabled (not applied). Note - * * that a dropout value of 1.0 is functionally equivalent to no dropout: i.e., 100% probability of retaining - * * each input activation.
- * *

- * * Note 1: Dropout is applied at training time only - and is automatically not applied at test time - * * (for evaluation, etc)
- * * Note 2: This sets the probability per-layer. Care should be taken when setting lower values for - * * complex networks (too much information may be lost with aggressive (very low) dropout values).
- * * Note 3: Frequently, dropout is not applied to (or, has higher retain probability for) input (first layer) - * * layers. Dropout is also often not applied to output layers. This needs to be handled MANUALLY by the user - * * - set .dropout(0) on those layers when using global dropout setting.
- * * Note 4: Implementation detail (most users can ignore): DL4J uses inverted dropout, as described here: - * * http://cs231n.github.io/neural-networks-2/ - * *

- * *
- * * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * * value, and can be overridden on a per-layer basis. - * * - * * @param inputRetainProbability Dropout probability (probability of retaining each input activation value for a layer) - * * @see #dropOut(IDropout) + * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. * Dropout + * probability. This is the probability of retaining each input activation value for a + * layer. * dropOut(x) will keep an input activation with probability x, and set to 0 with + * probability 1-x.
+ * * dropOut(0.0) is a special value / special case - when set to 0.0., dropout is disabled (not + * applied). Note * that a dropout value of 1.0 is functionally equivalent to no dropout: i.e., + * 100% probability of retaining * each input activation.
+ * * * + *

* Note 1: Dropout is applied at training time only - and is automatically not applied at + * test time * (for evaluation, etc)
+ * * Note 2: This sets the probability per-layer. Care should be taken when setting lower values + * for * complex networks (too much information may be lost with aggressive (very low) dropout + * values).
+ * * Note 3: Frequently, dropout is not applied to (or, has higher retain probability for) input + * (first layer) * layers. Dropout is also often not applied to output layers. This needs to be + * handled MANUALLY by the user * - set .dropout(0) on those layers when using global dropout + * setting.
+ * * Note 4: Implementation detail (most users can ignore): DL4J uses inverted dropout, as + * described here: * http://cs231n.github.io/neural-networks-2/ + * * *
+ * * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different * value is explicitly set on a given layer. In other words: values set via + * this method are used as the default * value, and can be overridden on a per-layer basis. * + * * @param inputRetainProbability Dropout probability (probability of retaining each input + * activation value for a layer) * @see #dropOut(IDropout) * - * @param dropout Dropout, such as {@link Dropout}, {@link org.deeplearning4j.nn.conf.dropout.GaussianDropout}, - * {@link org.deeplearning4j.nn.conf.dropout.GaussianNoise} etc + * @param dropout Dropout, such as {@link Dropout}, {@link + * org.deeplearning4j.nn.conf.dropout.GaussianDropout}, {@link + * org.deeplearning4j.nn.conf.dropout.GaussianNoise} etc * @return */ - @Getter - @Setter - private IDropout idropOut; + @Getter @Setter private IDropout idropOut; /** * Set the weight noise (such as {@link org.deeplearning4j.nn.conf.weightnoise.DropConnect} and * {@link org.deeplearning4j.nn.conf.weightnoise.WeightNoise}) for the layers in this network.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. + * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. * * @param weightNoise Weight noise instance to use */ - @Getter - @Setter - private IWeightNoise weightNoise; - @Getter - @Setter - @lombok.Builder.Default - private double biasInit = 0.0; - @Getter - @Setter - @lombok.Builder.Default - private double gainInit = 1.0; + @Getter @Setter private IWeightNoise weightNoise; + + @Getter @Setter @lombok.Builder.Default private double biasInit = 0.0; + @Getter @Setter @lombok.Builder.Default private double gainInit = 1.0; /** * Handle {@link WeightInit} and {@link Distribution} from legacy configs in Json format. Copied @@ -515,10 +433,10 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor * * @return True if all is well and layer iteration shall continue. False else-wise. */ - private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l, - ObjectMapper mapper, - JsonNode confs, int layerCount) { - if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInitFn() == null) { + private static boolean handleLegacyWeightInitFromJson( + String json, LayerConfiguration l, ObjectMapper mapper, JsonNode confs, int layerCount) { + if ((l instanceof BaseLayerConfiguration) + && ((BaseLayerConfiguration) l).getWeightInit() == null) { try { JsonNode jsonNode = mapper.readTree(json); if (confs == null) { @@ -528,7 +446,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor ArrayNode layerConfs = (ArrayNode) confs; JsonNode outputLayerNNCNode = layerConfs.get(layerCount); if (outputLayerNNCNode == null) { - return false; //Should never happen... + return false; // Should never happen... 
} JsonNode layerWrapperNode = outputLayerNNCNode.get("layer"); @@ -537,8 +455,8 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor } JsonNode layerNode = layerWrapperNode.elements().next(); - JsonNode weightInit = layerNode.get( - "weightInit"); //Should only have 1 element: "dense", "output", etc + JsonNode weightInit = + layerNode.get("weightInit"); // Should only have 1 element: "dense", "output", etc JsonNode distribution = layerNode.get("dist"); Distribution dist = null; @@ -547,9 +465,9 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor } if (weightInit != null) { - final IWeightInit wi = WeightInit.valueOf(weightInit.asText()) - .getWeightInitFunction(dist); - ((BaseLayerConfiguration) l).setWeightInitFn(wi); + final IWeightInit wi = + WeightInit.valueOf(weightInit.asText()).getWeightInitFunction(dist); + ((BaseLayerConfiguration) l).setWeightInit(wi); } } @@ -560,7 +478,6 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor } } return true; - } /** @@ -582,10 +499,9 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor } public static NeuralNetBaseBuilderConfiguration fromYaml(String input) { - throw new RuntimeException("Needs fixing - not supported."); //TODO + throw new RuntimeException("Needs fixing - not supported."); // TODO } - /** * @return JSON representation of NN configuration */ @@ -606,8 +522,10 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor public String toJson() { ObjectMapper mapper = NeuralNetBaseBuilderConfiguration.mapper(); synchronized (mapper) { - //JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields occasionally - //when writeValueAsString is used by multiple threads. This results in invalid JSON. See issue #3243 + // JSON mappers are supposed to be thread safe: however, in practice they seem to miss fields + // occasionally + // when writeValueAsString is used by multiple threads. This results in invalid JSON. 
See + // issue #3243 try { return mapper.writeValueAsString(this); } catch (com.fasterxml.jackson.core.JsonProcessingException e) { @@ -616,18 +534,52 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor } } - public abstract static class NeuralNetBaseBuilderConfigurationBuilder - > { + @Override + public NeuralNetBaseBuilderConfiguration clone() { + NeuralNetBaseBuilderConfiguration clone; + try { + clone = (NeuralNetBaseBuilderConfiguration) super.clone(); + } catch (CloneNotSupportedException ex) { + throw new RuntimeException(ex); + } + if (clone.stepFunction != null) { + clone.stepFunction = clone.stepFunction.clone(); + } + /** if (clone.variables != null) { clone.variables = new ArrayList<>(clone.variables); } */ + clone.getInnerConfigurations().addAll(innerConfigurations); - List innerConfigurations$value = new ArrayList<>(); //initialize with an empty list + if (clone.getInputPreProcessors() != null) { + Map map = new HashMap<>(); + for (Map.Entry entry : clone.getInputPreProcessors().entrySet()) { + map.put(entry.getKey(), entry.getValue().clone()); + } + clone.getInputPreProcessors().clear(); + clone.getInputPreProcessors().putAll(map); + } + + clone.setInferenceWorkspaceMode(this.inferenceWorkspaceMode); + clone.setTrainingWorkspaceMode(this.trainingWorkspaceMode); + clone.setCacheMode(this.cacheMode); + clone.setValidateOutputLayerConfig(this.validateOutputLayerConfig); + clone.setDataType(this.dataType); + + return clone; + } + + public abstract static class NeuralNetBaseBuilderConfigurationBuilder< + C extends NeuralNetBaseBuilderConfiguration, + B extends NeuralNetBaseBuilderConfiguration.NeuralNetBaseBuilderConfigurationBuilder> { + + List innerConfigurations$value = new ArrayList<>(); // initialize with an empty list /** * Set constraints to be applied to all layers. Default: no constraints.
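The hunk above also introduces a clone() override on NeuralNetBaseBuilderConfiguration: it delegates to Object.clone(), then clones the step function, copies the input pre-processor map entry by entry, and carries the workspace modes, cache mode, output-layer validation flag, and data type over to the copy. A minimal usage sketch, assuming a concrete subclass such as NeuralNetConfiguration with the builder shown in the class javadoc (illustrative fragment only, imports omitted; builder method names follow the @SuperBuilder fields in this file):

    NeuralNetConfiguration conf = NeuralNetConfiguration.builder()
        .weightInit(new WeightInitXavier())  // field renamed from weightInitFn to weightInit by this patch
        .updater(new Sgd(0.05))
        .build();
    // per the override above, stepFunction and input pre-processors are cloned,
    // and settings such as cacheMode and dataType are carried over to the copy
    NeuralNetBaseBuilderConfiguration copy = conf.clone();
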
- * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization, - * etc). These constraints are applied at each iteration, after the parameters have been updated.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different - * value is explicitly set on a given layer. In other words: values set via this method are used as the default - * value, and can be overridden on a per-layer basis. + * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm + * regularization, etc). These constraints are applied at each iteration, after the parameters + * have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis. * * @param constraints Constraints to apply to all weight parameters of all layers */ @@ -638,32 +590,35 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor } /** - * For the (perhaps partially constructed) network configuration, return a list of activation sizes for each - * layer in the network.
- * Note: To use this method, the network input type must have been set using {@link #setInputType(InputType)} first + * For the (perhaps partially constructed) network configuration, return a list of activation + * sizes for each layer in the network.
+ * Note: To use this method, the network input type must have been set using {@link + * #setInputType(InputType)} first + * * @return A list of activation types for the network, indexed by layer number */ - public List getLayerActivationTypes(){ - Preconditions.checkState(inputType != null, "Can only calculate activation types if input type has" + - "been set. Use setInputType(InputType)"); - - - throw new RuntimeException("Error calculating layer activation types: error instantiating MultiLayerConfiguration"); + public List getLayerActivationTypes() { + Preconditions.checkState( + inputType != null, + "Can only calculate activation types if input type has" + + "been set. Use setInputType(InputType)"); + throw new RuntimeException( + "Error calculating layer activation types: error instantiating MultiLayerConfiguration"); } - /** * Set constraints to be applied to all layers. Default: no constraints.
- * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization,
- * etc). These constraints are applied at each iteration, after the parameters have been updated.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different
- * value is explicitly set on a given layer. In other words: values set via this method are used as the default
- * value, and can be overridden on a per-layer basis.
+ * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
+ * regularization, etc). These constraints are applied at each iteration, after the parameters
+ * have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis. * * @param constraints Constraints to apply to all parameters of all layers */ - public B constrainAllParameters(LayerConstraint... constraints){ + public B constrainAllParameters(LayerConstraint... constraints) { allParamContraints$value = Arrays.asList(constraints); allParamContraints$set = true; return (B) this; @@ -671,11 +626,12 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor /** * Set constraints to be applied to all layers. Default: no constraints.
- * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm regularization,
- * etc). These constraints are applied at each iteration, after the parameters have been updated.
- * Note: values set by this method will be applied to all applicable layers in the network, unless a different
- * value is explicitly set on a given layer. In other words: values set via this method are used as the default
- * value, and can be overridden on a per-layer basis.
+ * Constraints can be used to enforce certain conditions (non-negativity of parameters, max-norm
+ * regularization, etc). These constraints are applied at each iteration, after the parameters
+ * have been updated.
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis. * * @param constraints Constraints to apply to all bias parameters of all layers */ @@ -692,14 +648,12 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor * @param processor what to use to preProcess the data. * @return builder pattern */ - public B inputPreProcessor(Integer layer, - InputPreProcessor processor) { + public B inputPreProcessor(Integer layer, InputPreProcessor processor) { inputPreProcessors$value.put(layer, processor); inputPreProcessors$set = true; return (B) this; } - /** * Set layer at index * @@ -725,14 +679,12 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor return (B) this; } - //TODO this is a dirty workaround + // TODO this is a dirty workaround public boolean isOverrideNinUponBuild() { return isOverrideNinUponBuild(); } - /** - * Specify additional layer configurations - */ + /** Specify additional layer configurations */ @Deprecated public B layersFromArray(@NonNull LayerConfiguration[] arrLayers) { innerConfigurations$value.addAll(List.of(arrLayers)); @@ -740,9 +692,7 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor return (B) this; } - /** - * Specify additional layer configurations - */ + /** Specify additional layer configurations */ @Deprecated public B layersFromList(@NonNull List listLayers) { innerConfigurations$value.addAll(listLayers); @@ -750,15 +700,14 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor return (B) this; } - /** - * L1 regularization coefficient for the weights (excluding biases).
Note: values set by
- * this method will be applied to all applicable layers in the network, unless a different value
- * is explicitly set on a given layer. In other words: values set via this method are used as
- * the default value, and can be overridden on a per-layer basis.
+ * L1 regularization coefficient for the weights (excluding biases).
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis. */ public B l1(double l1) { - //Check if existing L1 exists; if so, replace it + // Check if existing L1 exists; if so, replace it NetworkUtils.removeInstances(regularization$value, L1Regularization.class); if (l1 > 0.0) { regularization$value.add(new L1Regularization(l1)); @@ -770,21 +719,23 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor /** * L2 regularization coefficient for the weights (excluding biases).
* Note: Generally, {@link WeightDecay} (set via {@link #weightDecay(double)} should be
- * preferred to
- * L2 regularization. See {@link WeightDecay} javadoc for further details.
Note: values set
- * by this method will be applied to all applicable layers in the network, unless a different
- * value is explicitly set on a given layer. In other words: values set via this method are used
- * as the default value, and can be overridden on a per-layer basis.
Note: L2 regularization
- * and weight decay usually should not be used together; if any weight decay (or L2) has been
- * added for the biases, these will be removed first.
+ * preferred to L2 regularization. See {@link WeightDecay} javadoc for further details.
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis.
+ * Note: L2 regularization and weight decay usually should not be used together; if any weight + * decay (or L2) has been added for the biases, these will be removed first. * * @see #weightDecay(double, boolean) */ public B l2(double l2) { - //Check if existing L2 exists; if so, replace it. Also remove weight decay - it doesn't make sense to use both + // Check if existing L2 exists; if so, replace it. Also remove weight decay - it doesn't make + // sense to use both NetworkUtils.removeInstances(regularization$value, L2Regularization.class); if (l2 > 0.0) { - NetworkUtils.removeInstancesWithWarning(regularization$value, WeightDecay.class, + NetworkUtils.removeInstancesWithWarning( + regularization$value, + WeightDecay.class, "WeightDecay regularization removed: incompatible with added L2 regularization"); regularization$value.add(new L2Regularization(l2)); } @@ -793,10 +744,10 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor } /** - * L1 regularization coefficient for the bias.
Note: values set by this method will be
- * applied to all applicable layers in the network, unless a different value is explicitly set
- * on a given layer. In other words: values set via this method are used as the default value,
- * and can be overridden on a per-layer basis.
+ * L1 regularization coefficient for the bias.
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis. */ public B l1Bias(double l1Bias) { NetworkUtils.removeInstances(regularizationBias$value, L1Regularization.class); @@ -809,21 +760,23 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor /** * L2 regularization coefficient for the bias.
- * Note: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double, boolean)} - * should be preferred to - * L2 regularization. See {@link WeightDecay} javadoc for further details.
Note: values set
- * by this method will be applied to all applicable layers in the network, unless a different
- * value is explicitly set on a given layer. In other words: values set via this method are used
- * as the default value, and can be overridden on a per-layer basis.
Note: L2 regularization - * and weight decay usually should not be used together; if any weight decay (or L2) has been - * added for the biases, these will be removed first. + * Note: Generally, {@link WeightDecay} (set via {@link #weightDecayBias(double, + * boolean)} should be preferred to L2 regularization. See {@link WeightDecay} javadoc for + * further details.
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis.
+ * Note: L2 regularization and weight decay usually should not be used together; if any weight + * decay (or L2) has been added for the biases, these will be removed first. * * @see #weightDecayBias(double, boolean) */ public B l2Bias(double l2Bias) { NetworkUtils.removeInstances(regularizationBias$value, L2Regularization.class); if (l2Bias > 0.0) { - NetworkUtils.removeInstancesWithWarning(regularizationBias$value, WeightDecay.class, + NetworkUtils.removeInstancesWithWarning( + regularizationBias$value, + WeightDecay.class, "L2 bias regularization removed: incompatible with added WeightDecay regularization"); regularizationBias$value.add(new L2Regularization(l2Bias)); } @@ -831,12 +784,12 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor } /** - * Add weight decay regularization for the network parameters (excluding biases).
This
- * applies weight decay with multiplying the learning rate - see {@link WeightDecay} for
- * more details.
Note: values set by this method will be applied to all applicable layers in
- * the network, unless a different value is explicitly set on a given layer. In other words:
- * values set via this method are used as the default value, and can be overridden on a
- * per-layer basis.
+ * Add weight decay regularization for the network parameters (excluding biases).
+ * This applies weight decay with multiplying the learning rate - see {@link WeightDecay} + * for more details.
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis.
* * @param coefficient Weight decay regularization coefficient * @see #weightDecay(double, boolean) @@ -846,22 +799,25 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor } /** - * Add weight decay regularization for the network parameters (excluding biases). See - * {@link WeightDecay} for more details.
Note: values set by this method will be applied to
- * all applicable layers in the network, unless a different value is explicitly set on a given
- * layer. In other words: values set via this method are used as the default value, and can be
- * overridden on a per-layer basis.
+ * Add weight decay regularization for the network parameters (excluding biases). See {@link
+ * WeightDecay} for more details.
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis.
* * @param coefficient Weight decay regularization coefficient - * @param applyLR Whether the learning rate should be multiplied in when performing weight - * decay updates. See {@link WeightDecay} for more details. + * @param applyLR Whether the learning rate should be multiplied in when performing weight decay + * updates. See {@link WeightDecay} for more details. * @see #weightDecay(double, boolean) */ public B weightDecay(double coefficient, boolean applyLR) { - //Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both + // Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't + // make sense to use both NetworkUtils.removeInstances(regularization$value, WeightDecay.class); if (coefficient > 0.0) { - NetworkUtils.removeInstancesWithWarning(regularization$value, L2Regularization.class, + NetworkUtils.removeInstancesWithWarning( + regularization$value, + L2Regularization.class, "L2 regularization removed: incompatible with added WeightDecay regularization"); regularization$value.add(new WeightDecay(coefficient, applyLR)); } @@ -871,10 +827,10 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor /** * Weight decay for the biases only - see {@link #weightDecay(double)} for more details. This - * applies weight decay with multiplying the learning rate.
Note: values set by this
- * method will be applied to all applicable layers in the network, unless a different value is
- * explicitly set on a given layer. In other words: values set via this method are used as the
- * default value, and can be overridden on a per-layer basis.
+ * applies weight decay with multiplying the learning rate.
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis.
* * @param coefficient Weight decay regularization coefficient * @see #weightDecayBias(double, boolean) @@ -892,10 +848,13 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor * @param coefficient Weight decay regularization coefficient */ public B weightDecayBias(double coefficient, boolean applyLR) { - //Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't make sense to use both + // Check if existing weight decay if it exists; if so, replace it. Also remove L2 - it doesn't + // make sense to use both NetworkUtils.removeInstances(regularizationBias$value, WeightDecay.class); if (coefficient > 0) { - NetworkUtils.removeInstancesWithWarning(regularizationBias$value, L2Regularization.class, + NetworkUtils.removeInstancesWithWarning( + regularizationBias$value, + L2Regularization.class, "L2 bias regularization removed: incompatible with added WeightDecay regularization"); regularizationBias$value.add(new WeightDecay(coefficient, applyLR)); } @@ -904,25 +863,19 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor } /** - * Activation function / neuron non-linearity
Note: values set by this method will be
- * applied to all applicable layers in the network, unless a different value is explicitly set
- * on a given layer. In other words: values set via this method are used as the default value,
- * and can be overridden on a per-layer basis.
+ * Activation function / neuron non-linearity
+ * Note: values set by this method will be applied to all applicable layers in the network, + * unless a different value is explicitly set on a given layer. In other words: values set via + * this method are used as the default value, and can be overridden on a per-layer basis. */ @Deprecated public B activation(@NonNull Activation activation) { return (B) activationFn(activation.getActivationFunction()); } - - - @Deprecated - public B weightInit(@NonNull WeightInit wi) { - return (B) weightInitFn(wi.getWeightInitFunction()); - } - /** * legacy code, does nothing + * * @return */ @Deprecated @@ -930,7 +883,6 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor return (B) this; } - /** * Set weight initialization scheme to random sampling via the specified distribution. * Equivalent to: {@code .weightInit(new WeightInitDistribution(distribution))} Note: values set @@ -941,11 +893,26 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor * @param distribution Distribution to use for weight initialization */ public B weightInit(@NonNull Distribution distribution) { - return (B) weightInitFn(new WeightInitDistribution(distribution)); + this.weightInit$value = new WeightInitDistribution(distribution); + this.weightInit$set = true; + return (B) this; + } + + public B weightInit(@NonNull WeightInit weightInit) { + this.weightInit$value = weightInit.getWeightInitFunction(); + this.weightInit$set = true; + return (B) this; + } + + public B weightInit(@NonNull IWeightInit iWeightInit) { + this.weightInit$value = iWeightInit; + this.weightInit$set = true; + return (B) this; } /** * Same as {@link #weightInit(Distribution)}. + * * @param distribution * @return */ @@ -959,61 +926,25 @@ public abstract class NeuralNetBaseBuilderConfiguration implements INeuralNetwor /** * Creates a new {@link Dropout} and sets the dropout in the builder for this configuration + * * @param dropout activationRetainProbability * @return builder */ - public B dropOut( double dropout) { - return (B) idropOut( new Dropout(dropout)); + public B dropOut(double dropout) { + return (B) idropOut(new Dropout(dropout)); } /** * Add multiple inner neural net configurations at once + * * @param confs list of configurations * @return builder */ @Deprecated public B confs(@NonNull List confs) { innerConfigurations$value.addAll(confs); - innerConfigurations$set=true; + innerConfigurations$set = true; return (B) this; } } - - @Override - public NeuralNetBaseBuilderConfiguration clone() { - NeuralNetBaseBuilderConfiguration clone; - try { - clone = (NeuralNetBaseBuilderConfiguration) super.clone(); - } catch(CloneNotSupportedException ex) { - throw new RuntimeException(ex); - } - if (clone.stepFunction != null) { - clone.stepFunction = clone.stepFunction.clone(); - } - /** - if (clone.variables != null) { - clone.variables = new ArrayList<>(clone.variables); - } - **/ - - clone.getInnerConfigurations().addAll(innerConfigurations); - - if (clone.getInputPreProcessors() != null) { - Map map = new HashMap<>(); - for (Map.Entry entry : clone.getInputPreProcessors().entrySet()) { - map.put(entry.getKey(), entry.getValue().clone()); - } - clone.getInputPreProcessors().clear(); - clone.getInputPreProcessors().putAll(map); - } - - clone.setInferenceWorkspaceMode(this.inferenceWorkspaceMode); - clone.setTrainingWorkspaceMode(this.trainingWorkspaceMode); - clone.setCacheMode(this.cacheMode); - clone.setValidateOutputLayerConfig(this.validateOutputLayerConfig); - 
clone.setDataType(this.dataType); - - return clone; - - } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java index ed5a406b4..fe946e022 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/NeuralNetConfiguration.java @@ -35,15 +35,11 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; -import lombok.Data; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.NonNull; -import lombok.Setter; + +import lombok.*; import lombok.experimental.SuperBuilder; import lombok.extern.jackson.Jacksonized; import lombok.extern.slf4j.Slf4j; -import lombok.val; import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.api.layers.LayerConstraint; @@ -67,9 +63,9 @@ import org.deeplearning4j.nn.conf.memory.LayerMemoryReport; import org.deeplearning4j.nn.conf.memory.MemoryReport; import org.deeplearning4j.nn.conf.memory.NetworkMemoryReport; import org.deeplearning4j.nn.conf.serde.JsonMappers; +import org.deeplearning4j.nn.conf.stepfunctions.DefaultStepFunction; import org.deeplearning4j.nn.conf.stepfunctions.StepFunction; import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; -import org.deeplearning4j.nn.conf.weightnoise.WeightNoise; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.util.OutputLayerUtil; @@ -319,16 +315,14 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { private boolean validateTbpttConfig = true; /** * Gradient updater configuration. For example, {@link org.nd4j.linalg.learning.config.Adam} or - * {@link org.nd4j.linalg.learning.config.Nesterovs}
Note: values set by this method will be - * applied to all applicable layers in the network, unless a different value is explicitly set on - * a given layer. In other words: values set via this method are used as the default value, and - * can be overridden on a per-layer basis. + * {@link org.nd4j.linalg.learning.config.Nesterovs}
+ * Note: values set by this method will be applied to all applicable layers in the network, unless + * a different value is explicitly set on a given layer. In other words: values set via this + * method are used as the default value, and can be overridden on a per-layer basis. * * @param updater Updater to use */ - @Getter - @Setter - private IUpdater updater; + @Getter @Setter @Builder.Default private IUpdater updater = new Sgd(); /** * Gradient normalization strategy. Used to specify gradient renormalization, gradient clipping * etc. See {@link GradientNormalization} for details
Note: values set by this method will be @@ -357,19 +351,9 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { @Setter private double gradientNormalizationThreshold; - /** - * Activation function / neuron non-linearity
Note: values set by this method will be applied - * to all applicable layers in the network, unless a different value is explicitly set on a given - * layer. In other words: values set via this method are used as the default value, and can be - * overridden on a per-layer basis. - */ - @Getter - @Setter - private IActivation activation; - //whether to constrain the gradient to unit norm or not - @Getter - @Setter - private StepFunction stepFunction; + // whether to constrain the gradient to unit norm or not + @Getter @Setter @Builder.Default private StepFunction stepFunction = new DefaultStepFunction(); + @Getter @Setter @lombok.Builder.Default @@ -400,13 +384,10 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { @Getter @lombok.Builder.Default private List regularizationBias = new ArrayList<>(); - @Getter - @Setter - @lombok.Builder.Default - private IUpdater iUpdater = new Sgd(); + /** * Gradient updater configuration, for the biases only. If not set, biases will use the updater as - * set by {@link #setIUpdater(IUpdater)}
+ * set by {@link #setUpdater(IUpdater)}
* Note: values set by this method will be applied to all applicable layers in the network, unless a different * value is explicitly set on a given layer. In other words: values set via this method are used as the default * value, and can be overridden on a per-layer basis. @@ -420,7 +401,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { @Getter @Setter @lombok.Builder.Default - private IActivation activationFn = new ActivationSigmoid(); + private IActivation activation = new ActivationSigmoid(); /** * Sets the convolution mode for convolutional layers, which impacts padding and output sizes. @@ -698,7 +679,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { private static boolean handleLegacyWeightInitFromJson(String json, LayerConfiguration l, ObjectMapper mapper, JsonNode confs, int layerCount) { - if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInitFn() == null) { + if ((l instanceof BaseLayerConfiguration) && ((BaseLayerConfiguration) l).getWeightInit() == null) { try { JsonNode jsonNode = mapper.readTree(json); if (confs == null) { @@ -729,7 +710,7 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { if (weightInit != null) { final IWeightInit wi = WeightInit.valueOf(weightInit.asText()) .getWeightInitFunction(dist); - ((BaseLayerConfiguration) l).setWeightInitFn(wi); + ((BaseLayerConfiguration) l).setWeightInit(wi); } } @@ -851,8 +832,8 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { * that do not have an individual setting (nor a default) */ for(LayerConfiguration lconf : this.getFlattenedLayerConfigurations()) { - if(lconf.getActivationFn() == null ) lconf.setActivationFn(this.getActivationFn()); - if(lconf.getIUpdater() == null ) lconf.setIUpdater( this.getIUpdater() ); + if(lconf.getActivationFn() == null ) lconf.setActivationFn(this.getActivation()); + if(lconf.getIUpdater() == null ) lconf.setIUpdater( this.getUpdater() ); if(lconf.getIDropout() == null ) lconf.setIDropout( this.getIdropOut() ); if(lconf.getWeightNoise() == null ) lconf.setWeightNoise( this.getWeightNoise()); @@ -1108,29 +1089,27 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { */ public List getFlattenedLayerConfigurations(NeuralNetConfiguration conf) { List ret = new ArrayList<>(); //create the final return list - for( Object obj : conf.getInnerConfigurations().stream().skip(1) //don't include self - .collect(Collectors.toList())) { - //if Layer Config, include in list and inherit parameters from this conf - //else if neural net configuration, call self recursively to resolve layer configurations - if (obj instanceof LayerConfiguration) - ret.add((LayerConfiguration) obj); - else if (obj instanceof NeuralNetConfiguration) - ret.addAll(getFlattenedLayerConfigurations( - (NeuralNetConfiguration) obj)); - else { - log.error( - "The list of layers and neural network configurations does contain an object of {}. 
Element will be ignored.", - obj.getClass().getSimpleName()); - } - } - /** - LayerConfiguration lc = ((LayerConfiguration) lc).getType().getClazz().cast(obj); - switch(lc.getType()) { - case FC: { //fully connected layer - ((FeedForwardLayer) lc).setWeightInitFn(this.getWeightInitFn()); - } - if(lc instanceof FeedForwardLayer && ((FeedForwardLayer) lc).getWeightInitFn() == null) { - **/ + //When properly initialized, _this_ configuration is set first in the list, however we + //can find cases where this is not true, thus the first configuration is another net or layer configuration + //and should not be skipped. In essence, skip first configuration if that is "this". + int iSkip = 0; + if(conf.getInnerConfigurations().size()>0 && conf.getInnerConfigurations().get(0).equals(this)) { iSkip=1;} + conf.getInnerConfigurations().stream().skip(iSkip) + .forEach(obj -> { + //if Layer Config, include in list and inherit parameters from this conf + //else if neural net configuration, call self recursively to resolve layer configurations + if (obj instanceof LayerConfiguration) { + ((LayerConfiguration) obj).setNetConfiguration(conf); + ret.add((LayerConfiguration) obj); + } else if (obj instanceof NeuralNetConfiguration) + ret.addAll(getFlattenedLayerConfigurations( + (NeuralNetConfiguration) obj)); + else { + log.error( + "The list of layers and neural network configurations does contain an object of {}. Element will be ignored.", + obj.getClass().getSimpleName()); + } + }); return ret; } @@ -1143,17 +1122,6 @@ public class NeuralNetConfiguration extends NeuralNetBaseBuilderConfiguration { return getFlattenedLayerConfigurations(this); } - - /** - * Get the configuration of the first layer - * @return layer configuration - */ - /** - public LayerConfiguration getFirstLayer() { - return getFlattenedLayerConfigurations().get(0); - } -**/ - /** * Add a new layer to the first position * @param layer configuration diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java index 121f9b38f..5f99c8082 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java @@ -23,6 +23,7 @@ package org.deeplearning4j.nn.conf.layers; import lombok.*; import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; import org.deeplearning4j.nn.conf.GradientNormalization; +import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.Updater; import org.deeplearning4j.nn.conf.distribution.Distribution; import org.deeplearning4j.nn.conf.weightnoise.IWeightNoise; @@ -30,6 +31,7 @@ import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.weights.WeightInitDistribution; import org.deeplearning4j.util.NetworkUtils; +import org.jetbrains.annotations.NotNull; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.IActivation; import org.nd4j.linalg.api.buffer.DataType; @@ -52,7 +54,7 @@ import java.util.List; public abstract class BaseLayerConfiguration extends LayerConfiguration implements ITraininableLayerConfiguration, Serializable, Cloneable { @NonNull - protected IWeightInit weightInitFn; + protected IWeightInit weightInit; protected double biasInit = 0.0; protected double gainInit = 0.0; 
protected List regularization; @@ -68,7 +70,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen public BaseLayerConfiguration(Builder builder) { super(builder); this.layerName = builder.layerName; - this.weightInitFn = builder.weightInitFn; + this.weightInit = builder.weightInit; this.biasInit = builder.biasInit; this.gainInit = builder.gainInit; this.regularization = builder.regularization; @@ -89,7 +91,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen public void resetLayerDefaultConfig() { //clear the learning related params for all layers in the origConf and set to defaults this.setIUpdater(null); - this.setWeightInitFn(null); + this.setWeightInit(null); this.setBiasInit(Double.NaN); this.setGainInit(Double.NaN); this.regularization = null; @@ -103,9 +105,6 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen @Override public BaseLayerConfiguration clone() { BaseLayerConfiguration clone = (BaseLayerConfiguration) super.clone(); - if (clone.iDropout != null) { - clone.iDropout = clone.iDropout.clone(); - } if(regularization != null){ //Regularization fields are _usually_ thread safe and immutable, but let's clone to be sure clone.regularization = new ArrayList<>(regularization.size()); @@ -170,7 +169,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen * * @see IWeightInit */ - protected IWeightInit weightInitFn = null; + protected IWeightInit weightInit = null; /** * Bias initialization value, for layers with biases. Defaults to 0 @@ -255,7 +254,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen * @see IWeightInit */ public T weightInit(IWeightInit weightInit) { - this.setWeightInitFn(weightInit); + this.setWeightInit(weightInit); return (T) this; } @@ -270,7 +269,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen "Not supported!, Use weightInit(Distribution distribution) instead!"); } - this.setWeightInitFn(weightInit.getWeightInitFunction()); + this.setWeightInit(weightInit.getWeightInitFunction()); return (T) this; } @@ -508,4 +507,19 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen } } + /** + * Inherit setting from neural network for those settings, that are not already set or do have + * a layer(type) specific default. 
+ * @param conf the neural net configration to inherit parameters from + */ + @Override + public void runInheritance(@NotNull NeuralNetConfiguration conf) { + super.runInheritance(conf); + if(this.biasUpdater == null ) this.biasUpdater = conf.getBiasUpdater(); + if(this.iUpdater == null ) this.iUpdater = conf.getUpdater(); + if(this.regularizationBias == null) this.regularizationBias = conf.getRegularizationBias(); + if(this.regularization == null ) this.regularization = conf.getRegularization(); + if(this.gradientNormalization == null) this.gradientNormalization = conf.getGradientNormalization(); + } + } \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java index 25ad6ba4b..9ef539ae9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/ConvolutionLayer.java @@ -172,6 +172,7 @@ public class ConvolutionLayer extends FeedForwardLayer { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { setNetConfiguration(conf); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + lconf.runInheritance(); LayerValidation.assertNInNOutSet("ConvolutionLayer", getLayerName(), layerIndex, getNIn(), getNOut()); @@ -404,9 +405,10 @@ public class ConvolutionLayer extends FeedForwardLayer { /** * Set the convolution mode for the Convolution layer. See {@link ConvolutionMode} for more details + * Default is {@link ConvolutionMode}.Truncate. * */ - protected ConvolutionMode convolutionMode; + protected ConvolutionMode convolutionMode = ConvolutionMode.Truncate; /** * Kernel dilation. Default: {1, 1}, which is standard convolutions. 
Used for implementing dilated convolutions, diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java index bfd88a62d..b1dd9856a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/DenseLayer.java @@ -62,19 +62,18 @@ public class DenseLayer extends FeedForwardLayer { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("DenseLayerConfiguration", getLayerName(), layerIndex, getNIn(), getNOut()); - LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + lconf.runInheritance(); org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer ret = new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(lconf, networkDataType); - if(getWeightInitFn() == null) setWeightInitFn(new WeightInitXavier()); + + if(getWeightInit() == null) setWeightInit(new WeightInitXavier()); ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); ret.setParamsViewArray(layerParamsView); Map paramTable = initializer().init(this, layerParamsView, initializeParams); ret.setParamTable(paramTable); - ret.setLayerConfiguration(lconf); - return ret; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java index 2ec7b654c..16aeb1acd 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/EmbeddingSequenceLayer.java @@ -217,14 +217,14 @@ public class EmbeddingSequenceLayer extends FeedForwardLayer { return this; } - @Override + public void setWeightInitFn(IWeightInit weightInit){ if(weightInit instanceof WeightInitEmbedding){ long[] shape = ((WeightInitEmbedding) weightInit).shape(); nIn(shape[0]); nOut(shape[1]); } - this.weightInitFn = weightInit; + this.weightInit = weightInit; } /** diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java index b0131b80d..394012c4f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LayerConfiguration.java @@ -66,28 +66,29 @@ import org.nd4j.linalg.learning.regularization.Regularization; @Slf4j public abstract class LayerConfiguration implements ILayerConfiguration, Serializable, Cloneable { // ITraininableLayerConfiguration - protected String layerName = "noname"; + protected String layerName; @Getter protected List variables = new ArrayList<>(); - public void addVariable(String s) {variables.add(s);} - - protected IDropout iDropout; protected List constraints; protected IWeightNoise weightNoise; + private IDropout iDropout; /** * The type of the layer, basically defines the base class and its properties */ @Getter @Setter @NonNull private LayerType type = LayerType.UNKNOWN; - @Getter @Setter private NeuralNetConfiguration netConfiguration; + @Getter @Setter + private IActivation activationFn; public LayerConfiguration(Builder builder) { this.layerName = 
builder.layerName; this.iDropout = builder.iDropout; } + public void addVariable(String s) {variables.add(s);} + public String toJson() { throw new RuntimeException("toJson is not implemented for LayerConfiguration"); } @@ -151,6 +152,7 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali public LayerConfiguration getLayer() { return this; } + @Override public LayerConfiguration clone() { try { @@ -218,7 +220,6 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali */ public abstract void setNIn(InputType inputType, boolean override); - /** * For the given type of input to this layer, what preprocessor (if any) is required?
* Returns null if no preprocessor is required, otherwise returns an appropriate {@link @@ -263,11 +264,11 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali "Not supported: all layers with parameters should override this method"); } - public IUpdater getIUpdater() { throw new UnsupportedOperationException( "Not supported: all layers with parameters should override this method"); } + public void setIUpdater(IUpdater iUpdater) { log.warn("Setting an IUpdater on {} with name {} has no effect.", getClass().getSimpleName(), getLayerName()); } @@ -285,15 +286,33 @@ public abstract class LayerConfiguration implements ILayerConfiguration, Seriali this.variables.clear(); } - @Getter @Setter - private IActivation activationFn; + /** + * Inherit setting from neural network for those settings, that are not already set or do have + * a layer(type) specific default. This implementation does not require the neural network configuration to be + * the same as the one returned from this layers {@link #getNetConfiguration()}. + * + * @param conf a neural net configration to inherit parameters from + * + */ + public void runInheritance(@NonNull NeuralNetConfiguration conf) { + if(this.activationFn == null ) this.activationFn = conf.getActivation(); + if(this.iDropout == null ) this.iDropout = conf.getIdropOut(); + if(this.weightNoise == null) this.weightNoise = conf.getWeightNoise(); + } + + /** Runs {@link #runInheritance(NeuralNetConfiguration)} using the layers configurations embedded neural net + * configuration (the one returned from {@link #getNetConfiguration()}. + */ + public void runInheritance() { + runInheritance(getNetConfiguration()); + } @SuppressWarnings("unchecked") @Getter @Setter public abstract static class Builder> { - protected String layerName = "noname"; + protected String layerName; protected List allParamConstraints; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index 2a8afacb7..ea679c9d4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -215,7 +215,7 @@ public class LocallyConnected1D extends SameDiffLayer { public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { NeuralNetConfiguration global_conf = globalConfig.build(); if (activation == null) { - activation = SameDiffLayerUtils.fromIActivation(global_conf.getActivationFn()); + activation = SameDiffLayerUtils.fromIActivation(global_conf.getActivation()); } if (cm == null) { cm = global_conf.getConvolutionMode(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index a33445ce7..5dd5ec62e 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -232,7 +232,7 @@ public class LocallyConnected2D extends SameDiffLayer { public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { NeuralNetConfiguration gconf = globalConfig.build(); if (activation == null) { - activation = 
SameDiffLayerUtils.fromIActivation(gconf.getActivationFn()); + activation = SameDiffLayerUtils.fromIActivation(gconf.getActivation()); } if (cm == null) { cm = gconf.getConvolutionMode(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java index 50647d0f1..249339df9 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/PReLULayer.java @@ -117,7 +117,7 @@ public class PReLULayer extends BaseLayerConfiguration { public Builder(){ //Default to 0s, and don't inherit global default - this.weightInitFn = new WeightInitConstant(0); + this.weightInit = new WeightInitConstant(0); } /** diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java index 10924fd90..a1bbd9f83 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RecurrentAttentionLayer.java @@ -152,7 +152,7 @@ public class RecurrentAttentionLayer extends SameDiffLayer { @Override public void applyGlobalConfigToLayer(NeuralNetConfiguration.NeuralNetConfigurationBuilder globalConfig) { if (activation == null) { - activation = SameDiffLayerUtils.fromIActivation(globalConfig.build().getActivationFn()); + activation = SameDiffLayerUtils.fromIActivation(globalConfig.build().getActivation()); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java index 18c4601c8..0d05a9486 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/AbstractSameDiffLayer.java @@ -196,7 +196,7 @@ public abstract class AbstractSameDiffLayer extends LayerConfiguration { regularizationBias = bConf.getRegularizationBias(); } if (updater == null) { - updater = bConf.getIUpdater(); + updater = bConf.getUpdater(); } if (biasUpdater == null) { biasUpdater = bConf.getBiasUpdater(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java index e9bded983..accc675d0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/samediff/SameDiffVertex.java @@ -156,7 +156,7 @@ public abstract class SameDiffVertex extends GraphVertex implements ITraininable regularizationBias = b_conf.getRegularizationBias(); } if (updater == null) { - updater = b_conf.getIUpdater(); + updater = b_conf.getUpdater(); } if (biasUpdater == null) { biasUpdater = b_conf.getBiasUpdater(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java index 85f06a40b..a4cf67c79 100644 --- 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/variational/VariationalAutoencoder.java @@ -72,6 +72,7 @@ public class VariationalAutoencoder extends BasePretrainNetwork { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret = new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(lconf, networkDataType); + lconf.runInheritance(); ret.addTrainingListeners(trainingListeners); ret.setIndex(layerIndex); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java index 292b85c10..24a17c263 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/BaseNetConfigDeserializer.java @@ -98,7 +98,7 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im protected boolean requiresWeightInitFromLegacy(LayerConfiguration[] layers){ for(LayerConfiguration l : layers){ if(l instanceof BaseLayerConfiguration - && ((BaseLayerConfiguration)l).getWeightInitFn() == null){ + && ((BaseLayerConfiguration)l).getWeightInit() == null){ return true; } } @@ -254,7 +254,7 @@ public abstract class BaseNetConfigDeserializer extends StdDeserializer im d = NeuralNetConfiguration.mapper().readValue(dist, Distribution.class); } IWeightInit iwi = w.getWeightInitFunction(d); - baseLayerConfiguration.setWeightInitFn(iwi); + baseLayerConfiguration.setWeightInit(iwi); } catch (Throwable t){ log.warn("Failed to infer weight initialization from legacy JSON format",t); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java index 92399e037..9f93c43e0 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/ComputationGraphConfigurationDeserializer.java @@ -129,7 +129,7 @@ public class ComputationGraphConfigurationDeserializer } if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration - && ((BaseLayerConfiguration)layers[layerIdx]).getWeightInitFn() == null){ + && ((BaseLayerConfiguration)layers[layerIdx]).getWeightInit() == null){ handleWeightInitBackwardCompatibility((BaseLayerConfiguration)layers[layerIdx], (ObjectNode)next); } @@ -160,7 +160,7 @@ public class ComputationGraphConfigurationDeserializer layerIdx++; } else if("org.deeplearning4j.nn.conf.graph.LayerVertex".equals(cls)){ if(requiresLegacyWeightInitHandling && layers[layerIdx] instanceof BaseLayerConfiguration - && ((BaseLayerConfiguration)layers[layerIdx]).getWeightInitFn() == null) { + && ((BaseLayerConfiguration)layers[layerIdx]).getWeightInit() == null) { //Post JSON format change for subclasses, but before WeightInit was made a class confNode = (ObjectNode) next.get("layerConf"); next = confNode.get("layer"); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/NeuralNetConfigurationDeserializer.java 
b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/NeuralNetConfigurationDeserializer.java index 633650b95..7863aca02 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/NeuralNetConfigurationDeserializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/serde/NeuralNetConfigurationDeserializer.java @@ -141,7 +141,7 @@ public class NeuralNetConfigurationDeserializer extends BaseNetConfigDeserialize } if(requiresLegacyWeightInitHandling && layers[i] instanceof BaseLayerConfiguration - && ((BaseLayerConfiguration) layers[i]).getWeightInitFn() == null) { + && ((BaseLayerConfiguration) layers[i]).getWeightInit() == null) { handleWeightInitBackwardCompatibility((BaseLayerConfiguration) layers[i], on); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java index e8501f312..9774b8c07 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java @@ -88,14 +88,19 @@ public abstract class AbstractLayer impl cacheMode = layerConfiguration.getNetConfiguration().getCacheMode(); } this.dataType = dataType; + if (layerConfiguration.getNetConfiguration() == null) { + throw new RuntimeException("You cannot create a layer from a layer configuration, that is not part of any neural network configuration."); + } this.net = layerConfiguration.getNetConfiguration().getNet(); } public void addTrainingListeners(TrainingListener... listeners) { + if(listeners != null) trainingListeners.addAll(List.of(listeners)); } public void addTrainingListeners(Collection listeners) { + if(listeners != null) trainingListeners.addAll(listeners); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java index 6363c77c5..01aede19f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java @@ -77,7 +77,6 @@ public abstract class BaseLayer * INDArray params; */ public BaseLayer(LayerConfiguration conf, DataType dataType) { - super(conf, dataType); } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNParamInitializer.java index f2e639e21..ad931eb8f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/ocnn/OCNNParamInitializer.java @@ -21,7 +21,6 @@ package org.deeplearning4j.nn.layers.ocnn; import lombok.val; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.params.DefaultParamInitializer; import org.deeplearning4j.nn.weights.IWeightInit; @@ -154,7 +153,7 @@ public class OCNNParamInitializer extends DefaultParamInitializer { boolean initializeParameters) { org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer ocnnOutputLayer = ( org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer) configuration; - IWeightInit weightInit = ocnnOutputLayer.getWeightInitFn(); + IWeightInit weightInit = ocnnOutputLayer.getWeightInit(); if 
(initializeParameters) { INDArray ret = weightInit.init(weightParamView.size(0), //Fan in weightParamView.size(1), //Fan out diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java index 3c9d2706b..bf21a6dc8 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/variational/VariationalAutoencoder.java @@ -92,7 +92,7 @@ public class VariationalAutoencoder implements Layer { protected int epochCount; @Getter @Setter @NonNull private LayerConfiguration layerConfiguration; - private @Getter @Setter Collection trainingListeners; + private @Getter @Setter Collection trainingListeners = new HashSet<>(); public VariationalAutoencoder(@NonNull LayerConfiguration layerConfiguration, DataType dataType) { this.layerConfiguration = layerConfiguration; @@ -113,6 +113,27 @@ public class VariationalAutoencoder implements Layer { .getNumSamples(); } + /** + * Replace the TrainingListeners for this model + * + * @param listeners new listeners + */ + @Override + public void addTrainingListeners(TrainingListener... listeners) { + if(listeners != null) + trainingListeners.addAll(List.of(listeners)); + } + +/** +* + * @param listeners + */ + @Override + public void addTrainingListeners(Collection listeners) { + if(listeners != null) + trainingListeners.addAll(listeners); + } + /** * Get a reference to the network this layer is part of. * @@ -1214,24 +1235,6 @@ public class VariationalAutoencoder implements Layer { //No-op for individual layers } - /** - * Replace the TrainingListeners for this model - * - * @param listeners new listeners - */ - @Override - public void addTrainingListeners(TrainingListener... 
listeners) { - trainingListeners.addAll(List.of(listeners)); - } - -/** -* - * @param listeners - */ - @Override - public void addTrainingListeners(Collection listeners) { - trainingListeners.addAll(listeners); - } @AllArgsConstructor @Data diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java index 745e77a69..b11f9f3d2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Convolution3DParamInitializer.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.params; import lombok.val; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Convolution3D; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.WeightInitUtil; @@ -131,7 +130,7 @@ public class Convolution3DParamInitializer extends ConvolutionParamInitializer { val weightsShape = new long[]{outputDepth, inputDepth, kernel[0], kernel[1], kernel[2]}; - return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', + return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView); } else { int[] kernel = layerConf.getKernelSize(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java index a8b3ce7aa..9b53e3713 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/ConvolutionParamInitializer.java @@ -180,7 +180,7 @@ public class ConvolutionParamInitializer extends AbstractParamInitializer { val weightsShape = new long[] {outputDepth, inputDepth, kernel[0], kernel[1]}; - return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', weightView); + return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView); } else { int[] kernel = layerConf.getKernelSize(); return WeightInitUtil.reshapeWeights( diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Deconvolution3DParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Deconvolution3DParamInitializer.java index 7f8b8e9e6..6e2d2b128 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Deconvolution3DParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/Deconvolution3DParamInitializer.java @@ -22,7 +22,6 @@ package org.deeplearning4j.nn.params; import lombok.val; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.Deconvolution3D; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.WeightInitUtil; @@ -130,7 +129,7 @@ public class Deconvolution3DParamInitializer extends ConvolutionParamInitializer //libnd4j: [kD, kH, kW, oC, iC] val weightsShape = new long[]{kernel[0], kernel[1], kernel[2], outputDepth, inputDepth}; - return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', weightView); + return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView); } else { int[] kernel = layerConf.getKernelSize(); return 
WeightInitUtil.reshapeWeights( diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java index 1c7ac91d9..463c24ae3 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DeconvolutionParamInitializer.java @@ -21,7 +21,6 @@ package org.deeplearning4j.nn.params; import lombok.val; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.WeightInitUtil; import org.nd4j.linalg.api.ndarray.INDArray; @@ -61,7 +60,7 @@ public class DeconvolutionParamInitializer extends ConvolutionParamInitializer { val weightsShape = new long[] {inputDepth, outputDepth, kernel[0], kernel[1]}; - INDArray weights = layerConf.getWeightInitFn().init( + INDArray weights = layerConf.getWeightInit().init( fanIn, fanOut, weightsShape, 'c', weightView); return weights; diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java index c20562223..239fd20bf 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java @@ -196,13 +196,13 @@ public class DefaultParamInitializer extends AbstractParamInitializer { (org.deeplearning4j.nn.conf.layers.FeedForwardLayer) conf; if (initializeParameters) { - if( layerConf.getWeightInitFn() == null) { + if( layerConf.getWeightInit() == null) { // set a default and set warning - layerConf.setWeightInitFn(new WeightInitXavier()); + layerConf.setWeightInit(new WeightInitXavier()); log.warn("Weight Initializer function was not set on layer {} of class {}, it will default to {}", conf.getLayerName(), conf.getClass().getSimpleName(), WeightInitXavier.class.getSimpleName()); } - return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInitFn(), + return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), layerConf.getWeightInit(), weightParamView, true); } else { return createWeightMatrix(layerConf.getNIn(), layerConf.getNOut(), null, weightParamView, false); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java index 72f2ac6ba..d1bd00449 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DepthwiseConvolutionParamInitializer.java @@ -23,8 +23,6 @@ package org.deeplearning4j.nn.params; import lombok.val; import org.deeplearning4j.nn.api.AbstractParamInitializer; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DepthwiseConvolution2D; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.WeightInitUtil; @@ -193,7 +191,7 @@ public class DepthwiseConvolutionParamInitializer extends AbstractParamInitializ val weightsShape = new long[] 
{kernel[0], kernel[1], inputDepth, depthMultiplier}; - return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', + return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView); } else { int[] kernel = layerConf.getKernelSize(); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java index 5239a6c2c..e74d69a1a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesBidirectionalLSTMParamInitializer.java @@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params; import lombok.val; import org.deeplearning4j.nn.api.AbstractParamInitializer; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitUtil; @@ -159,14 +157,14 @@ public class GravesBidirectionalLSTMParamInitializer extends AbstractParamInitia val inputWShape = new long[]{nLast, 4 * nL}; val recurrentWShape = new long[]{nL, 4 * nL + 3}; - params.put(INPUT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, + params.put(INPUT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, iwF)); - params.put(RECURRENT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, recurrentWShape, + params.put(RECURRENT_WEIGHT_KEY_FORWARDS, layerConf.getWeightInit().init(fanIn, fanOut, recurrentWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, rwF)); params.put(BIAS_KEY_FORWARDS, bF); - params.put(INPUT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, + params.put(INPUT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, iwR)); - params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInitFn().init(fanIn, fanOut, recurrentWShape, + params.put(RECURRENT_WEIGHT_KEY_BACKWARDS, layerConf.getWeightInit().init(fanIn, fanOut, recurrentWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, rwR)); params.put(BIAS_KEY_BACKWARDS, bR); } else { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java index 5c59e5f7e..265027812 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/GravesLSTMParamInitializer.java @@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params; import lombok.val; import org.deeplearning4j.nn.api.AbstractParamInitializer; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInitUtil; @@ -124,10 +122,10 @@ public class GravesLSTMParamInitializer extends AbstractParamInitializer { if(layerConf.getWeightInitFnRecurrent() != null){ rwInit = layerConf.getWeightInitFnRecurrent(); } else { - rwInit = 
layerConf.getWeightInitFn(); + rwInit = layerConf.getWeightInit(); } - params.put(INPUT_WEIGHT_KEY,layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, + params.put(INPUT_WEIGHT_KEY,layerConf.getWeightInit().init(fanIn, fanOut, inputWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView)); params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView)); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java index 04f12ea32..040822a8a 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/LSTMParamInitializer.java @@ -27,7 +27,6 @@ import java.util.List; import java.util.Map; import lombok.val; import org.deeplearning4j.nn.api.AbstractParamInitializer; -import org.deeplearning4j.nn.api.ParamInitializer; import org.deeplearning4j.nn.conf.layers.LSTM; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.weights.IWeightInit; @@ -132,10 +131,10 @@ public class LSTMParamInitializer extends AbstractParamInitializer { if(layerConf.getWeightInitFnRecurrent() != null){ rwInit = layerConf.getWeightInitFnRecurrent(); } else { - rwInit = layerConf.getWeightInitFn(); + rwInit = layerConf.getWeightInit(); } - params.put(INPUT_WEIGHT_KEY, layerConf.getWeightInitFn().init(fanIn, fanOut, inputWShape, + params.put(INPUT_WEIGHT_KEY, layerConf.getWeightInit().init(fanIn, fanOut, inputWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, inputWeightView)); params.put(RECURRENT_WEIGHT_KEY, rwInit.init(fanIn, fanOut, recurrentWShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, recurrentWeightView)); biasView.put(new INDArrayIndex[] {NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(nL, 2 * nL)}, diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java index 32b05a04c..11d5638fe 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/PReLUParamInitializer.java @@ -133,7 +133,7 @@ public class PReLUParamInitializer extends AbstractParamInitializer { PReLULayer layerConf = (PReLULayer) conf; if (initializeParameters) { - return layerConf.getWeightInitFn().init(layerConf.getNIn(), layerConf.getNOut(), + return layerConf.getWeightInit().init(layerConf.getNIn(), layerConf.getNOut(), weightShape, IWeightInit.DEFAULT_WEIGHT_INIT_ORDER, weightParamView); } else { return WeightInitUtil.reshapeWeights(weightShape, weightParamView); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java index 9df032560..58547886f 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SeparableConvolutionParamInitializer.java @@ -23,8 +23,6 @@ package org.deeplearning4j.nn.params; import lombok.val; import org.deeplearning4j.nn.api.AbstractParamInitializer; -import 
org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D; import org.deeplearning4j.nn.weights.WeightInitUtil; @@ -220,7 +218,7 @@ public class SeparableConvolutionParamInitializer extends AbstractParamInitializ val weightsShape = new long[] {depthMultiplier, inputDepth, kernel[0], kernel[1]}; - return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', + return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView); } else { int[] kernel = layerConf.getKernelSize(); @@ -249,7 +247,7 @@ public class SeparableConvolutionParamInitializer extends AbstractParamInitializ val weightsShape = new long[] {outputDepth, depthMultiplier * inputDepth, 1, 1}; - return layerConf.getWeightInitFn().init(fanIn, fanOut, weightsShape, 'c', + return layerConf.getWeightInit().init(fanIn, fanOut, weightsShape, 'c', weightView); } else { return WeightInitUtil.reshapeWeights( diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java index 603492afa..488c00396 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/SimpleRnnParamInitializer.java @@ -22,8 +22,6 @@ package org.deeplearning4j.nn.params; import lombok.val; import org.deeplearning4j.nn.api.AbstractParamInitializer; -import org.deeplearning4j.nn.api.ParamInitializer; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.recurrent.SimpleRnn; import org.deeplearning4j.nn.weights.IWeightInit; @@ -102,14 +100,14 @@ public class SimpleRnnParamInitializer extends AbstractParamInitializer { if (initializeParams) { m = getSubsets(paramsView, nIn, nOut, false, hasLayerNorm(c)); - INDArray w = c.getWeightInitFn().init(nIn, nOut, new long[]{nIn, nOut}, 'f', m.get(WEIGHT_KEY)); + INDArray w = c.getWeightInit().init(nIn, nOut, new long[]{nIn, nOut}, 'f', m.get(WEIGHT_KEY)); m.put(WEIGHT_KEY, w); IWeightInit rwInit; if (c.getWeightInitFnRecurrent() != null) { rwInit = c.getWeightInitFnRecurrent(); } else { - rwInit = c.getWeightInitFn(); + rwInit = c.getWeightInit(); } INDArray rw = rwInit.init(nOut, nOut, new long[]{nOut, nOut}, 'f', m.get(RECURRENT_WEIGHT_KEY)); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java index 9284843d5..362c35170 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/VariationalAutoencoderParamInitializer.java @@ -21,7 +21,6 @@ package org.deeplearning4j.nn.params; import lombok.val; -import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.LayerConfiguration; import org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder; import org.deeplearning4j.nn.weights.IWeightInit; @@ -200,7 +199,7 @@ public class VariationalAutoencoderParamInitializer extends DefaultParamInitiali int[] 
encoderLayerSizes = layer.getEncoderLayerSizes(); int[] decoderLayerSizes = layer.getDecoderLayerSizes(); - IWeightInit weightInit = layer.getWeightInitFn(); + IWeightInit weightInit = layer.getWeightInit(); int soFar = 0; for (int i = 0; i < encoderLayerSizes.length; i++) { diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java index 5d68bd890..b62e77e83 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/FineTuneConfiguration.java @@ -164,7 +164,7 @@ public class FineTuneConfiguration { bl.setActivationFn(activationFn); } if (weightInitFn != null) { - bl.setWeightInitFn(weightInitFn); + bl.setWeightInit(weightInitFn); } if (biasInit != null) { bl.setBiasInit(biasInit); @@ -264,10 +264,10 @@ public class FineTuneConfiguration { NeuralNetConfiguration.NeuralNetConfigurationBuilder confBuilder = NeuralNetConfiguration.builder(); if (activationFn != null) { - confBuilder.activationFn(activationFn); + confBuilder.activation(activationFn); } if (weightInitFn != null) { - confBuilder.weightInitFn(weightInitFn); + confBuilder.weightInit(weightInitFn); } if (biasInit != null) { confBuilder.biasInit(biasInit); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java index 663420f0a..708568d19 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/transferlearning/TransferLearning.java @@ -462,7 +462,7 @@ public class TransferLearning { Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nInReplace can only be applide on FeedForward layers;" + "got layer of type %s", layerImpl.getClass().getSimpleName()); FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl; - layerImplF.setWeightInitFn(init); + layerImplF.setWeightInit(init); layerImplF.setNIn(nIn); long numParams = layerImpl.initializer().numParams(layerConf); INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); @@ -480,7 +480,7 @@ public class TransferLearning { Preconditions.checkArgument(layerImpl instanceof FeedForwardLayer, "nOutReplace can only be applide on FeedForward layers;" + "got layer of type %s", layerImpl.getClass().getSimpleName()); FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl; - layerImplF.setWeightInitFn(scheme); + layerImplF.setWeightInit(scheme); layerImplF.setNOut(nOut); long numParams = layerImpl.initializer().numParams(layerConf); INDArray params = Nd4j.create(origModel.getNetConfiguration().getDataType(), 1, numParams); @@ -492,7 +492,7 @@ public class TransferLearning { layerImpl = layerConf; //modify in place if(layerImpl instanceof FeedForwardLayer) { layerImplF = (FeedForwardLayer) layerImpl; - layerImplF.setWeightInitFn(schemeNext); + layerImplF.setWeightInit(schemeNext); layerImplF.setNIn(nOut); numParams = layerImpl.initializer().numParams(layerConf); if (numParams > 0) { @@ -738,7 +738,7 @@ public class TransferLearning { layerImpl.resetLayerDefaultConfig(); FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl; - layerImplF.setWeightInitFn(scheme); + 
layerImplF.setWeightInit(scheme); layerImplF.setNIn(nIn); if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex @@ -767,7 +767,7 @@ public class TransferLearning { LayerConfiguration layerImpl = layerConf.clone(); layerImpl.resetLayerDefaultConfig(); FeedForwardLayer layerImplF = (FeedForwardLayer) layerImpl; - layerImplF.setWeightInitFn(scheme); + layerImplF.setWeightInit(scheme); layerImplF.setNOut(nOut); if(editedVertices.contains(layerName) && editedConfigBuilder.getVertices().get(layerName) instanceof LayerVertex @@ -806,7 +806,7 @@ public class TransferLearning { continue; layerImpl = layerConf.clone(); layerImplF = (FeedForwardLayer) layerImpl; - layerImplF.setWeightInitFn(schemeNext); + layerImplF.setWeightInit(schemeNext); layerImplF.setNIn(nOut); nInFromNewConfig.put(fanoutVertexName, nOut); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java index cec9da44a..e7a74999c 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/updater/BaseMultiLayerUpdater.java @@ -207,10 +207,11 @@ public abstract class BaseMultiLayerUpdater implements Updater */ public void setStateViewArray(INDArray viewArray) { if(this.updaterStateViewArray == null){ - if(viewArray == null) + if(viewArray == null || viewArray.length()==0) return; //No op - for example, SGD and NoOp updater - i.e., no stored state else { - throw new IllegalStateException("Attempting to set updater state view array with null value"); + //this.updaterStateViewArray.set + // throw new IllegalStateException("Attempting to set updater state view array with null value"); } } if (this.updaterStateViewArray.length() != viewArray.length()) @@ -296,7 +297,7 @@ public abstract class BaseMultiLayerUpdater implements Updater //PRE apply (gradient clipping, etc): done on a per-layer basis for (Map.Entry entry : layerGradients.entrySet()) { String layerName = entry.getKey(); - ITrainableLayer layer = layersByName.get(layerName); + ITrainableLayer layer = layersByName.get(layerName); //Todo Layers may have the same name!? 
preApply(layer, layerGradients.get(layerName), iteration); } diff --git a/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java index 0b7ce4627..7b5176670 100644 --- a/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java +++ b/cavis-dnn/cavis-dnn-nn/src/test/java/net/brutex/ai/dnn/api/dnnTest.java @@ -29,7 +29,6 @@ import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.datasets.iterator.FloatsDataSetIterator; import org.deeplearning4j.nn.conf.GradientNormalization; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; -import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.ActivationLayer; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; @@ -39,7 +38,6 @@ import org.deeplearning4j.optimize.listeners.ScoreToChartListener; import org.junit.jupiter.api.Test; import org.nd4j.common.primitives.Pair; import org.nd4j.linalg.activations.Activation; -import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationLReLU; import org.nd4j.linalg.activations.impl.ActivationSigmoid; import org.nd4j.linalg.learning.config.Adam; @@ -85,8 +83,8 @@ class dnnTest { .updater(Adam.builder().learningRate(0.0002).beta1(0.5).build()) .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer) .gradientNormalizationThreshold(100) - .weightInitFn(new WeightInitXavier()) - .activationFn(new ActivationSigmoid()) + .weightInit(new WeightInitXavier()) + .activation(new ActivationSigmoid()) // .inputType(InputType.convolutional(28, 28, 1)) .layer(new DenseLayer.Builder().nIn(6).nOut(20).build()) .layer(new ActivationLayer.Builder(new ActivationLReLU(0.2)).build()) diff --git a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index c0df01142..89b2ceef6 100644 --- a/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/cavis-ui/cavis-ui-vertx/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -1182,7 +1182,7 @@ public class TrainModule implements UIModule { String.valueOf(nParams)}); if (nParams > 0) { try { - String str = JsonMappers.getMapper().writeValueAsString(bl.getWeightInitFn()); + String str = JsonMappers.getMapper().writeValueAsString(bl.getWeightInit()); layerInfoRows.add(new String[]{ i18N.getMessage("train.model.layerinfotable.layerWeightInit"), str}); } catch (JsonProcessingException e) { diff --git a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Darknet19.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Darknet19.java index f2b07ec58..c94161a4b 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Darknet19.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/Darknet19.java @@ -29,6 +29,7 @@ import org.deeplearning4j.nn.conf.ComputationGraphConfiguration.GraphBuilder; import org.deeplearning4j.nn.conf.inputs.InputType; import org.deeplearning4j.nn.conf.layers.*; import org.deeplearning4j.nn.graph.ComputationGraph; +import org.deeplearning4j.nn.weights.IWeightInit; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.zoo.ModelMetaData; import org.deeplearning4j.zoo.PretrainedType; diff --git 
a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/ResNet50.java b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/ResNet50.java index f530e0781..70abb5722 100644 --- a/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/ResNet50.java +++ b/cavis-zoo/cavis-zoo-models/src/main/java/org/deeplearning4j/zoo/model/ResNet50.java @@ -176,7 +176,7 @@ public class ResNet50 extends ZooModel { .activation(Activation.IDENTITY) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) .updater(updater) - .weightInitFn(weightInit) + .weightInit(weightInit) .l1(1e-7) .l2(5e-5) .miniBatch(true) From 9d4939ccfd8b75948caf9dd58750ef676d5a332d Mon Sep 17 00:00:00 2001 From: brian Date: Sat, 15 Apr 2023 12:50:26 +0200 Subject: [PATCH 125/126] Playing with some new code 2 - clean build/test Signed-off-by: brian --- .../java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java | 1 + .../deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java | 1 + .../org/deeplearning4j/nn/conf/layers/BatchNormalization.java | 3 ++- .../java/org/deeplearning4j/nn/conf/layers/OutputLayer.java | 1 + .../org/deeplearning4j/nn/params/DefaultParamInitializer.java | 2 +- 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java index 72615eca8..c0dbc4f56 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/AutoEncoder.java @@ -57,6 +57,7 @@ public class AutoEncoder extends BasePretrainNetwork { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { this.setNetConfiguration(conf); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + runInheritance(); org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder ret = new org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder(lconf, networkDataType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java index 5f99c8082..b16ecb768 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BaseLayerConfiguration.java @@ -520,6 +520,7 @@ public abstract class BaseLayerConfiguration extends LayerConfiguration implemen if(this.regularizationBias == null) this.regularizationBias = conf.getRegularizationBias(); if(this.regularization == null ) this.regularization = conf.getRegularization(); if(this.gradientNormalization == null) this.gradientNormalization = conf.getGradientNormalization(); + if(this.weightInit == null) this.weightInit = conf.getWeightInit(); } } \ No newline at end of file diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java index ab0044448..5e266afb2 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/BatchNormalization.java @@ -92,8 +92,9 @@ public class 
BatchNormalization extends FeedForwardLayer { public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { this.setNetConfiguration(conf); - LayerValidation.assertNOutSet("BatchNormalization", getLayerName(), layerIndex, getNOut()); + runInheritance(); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); org.deeplearning4j.nn.layers.normalization.BatchNormalization ret = new org.deeplearning4j.nn.layers.normalization.BatchNormalization(lconf, networkDataType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java index f024caec2..2884ac424 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/OutputLayer.java @@ -54,6 +54,7 @@ public class OutputLayer extends BaseOutputLayer { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("OutputLayer", getLayerName(), layerIndex, getNIn(), getNOut()); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + runInheritance(); org.deeplearning4j.nn.layers.OutputLayer ret = new org.deeplearning4j.nn.layers.OutputLayer(lconf, networkDataType); ret.addTrainingListeners(trainingListeners); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java index 239fd20bf..954ddc250 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java @@ -210,7 +210,7 @@ public class DefaultParamInitializer extends AbstractParamInitializer { } protected INDArray createWeightMatrix(long nIn, long nOut, - @NonNull IWeightInit weightInit, + IWeightInit weightInit, INDArray weightParamView, boolean initializeParameters) { val shape = new long[] {nIn, nOut}; From 82e65bdf59fc5bd83df94703c38ec4ef4e2d264b Mon Sep 17 00:00:00 2001 From: brian Date: Mon, 17 Apr 2023 09:41:12 +0200 Subject: [PATCH 126/126] Playing with some new code 2 - clean build/test Signed-off-by: brian --- .../nd4j/linalg/workspace/WorkspaceMgr.java | 10 +- .../eval/EvaluationToolsTests.java | 2 +- .../nn/conf/layers/Cnn3DLossLayer.java | 2 + .../nn/conf/layers/CnnLossLayer.java | 2 + .../nn/conf/layers/GravesLSTM.java | 4 + .../nn/conf/layers/RnnLossLayer.java | 2 + .../nn/conf/layers/SubsamplingLayer.java | 1 + .../nn/layers/AbstractLayer.java | 12 +- .../deeplearning4j/nn/layers/BaseLayer.java | 13 +- .../nn/params/DefaultParamInitializer.java | 6 +- .../cpu/nativecpu/ops/CpuOpContext.java | 3 +- .../nd4j/jita/workspace/CudaWorkspace.java | 768 ++++++++++-------- .../workspace/CudaWorkspaceDeallocator.java | 2 +- .../ops/executioner/CudaExecutioner.java | 2 +- .../ops/executioner/CudaOpContext.java | 3 +- 15 files changed, 473 insertions(+), 359 deletions(-) diff --git a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/WorkspaceMgr.java b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/WorkspaceMgr.java index 90b77f449..f096a1f6f 100644 --- 
a/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/WorkspaceMgr.java +++ b/cavis-dnn/cavis-dnn-api/src/main/java/org/nd4j/linalg/workspace/WorkspaceMgr.java @@ -67,7 +67,7 @@ public interface WorkspaceMgr> { /** * Set arrays to be scoped out (not in any workspace) for the specified array type. - * This means that create, dup, leverage etc methods will return result arrays that are not attached to any workspace + * This means that create, dup, leverage etc. methods will return result arrays that are not attached to any workspace * * @param arrayType Array type to set scoped out for */ @@ -120,7 +120,7 @@ public interface WorkspaceMgr> { boolean isWorkspaceOpen(T arrayType); /** - * Assert thath the workspace for the specified array type is open. + * Assert that the workspace for the specified array type is open. * For array types that are set to scoped out, this will be treated as a no-op * @param arrayType Array type to check * @param msg May be null. If non-null: include this in the exception @@ -129,7 +129,7 @@ public interface WorkspaceMgr> { void assertOpen(T arrayType, String msg) throws ND4JWorkspaceException; /** - * Assert thath the workspace for the specified array type is not open. + * Assert that the workspace for the specified array type is not open. * For array types that are set to scoped out, this will be treated as a no-op * @param arrayType Array type to check * @param msg May be null. If non-null: include this in the exception @@ -193,7 +193,7 @@ public interface WorkspaceMgr> { /** * Create an uninitialized array in the specified array type's workspace (or detached if none is specified). - * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#createUninitialized(int)} (int...)}, other than the array location + * Equivalent to {@link org.nd4j.linalg.factory.Nd4j#createUninitialized(int...)}, other than the array location * @param arrayType Array type * @param dataType Data type of the created array * @param shape Shape @@ -231,7 +231,7 @@ public interface WorkspaceMgr> { /** * Cast the specified array to the specified datatype.
- * If the array is already the correct type, the bahaviour depends on the 'dupIfCorrectType' argument.
+ * If the array is already the correct type, the behaviour depends on the 'dupIfCorrectType' argument.
* dupIfCorrectType = false && toCast.dataType() == dataType: return input array as-is (unless workspace is wrong)
* dupIfCorrectType = true && toCast.dataType() == dataType: duplicate the array into the specified workspace
* @param arrayType Array type diff --git a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java index aa9b2686f..7cf2431f9 100644 --- a/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java +++ b/cavis-dnn/cavis-dnn-core/src/test/java/org/deeplearning4j/eval/EvaluationToolsTests.java @@ -81,7 +81,7 @@ public class EvaluationToolsTests extends BaseDL4JTest { String str = EvaluationTools.rocChartToHtml(roc); - // System.out.println(str); + System.out.println(str); } } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java index 79782d956..8ae76bd41 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/Cnn3DLossLayer.java @@ -58,6 +58,8 @@ public class Cnn3DLossLayer extends FeedForwardLayer { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { setNetConfiguration(conf); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + runInheritance(); + org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer ret = new org.deeplearning4j.nn.layers.convolution.Cnn3DLossLayer(lconf, networkDataType); ret.addTrainingListeners(trainingListeners); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java index b4f93482d..50e917dac 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/CnnLossLayer.java @@ -63,6 +63,8 @@ public class CnnLossLayer extends FeedForwardLayer { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { setNetConfiguration(conf); LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + runInheritance(); + org.deeplearning4j.nn.layers.convolution.CnnLossLayer ret = new org.deeplearning4j.nn.layers.convolution.CnnLossLayer(lconf, networkDataType); ret.addTrainingListeners(trainingListeners); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java index 1cdd16dba..9c50ccba4 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/GravesLSTM.java @@ -77,7 +77,11 @@ public class GravesLSTM extends AbstractLSTM { public Layer instantiate(NeuralNetConfiguration conf, Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerValidation.assertNInNOutSet("GravesLSTM", getLayerName(), layerIndex, getNIn(), getNOut()); + LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + lconf.setNetConfiguration(conf); + runInheritance(); + org.deeplearning4j.nn.layers.recurrent.GravesLSTM ret = new org.deeplearning4j.nn.layers.recurrent.GravesLSTM(lconf, networkDataType); diff --git 
a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java index 4742b9e5b..e7db009ed 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/RnnLossLayer.java @@ -61,6 +61,8 @@ public class RnnLossLayer extends FeedForwardLayer { int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + lconf.setNetConfiguration(conf); + runInheritance(); org.deeplearning4j.nn.layers.recurrent.RnnLossLayer ret = new org.deeplearning4j.nn.layers.recurrent.RnnLossLayer(lconf, networkDataType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java index bddd9fc30..55e766133 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/conf/layers/SubsamplingLayer.java @@ -135,6 +135,7 @@ public class SubsamplingLayer extends NoParamLayer { Collection trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) { LayerConfiguration lconf = conf.getFlattenedLayerConfigurations().get(layerIndex); + runInheritance(); org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer ret = new org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer(lconf, networkDataType); diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java index 9774b8c07..d14f20d85 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java @@ -24,6 +24,8 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; + import lombok.*; import net.brutex.ai.dnn.api.IModel; import org.deeplearning4j.nn.api.ITraininableLayerConfiguration; @@ -328,13 +330,9 @@ public abstract class AbstractLayer impl @Override public void clearNoiseWeightParams() {} - public List variables() { - return variables; - } - - public List variables(boolean copy) { + public List getVariables(boolean copy) { if (copy) { - return variables(); + return new ArrayList<>(getVariables()); } return variables; } @@ -585,7 +583,7 @@ public abstract class AbstractLayer impl */ @Override public INDArray getParams() { - // throw new RuntimeException("Not implemented"); + //throw new RuntimeException("Not implemented"); return null; } diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java index 01aede19f..1a055c528 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/layers/BaseLayer.java @@ -184,6 +184,17 @@ public abstract class BaseLayer setParams(params, 'f'); } + /** + * * The AbstractLayer does not implement Params, ParamTable and 
GradientView. A RuntimeException + * * will be triggered when calling this. + * + * @return 1d parameter vector + */ + @Override + public INDArray getParams() { + return paramsFlattened; + } + /** */ @Override public void close() {} @@ -358,7 +369,7 @@ public abstract class BaseLayer protected void setParams(INDArray params, char order) { if (params == null) { - log.warn( + log.trace( "setParams(INDArray params, char order): params is null. Skipping setParams in Layer {}[{}] at index {}", getLayerConfiguration().getLayerName(), getClass().getSimpleName(), diff --git a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java index 954ddc250..a7f444c91 100644 --- a/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java +++ b/cavis-dnn/cavis-dnn-nn/src/main/java/org/deeplearning4j/nn/params/DefaultParamInitializer.java @@ -110,14 +110,14 @@ public class DefaultParamInitializer extends AbstractParamInitializer { INDArray weightView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(0, nWeightParams)); params.put(WEIGHT_KEY, createWeightMatrix(layerConf, weightView, initializeParams)); - layerConf.getNetConfiguration().addNetWideVariable(WEIGHT_KEY); + layerConf.addVariable(WEIGHT_KEY); long offset = nWeightParams; if(hasBias(layerConf)){ INDArray biasView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(offset, offset + nOut)); params.put(BIAS_KEY, createBias(layerConf, biasView, initializeParams)); - layerConf.getNetConfiguration().addNetWideVariable(BIAS_KEY); + layerConf.addVariable(BIAS_KEY); offset += nOut; } @@ -125,7 +125,7 @@ public class DefaultParamInitializer extends AbstractParamInitializer { INDArray gainView = paramsView.get(NDArrayIndex.interval(0,0,true), NDArrayIndex.interval(offset, offset + nOut)); params.put(GAIN_KEY, createGain(conf, gainView, initializeParams)); - conf.getNetConfiguration().addNetWideVariable(GAIN_KEY); + conf.addVariable(GAIN_KEY); } return params; diff --git a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index d6ddf49de..2c1a5860c 100644 --- a/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/cavis-native/cavis-native-cpu/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -50,7 +50,8 @@ public class CpuOpContext extends BaseOpContext implements OpContext, Deallocata @Override public void close() { - // no-op + nativeOps.ctxPurge(context); + context.deallocate(); } @Override diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java index 39dad7bd4..94c2ca71b 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspace.java @@ -20,6 +20,8 @@ package org.nd4j.jita.workspace; +import java.util.List; +import java.util.Queue; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; @@ -39,10 +41,6 @@ import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.nativeblas.NativeOpsHolder; 
-import java.util.List; -import java.util.Queue; - - /** * CUDA-aware MemoryWorkspace implementation * @@ -51,395 +49,489 @@ import java.util.Queue; @Slf4j public class CudaWorkspace extends Nd4jWorkspace { + public CudaWorkspace(@NonNull WorkspaceConfiguration configuration) { + super(configuration); + } - public CudaWorkspace(@NonNull WorkspaceConfiguration configuration) { - super(configuration); + public CudaWorkspace(@NonNull WorkspaceConfiguration configuration, @NonNull String workspaceId) { + super(configuration, workspaceId); + } + + public CudaWorkspace( + @NonNull WorkspaceConfiguration configuration, + @NonNull String workspaceId, + Integer deviceId) { + super(configuration, workspaceId); + this.deviceId = deviceId; + } + + @Override + protected void init() { + if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP) { + throw new ND4JIllegalStateException("CUDA do not support MMAP workspaces yet"); } - public CudaWorkspace(@NonNull WorkspaceConfiguration configuration, @NonNull String workspaceId) { - super(configuration, workspaceId); + super.init(); + + if (currentSize.get() > 0) { + log.debug("Allocating {} bytes at DEVICE & HOST space...", currentSize.get()); + isInit.set(true); + + long bytes = currentSize.get(); + + log.debug( + "Allocating [{}] workspace on device_{}, {} bytes...", + id, + Nd4j.getAffinityManager().getDeviceForCurrentThread(), + bytes); + + if (isDebug.get()) { + Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread(); + } + + Pointer ptr = memoryManager.allocate((bytes + SAFETY_OFFSET), MemoryKind.HOST, false); + if (ptr == null) throw new ND4JIllegalStateException("Can't allocate memory for workspace"); + + workspace.setHostPointer(new PagedPointer(ptr)); + + if (workspaceConfiguration.getPolicyMirroring() != MirroringPolicy.HOST_ONLY) { + workspace.setDevicePointer( + new PagedPointer( + memoryManager.allocate((bytes + SAFETY_OFFSET), MemoryKind.DEVICE, false))); + AllocationsTracker.getInstance() + .markAllocated( + AllocationKind.GENERAL, + Nd4j.getAffinityManager().getDeviceForCurrentThread(), + bytes + SAFETY_OFFSET); + + MemoryTracker.getInstance() + .incrementWorkspaceAllocatedAmount( + Nd4j.getAffinityManager().getDeviceForCurrentThread(), bytes + SAFETY_OFFSET); + + // if base pointer isn't aligned to 16 bytes (128 bits) - adjust the offfset then + val addr = workspace.getDevicePointer().address(); + val div = addr % alignmentBase; + if (div != 0) { + deviceOffset.set(alignmentBase - div); + hostOffset.set(alignmentBase - div); + } + } + } + } + + @Override + public PagedPointer alloc(long requiredMemory, DataType type, boolean initialize) { + return this.alloc(requiredMemory, MemoryKind.DEVICE, type, initialize); + } + + @Override + public synchronized void destroyWorkspace(boolean extended) { + val size = currentSize.getAndSet(0); + reset(); + + if (extended) clearExternalAllocations(); + + clearPinnedAllocations(extended); + + if (workspace.getHostPointer() != null) + NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(workspace.getHostPointer()); + + if (workspace.getDevicePointer() != null) { + NativeOpsHolder.getInstance() + .getDeviceNativeOps() + .freeDevice(workspace.getDevicePointer(), 0); + AllocationsTracker.getInstance() + .markReleased( + AllocationKind.GENERAL, + Nd4j.getAffinityManager().getDeviceForCurrentThread(), + size + SAFETY_OFFSET); + + MemoryTracker.getInstance() + .decrementWorkspaceAmount( + Nd4j.getAffinityManager().getDeviceForCurrentThread(), size + SAFETY_OFFSET); } - public 
CudaWorkspace(@NonNull WorkspaceConfiguration configuration, @NonNull String workspaceId, Integer deviceId) { - super(configuration, workspaceId); - this.deviceId = deviceId; + workspace.setDevicePointer(null); + workspace.setHostPointer(null); + } + + @Override + public PagedPointer alloc( + long requiredMemory, MemoryKind kind, DataType type, boolean initialize) { + long numElements = requiredMemory / Nd4j.sizeOfDataType(type); + + // alignment + if (requiredMemory % alignmentBase != 0) + requiredMemory += alignmentBase - (requiredMemory % alignmentBase); + + if (!isUsed.get()) { + if (disabledCounter.incrementAndGet() % 10 == 0) + log.warn( + "Workspace was turned off, and wasn't enabled after {} allocations", + disabledCounter.get()); + + if (kind == MemoryKind.DEVICE) { + val pointer = + new PagedPointer( + memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements); + externalAllocations.add(new PointersPair(null, pointer)); + MemoryTracker.getInstance() + .incrementWorkspaceAllocatedAmount( + Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory); + return pointer; + } else { + val pointer = + new PagedPointer( + memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements); + externalAllocations.add(new PointersPair(pointer, null)); + return pointer; + } } - @Override - protected void init() { - if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP) { - throw new ND4JIllegalStateException("CUDA do not support MMAP workspaces yet"); + boolean trimmer = + (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED + && requiredMemory + cycleAllocations.get() > initialBlockSize.get() + && initialBlockSize.get() > 0 + && kind == MemoryKind.DEVICE) + || trimmedMode.get(); + + if (trimmer + && workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE + && !trimmedMode.get()) { + trimmedMode.set(true); + trimmedStep.set(stepsCount.get()); + } + + if (kind == MemoryKind.DEVICE) { + if (deviceOffset.get() + requiredMemory <= currentSize.get() + && !trimmer + && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { + cycleAllocations.addAndGet(requiredMemory); + long prevOffset = deviceOffset.getAndAdd(requiredMemory); + + if (workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) return null; + + val ptr = workspace.getDevicePointer().withOffset(prevOffset, numElements); + + log.debug( + "Workspace [{}] device_{}: alloc array of {} bytes, capacity of {} elements; prevOffset: {}; newOffset: {}; size: {}; address: {}", + id, + Nd4j.getAffinityManager().getDeviceForCurrentThread(), + requiredMemory, + numElements, + prevOffset, + deviceOffset.get(), + currentSize.get(), + ptr.address()); + + if (initialize) { + val context = AtomicAllocator.getInstance().getDeviceContext(); + + int ret = + NativeOpsHolder.getInstance() + .getDeviceNativeOps() + .memsetAsync(ptr, 0, requiredMemory, 0, context.getSpecialStream()); + if (ret == 0) + throw new ND4JIllegalStateException( + "memset failed device_" + Nd4j.getAffinityManager().getDeviceForCurrentThread()); + + context.syncSpecialStream(); } - super.init(); + return ptr; + } else { - if (currentSize.get() > 0) { - //log.info("Allocating {} bytes at DEVICE & HOST space...", currentSize.get()); - isInit.set(true); - - long bytes = currentSize.get(); - - if (isDebug.get()) - log.info("Allocating [{}] workspace on device_{}, {} bytes...", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), bytes); - - if 
(isDebug.get()) { - Nd4j.getWorkspaceManager().printAllocationStatisticsForCurrentThread(); - } - - Pointer ptr = memoryManager.allocate((bytes + SAFETY_OFFSET), MemoryKind.HOST, false); - if (ptr == null) - throw new ND4JIllegalStateException("Can't allocate memory for workspace"); - - workspace.setHostPointer(new PagedPointer(ptr)); - - if (workspaceConfiguration.getPolicyMirroring() != MirroringPolicy.HOST_ONLY) { - workspace.setDevicePointer(new PagedPointer(memoryManager.allocate((bytes + SAFETY_OFFSET), MemoryKind.DEVICE, false))); - AllocationsTracker.getInstance().markAllocated(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), bytes + SAFETY_OFFSET); - - MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), bytes + SAFETY_OFFSET); - - // if base pointer isn't aligned to 16 bytes (128 bits) - adjust the offfset then - val addr = workspace.getDevicePointer().address(); - val div = addr % alignmentBase; - if (div != 0) { - deviceOffset.set(alignmentBase - div); - hostOffset.set(alignmentBase - div); - } - } - } - } - - @Override - public PagedPointer alloc(long requiredMemory, DataType type, boolean initialize) { - return this.alloc(requiredMemory, MemoryKind.DEVICE, type, initialize); - } - - - @Override - public synchronized void destroyWorkspace(boolean extended) { - val size = currentSize.getAndSet(0); - reset(); - - if (extended) - clearExternalAllocations(); - - clearPinnedAllocations(extended); - - if (workspace.getHostPointer() != null) - NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(workspace.getHostPointer()); - - if (workspace.getDevicePointer() != null) { - NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(workspace.getDevicePointer(), 0); - AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), size + SAFETY_OFFSET); - - MemoryTracker.getInstance().decrementWorkspaceAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), size + SAFETY_OFFSET); + // spill + if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED + && currentSize.get() > 0 + && !trimmer + && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { + // log.info("End of space reached. 
Current offset: {}; requiredMemory: {}", + // deviceOffset.get(), requiredMemory); + deviceOffset.set(0); + resetPlanned.set(true); + return alloc(requiredMemory, kind, type, initialize); } - workspace.setDevicePointer(null); - workspace.setHostPointer(null); + if (!trimmer) spilledAllocationsSize.addAndGet(requiredMemory); + else pinnedAllocationsSize.addAndGet(requiredMemory); - } + log.debug( + "Workspace [{}] device_{}: spilled DEVICE array of {} bytes, capacity of {} elements", + id, + Nd4j.getAffinityManager().getDeviceForCurrentThread(), + requiredMemory, + numElements); + val shape = + new AllocationShape( + requiredMemory / Nd4j.sizeOfDataType(type), Nd4j.sizeOfDataType(type), type); - @Override - public PagedPointer alloc(long requiredMemory, MemoryKind kind, DataType type, boolean initialize) { - long numElements = requiredMemory / Nd4j.sizeOfDataType(type); + cycleAllocations.addAndGet(requiredMemory); - // alignment - if (requiredMemory % alignmentBase != 0) - requiredMemory += alignmentBase - (requiredMemory % alignmentBase); + if (workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) return null; - if (!isUsed.get()) { - if (disabledCounter.incrementAndGet() % 10 == 0) - log.warn("Worskpace was turned off, and wasn't enabled after {} allocations", disabledCounter.get()); + switch (workspaceConfiguration.getPolicySpill()) { + case REALLOCATE: + case EXTERNAL: + if (!trimmer) { + externalCount.incrementAndGet(); + // + // AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape, + // null, AllocationStatus.DEVICE).getDevicePointer() + val pointer = + new PagedPointer( + memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), + numElements); + pointer.isLeaked(); - if (kind == MemoryKind.DEVICE) { - val pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements); - externalAllocations.add(new PointersPair(null, pointer)); - MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory); - return pointer; + val pp = new PointersPair(null, pointer); + pp.setRequiredMemory(requiredMemory); + externalAllocations.add(pp); + + MemoryTracker.getInstance() + .incrementWorkspaceAllocatedAmount( + Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory); + return pointer; } else { - val pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements); - externalAllocations.add(new PointersPair(pointer, null)); - return pointer; + pinnedCount.incrementAndGet(); + + val pointer = + new PagedPointer( + memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), + numElements); + pointer.isLeaked(); + + pinnedAllocations.add( + new PointersPair(stepsCount.get(), requiredMemory, null, pointer)); + MemoryTracker.getInstance() + .incrementWorkspaceAllocatedAmount( + Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory); + return pointer; } + case FAIL: + default: + { + throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full"); + } + } + } + } else if (kind == MemoryKind.HOST) { + if (hostOffset.get() + requiredMemory <= currentSize.get() + && !trimmer + && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { + long prevOffset = hostOffset.getAndAdd(requiredMemory); + val ptr = workspace.getHostPointer().withOffset(prevOffset, numElements); + + // && 
workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY + if (initialize) Pointer.memset(ptr, 0, requiredMemory); + return ptr; + } else { + // log.info("Spilled HOST array of {} bytes, capacity of {} elements", requiredMemory, + // numElements); + if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED + && currentSize.get() > 0 + && !trimmer + && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { + // log.info("End of space reached. Current offset: {}; requiredMemory: {}", + // deviceOffset.get(), requiredMemory); + hostOffset.set(0); + // resetPlanned.set(true); + return alloc(requiredMemory, kind, type, initialize); } - boolean trimmer = (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && requiredMemory + cycleAllocations.get() > initialBlockSize.get() && initialBlockSize.get() > 0 && kind == MemoryKind.DEVICE) || trimmedMode.get(); + val shape = + new AllocationShape( + requiredMemory / Nd4j.sizeOfDataType(type), Nd4j.sizeOfDataType(type), type); - if (trimmer && workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE && !trimmedMode.get()) { - trimmedMode.set(true); - trimmedStep.set(stepsCount.get()); - } + switch (workspaceConfiguration.getPolicySpill()) { + case REALLOCATE: + case EXTERNAL: + if (!trimmer) { + // memoryManager.allocate(requiredMemory, MemoryKind.HOST, true) + // AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape, + // null, AllocationStatus.DEVICE).getDevicePointer() + PagedPointer pointer = + new PagedPointer( + memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), + numElements); - if (kind == MemoryKind.DEVICE) { - if (deviceOffset.get() + requiredMemory <= currentSize.get() && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { - cycleAllocations.addAndGet(requiredMemory); - long prevOffset = deviceOffset.getAndAdd(requiredMemory); - - if (workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) - return null; - - val ptr = workspace.getDevicePointer().withOffset(prevOffset, numElements); - - if (isDebug.get()) - log.info("Workspace [{}] device_{}: alloc array of {} bytes, capacity of {} elements; prevOffset: {}; newOffset: {}; size: {}; address: {}", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory, numElements, prevOffset, deviceOffset.get(), currentSize.get(), ptr.address()); - - if (initialize) { - val context = AtomicAllocator.getInstance().getDeviceContext(); - - int ret = NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(ptr, 0, requiredMemory, 0, context.getSpecialStream()); - if (ret == 0) - throw new ND4JIllegalStateException("memset failed device_" + Nd4j.getAffinityManager().getDeviceForCurrentThread()); - - context.syncSpecialStream(); - } - - return ptr; + externalAllocations.add(new PointersPair(pointer, null)); + return pointer; } else { + // AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape, + // null, AllocationStatus.DEVICE).getDevicePointer() + PagedPointer pointer = + new PagedPointer( + memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), + numElements); + pointer.isLeaked(); - // spill - if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && currentSize.get() > 0 && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { - //log.info("End of space reached. 
Current offset: {}; requiredMemory: {}", deviceOffset.get(), requiredMemory); - deviceOffset.set(0); - resetPlanned.set(true); - return alloc(requiredMemory, kind, type, initialize); - } - - if (!trimmer) - spilledAllocationsSize.addAndGet(requiredMemory); - else - pinnedAllocationsSize.addAndGet(requiredMemory); - - if (isDebug.get()) { - log.info("Workspace [{}] device_{}: spilled DEVICE array of {} bytes, capacity of {} elements", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory, numElements); - } - - val shape = new AllocationShape(requiredMemory / Nd4j.sizeOfDataType(type), Nd4j.sizeOfDataType(type), type); - - cycleAllocations.addAndGet(requiredMemory); - - if (workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY) - return null; - - switch (workspaceConfiguration.getPolicySpill()) { - case REALLOCATE: - case EXTERNAL: - if (!trimmer) { - externalCount.incrementAndGet(); - // - //AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape, null, AllocationStatus.DEVICE).getDevicePointer() - val pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements); - pointer.isLeaked(); - - val pp = new PointersPair(null, pointer); - pp.setRequiredMemory(requiredMemory); - externalAllocations.add(pp); - - MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory); - return pointer; - } else { - pinnedCount.incrementAndGet(); - - val pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.DEVICE, initialize), numElements); - pointer.isLeaked(); - - pinnedAllocations.add(new PointersPair(stepsCount.get(), requiredMemory, null, pointer)); - MemoryTracker.getInstance().incrementWorkspaceAllocatedAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), requiredMemory); - return pointer; - } - case FAIL: - default: { - throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full"); - } - } + pinnedAllocations.add(new PointersPair(stepsCount.get(), 0L, pointer, null)); + return pointer; } - } else if (kind == MemoryKind.HOST) { - if (hostOffset.get() + requiredMemory <= currentSize.get() && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { - - long prevOffset = hostOffset.getAndAdd(requiredMemory); - - val ptr = workspace.getHostPointer().withOffset(prevOffset, numElements); - - // && workspaceConfiguration.getPolicyMirroring() == MirroringPolicy.HOST_ONLY - if (initialize) - Pointer.memset(ptr, 0, requiredMemory); - return ptr; - } else { - // log.info("Spilled HOST array of {} bytes, capacity of {} elements", requiredMemory, numElements); - if (workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && currentSize.get() > 0 && !trimmer && Nd4j.getWorkspaceManager().getDebugMode() != DebugMode.SPILL_EVERYTHING) { - //log.info("End of space reached. 
Current offset: {}; requiredMemory: {}", deviceOffset.get(), requiredMemory); - hostOffset.set(0); - //resetPlanned.set(true); - return alloc(requiredMemory, kind, type, initialize); - } - - val shape = new AllocationShape(requiredMemory / Nd4j.sizeOfDataType(type), Nd4j.sizeOfDataType(type), type); - - switch (workspaceConfiguration.getPolicySpill()) { - case REALLOCATE: - case EXTERNAL: - if (!trimmer) { - //memoryManager.allocate(requiredMemory, MemoryKind.HOST, true) - //AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape, null, AllocationStatus.DEVICE).getDevicePointer() - PagedPointer pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements); - - externalAllocations.add(new PointersPair(pointer, null)); - return pointer; - } else { - //AtomicAllocator.getInstance().getMemoryHandler().getMemoryProvider().malloc(shape, null, AllocationStatus.DEVICE).getDevicePointer() - PagedPointer pointer = new PagedPointer(memoryManager.allocate(requiredMemory, MemoryKind.HOST, initialize), numElements); - pointer.isLeaked(); - - pinnedAllocations.add(new PointersPair(stepsCount.get(), 0L, pointer, null)); - return pointer; - } - case FAIL: - default: { - throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full"); - } - } - } - } else throw new ND4JIllegalStateException("Unknown MemoryKind was passed in: " + kind); - - //throw new ND4JIllegalStateException("Shouldn't ever reach this line"); - } - - @Override - protected void clearPinnedAllocations(boolean extended) { - if (isDebug.get()) - log.info("Workspace [{}] device_{} threadId {} cycle {}: clearing pinned allocations...", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Thread.currentThread().getId(), cyclesCount.get()); - - while (!pinnedAllocations.isEmpty()) { - val pair = pinnedAllocations.peek(); - if (pair == null) - throw new RuntimeException(); - - long stepNumber = pair.getAllocationCycle(); - long stepCurrent = stepsCount.get(); - - if (isDebug.get()) - log.info("Allocation step: {}; Current step: {}", stepNumber, stepCurrent); - - if (stepNumber + 2 < stepCurrent || extended) { - pinnedAllocations.remove(); - - if (pair.getDevicePointer() != null) { - NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pair.getDevicePointer(), 0); - MemoryTracker.getInstance().decrementWorkspaceAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), pair.getRequiredMemory()); - pinnedCount.decrementAndGet(); - - if (isDebug.get()) - log.info("deleting external device allocation "); - } - - if (pair.getHostPointer() != null) { - NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pair.getHostPointer()); - - if (isDebug.get()) - log.info("deleting external host allocation "); - } - - val sizez = pair.getRequiredMemory() * -1; - pinnedAllocationsSize.addAndGet(sizez); - } else { - break; + case FAIL: + default: + { + throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full"); } } - } + } + } else throw new ND4JIllegalStateException("Unknown MemoryKind was passed in: " + kind); - @Override - protected void clearExternalAllocations() { - if (isDebug.get()) - log.info("Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Thread.currentThread().getId(), guid); + // throw new ND4JIllegalStateException("Shouldn't ever reach this line"); + } - Nd4j.getExecutioner().commit(); + @Override + protected void 
clearPinnedAllocations(boolean extended) { - try { - for (PointersPair pair : externalAllocations) { - if (pair.getHostPointer() != null) { - NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pair.getHostPointer()); + log.debug( + "Workspace [{}] device_{} threadId {} cycle {}: clearing pinned allocations...", + id, + Nd4j.getAffinityManager().getDeviceForCurrentThread(), + Thread.currentThread().getId(), + cyclesCount.get()); - if (isDebug.get()) - log.info("deleting external host allocation... "); - } + while (!pinnedAllocations.isEmpty()) { + val pair = pinnedAllocations.peek(); + if (pair == null) throw new RuntimeException(); - if (pair.getDevicePointer() != null) { - NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pair.getDevicePointer(), 0); + long stepNumber = pair.getAllocationCycle(); + long stepCurrent = stepsCount.get(); - if (isDebug.get()) - log.info("deleting external device allocation... "); + log.debug("Allocation step: {}; Current step: {}", stepNumber, stepCurrent); - val sizez = pair.getRequiredMemory(); - if (sizez != null) { - AllocationsTracker.getInstance().markReleased(AllocationKind.GENERAL, Nd4j.getAffinityManager().getDeviceForCurrentThread(), sizez); - MemoryTracker.getInstance().decrementWorkspaceAmount(Nd4j.getAffinityManager().getDeviceForCurrentThread(), sizez); - } - } - } - } catch (Exception e) { - log.error("RC: Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...", id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Thread.currentThread().getId(), guid); - throw new RuntimeException(e); + if (stepNumber + 2 < stepCurrent || extended) { + pinnedAllocations.remove(); + + if (pair.getDevicePointer() != null) { + NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pair.getDevicePointer(), 0); + MemoryTracker.getInstance() + .decrementWorkspaceAmount( + Nd4j.getAffinityManager().getDeviceForCurrentThread(), pair.getRequiredMemory()); + pinnedCount.decrementAndGet(); + + log.debug("deleting external device allocation "); } - spilledAllocationsSize.set(0); - externalCount.set(0); - externalAllocations.clear(); - } + if (pair.getHostPointer() != null) { + NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pair.getHostPointer()); - @Override - protected void resetWorkspace() { - if (currentSize.get() < 1) { + log.debug("deleting external host allocation "); } + val sizez = pair.getRequiredMemory() * -1; + pinnedAllocationsSize.addAndGet(sizez); + } else { + break; + } + } + } -/* - if (Nd4j.getExecutioner() instanceof GridExecutioner) - ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking(); + @Override + protected void clearExternalAllocations() { - CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); + log.debug( + "Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...", + id, + Nd4j.getAffinityManager().getDeviceForCurrentThread(), + Thread.currentThread().getId(), + guid); - //log.info("workspace: {}, size: {}", workspace.getDevicePointer().address(), currentSize.get()); + Nd4j.getExecutioner().commit(); - NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(workspace.getDevicePointer(), 0, currentSize.get() + SAFETY_OFFSET, 0, context.getSpecialStream()); + try { + for (PointersPair pair : externalAllocations) { + if (pair.getHostPointer() != null) { + NativeOpsHolder.getInstance().getDeviceNativeOps().freeHost(pair.getHostPointer()); - Pointer.memset(workspace.getHostPointer(), 0, 
currentSize.get() + SAFETY_OFFSET); + log.debug("deleting external host allocation... "); + } - context.getSpecialStream().synchronize(); - */ + if (pair.getDevicePointer() != null) { + NativeOpsHolder.getInstance().getDeviceNativeOps().freeDevice(pair.getDevicePointer(), 0); + + log.debug("deleting external device allocation... "); + + val sizez = pair.getRequiredMemory(); + if (sizez != null) { + AllocationsTracker.getInstance() + .markReleased( + AllocationKind.GENERAL, + Nd4j.getAffinityManager().getDeviceForCurrentThread(), + sizez); + MemoryTracker.getInstance() + .decrementWorkspaceAmount( + Nd4j.getAffinityManager().getDeviceForCurrentThread(), sizez); + } + } + } + } catch (Exception e) { + log.error( + "RC: Workspace [{}] device_{} threadId {} guid [{}]: clearing external allocations...", + id, + Nd4j.getAffinityManager().getDeviceForCurrentThread(), + Thread.currentThread().getId(), + guid); + throw new RuntimeException(e); } - protected PointersPair workspace() { - return workspace; - } + spilledAllocationsSize.set(0); + externalCount.set(0); + externalAllocations.clear(); + } - protected Queue pinnedPointers() { - return pinnedAllocations; - } + @Override + protected void resetWorkspace() { + if (currentSize.get() < 1) {} - protected List externalPointers() { - return externalAllocations; - } + /* + if (Nd4j.getExecutioner() instanceof GridExecutioner) + ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking(); - @Override - public Deallocator deallocator() { - return new CudaWorkspaceDeallocator(this); - } + CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); - @Override - public String getUniqueId() { - return "Workspace_" + getId() + "_" + Nd4j.getDeallocatorService().nextValue(); - } + //log.info("workspace: {}, size: {}", workspace.getDevicePointer().address(), currentSize.get()); - @Override - public int targetDevice() { - return deviceId; - } + NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(workspace.getDevicePointer(), 0, currentSize.get() + SAFETY_OFFSET, 0, context.getSpecialStream()); - @Override - public long getPrimaryOffset() { - return getDeviceOffset(); - } + Pointer.memset(workspace.getHostPointer(), 0, currentSize.get() + SAFETY_OFFSET); + + context.getSpecialStream().synchronize(); + */ + } + + protected PointersPair workspace() { + return workspace; + } + + protected Queue pinnedPointers() { + return pinnedAllocations; + } + + protected List externalPointers() { + return externalAllocations; + } + + @Override + public Deallocator deallocator() { + return new CudaWorkspaceDeallocator(this); + } + + @Override + public String getUniqueId() { + return "Workspace_" + getId() + "_" + Nd4j.getDeallocatorService().nextValue(); + } + + @Override + public int targetDevice() { + return deviceId; + } + + @Override + public long getPrimaryOffset() { + return getDeviceOffset(); + } } diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspaceDeallocator.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspaceDeallocator.java index 806986fc7..41b936bf7 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspaceDeallocator.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/jita/workspace/CudaWorkspaceDeallocator.java @@ -48,7 +48,7 @@ public class CudaWorkspaceDeallocator implements Deallocator { @Override public void deallocate() { - log.trace("Deallocating CUDA workspace"); 
+ log.debug("Deallocating CUDA workspace"); // purging workspace planes if (pointersPair != null) { diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 612dfdda8..5524bddbc 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1582,7 +1582,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { } if (nativeOps.lastErrorCode() != 0) - throw new RuntimeException(nativeOps.lastErrorMessage()); + throw new RuntimeException(nativeOps.lastErrorMessage() + " error code: " + nativeOps.lastErrorCode()); profilingConfigurableHookOut(op, oc, st); diff --git a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index 3ba143e36..52af2eeb5 100644 --- a/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/cavis-native/cavis-native-jcublas/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -56,7 +56,8 @@ public class CudaOpContext extends BaseOpContext implements OpContext, Deallocat @Override public void close() { - // no-op + nativeOps.ctxPurge(context); + context.deallocate(); } @Override
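The CudaOpContext hunk above replaces the former no-op close() with an explicit nativeOps.ctxPurge(context) followed by context.deallocate(), so a context now releases its native resources the moment the caller closes it instead of waiting for the deallocator service or garbage collection. A minimal caller-side sketch of the scoping this change rewards is below; the buildContext() factory call and the surrounding try/finally are illustrative assumptions, not part of this patch.

import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.factory.Nd4j;

public class OpContextScopeSketch {
    public static void main(String[] args) throws Exception {
        // Assumed factory method: obtain an op context from the current executioner.
        OpContext ctx = Nd4j.getExecutioner().buildContext();
        try {
            // ... configure inputs/outputs and execute an op against ctx here ...
        } finally {
            // With this patch, close() purges the native context and deallocates it
            // immediately rather than acting as a no-op.
            ctx.close();
        }
    }
}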